Spaces:
Build error
Build error
| # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | |
| # SPDX-License-Identifier: Apache-2.0 | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| from __future__ import annotations | |
| import argparse | |
| import os | |
| import signal | |
| from loguru import logger as log | |
| from v2v_utils import move_to_device, clone_tensors | |
# Fallback rendezvous settings used when MASTER_ADDR / MASTER_PORT are not
# already set in the environment (see the `setdefault` calls in the worker);
# they mirror the defaults one would pass to torchrun for single-node runs.
TORCHRUN_DEFAULT_MASTER_ADDR = 'localhost'
TORCHRUN_DEFAULT_MASTER_PORT = 12355
| def _get_inference_class(cosmos_variant: str): | |
| if cosmos_variant == 'predict1': | |
| from cosmos_predict1.diffusion.inference.gen3c_persistent import Gen3cPersistentModel | |
| from cosmos_predict1.utils.distributed import is_rank0 | |
| return Gen3cPersistentModel, is_rank0 | |
| else: | |
| raise ValueError(f"Unsupported cosmos variant: {cosmos_variant}") | |
def _inference_worker(rank: int, args: argparse.Namespace,
                      gpu_count: int,
                      cosmos_variant: str,
                      input_queues: 'list[torch.multiprocessing.Queue]',
                      result_queue: 'torch.multiprocessing.Queue',
                      attrs_queue: 'torch.multiprocessing.Queue'):
    """
    One such function will run, in a separate process, for each GPU.
    Each process loads the model and keeps it in memory, then serves tasks
    pulled from its input queue until it receives a `None` sentinel, its
    queue is closed, or the process gets a SIGINT.

    Args:
        rank: Index of this worker; also used as the CUDA device index.
        args: Parsed command-line arguments forwarded to the model constructor.
        gpu_count: Total number of worker processes (the world size).
        cosmos_variant: Cosmos implementation to load ('predict1' only).
        input_queues: One task queue per rank; this worker only reads
            `input_queues[rank]`.
        result_queue: Queue used to report readiness and task results back to
            the parent process.
        attrs_queue: Queue used to answer attribute lookups (`getattr` and
            `get_cache_input_depths` tasks).
    """
    log.debug(f'inference_worker for rank {rank} starting, doing imports now')
    # Heavy imports happen inside the worker process so the parent stays light.
    import torch
    import torch.distributed as dist  # noqa: F401 -- kept for parity with the original worker setup
    InferenceAR, is_tp_cp_pp_rank0 = _get_inference_class(cosmos_variant)
    log.debug(f'inference_worker for rank {rank} done with imports.')

    # Emulate the environment variables that torchrun would normally set.
    # The FQDN of the host that is running worker with rank 0; used to initialize the Torch Distributed backend.
    os.environ.setdefault("MASTER_ADDR", TORCHRUN_DEFAULT_MASTER_ADDR)
    # The port on the MASTER_ADDR that can be used to host the C10d TCP store.
    os.environ.setdefault("MASTER_PORT", str(TORCHRUN_DEFAULT_MASTER_PORT))
    # The local rank.
    os.environ["LOCAL_RANK"] = str(rank)
    # The global rank.
    os.environ["RANK"] = str(rank)
    # The rank of the worker group. When running a single worker group per node, this is the rank of the node.
    os.environ["GROUP_RANK"] = str(rank)
    # The rank of the worker across all the workers that have the same role.
    os.environ["ROLE_RANK"] = str(rank)
    # The local world size (number of workers running locally); equals --nproc-per-node on torchrun.
    os.environ["LOCAL_WORLD_SIZE"] = str(gpu_count)
    # The world size (total number of workers in the job).
    os.environ["WORLD_SIZE"] = str(gpu_count)
    # The total number of workers launched with the same role.
    os.environ["ROLE_WORLD_SIZE"] = str(gpu_count)
    # We're already parallelizing over the context, so we can't also parallelize inside the tokenizers (?)
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    torch.cuda.set_device(rank)
    # Keep only this rank's queue and drop the list, so the handles to the
    # other workers' queues can be released in this process.
    input_queue = input_queues[rank]
    del input_queues

    # Load model once; it stays resident for the lifetime of the process.
    log.debug(f'inference_worker for rank {rank} creating the model object now')
    local_model = InferenceAR(args)
    del args
    log.debug(f'inference_worker for rank {rank} ready, pushing a "ready" message to the queue')
    result_queue.put((rank, "ready"))

    # Install interrupt signal handler so that we can shut down gracefully.
    should_quit = False

    def signal_handler(signum, frame):
        nonlocal should_quit
        log.info(f"[RANK{rank}] Received signal {signum}, shutting down")
        should_quit = True
        try:
            # Unblock the `input_queue.get()` below with the exit sentinel.
            input_queue.put(None)
        except ValueError:
            # Queue already closed, nothing to do.
            pass
    signal.signal(signal.SIGINT, signal_handler)

    while not should_quit:
        try:
            inputs_task = input_queue.get()
        except ValueError:
            # Queue was closed, we can exit.
            log.debug(f"[RANK{rank}] Input queue was closed, exiting.")
            break
        if inputs_task is None:
            # Special sentinel value to indicate that we are done and can exit.
            log.debug(f"[RANK{rank}] Got input {inputs_task}, exiting.")
            break

        # Note: we don't need to chunk the inputs for this rank / process, this is done
        # automatically in the model.
        # Note: we don't need to move the inputs to a specific device either since the
        # Gen3C API expects NumPy arrays.

        # Run the requested task
        with torch.no_grad():
            task_type, args, kwargs = inputs_task
            log.debug(f"[RANK{rank}] Got task: {task_type=}")

            if task_type == 'inference':
                log.debug(f"[RANK{rank}] Running `inference_on_cameras()`...")
                output = local_model.inference_on_cameras(*args, **kwargs)
                log.debug(f"[RANK{rank}] Done `inference_on_cameras()`!")
                # Only the tp/cp/pp rank-0 worker pushes results; the parent
                # expects exactly one result per inference request.
                if is_tp_cp_pp_rank0():
                    log.debug(f"[RANK{rank}] Moving outputs of `inference_on_cameras()` to the CPU")
                    output = move_to_device(output, device='cpu')
                    log.debug(f"[RANK{rank}] Pushing outputs of `inference_on_cameras()` to the results queue")
                    result_queue.put(output)

            elif task_type == 'seeding':
                log.debug(f"[RANK{rank}] Calling `seed_model_from_values()...`")
                if cosmos_variant == 'predict1':
                    output = local_model.seed_model_from_values(*args, **kwargs)
                else:
                    raise NotImplementedError(f"Unsupported cosmos variant: {cosmos_variant}")
                # Every rank acknowledges completion so the parent can wait on all of them.
                output = move_to_device(output, device='cpu')
                result_queue.put((rank, "seed_model_from_values_done", output))
                log.debug(f"[RANK{rank}] Done with `seed_model_from_values()`")

            elif task_type == 'clear_cache':
                log.debug(f"[RANK{rank}] Calling `clear_cache()...`")
                local_model.clear_cache()
                result_queue.put((rank, "clear_cache_done"))
                log.debug(f"[RANK{rank}] Done with `clear_cache()`")

            elif task_type == 'get_cache_input_depths':
                log.debug(f"[RANK{rank}] Calling `get_cache_input_depths()...`")
                input_depths = local_model.get_cache_input_depths()
                attrs_queue.put(('cache_input_depths', input_depths.cpu(), True))
                log.debug(f"[RANK{rank}] Done with `get_cache_input_depths()`")

            elif task_type == 'getattr':
                assert kwargs is None
                assert len(args) == 1
                attr_name = args[0]
                assert isinstance(attr_name, str)
                has_attr = hasattr(local_model, attr_name)
                # Bug fix: pass a default so a missing attribute is reported
                # as (name, None, False) -- matching the `has_attr` flag the
                # parent's `__getattr__` checks -- instead of raising
                # AttributeError and killing the worker process.
                attr_value_or_none = getattr(local_model, attr_name, None)
                if has_attr and (attr_value_or_none is not None) and torch.is_tensor(attr_value_or_none):
                    log.debug(f"[RANK{rank}] Attribute {attr_name=} is a torch tensor on "
                              f"device {attr_value_or_none.device}, cloning it before sending it through the queue")
                    attr_value_or_none = attr_value_or_none.clone()
                log.debug(f"[RANK{rank}] Pushing attribute value for {attr_name=}")
                attrs_queue.put((attr_name, attr_value_or_none, has_attr))

            else:
                raise NotImplementedError(f"Unsupported task type for Cosmos inference worker: {task_type}")

    # Cleanup before exiting
    local_model.cleanup()
    del local_model
