Spaces:
Build error
Build error
| # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | |
| # SPDX-License-Identifier: Apache-2.0 | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| from __future__ import annotations | |
| import collections | |
| import collections.abc | |
| import functools | |
| import json | |
| import os | |
| import random | |
| import time | |
| from contextlib import ContextDecorator | |
| from pathlib import Path | |
| from typing import Any, Callable, List, Optional, Tuple, TypeVar | |
| from urllib.parse import urlparse | |
| import boto3 | |
| import numpy as np | |
| import termcolor | |
| import torch | |
| from torch import nn | |
| from torch.distributed._functional_collectives import AsyncCollectiveTensor | |
| from torch.distributed._tensor.api import DTensor | |
| from cosmos_predict1.utils import distributed, log | |
| from cosmos_predict1.utils.easy_io import easy_io | |
| def to( | |
| data: Any, | |
| device: str | torch.device | None = None, | |
| dtype: torch.dtype | None = None, | |
| memory_format: torch.memory_format = torch.preserve_format, | |
| ) -> Any: | |
| """Recursively cast data into the specified device, dtype, and/or memory_format. | |
| The input data can be a tensor, a list of tensors, a dict of tensors. | |
| See the documentation for torch.Tensor.to() for details. | |
| Args: | |
| data (Any): Input data. | |
| device (str | torch.device): GPU device (default: None). | |
| dtype (torch.dtype): data type (default: None). | |
| memory_format (torch.memory_format): memory organization format (default: torch.preserve_format). | |
| Returns: | |
| data (Any): Data cast to the specified device, dtype, and/or memory_format. | |
| """ | |
| assert ( | |
| device is not None or dtype is not None or memory_format is not None | |
| ), "at least one of device, dtype, memory_format should be specified" | |
| if isinstance(data, torch.Tensor): | |
| is_cpu = (isinstance(device, str) and device == "cpu") or ( | |
| isinstance(device, torch.device) and device.type == "cpu" | |
| ) | |
| data = data.to( | |
| device=device, | |
| dtype=dtype, | |
| memory_format=memory_format, | |
| non_blocking=(not is_cpu), | |
| ) | |
| return data | |
| elif isinstance(data, collections.abc.Mapping): | |
| return type(data)({key: to(data[key], device=device, dtype=dtype, memory_format=memory_format) for key in data}) | |
| elif isinstance(data, collections.abc.Sequence) and not isinstance(data, (str, bytes)): | |
| return type(data)([to(elem, device=device, dtype=dtype, memory_format=memory_format) for elem in data]) | |
| else: | |
| return data | |
| def serialize(data: Any) -> Any: | |
| """Serialize data by hierarchically traversing through iterables. | |
| Args: | |
| data (Any): Input data. | |
| Returns: | |
| data (Any): Serialized data. | |
| """ | |
| if isinstance(data, collections.abc.Mapping): | |
| return type(data)({key: serialize(data[key]) for key in data}) | |
| elif isinstance(data, collections.abc.Sequence) and not isinstance(data, (str, bytes)): | |
| return type(data)([serialize(elem) for elem in data]) | |
| else: | |
| try: | |
| json.dumps(data) | |
| except TypeError: | |
| data = str(data) | |
| return data | |
| def print_environ_variables(env_vars: list[str]) -> None: | |
| """Print a specific list of environment variables. | |
| Args: | |
| env_vars (list[str]): List of specified environment variables. | |
| """ | |
| for env_var in env_vars: | |
| if env_var in os.environ: | |
| log.info(f"Environment variable {Color.green(env_var)}: {Color.yellow(os.environ[env_var])}") | |
| else: | |
| log.warning(f"Environment variable {Color.green(env_var)} not set!") | |
| def set_random_seed(seed: int, by_rank: bool = False) -> None: | |
| """Set random seed. This includes random, numpy, Pytorch. | |
| Args: | |
| seed (int): Random seed. | |
| by_rank (bool): if true, each GPU will use a different random seed. | |
| """ | |
| if by_rank: | |
| seed += distributed.get_rank() | |
| log.info(f"Using random seed {seed}.") | |
| random.seed(seed) | |
| np.random.seed(seed) | |
| torch.manual_seed(seed) # sets seed on the current CPU & all GPUs | |
| def arch_invariant_rand( | |
| shape: List[int] | Tuple[int], dtype: torch.dtype, device: str | torch.device, seed: int | None = None | |
| ): | |
| """Produce a GPU-architecture-invariant randomized Torch tensor. | |
| Args: | |
| shape (list or tuple of ints): Output tensor shape. | |
| dtype (torch.dtype): Output tensor type. | |
| device (torch.device): Device holding the output. | |
| seed (int): Optional randomization seed. | |
| Returns: | |
| tensor (torch.tensor): Randomly-generated tensor. | |
| """ | |
| # Create a random number generator, optionally seeded | |
| rng = np.random.RandomState(seed) | |
| # # Generate random numbers using the generator | |
| random_array = rng.standard_normal(shape).astype(np.float32) # Use standard_normal for normal distribution | |
| # Convert to torch tensor and return | |
| return torch.from_numpy(random_array).to(dtype=dtype, device=device) | |
| T = TypeVar("T", bound=Callable[..., Any]) | |
| class timer(ContextDecorator): # noqa: N801 | |
| """Simple timer for timing the execution of code. | |
| It can be used as either a context manager or a function decorator. The timing result will be logged upon exit. | |
| Example: | |
| def func_a(): | |
| time.sleep(1) | |
| with timer("func_a"): | |
| func_a() | |
| @timer("func_b) | |
| def func_b(): | |
| time.sleep(1) | |
| func_b() | |
| """ | |
| def __init__(self, context: str, debug: bool = False): | |
| self.context = context | |
| self.debug = debug | |
| def __enter__(self) -> None: | |
| self.tic = time.time() | |
| def __exit__(self, exc_type, exc_value, traceback) -> None: # noqa: ANN001 | |
| time_spent = time.time() - self.tic | |
| if self.debug: | |
| log.debug(f"Time spent on {self.context}: {time_spent:.4f} seconds") | |
| else: | |
| log.debug(f"Time spent on {self.context}: {time_spent:.4f} seconds") | |
| def __call__(self, func: T) -> T: | |
| def wrapper(*args, **kwargs): # noqa: ANN202 | |
| tic = time.time() | |
| result = func(*args, **kwargs) | |
| time_spent = time.time() - tic | |
| if self.debug: | |
| log.debug(f"Time spent on {self.context}: {time_spent:.4f} seconds") | |
| else: | |
| log.debug(f"Time spent on {self.context}: {time_spent:.4f} seconds") | |
| return result | |
| return wrapper # type: ignore | |
| class TrainingTimer: | |
| """Timer for timing the execution of code, aggregating over multiple training iterations. | |
| It is used as a context manager to measure the execution time of code and store the timing results | |
| for each function. The context managers can be nested. | |
| Attributes: | |
| results (dict): A dictionary to store timing results for various code. | |
| Example: | |
| timer = Timer() | |
| for i in range(100): | |
| with timer("func_a"): | |
| func_a() | |
| avg_time = sum(timer.results["func_a"]) / len(timer.results["func_a"]) | |
| print(f"func_a() took {avg_time} seconds.") | |
| """ | |
| def __init__(self) -> None: | |
| self.results = dict() | |
| self.average_results = dict() | |
| self.start_time = [] | |
| self.func_stack = [] | |
| self.reset() | |
| def reset(self) -> None: | |
| self.results = {key: [] for key in self.results} | |
| def __enter__(self) -> TrainingTimer: | |
| self.start_time.append(time.time()) | |
| return self | |
| def __exit__(self, exc_type, exc_value, traceback) -> None: # noqa: ANN001 | |
| end_time = time.time() | |
| result = end_time - self.start_time.pop() | |
| key = self.func_stack.pop() | |
| self.results.setdefault(key, []) | |
| self.results[key].append(result) | |
| def __call__(self, func_name: str) -> TrainingTimer: | |
| self.func_stack.append(func_name) | |
| return self | |
| def __getattr__(self, func_name: str) -> TrainingTimer: | |
| return self.__call__(func_name) | |
| def nested(self, func_name: str) -> TrainingTimer: | |
| return self.__call__(func_name) | |
| def compute_average_results(self) -> dict[str, float]: | |
| results = dict() | |
| for key, value_list in self.results.items(): | |
| results[key] = sum(value_list) / len(value_list) | |
| return results | |
| def timeout_handler(timeout_period: float, signum: int, frame: int) -> None: | |
| # What to do when the process gets stuck. For now, we simply end the process. | |
| error_message = f"Timeout error: more than {timeout_period} seconds passed since the last iteration." | |
| raise TimeoutError(error_message) | |
| class Color: | |
| """A convenience class to colorize strings in the console. | |
| Example: | |
| import | |
| print("This is {Color.red('important')}.") | |
| """ | |
| def red(x: str) -> str: | |
| return termcolor.colored(str(x), color="red") | |
| def green(x: str) -> str: | |
| return termcolor.colored(str(x), color="green") | |
| def cyan(x: str) -> str: | |
| return termcolor.colored(str(x), color="cyan") | |
| def yellow(x: str) -> str: | |
| return termcolor.colored(str(x), color="yellow") | |
| class BufferCnt: | |
| """ | |
| Buffer counter which keeps track of the condition when called and returns True when the condition in met "thres" | |
| amount of times, otherwise returns False. | |
| Example usage: | |
| buf = BufferCnt(thres=3) | |
| for _ in range(5): | |
| if buf(random.random() > 0.5): | |
| print("We got lucky 3 times out of 5.") | |
| Args: | |
| thres (int): The amount of times the expression needs to be True before returning True. | |
| reset_over_thres (bool): Whether to reset the buffer after returning True. | |
| """ | |
| def __init__(self, thres=10, reset_over_thres=False): | |
| self._cnt = 0 | |
| self.thres = thres | |
| self.reset_over_thres = reset_over_thres | |
| def __call__(self, expre, thres=None): | |
| if expre is True: | |
| self._cnt += 1 | |
| else: | |
| self._cnt = 0 | |
| if thres is None: | |
| thres = self.thres | |
| if self._cnt >= thres: | |
| if self.reset_over_thres: | |
| self.reset() | |
| return True | |
| return False | |
| def cnt(self): | |
| return self._cnt | |
| def reset(self): | |
| self._cnt = 0 | |
| def get_local_tensor_if_DTensor(tensor: torch.Tensor | DTensor) -> torch.tensor: | |
| if isinstance(tensor, DTensor): | |
| local = tensor.to_local() | |
| # As per PyTorch documentation, if the communication is not finished yet, we need to wait for it to finish | |
| # https://pytorch.org/docs/stable/distributed.tensor.html#torch.distributed.tensor.DTensor.to_local | |
| if isinstance(local, AsyncCollectiveTensor): | |
| return local.wait() | |
| else: | |
| return local | |
| return tensor | |
| def disabled_train(self: Any, mode: bool = True) -> Any: | |
| """Overwrite model.train with this function to make sure train/eval mode | |
| does not change anymore.""" | |
| return self | |
| def count_params(model: nn.Module, verbose=False) -> int: | |
| total_params = sum(p.numel() for p in model.parameters() if p.requires_grad) | |
| if verbose: | |
| print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.") | |
| return total_params | |
| def expand_dims_like(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: | |
| while x.dim() != y.dim(): | |
| x = x.unsqueeze(-1) | |
| return x | |
| def download_from_s3_with_cache( | |
| s3_path: str, | |
| cache_fp: Optional[str] = None, | |
| cache_dir: Optional[str] = None, | |
| rank_sync: bool = True, | |
| backend_args: Optional[dict] = None, | |
| backend_key: Optional[str] = None, | |
| ) -> str: | |
| """download data from S3 with optional caching. | |
| This function first attempts to load the data from a local cache file. If | |
| the cache file doesn't exist, it downloads the data from S3 to the cache | |
| location. Caching is performed in a rank-aware manner | |
| using `distributed.barrier()` to ensure only one download occurs across | |
| distributed workers (if `rank_sync` is True). | |
| Args: | |
| s3_path (str): The S3 path of the data to load. | |
| cache_fp (str, optional): The path to the local cache file. If None, | |
| a filename will be generated based on `s3_path` within `cache_dir`. | |
| cache_dir (str, optional): The directory to store the cache file. If | |
| None, the environment variable `COSMOS_CACHE_DIR` (defaulting | |
| to "/tmp") will be used. | |
| rank_sync (bool, optional): Whether to synchronize download across | |
| distributed workers using `distributed.barrier()`. Defaults to True. | |
| backend_args (dict, optional): The backend arguments passed to easy_io to construct the backend. | |
| backend_key (str, optional): The backend key passed to easy_io to registry the backend or retrieve the backend if it is already registered. | |
| Returns: | |
| cache_fp (str): The path to the local cache file. | |
| Raises: | |
| FileNotFoundError: If the data cannot be found in S3 or the cache. | |
| """ | |
| cache_dir = os.environ.get("TORCH_HOME") if cache_dir is None else cache_dir | |
| cache_dir = ( | |
| os.environ.get("COSMOS_CACHE_DIR", os.path.expanduser("~/.cache/cosmos")) if cache_dir is None else cache_dir | |
| ) | |
| cache_dir = os.path.expanduser(cache_dir) | |
| if cache_fp is None: | |
| cache_fp = os.path.join(cache_dir, s3_path.replace("s3://", "")) | |
| if not cache_fp.startswith("/"): | |
| cache_fp = os.path.join(cache_dir, cache_fp) | |
| if distributed.get_rank() == 0: | |
| if os.path.exists(cache_fp): | |
| # check the size of cache_fp | |
| if os.path.getsize(cache_fp) < 1: | |
| os.remove(cache_fp) | |
| log.warning(f"Removed empty cache file {cache_fp}.") | |
| if rank_sync: | |
| if not os.path.exists(cache_fp): | |
| log.critical(f"Local cache {cache_fp} Not exist! Downloading {s3_path} to {cache_fp}.") | |
| log.info(f"backend_args: {backend_args}") | |
| log.info(f"backend_key: {backend_key}") | |
| easy_io.copyfile_to_local( | |
| s3_path, cache_fp, dst_type="file", backend_args=backend_args, backend_key=backend_key | |
| ) | |
| log.info(f"Downloaded {s3_path} to {cache_fp}.") | |
| else: | |
| log.info(f"Local cache {cache_fp} already exist! {s3_path} -> {cache_fp}.") | |
| distributed.barrier() | |
| else: | |
| if not os.path.exists(cache_fp): | |
| easy_io.copyfile_to_local( | |
| s3_path, cache_fp, dst_type="file", backend_args=backend_args, backend_key=backend_key | |
| ) | |
| log.info(f"Downloaded {s3_path} to {cache_fp}.") | |
| return cache_fp | |
| def load_from_s3_with_cache( | |
| s3_path: str, | |
| cache_fp: Optional[str] = None, | |
| cache_dir: Optional[str] = None, | |
| rank_sync: bool = True, | |
| backend_args: Optional[dict] = None, | |
| backend_key: Optional[str] = None, | |
| easy_io_kwargs: Optional[dict] = None, | |
| ) -> Any: | |
| """Loads data from S3 with optional caching. | |
| This function first attempts to load the data from a local cache file. If | |
| the cache file doesn't exist, it downloads the data from S3 to the cache | |
| location and then loads it. Caching is performed in a rank-aware manner | |
| using `distributed.barrier()` to ensure only one download occurs across | |
| distributed workers (if `rank_sync` is True). | |
| Args: | |
| s3_path (str): The S3 path of the data to load. | |
| cache_fp (str, optional): The path to the local cache file. If None, | |
| a filename will be generated based on `s3_path` within `cache_dir`. | |
| cache_dir (str, optional): The directory to store the cache file. If | |
| None, the environment variable `COSMOS_CACHE_DIR` (defaulting | |
| to "/tmp") will be used. | |
| rank_sync (bool, optional): Whether to synchronize download across | |
| distributed workers using `distributed.barrier()`. Defaults to True. | |
| backend_args (dict, optional): The backend arguments passed to easy_io to construct the backend. | |
| backend_key (str, optional): The backend key passed to easy_io to registry the backend or retrieve the backend if it is already registered. | |
| Returns: | |
| Any: The loaded data from the S3 path or cache file. | |
| Raises: | |
| FileNotFoundError: If the data cannot be found in S3 or the cache. | |
| """ | |
| cache_fp = download_from_s3_with_cache(s3_path, cache_fp, cache_dir, rank_sync, backend_args, backend_key) | |
| if easy_io_kwargs is None: | |
| easy_io_kwargs = {} | |
| return easy_io.load(cache_fp, **easy_io_kwargs) | |
| def sync_s3_dir_to_local( | |
| s3_dir: str, | |
| s3_credential_path: str, | |
| cache_dir: Optional[str] = None, | |
| rank_sync: bool = True, | |
| ) -> str: | |
| """ | |
| Download an entire directory from S3 to the local cache directory. | |
| Args: | |
| s3_dir (str): The AWS S3 directory to download. | |
| s3_credential_path (str): The path to the AWS S3 credentials file. | |
| rank_sync (bool, optional): Whether to synchronize download across | |
| distributed workers using `distributed.barrier()`. Defaults to True. | |
| cache_dir (str, optional): The cache folder to sync the S3 directory to. | |
| If None, the environment variable `COSMOS_CACHE_DIR` (defaulting | |
| to "~/.cache/cosmos") will be used. | |
| Returns: | |
| local_dir (str): The path to the local directory. | |
| """ | |
| if not s3_dir.startswith("s3://"): | |
| # If the directory exists locally, return the local path | |
| assert os.path.exists(s3_dir), f"{s3_dir} is not a S3 path or a local path." | |
| return s3_dir | |
| # Load AWS credentials from the file | |
| with open(s3_credential_path, "r") as f: | |
| credentials = json.load(f) | |
| # Create an S3 client | |
| s3 = boto3.client( | |
| "s3", | |
| **credentials, | |
| ) | |
| # Parse the S3 URL | |
| parsed_url = urlparse(s3_dir) | |
| source_bucket = parsed_url.netloc | |
| source_prefix = parsed_url.path.lstrip("/") | |
| # If the local directory is not specified, use the default cache directory | |
| cache_dir = ( | |
| os.environ.get("COSMOS_CACHE_DIR", os.path.expanduser("~/.cache/cosmos")) if cache_dir is None else cache_dir | |
| ) | |
| cache_dir = os.path.expanduser(cache_dir) | |
| Path(cache_dir).mkdir(parents=True, exist_ok=True) | |
| # List objects in the bucket with the given prefix | |
| response = s3.list_objects_v2(Bucket=source_bucket, Prefix=source_prefix) | |
| # Download each matching object | |
| for obj in response.get("Contents", []): | |
| if obj["Key"].startswith(source_prefix): | |
| # Create the full path for the destination file, preserving the directory structure | |
| rel_path = os.path.relpath(obj["Key"], source_prefix) | |
| dest_path = os.path.join(cache_dir, source_prefix, rel_path) | |
| # Ensure the directory exists | |
| os.makedirs(os.path.dirname(dest_path), exist_ok=True) | |
| # Check if the file already exists | |
| if os.path.exists(dest_path): | |
| continue | |
| else: | |
| log.info(f"Downloading {obj['Key']} to {dest_path}") | |
| # Download the file | |
| if not rank_sync or distributed.get_rank() == 0: | |
| s3.download_file(source_bucket, obj["Key"], dest_path) | |
| if rank_sync: | |
| distributed.barrier() | |
| local_dir = os.path.join(cache_dir, source_prefix) | |
| return local_dir | |