File size: 4,045 Bytes
6d7fc1c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e22a639
 
 
 
226c7c9
 
 
 
e22a639
226c7c9
 
e22a639
 
 
 
 
226c7c9
 
e22a639
 
226c7c9
 
e22a639
 
 
 
 
226c7c9
 
 
 
e22a639
 
6d7fc1c
 
 
 
 
 
 
 
 
 
 
226c7c9
6d7fc1c
 
 
 
 
 
 
 
 
226c7c9
 
6d7fc1c
 
 
 
 
 
 
 
 
 
 
226c7c9
 
 
6d7fc1c
 
 
 
 
 
 
 
 
226c7c9
6d7fc1c
226c7c9
 
6d7fc1c
 
 
 
226c7c9
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import importlib
import os
import sys


def setup_environment():
    """Install the system and pip dependencies needed to run Cosmos.

    Each step runs in its own shell through ``os.system``, in order. A failing
    step is reported on stdout but does not stop the remaining steps.

    Returns:
        True if every command exited with status 0, False otherwise.
    """
    commands = [
        "apt-get update && apt-get install -qqy libmagickwand-dev",
        # flash-attn source build, currently disabled:
        # 'export FLASH_ATTENTION_SKIP_CUDA_BUILD=FALSE && \
        #     pip install --timeout=1000000000 --no-build-isolation "flash-attn<=2.7.4.post1"'
        "pip install --timeout=1000000000 "
        "https://download.pytorch.org/whl/cu128/flashinfer/flashinfer_python-0.2.5%2Bcu128torch2.7-cp38-abi3-linux_x86_64.whl",
        # NOTE(review): the export only applies to this single shell invocation.
        'export VLLM_ATTENTION_BACKEND=FLASHINFER && pip install "vllm==0.9.0"',
        'pip install --timeout=1000000000 "decord==0.6.0"',
        # Symlink headers from the nvidia/* site-packages into $CONDA_PREFIX/include
        # and $CONDA_PREFIX/include/python3.10.
        "export CONDA_PREFIX=/usr/local/cuda && "
        "ln -sf $CONDA_PREFIX/lib/python3.10/site-packages/nvidia/*/include/* $CONDA_PREFIX/include/",
        "export CONDA_PREFIX=/usr/local/cuda && "
        "ln -sf $CONDA_PREFIX/lib/python3.10/site-packages/nvidia/*/include/* $CONDA_PREFIX/include/python3.10",
        'pip install --timeout=1000000000 --no-build-isolation "transformer-engine[pytorch]"',
        # cosmos-transfer1 install from git, currently disabled:
        # 'pip install --timeout=1000000000 \
        #     "git+https://github.com/nvidia-cosmos/cosmos-transfer1@e4055e39ee9c53165e85275bdab84ed20909714a"'
    ]

    all_success = True
    for command in commands:
        # os.system returns the shell's exit status; anything nonzero is a failure.
        if os.system(command) != 0:
            print(f"\033[93m[WARNING]\033[0m Setup command failed: \033[93m{command}\033[0m")
            all_success = False
    return all_success


def parse_args():
    """Read the command-line flags of the environment checker.

    Returns:
        An ``argparse.Namespace`` whose boolean ``training`` attribute is True
        when ``--training`` was given.
    """
    cli = argparse.ArgumentParser()
    training_help = "Whether to check training-specific dependencies"
    cli.add_argument("--training", action="store_true", help=training_help)
    namespace = cli.parse_args()
    return namespace


def check_packages(package_list):
    """Try to import each package and report the outcome on stdout.

    Args:
        package_list: Iterable of importable module names; dotted names such
            as ``"megatron.core"`` are accepted.

    Returns:
        True if every module imported successfully (including when the
        iterable is empty), False otherwise.
    """
    all_success = True
    for package in package_list:
        try:
            importlib.import_module(package)
        # Deliberately broad: module initialization code can raise exceptions
        # other than ImportError, and the check should report them, not crash.
        except Exception as exc:
            print(
                f"\033[91m[ERROR]\033[0m Package not successfully imported: \033[93m{package}\033[0m "
                f"({type(exc).__name__}: {exc})"
            )
            all_success = False
        else:
            print(f"\033[92m[SUCCESS]\033[0m {package} found")

    return all_success


def main():
    """Verify the interpreter version and import the packages Cosmos needs.

    Exits the process with status 1 if Python is older than 3.10. When
    ``--training`` is passed on the command line, training-only packages are
    checked as well.

    Returns:
        True if every checked package imported successfully, False otherwise.
    """
    args = parse_args()

    # Tuple comparison is the idiomatic version check and also accepts any
    # future major version.
    if sys.version_info < (3, 10):
        detected = f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}"
        print(f"\033[91m[ERROR]\033[0m Python 3.10+ is required. You have: \033[93m{detected}\033[0m")
        sys.exit(1)

    if "CONDA_PREFIX" not in os.environ:
        print(
            "\033[93m[WARNING]\033[0m CONDA_PREFIX is not set. "
            "When manually installed, Cosmos should run under the cosmos-transfer1 conda environment (see INSTALL.md). "
            "This warning can be ignored when running in the container."
        )

    print("Attempting to import critical packages...")

    packages = ["torch", "torchvision", "transformers", "megatron.core", "transformer_engine", "vllm", "pandas"]
    packages_training = [
        "apex.multi_tensor_apply",
    ]

    all_success = check_packages(packages)
    # Run the training check even if the core check failed, so every missing
    # package is reported in one pass.
    if args.training and not check_packages(packages_training):
        all_success = False

    if all_success:
        print("-----------------------------------------------------------")
        print("\033[92m[SUCCESS]\033[0m Cosmos environment setup is successful!")

    return all_success


if __name__ == "__main__":
    success = main()
    print(f"Environment check success ? {success}")
    # Exit nonzero on failure so shells/CI can detect a broken environment,
    # matching the sys.exit(1) already used for an unsupported Python version.
    sys.exit(0 if success else 1)