Spaces:
Build error
Build error
| # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | |
| # SPDX-License-Identifier: Apache-2.0 | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| import os | |
| import threading | |
| from collections import namedtuple | |
| from typing import Any, Dict, Optional, Set, Tuple, Union | |
| import torch | |
| import torch.distributed | |
| from megatron.core import parallel_state | |
| from torch.distributed import ProcessGroup, get_process_group_ranks | |
| from cosmos_predict1.utils import distributed, log, misc | |
| from cosmos_predict1.utils.base import AbstractCheckpointer | |
| from cosmos_predict1.utils.checkpointer.safe_broadcast import broadcast_object | |
| from cosmos_predict1.utils.easy_io import easy_io | |
| from cosmos_predict1.utils.model import Model | |
| StateDictItemPath = namedtuple("StateDictItemPath", ["state_dict", "save_path"]) | |
| class Checkpointer(AbstractCheckpointer): | |
| """ | |
| Checkpointer for DDP. | |
| Note: This implementation only supports local filesystem. | |
| """ | |
| KEYS_TO_SAVE = ["model", "optim", "scheduler", "trainer"] | |
| KEYS_TO_POSTFIX = { | |
| "model": "model", | |
| "optim": "optim", | |
| "scheduler": "scheduler", | |
| "trainer": "", | |
| } | |
| def __init__(self, *args, **kwargs): | |
| super().__init__(*args, **kwargs) | |
| pp_world_size = parallel_state.get_pipeline_model_parallel_world_size() | |
| ep_world_size = parallel_state.get_expert_model_parallel_world_size() | |
| assert pp_world_size < 2, "Pipeline Parallelism (PP) is not tested yet." | |
| assert ep_world_size < 2, "Expert Parallelism (EP) is not tested yet." | |
| self.mp_world_size = parallel_state.get_model_parallel_group().size() | |
| if self.mp_world_size > 1 and self.__class__ == Checkpointer: | |
| raise NotImplementedError( | |
| "Model Parallelism (MP) is enabled - " | |
| "you should use TensorParallel Checkpointer instead of DDP Checkpointer." | |
| ) | |
| # DDP rank (with context parallelism considered) | |
| self.rank_dp_w_cp = parallel_state.get_data_parallel_rank(with_context_parallel=True) | |
| # Context parallelism rank | |
| self.cp_rank = parallel_state.get_context_parallel_rank() | |
| # Model parallelism rank (including Tensor+Pipeline+Expert Parallelisms) | |
| self.mp_rank = parallel_state.get_model_parallel_group().rank() | |
| # self.mp_rank = parallel_state.get_model_parallel_group(with_expert_parallel=ep_world_size > 1).rank() | |
| if self.broadcast_via_filesystem: | |
| log.info("Broadcasting checkpoint data via the local filesystem.") | |
| if not self.strict_resume: | |
| log.warning("Strict resume mode is off. Some model parameters may not be loaded.") | |
| # collect ranks of all model parallel groups | |
| all_ranks = [None for _ in range(distributed.get_world_size())] | |
| torch.distributed.all_gather_object( | |
| all_ranks, get_process_group_ranks(parallel_state.get_model_parallel_group()) | |
| ) | |
| all_ranks = list(set(tuple(rank) if isinstance(rank, list) else rank for rank in all_ranks)) | |
| for ranks in all_ranks: | |
| group = torch.distributed.new_group(list(ranks), backend="gloo") | |
| if distributed.get_rank() in ranks: | |
| self.mp_gloo_pg = group | |
| self.print("Checkpointer Initialized.") | |
| def print(self, message: str): | |
| """ | |
| Print message to the console. Include the parallelism rank information when verbose is set to True. | |
| """ | |
| if self.verbose: | |
| log.info( | |
| f"[Parallelism Rank: DP-{self.rank_dp_w_cp}, TP-{self.mp_rank}, CP-{self.cp_rank}]: {message}", | |
| rank0_only=False, | |
| ) | |
| else: | |
| log.info(message, rank0_only=True) | |
| def add_type_postfix_to_checkpoint_path(self, key: str, checkpoint_path: str, model: Model) -> str: | |
| del model | |
| assert key in self.KEYS_TO_SAVE | |
| post_fix = self.KEYS_TO_POSTFIX[key] | |
| if post_fix: | |
| _ckpt_path = checkpoint_path.replace(".pt", f"_{post_fix}.pt") | |
| else: | |
| _ckpt_path = checkpoint_path | |
| return _ckpt_path | |
| def save( | |
| self, | |
| model: Model, | |
| optimizer: torch.optim.Optimizer, | |
| scheduler: torch.optim.lr_scheduler.LRScheduler, | |
| grad_scaler: torch.amp.GradScaler, | |
| iteration: int, | |
| ) -> None: | |
| """Save network weights, optimizer parameters, scheduler parameters to a checkpoint. | |
| Args: | |
| model (Model): The PyTorch model. | |
| optimizer (torch.optim.Optimizer): The model optimizer. | |
| scheduler (torch.optim.lr_scheduler.LRScheduler): The optimization scheduler. | |
| grad_scaler (torch.amp.GradScaler): The gradient scaler (for mixed precision training). | |
| iteration (int): Current iteration number. | |
| """ | |
| self.callbacks.on_save_checkpoint_start(model, iteration) | |
| checkpoint_file = self.format_checkpoint_filename(model, iteration) | |
| state_dict = self.generate_save_state_dict(model, optimizer, scheduler, grad_scaler, iteration) | |
| state_dict = self._map_state_dict_path_during_save(state_dict, checkpoint_file, model) | |
| if state_dict: | |
| # Wait for previous saver thread to end. | |
| if self.save_thread: | |
| self.save_thread.join() | |
| # Run the checkpoint saver in a separate thread. | |
| self.save_thread = threading.Thread( | |
| target=self._save_worker, | |
| daemon=False, | |
| args=(state_dict, checkpoint_file, distributed.get_rank()), | |
| ) | |
| self.save_thread.start() | |
| # Note: Checkpoints are saved on a separate thread and this callback is not accurate. | |
| # Please check logs from on_save_checkpoint_success() for better accuracy | |
| self.callbacks.on_save_checkpoint_end(model=None, iteration=iteration) | |
| def _map_state_dict_path_during_save(self, state_dict, checkpoint_file, model) -> dict[str, StateDictItemPath]: | |
| new_dict = {} | |
| for key, _state_dict in state_dict.items(): | |
| _ckpt_path = self.add_type_postfix_to_checkpoint_path(key, checkpoint_file, model) | |
| checkpoint_path = os.path.join(self.save_dirname, _ckpt_path) | |
| new_dict[key] = StateDictItemPath(_state_dict, checkpoint_path) | |
| return new_dict | |
| def _save_worker(self, state_dict: dict[str, StateDictItemPath], checkpoint_file: str, rank: int = 0) -> None: | |
| """Worker to save checkpoint to disk, spawned with a child thread (in parallel with the training). | |
| Args: | |
| state_dict (dict[str, StateDictItemPath]): The state dict of the model/optimizer/scheduler. | |
| checkpoint_file (str): The file name of the model checkpoint. | |
| rank (int): GPU device (default: 0). | |
| """ | |
| try: | |
| for key, item in state_dict.items(): | |
| self.print(f"Saving {key} to {item.save_path}") | |
| try: | |
| easy_io.dump( | |
| item.state_dict, | |
| item.save_path, | |
| fast_backend=True, # optional for fast backend, cpu heavy | |
| ) | |
| self.print(f"Saved {key} to {item.save_path}") | |
| except Exception as e: | |
| self.print(f"Failed to save {key} to {item.save_path}: {str(e)}") | |
| raise # Re-raise the exception after logging | |
| # Synchronize only rank 0 of each model parallel group | |
| if self.mp_world_size > 1: | |
| torch.distributed.barrier(group=self.mp_gloo_pg) | |
| # Only rank 0 of MP group and rank 0 of DP with CP updates latest_checkpoint.txt | |
| if self.mp_rank == 0 and self.rank_dp_w_cp == 0: | |
| self._write_latest_checkpoint_file(checkpoint_file) | |
| if distributed.get_rank() == 0: # only rank 0 saves trained_data_record | |
| if "trained_data_record" in state_dict["model"].state_dict: | |
| self._write_trained_data_record( | |
| checkpoint_file, state_dict["model"].state_dict["trained_data_record"] | |
| ) | |
| iteration = int(checkpoint_file.replace("iter_", "").replace(".pt", "")) | |
| self.callbacks.on_save_checkpoint_success(iteration=iteration) | |
| except Exception as e: # noqa: BLE001 | |
| log.exception(f"Checkpoint failed to upload: {e}", rank0_only=not self.verbose) | |
| def format_checkpoint_filename(self, model: Model, iteration: int) -> str: | |
| """Generate the checkpoint file name. | |
| Args: | |
| iteration (int): The current iteration number. | |
| Returns: | |
| checkpoint_file (str): The checkpoint file name. | |
| """ | |
| del self, model | |
| return f"iter_{iteration:09}.pt" | |
| def generate_save_state_dict( | |
| self, | |
| model: Model, | |
| optimizer: torch.optim.Optimizer, | |
| scheduler: torch.optim.lr_scheduler.LRScheduler, | |
| grad_scaler: torch.amp.GradScaler, | |
| iteration: int, | |
| ) -> Optional[Dict[str, Any]]: | |
| state_dict = {} | |
| if self.rank_dp_w_cp == 0: | |
| trainer_state = dict( | |
| grad_scaler=grad_scaler.state_dict(), | |
| iteration=iteration, | |
| ) | |
| model_state = model.state_dict() | |
| optim_state = optimizer.state_dict() | |
| scheduler_state = scheduler.state_dict() | |
| self.callbacks.on_save_checkpoint(model, state_dict=trainer_state) | |
| trainer_state, model_state, optim_state, scheduler_state = misc.to( | |
| [trainer_state, model_state, optim_state, scheduler_state], device="cpu" | |
| ) | |
| state_dict = { | |
| "model": model_state, | |
| "optim": optim_state, | |
| "scheduler": scheduler_state, | |
| } | |
| if distributed.get_rank() == 0: # only rank 0 saves trainer state | |
| state_dict["trainer"] = trainer_state | |
| return state_dict | |
| return state_dict | |
| def load_broadcast_state_dict(self, checkpoint_path: str, model: Model, resume_keys: Set) -> dict[str, Any]: | |
| """ | |
| Load state_dict and broadcast. | |
| The main steps are: | |
| 1. Download TP-rank-specific checkpoints for every GPU of DDP-rank 0 and CP-rank 0. | |
| 2. Each rank loads its corresponding checkpoint from the local cache or receives it via broadcast. | |
| This approach ensures that each MP rank loads its specific part of the model, which is | |
| crucial for Model Parallelism where different parts of the model are distributed across | |
| multiple GPUs. | |
| When using Model Parallelism (e.g., Tensor Parallelism), the `broadcast_via_filesystem` option can | |
| be set to True. This allows each rank to load its specific checkpoint from the local filesystem | |
| instead of receiving it via network broadcast, which could be more efficient in some cases. | |
| For standard DDP without TP, `broadcast_via_filesystem` should remain False (default). | |
| Args: | |
| checkpoint_path (str): The base path of the checkpoint. | |
| model (Model): The model being loaded. | |
| resume_keys (Set): Set of keys to resume from the checkpoint. | |
| Returns: | |
| dict[str, Any]: A dictionary containing the loaded state for each resumed key. | |
| """ | |
| state_dict = {} | |
| sorted_resume_keys = sorted(resume_keys) | |
| # Step 1: Download TP-rank-specific checkpoints for every GPU of DDP-rank 0 and CP-rank 0. | |
| if self.rank_dp_w_cp == 0: | |
| for key in sorted_resume_keys: | |
| _ckpt_path = self.add_type_postfix_to_checkpoint_path(key, checkpoint_path, model) | |
| local_cache_path = os.path.join(self.load_dirname, os.path.basename(_ckpt_path)) | |
| if os.path.exists(local_cache_path): | |
| # If the local checkpoint exists, we can directly load it | |
| self.print(f"Checkpoint is already in local cache: {local_cache_path}. Loading...") | |
| _state_dict = easy_io.load(local_cache_path, fast_backend=True) | |
| else: | |
| _state_dict = easy_io.load(_ckpt_path, fast_backend=True) | |
| self.print(f"Downloading checkpoint from: {_ckpt_path}") | |
| if self.broadcast_via_filesystem: | |
| # Save the checkpoint to the local filesystem | |
| easy_io.dump(_state_dict, local_cache_path, fast_backend=True) | |
| state_dict[key] = _state_dict | |
| # Ensure all ranks wait for the download to complete | |
| distributed.barrier() | |
| # Step 2: Broadcast checkpoint data | |
| log.info( | |
| "Start broadcasting checkpoint from the source rank to all other ranks in the same DDP group.", | |
| rank0_only=True, | |
| ) | |
| for key in sorted_resume_keys: | |
| if self.broadcast_via_filesystem: | |
| # Load the checkpoint from the local filesystem for other ranks | |
| if self.rank_dp_w_cp != 0: | |
| _ckpt_path = self.add_type_postfix_to_checkpoint_path(key, checkpoint_path, model) | |
| local_cache_path = os.path.join(self.load_dirname, os.path.basename(_ckpt_path)) | |
| self.print(f"Loading checkpoint from: {local_cache_path}") | |
| state_dict[key] = easy_io.load(local_cache_path, fast_backend=True) | |
| else: | |
| # Broadcast the checkpoint to all GPUs of the current DDP rank | |
| group: ProcessGroup = parallel_state.get_data_parallel_group(with_context_parallel=True) | |
| min_rank = min(get_process_group_ranks(group)) | |
| _state_dict = broadcast_object( | |
| state_dict[key] if self.rank_dp_w_cp == 0 else None, | |
| min_rank, | |
| group=group, | |
| device=torch.device(torch.cuda.current_device()), | |
| ) | |
| if self.rank_dp_w_cp == 0: | |
| self.print(f'Broadcasted checkpoint["{key}"] to all other ranks in the same DDP group.') | |
| else: | |
| state_dict[key] = _state_dict | |
| self.print(f'Received checkpoint["{key}"] from source rank {min_rank}.') | |
| return state_dict | |
| def keys_to_resume_during_load(self) -> Tuple[Set, Union[str, None]]: | |
| latest_checkpoint_file = self._read_latest_checkpoint_file() | |
| resume_keys = [] | |
| if latest_checkpoint_file is not None: | |
| # 1. Resume training from latest_checkpoint.txt under the same name. | |
| checkpoint_path = os.path.join(self.load_dirname, latest_checkpoint_file) | |
| resume_keys.extend(self.KEYS_TO_SAVE) | |
| else: | |
| if self.load_path: | |
| # 2. Load the module weights specified by config_checkpoint.path. | |
| checkpoint_path = self.load_path | |
| if self.load_training_state: | |
| resume_keys.extend(self.KEYS_TO_SAVE) | |
| else: | |
| resume_keys.append("model") | |
| if self.only_load_scheduler_state: | |
| resume_keys.append("scheduler") | |
| else: | |
| checkpoint_path = None | |
| if len(self.keys_not_to_resume) > 0: | |
| for key in self.keys_not_to_resume: | |
| assert key in self.KEYS_TO_SAVE, f"Invalid key to resume: {key} not in {self.KEYS_TO_SAVE}" | |
| resume_keys = [key for key in resume_keys if key not in self.keys_not_to_resume] | |
| return set(resume_keys), checkpoint_path | |
| def load( | |
| self, | |
| model: Model, | |
| optimizer: torch.optim.Optimizer | None = None, | |
| scheduler: torch.optim.lr_scheduler.LRScheduler | None = None, | |
| grad_scaler: torch.amp.GradScaler | None = None, | |
| ) -> int: | |
| """Load network weights and optimizer states from a checkpoint in a single process. | |
| The priority of the checkpoint loading logic is: | |
| 1. Attempt to resume training if possible by looking for latest_checkpoint.txt under the same name. | |
| 2. If no latest checkpoint were found, it loads the model weights specified by config_checkpoint.path. | |
| - This is typically used for inference mode. | |
| - If config_checkpoint.load_optimizer_state is True, then also load the optimizer and scheduler states. | |
| 3. If none of the above, randomly initialize the model parameters and train from scratch. | |
| Args: | |
| model (Model): The PyTorch model. | |
| optimizer (torch.optim.Optimizer | None): The model optimizer (default: None). | |
| scheduler (torch.optim.lr_scheduler.LRScheduler | None): The optimization scheduler (default: None). | |
| grad_scaler (torch.amp.GradScaler | None): The gradient scaler (for mixed precision training). | |
| Returns: | |
| iteration (int): the iteration number to start/resume from. | |
| """ | |
| self.callbacks.on_load_checkpoint_start(model) | |
| resume_keys, checkpoint_path = self.keys_to_resume_during_load() | |
| iteration = 0 | |
| # Load checkpoint. | |
| if checkpoint_path is not None: | |
| self._check_checkpoint_exists(checkpoint_path) | |
| state_dict = self.load_broadcast_state_dict(checkpoint_path, model, set(resume_keys)) | |
| if "trainer" in state_dict: | |
| trainer_state = state_dict["trainer"] | |
| log.critical(state_dict.keys(), rank0_only=False) | |
| log.critical(trainer_state, rank0_only=False) | |
| log.info("- Loading the gradient scaler...") | |
| grad_scaler.load_state_dict(trainer_state["grad_scaler"]) | |
| self.callbacks.on_load_checkpoint(model, state_dict=trainer_state) | |
| iteration = trainer_state["iteration"] | |
| if "optim" in state_dict: | |
| assert optimizer | |
| optimizer_state = state_dict["optim"] | |
| log.info("- Loading the optimizer...") | |
| optimizer.load_state_dict(optimizer_state) | |
| if "scheduler" in state_dict: | |
| assert scheduler | |
| scheduler_state = state_dict["scheduler"] | |
| log.info("- Loading the scheduler...") | |
| scheduler.load_state_dict(scheduler_state) | |
| scheduler.last_epoch = iteration | |
| if "model" in state_dict: | |
| model_state = state_dict["model"] | |
| log.info("- Loading the model...") | |
| # model.load_state_dict(model_state) | |
| if self.strict_resume: | |
| log.info("\t Strict resume mode is on.") | |
| else: | |
| log.info("\t Strict resume mode is off.") | |
| model_load_info = model.load_state_dict(model_state, strict=self.strict_resume) | |
| log.info(f"\t {model_load_info}") | |
| self.print(f"Loaded checkpoint from {checkpoint_path} in iteration {iteration}") | |
| else: | |
| log.info("Training from scratch.") | |
| torch.cuda.empty_cache() | |
| self.callbacks.on_load_checkpoint_end(model) | |
| return iteration | |
| def _write_trained_data_record(self, checkpoint_file: str, trained_data_record: dict[str, int]) -> None: | |
| """Write json file to save number of seen samples and number of iterations. | |
| Args: | |
| checkpoint_file (str): iteration number for the saved checkpoint | |
| trained_data_record (dict[str, int]): example {"image": 0, "video": 0, "iteration": 0}. | |
| """ | |
| # filename: iter_xxxxxxxxx_trained_data_record.json | |
| checkpoint_path = os.path.join( | |
| self.save_dirname, f"{checkpoint_file.replace('.pt', '')}_trained_data_record.json" | |
| ) | |
| easy_io.dump(trained_data_record, checkpoint_path) | |