# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import signal

import torch
import torch.distributed as dist
import torch.utils.data
from megatron.core import parallel_state

from cosmos_predict1.checkpointer.tp import Checkpointer as TensorParallelCheckpointer
from cosmos_predict1.utils import distributed, ema, log, misc
from cosmos_predict1.utils.checkpointer import Checkpointer
from cosmos_predict1.utils.fsdp_checkpointer import FSDPCheckpointer
from cosmos_predict1.utils.model import Model
from cosmos_predict1.utils.trainer import Trainer as BaseTrainer


class Trainer(BaseTrainer):
    """Trainer that logs the average training loss.

    Extends the base trainer so that the loss reported to callbacks is averaged
    across all devices and across the gradient-accumulation window, rather than
    being the last micro-batch loss of a single rank.
    """

    def __init__(self, config):
        super().__init__(config)
        if config.trainer.distributed_parallelism == "ddp":
            if parallel_state.get_tensor_model_parallel_world_size() > 1:
                self.checkpointer = TensorParallelCheckpointer(config.checkpoint, config.job, callbacks=self.callbacks)
                log.critical("Using Tensor Parallelism Checkpointer")
            else:
                self.checkpointer = Checkpointer(config.checkpoint, config.job, callbacks=self.callbacks)
        elif config.trainer.distributed_parallelism == "fsdp":
            self.checkpointer = FSDPCheckpointer(config.checkpoint, config.job, callbacks=self.callbacks)
        else:
            raise ValueError(f"Unsupported distributed parallelism: {config.trainer.distributed_parallelism}")

    def train(
        self,
        model: Model,
        dataloader_train: torch.utils.data.DataLoader,
        dataloader_val: torch.utils.data.DataLoader,
    ) -> None:
        """The training function.

        Args:
            model (Model): The PyTorch model.
            dataloader_train (torch.utils.data.DataLoader): The training data loader.
            dataloader_val (torch.utils.data.DataLoader): The validation data loader.
        """
        # Leaving this for backward compatibility for now, but we can think about moving this to
        # model.on_train_start for all models.
        model = model.to("cuda", memory_format=self.config.trainer.memory_format)  # type: ignore
        log.info(f"Model Architecture:\n {model}")
        model.on_train_start(self.config.trainer.memory_format)
        # Initialize the optimizer and scheduler.
        self.callbacks.on_optimizer_init_start()
        optimizer, scheduler = model.init_optimizer_scheduler(self.config.optimizer, self.config.scheduler)
        grad_scaler = torch.amp.GradScaler("cuda", **self.config.trainer.grad_scaler_args)
        self.callbacks.on_optimizer_init_end()
        # Load the model checkpoint and get the starting iteration number.
        iteration = self.checkpointer.load(model, optimizer, scheduler, grad_scaler)
        # Set the scheduler to the current iteration.
        scheduler.last_epoch = iteration
        scheduler._step_count = iteration + 1
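        # Worked example for the two lines above (hypothetical numbers): resuming from
        # a checkpoint saved at iteration 1000 sets scheduler.last_epoch = 1000 and
        # scheduler._step_count = 1001, so the next scheduler.step() computes the
        # learning rate for iteration 1001 rather than restarting the schedule.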
        grad_accum_iter = 0
        log.critical(f"Distributed parallelism mode: {self.config.trainer.distributed_parallelism}")
        if self.config.trainer.distributed_parallelism == "ddp":
            # Create a DDP model wrapper.
            model_ddp = distributed.parallel_model_wrapper(self.config.trainer.ddp, model)
        elif self.config.trainer.distributed_parallelism == "fsdp":
            model_ddp = model
        else:
            raise ValueError(f"Unknown distributed parallelism mode: {self.config.trainer.distributed_parallelism}")
        log.info("Starting training...")
        self.callbacks.on_train_start(model, iteration=iteration)
        # Initial validation.
        if self.config.trainer.run_validation and iteration == 0:
            self.validate(model, dataloader_val, iteration=iteration)
        _end_training = False
        self.callbacks.on_before_dataloading(iteration)
        accumulated_loss = 0.0
        while True:
            dataloader_train_iter = iter(dataloader_train)
            while True:
                self.callbacks.on_before_dataloading(iteration)
                try:
                    data_batch = next(dataloader_train_iter)
                except StopIteration:
                    break
                self.callbacks.on_after_dataloading(iteration)
                # If max_iter is reached, exit the training loop.
                if iteration >= self.config.trainer.max_iter:
                    _end_training = True
                    break
                # Move all tensors in the data batch to GPU device.
                data_batch = misc.to(data_batch, device="cuda")
                # The actual training step.
                self.callbacks.on_training_step_start(model, data_batch, iteration=iteration)
                model_ddp.train()
                output_batch, loss, grad_accum_iter = self.training_step(
                    model_ddp,
                    optimizer,
                    scheduler,
                    grad_scaler,
                    data_batch,
                    iteration=iteration,
                    grad_accum_iter=grad_accum_iter,
                )
                # Accumulate loss.
                accumulated_loss += loss.detach()
                # If the gradients are still being accumulated, continue to load the next training batch.
                if grad_accum_iter != 0:
                    if self.enable_one_logger:
                        # Callback for skipped OneLoggerCallback.on_training_step_end().
                        self.one_logger.on_train_batch_end(set_barrier=False)
                    continue
                # Do the following when an actual optimizer (update) step has been made.
                iteration += 1
                # Average loss over accumulation steps.
                grad_accum_avg_loss = accumulated_loss / self.config.trainer.grad_accum_iter
                # Average loss across all devices.
                device_avg_loss = grad_accum_avg_loss.clone()
                dist.all_reduce(device_avg_loss, op=dist.ReduceOp.SUM)
                device_avg_loss /= dist.get_world_size()
                # Reset accumulation variables.
                accumulated_loss = 0.0
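                # Worked example (hypothetical numbers): with grad_accum_iter = 4 and a
                # world size of 2, per-rank accumulated losses of 4.0 and 6.0 become
                # 4.0 / 4 = 1.0 and 6.0 / 4 = 1.5, and the all-reduce then yields
                # (1.0 + 1.5) / 2 = 1.25 on every rank; that value is reported below.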
                self.callbacks.on_training_step_end(
                    model, data_batch, output_batch, device_avg_loss, iteration=iteration
                )
                # Validation.
                if self.config.trainer.run_validation and iteration % self.config.trainer.validation_iter == 0:
                    self.validate(model, dataloader_val, iteration=iteration)
                # Save checkpoint.
                if iteration % self.config.checkpoint.save_iter == 0:
                    self.checkpointer.save(model, optimizer, scheduler, grad_scaler, iteration=iteration)
                # This iteration is successful; reset the timeout signal.
                signal.alarm(self.config.trainer.timeout_period)
            if _end_training:
                break
        log.success("Done with training.")
        self.checkpointer.save(model, optimizer, scheduler, grad_scaler, iteration=iteration)
        self.callbacks.on_train_end(model, iteration=iteration)
        self.checkpointer.finalize()
        distributed.barrier()
        self.callbacks.on_app_end()
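
    # Typical driver code (a minimal sketch; `make_config`, `build_model`, and the
    # dataloader construction are hypothetical stand-ins, not functions defined in
    # this codebase):
    #
    #     config = make_config()
    #     trainer = Trainer(config)
    #     model = build_model(config)
    #     trainer.train(model, dataloader_train, dataloader_val)
    #
    # Distributed initialization (e.g. via torchrun) is assumed to have happened
    # before the Trainer is constructed.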

    def training_step(
        self,
        model_ddp: torch.nn.Module | distributed.DistributedDataParallel,
        optimizer: torch.optim.Optimizer,
        scheduler: torch.optim.lr_scheduler.LRScheduler,
        grad_scaler: torch.amp.GradScaler,
        data: dict[str, torch.Tensor],
        iteration: int = 0,
        grad_accum_iter: int = 0,
    ) -> tuple[dict[str, torch.Tensor], torch.Tensor, int]:
        """The training step.

        Args:
            model_ddp (torch.nn.Module | distributed.DistributedDataParallel): The model with a DDP wrapper or
                the bare module, depending on whether distributed training is enabled.
            optimizer (torch.optim.Optimizer): The model optimizer.
            scheduler (torch.optim.lr_scheduler.LRScheduler): The optimization scheduler.
            grad_scaler (torch.amp.GradScaler): The gradient scaler (for mixed-precision training).
            data (dict[str, torch.Tensor]): Data batch (dictionary of tensors).
            iteration (int): Current iteration number.
            grad_accum_iter (int): Number of gradient-accumulation iterations completed so far in this window.

        Returns:
            output (dict[str, torch.Tensor]): The model output from the training data batch (dictionary of tensors).
            loss (torch.Tensor): The total loss of the training data batch.
            grad_accum_iter (int): The updated gradient-accumulation counter (reset to 0 after an optimizer step).
        """
        # Only let DDP sync gradient at the last iteration of the gradient accumulation window.
        with distributed.ddp_sync_grad(model_ddp, grad_accum_iter == self.config.trainer.grad_accum_iter - 1):
            with self.training_timer("forward"):
                output_batch, loss = model_ddp.training_step(data, iteration)
            self.callbacks.on_before_backward(model_ddp, loss, iteration=iteration)
            with self.training_timer("backward"):
                loss_scaled = grad_scaler.scale(loss / self.config.trainer.grad_accum_iter)
                loss_scaled.backward()
                if self.config.trainer.distributed_parallelism == "ddp":
                    model_ddp.module.on_after_backward()
                else:
                    model_ddp.on_after_backward()
            self.callbacks.on_after_backward(model_ddp, iteration=iteration)
        grad_accum_iter += 1
        if grad_accum_iter == self.config.trainer.grad_accum_iter:
            with self.training_timer("optimizer_step"):
                self.callbacks.on_before_optimizer_step(
                    model_ddp, optimizer, scheduler, grad_scaler, iteration=iteration
                )
                grad_scaler.step(optimizer)
                grad_scaler.update()
                scheduler.step()
                self.callbacks.on_before_zero_grad(model_ddp, optimizer, scheduler, iteration=iteration)
                if self.config.trainer.distributed_parallelism == "ddp":
                    model_ddp.module.on_before_zero_grad(optimizer, scheduler, iteration=iteration)
                else:
                    model_ddp.on_before_zero_grad(optimizer, scheduler, iteration=iteration)
                optimizer.zero_grad(set_to_none=True)
            grad_accum_iter = 0
        return output_batch, loss, grad_accum_iter
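
    # Accumulation protocol example (assuming config.trainer.grad_accum_iter = 4):
    # four consecutive training_step() calls return grad_accum_iter = 1, 2, 3, 0.
    # DDP gradient sync happens only on the fourth call (the one entered with
    # grad_accum_iter == 3), which also steps the optimizer and scheduler; the
    # returned 0 signals the caller that an accumulation window just closed.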

    def validate(self, model: Model, dataloader_val: torch.utils.data.DataLoader, iteration: int = 0) -> None:
        """Validate on the full validation dataset.

        Args:
            model (Model): The PyTorch model.
            dataloader_val (torch.utils.data.DataLoader): The validation data loader.
            iteration (int): Current iteration number.
        """
        self.callbacks.on_validation_start(model, dataloader_val, iteration=iteration)
        model.eval()
        # Evaluate on the full validation set.
        with ema.ema_scope(model, enabled=getattr(model.config.ema, "enabled", False)):
            for val_iter, data_batch in enumerate(dataloader_val):
                if self.config.trainer.max_val_iter is not None and val_iter >= self.config.trainer.max_val_iter:
                    break
                data_batch = misc.to(data_batch, device="cuda")
                self.callbacks.on_validation_step_start(model, data_batch, iteration=iteration)
                output_batch, loss = model.validation_step(data_batch, iteration)
                self.callbacks.on_validation_step_end(model, data_batch, output_batch, loss, iteration=iteration)
        self.callbacks.on_validation_end(model, iteration=iteration)
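

if __name__ == "__main__":
    # Minimal self-check of the loss-reporting arithmetic used in train() above.
    # This is an illustrative sketch only: it simulates the per-rank accumulators
    # with plain tensors instead of running a distributed job, so no process group,
    # model, or dataloader is needed.
    grad_accum_iter = 4
    per_rank_accumulated = [torch.tensor(4.0), torch.tensor(6.0)]  # two simulated ranks
    world_size = len(per_rank_accumulated)
    # Per-rank average over the gradient-accumulation window.
    per_rank_avg = [acc / grad_accum_iter for acc in per_rank_accumulated]
    # The all-reduce in train() sums across ranks, then divides by the world size.
    device_avg_loss = torch.stack(per_rank_avg).sum() / world_size
    print(f"reported average loss: {device_avg_loss.item():.4f}")  # -> 1.2500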