def inference_worker(*args, **kwargs):
    """Entry point for the spawned worker processes.

    Runs `_inference_worker` and, if it raises, logs the full stack trace,
    best-effort destroys the distributed process group, and re-raises so
    the spawner observes the failure.
    """
    try:
        _inference_worker(*args, **kwargs)
    except Exception as exc:
        import traceback
        worker_rank = os.environ.get("LOCAL_RANK", "(unknown)")
        log.error(f"[RANK{worker_rank}] encountered exception: {exc}. Will re-raise after cleanup."
                  f" Stack trace:\n{traceback.format_exc()}")
        try:
            import torch.distributed as dist
            dist.destroy_process_group()
            log.info(f"[RANK{worker_rank}] Destroyed model parallel group after catching exception."
                     " Will re-raise now.")
        except Exception:
            # Best-effort cleanup: the process group may never have been initialized.
            pass
        raise exc
class MultiGPUInferenceAR():
    """
    Adapter class to run multi-GPU Cosmos inference in the context of the FastAPI inference server.
    This class implements the same interface as `InferenceAR`, but spawns one process per GPU and
    forwards inference requests to the multiple processes via a work queue.
    The worker processes wait for work from the queue, perform inference, and gather all results
    on the rank 0 process. That process then pushes results to the result queue.
    """
    def __init__(self, gpu_count: int, cosmos_variant: str, args: argparse.Namespace):
        """Spawn one worker process per GPU and block until each worker has
        loaded its model and reported "ready" on the result queue.

        Args:
            gpu_count: Number of worker processes; must not exceed the number of visible GPUs.
            cosmos_variant: Cosmos implementation the workers should load.
            args: Parsed command-line arguments forwarded to each worker's model constructor.
        """
        import torch
        import torch.multiprocessing as mp
        self.gpu_count = gpu_count
        assert self.gpu_count <= torch.cuda.device_count(), \
            f"Requested {self.gpu_count} GPUs, but only {torch.cuda.device_count()} are available."
        # 'spawn' start method (fresh interpreter per worker) — required for
        # CUDA use in subprocesses.
        ctx = mp.get_context('spawn')
        manager = ctx.Manager()
        # One dedicated input queue per worker; results and attribute lookups
        # come back through shared manager queues.
        self.input_queues: list[mp.Queue] = [ctx.Queue() for _ in range(self.gpu_count)]
        self.result_queue = manager.Queue()
        self.attrs_queue = manager.Queue()
        log.info(f"Spawning {self.gpu_count} processes (one per GPU)")
        # join=False: keep the spawn context so `cleanup()` can join later.
        self.ctx = mp.spawn(
            inference_worker,
            args=(args, self.gpu_count, cosmos_variant,
                  self.input_queues, self.result_queue, self.attrs_queue),
            nprocs=self.gpu_count,
            join=False
        )
        # Model loading can be slow; wait for one "ready" message per worker
        # before accepting any requests.
        log.info(f"Waiting for {self.gpu_count} processes to load the model...")
        for _ in range(self.gpu_count):
            v = self.result_queue.get()
            if not isinstance(v, tuple) or len(v) != 2 or v[1] != "ready":
                raise ValueError(f"Expected a 'ready' message from each process, but received: {v}")
            log.info(f"Process {v[0]} is ready.")
    def inference_on_cameras(self, *args, **kwargs):
        """Broadcast an inference request to every worker and block for the
        single result (pushed by the tp/cp/pp rank-0 worker only); tensors in
        the result are cloned to detach them from the queue's shared memory.
        """
        log.debug(f"inference_on_cameras(): submitting request to {len(self.input_queues)} inference processes.")
        for iq in self.input_queues:
            # Send the same input to each process
            task = ('inference', args, kwargs)
            iq.put(task)
        # Wait on the result queue to produce the result (this could take a while).
        log.debug(f"inference_on_cameras(): waiting for result...")
        outputs = self.result_queue.get()
        log.debug(f"inference_on_cameras(): got inference results! Cloning and returning.")
        return clone_tensors(outputs)
    def seed_model_from_values(self, *args, **kwargs):
        """Broadcast a seeding request to every worker, wait for each one to
        acknowledge completion, and return the output carried by the first
        acknowledgement received (cloned off the queue's shared memory)."""
        log.debug(f"seed_model_from_values(): submitting request to {len(self.input_queues)} inference processes.")
        for iq in self.input_queues:
            task = ('seeding', args, kwargs)
            iq.put(task)
        # TODO: refactor this, and maybe use some events or another primitive
        log.info(f"Waiting for {self.gpu_count} processes to be done with seeding...")
        for i in range(self.gpu_count):
            v = self.result_queue.get()
            if not isinstance(v, tuple) or len(v) != 3 or v[1] != "seed_model_from_values_done":
                raise ValueError(f"Expected a 'seed_model_from_values_done' message from each process, but received: {v}")
            log.info(f"Process {v[0]} is done with `seed_model_from_values()`.")
            # Arbitrarily pick the output from the first process
            if i == 0:
                outputs = v[2]
        return clone_tensors(outputs)
    def clear_cache(self):
        """Ask every worker to clear its model cache and block until all of
        them have acknowledged with a 'clear_cache_done' message."""
        for iq in self.input_queues:
            task = ('clear_cache', None, None)
            iq.put(task)
        # TODO: refactor this, and maybe use some events or another primitive
        log.info(f"Waiting for {self.gpu_count} processes to be done with clear_cache...")
        for _ in range(self.gpu_count):
            v = self.result_queue.get()
            if not isinstance(v, tuple) or len(v) != 2 or v[1] != "clear_cache_done":
                raise ValueError(f"Expected a 'clear_cache_done' message from each process, but received: {v}")
            log.info(f"Process {v[0]} is done with `clear_cache()`.")
    def get_cache_input_depths(self):
        """Query the worker at rank 0 for its cached input depths and return
        the value received on the attrs queue.

        Raises:
            ValueError: If the attribute name coming back on the queue does not
                match the request (a concurrent lookup raced on `attrs_queue`).
        """
        name = 'cache_input_depths'
        task = ('get_cache_input_depths', None, None)
        self.input_queues[0].put(task)
        # TODO: refactor this, and maybe use some events or another primitive
        looked_up_name, value, exists = self.attrs_queue.get()
        if looked_up_name != name:
            # TODO: this could be handled better (retry or enforce some ordering maybe).
            raise ValueError(f"Queried model for attribute '{name}' but got attribute '{looked_up_name}',"
                             " there was likely a race condition.")
        log.debug(f"Got a valid response, returning value for `get_cache_input_depths()`")
        return value
    def __getattr__(self, name: str):
        """Forward attribute lookups to the model living in the rank-0 worker.

        Raises:
            AttributeError: If the worker's model has no attribute `name`.
            ValueError: If the response on `attrs_queue` is for a different
                attribute (a concurrent lookup raced on the shared queue).
        """
        log.debug(f"__getattr__({name=}) called")
        # Note: this will not be called for methods we implement here, or attributes
        # that actually exist in this object.
        # Query the attribute from rank 0 (arbitrarily)
        task = ('getattr', (name,), None)
        self.input_queues[0].put(task)
        # Get result (blocking)
        log.debug(f"Waiting for response on `attrs_queue`...")
        looked_up_name, value, exists = self.attrs_queue.get()
        if looked_up_name != name:
            # TODO: this could be handled better (retry or enforce some ordering maybe).
            raise ValueError(f"Queried model for attribute '{name}' but got attribute '{looked_up_name}',"
                             " there was likely a race condition.")
        if not exists:
            raise AttributeError(f"Model has no attribute named '{name}'")
        log.debug(f"Got a valid response, returning {name} == {value}")
        return value
    def cleanup(self):
        """
        Clean up resources before shutting down.
        """
        log.info(f"MultiGPUInferenceAR winding down, asking {len(self.input_queues)} processes to clean up.")
        # "Close" all queues (there's no actual `close` method in PyTorch MP queues)
        # by sending each worker the `None` exit sentinel.
        for iq in self.input_queues:
            iq.put(None)
        # Wait for all processes to finish
        log.info(f"Waiting for {len(self.input_queues)} processes to finish (join).")
        self.ctx.join()
        log.info(f"{len(self.input_queues)} processes have finished.")