import signal

import torch
import torch.distributed as dist
import torch.utils.data
from megatron.core import parallel_state

from cosmos_predict1.checkpointer.tp import Checkpointer as TensorParallelCheckpointer
from cosmos_predict1.utils import distributed, ema, log, misc
from cosmos_predict1.utils.checkpointer import Checkpointer
from cosmos_predict1.utils.fsdp_checkpointer import FSDPCheckpointer
from cosmos_predict1.utils.model import Model
from cosmos_predict1.utils.trainer import Trainer
|
|
class Trainer(Trainer):
    """Trainer that extends the base trainer to log the average loss, averaged across all
    devices and across gradient-accumulation steps."""

    def __init__(self, config):
        super(Trainer, self).__init__(config)
        # Select a checkpointer matching the parallelism strategy: tensor-parallel, plain DDP, or FSDP.
        if config.trainer.distributed_parallelism == "ddp":
            if parallel_state.get_tensor_model_parallel_world_size() > 1:
                self.checkpointer = TensorParallelCheckpointer(config.checkpoint, config.job, callbacks=self.callbacks)
                log.critical("Using Tensor Parallelism Checkpointer")
            else:
                self.checkpointer = Checkpointer(config.checkpoint, config.job, callbacks=self.callbacks)
        elif config.trainer.distributed_parallelism == "fsdp":
            self.checkpointer = FSDPCheckpointer(config.checkpoint, config.job, callbacks=self.callbacks)
        else:
            raise ValueError(f"Unsupported distributed parallelism: {config.trainer.distributed_parallelism}")
|
    def train(
        self,
        model: Model,
        dataloader_train: torch.utils.data.DataLoader,
        dataloader_val: torch.utils.data.DataLoader,
    ) -> None:
        """The training function.

        Args:
            model (Model): The PyTorch model.
            dataloader_train (torch.utils.data.DataLoader): The training data loader.
            dataloader_val (torch.utils.data.DataLoader): The validation data loader.
        """
        model = model.to("cuda", memory_format=self.config.trainer.memory_format)
        log.info(f"Model Architecture:\n {model}")
        model.on_train_start(self.config.trainer.memory_format)

        self.callbacks.on_optimizer_init_start()
        optimizer, scheduler = model.init_optimizer_scheduler(self.config.optimizer, self.config.scheduler)
        grad_scaler = torch.amp.GradScaler("cuda", **self.config.trainer.grad_scaler_args)
        self.callbacks.on_optimizer_init_end()

        iteration = self.checkpointer.load(model, optimizer, scheduler, grad_scaler)
|
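        # Fast-forward the LR scheduler to the checkpointed iteration so that the learning
        # rate schedule resumes from where the loaded run left off.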
        scheduler.last_epoch = iteration
        scheduler._step_count = iteration + 1

        grad_accum_iter = 0
        log.critical(f"Distributed parallelism mode: {self.config.trainer.distributed_parallelism}")
|
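        # Wrap the model for the configured parallelism mode. In the FSDP case the model is
        # assumed to already carry its FSDP wrapping, so it is used as-is.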
        if self.config.trainer.distributed_parallelism == "ddp":
            model_ddp = distributed.parallel_model_wrapper(self.config.trainer.ddp, model)
        elif self.config.trainer.distributed_parallelism == "fsdp":
            model_ddp = model
        else:
            raise ValueError(f"Unknown distributed parallelism mode: {self.config.trainer.distributed_parallelism}")

        log.info("Starting training...")
        self.callbacks.on_train_start(model, iteration=iteration)

        if self.config.trainer.run_validation and iteration == 0:
            self.validate(model, dataloader_val, iteration=iteration)

        _end_training = False
        self.callbacks.on_before_dataloading(iteration)
        accumulated_loss = 0.0

        while True:
            dataloader_train_iter = iter(dataloader_train)
            while True:
                self.callbacks.on_before_dataloading(iteration)
                try:
                    data_batch = next(dataloader_train_iter)
                except StopIteration:
                    break
                self.callbacks.on_after_dataloading(iteration)

                if iteration >= self.config.trainer.max_iter:
                    _end_training = True
                    break

                data_batch = misc.to(data_batch, device="cuda")

                self.callbacks.on_training_step_start(model, data_batch, iteration=iteration)
                model_ddp.train()
                output_batch, loss, grad_accum_iter = self.training_step(
                    model_ddp,
                    optimizer,
                    scheduler,
                    grad_scaler,
                    data_batch,
                    iteration=iteration,
                    grad_accum_iter=grad_accum_iter,
                )
                accumulated_loss += loss.detach()
|
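                # A non-zero counter means we are still inside a gradient-accumulation window:
                # no optimizer step was taken, so do not advance the iteration counter or log yet.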
                if grad_accum_iter != 0:
                    if self.enable_one_logger:
                        self.one_logger.on_train_batch_end(set_barrier=False)
                    continue

                iteration += 1
|
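                # Log the loss averaged over the full gradient-accumulation window and over all
                # data-parallel ranks, so the reported value reflects the effective optimizer step
                # rather than the last micro-batch on this device.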
                grad_accum_avg_loss = accumulated_loss / self.config.trainer.grad_accum_iter
                device_avg_loss = grad_accum_avg_loss.clone()
                dist.all_reduce(device_avg_loss, op=dist.ReduceOp.SUM)
                device_avg_loss /= dist.get_world_size()
                accumulated_loss = 0.0

                self.callbacks.on_training_step_end(
                    model, data_batch, output_batch, device_avg_loss, iteration=iteration
                )

                if self.config.trainer.run_validation and iteration % self.config.trainer.validation_iter == 0:
                    self.validate(model, dataloader_val, iteration=iteration)

                if iteration % self.config.checkpoint.save_iter == 0:
                    self.checkpointer.save(model, optimizer, scheduler, grad_scaler, iteration=iteration)
|
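                # Re-arm the watchdog timer; a SIGALRM handler is assumed to be installed elsewhere
                # (e.g. by the base trainer) so that a hung iteration eventually aborts the job.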
                signal.alarm(self.config.trainer.timeout_period)
            if _end_training:
                break

        log.success("Done with training.")
        self.checkpointer.save(model, optimizer, scheduler, grad_scaler, iteration=iteration)
        self.callbacks.on_train_end(model, iteration=iteration)
        self.checkpointer.finalize()
        distributed.barrier()
        self.callbacks.on_app_end()
|
    def training_step(
        self,
        model_ddp: torch.nn.Module | distributed.DistributedDataParallel,
        optimizer: torch.optim.Optimizer,
        scheduler: torch.optim.lr_scheduler.LRScheduler,
        grad_scaler: torch.amp.GradScaler,
        data: dict[str, torch.Tensor],
        iteration: int = 0,
        grad_accum_iter: int = 0,
    ) -> tuple[dict[str, torch.Tensor], torch.Tensor, int]:
        """The training step.

        Args:
            model_ddp (torch.nn.Module | distributed.DistributedDataParallel): The model with a DDP wrapper or the
                bare module, depending on whether distributed training is enabled.
            optimizer (torch.optim.Optimizer): The model optimizer.
            scheduler (torch.optim.lr_scheduler.LRScheduler): The optimization scheduler.
            grad_scaler (torch.amp.GradScaler): The gradient scaler (for mixed precision training).
            data (dict[str, torch.Tensor]): Data batch (dictionary of tensors).
            iteration (int): Current iteration number.
            grad_accum_iter (int): Current gradient accumulation step within the accumulation window.

        Returns:
            output (dict[str, torch.Tensor]): The model output from the training data batch (dictionary of tensors).
            loss (torch.Tensor): The total loss of the training data batch.
            grad_accum_iter (int): Updated gradient accumulation step counter.
        """
|
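        # Suppress DDP gradient all-reduce on all but the last micro-batch of the accumulation
        # window; gradients are synchronized only on the step that triggers the optimizer update.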
        with distributed.ddp_sync_grad(model_ddp, grad_accum_iter == self.config.trainer.grad_accum_iter - 1):
            with self.training_timer("forward"):
                output_batch, loss = model_ddp.training_step(data, iteration)
            self.callbacks.on_before_backward(model_ddp, loss, iteration=iteration)
            with self.training_timer("backward"):
                loss_scaled = grad_scaler.scale(loss / self.config.trainer.grad_accum_iter)
                loss_scaled.backward()
                if self.config.trainer.distributed_parallelism == "ddp":
                    model_ddp.module.on_after_backward()
                else:
                    model_ddp.on_after_backward()
            self.callbacks.on_after_backward(model_ddp, iteration=iteration)

        grad_accum_iter += 1
        if grad_accum_iter == self.config.trainer.grad_accum_iter:
            with self.training_timer("optimizer_step"):
                self.callbacks.on_before_optimizer_step(
                    model_ddp, optimizer, scheduler, grad_scaler, iteration=iteration
                )
                grad_scaler.step(optimizer)
                grad_scaler.update()
                scheduler.step()
                self.callbacks.on_before_zero_grad(model_ddp, optimizer, scheduler, iteration=iteration)
                if self.config.trainer.distributed_parallelism == "ddp":
                    model_ddp.module.on_before_zero_grad(optimizer, scheduler, iteration=iteration)
                else:
                    model_ddp.on_before_zero_grad(optimizer, scheduler, iteration=iteration)
                optimizer.zero_grad(set_to_none=True)
            grad_accum_iter = 0
        return output_batch, loss, grad_accum_iter
|
    @torch.no_grad()
    def validate(self, model: Model, dataloader_val: torch.utils.data.DataLoader, iteration: int = 0) -> None:
        """Validate on the full validation dataset.

        Args:
            model (Model): The PyTorch model.
            dataloader_val (torch.utils.data.DataLoader): The validation data loader.
            iteration (int): Current iteration number.
        """
        self.callbacks.on_validation_start(model, dataloader_val, iteration=iteration)
        model.eval()
|
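        # Run validation under the EMA weights when EMA is enabled in the model config.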
        with ema.ema_scope(model, enabled=getattr(model.config.ema, "enabled", False)):
            for val_iter, data_batch in enumerate(dataloader_val):
                if self.config.trainer.max_val_iter is not None and val_iter >= self.config.trainer.max_val_iter:
                    break
                data_batch = misc.to(data_batch, device="cuda")
                self.callbacks.on_validation_step_start(model, data_batch, iteration=iteration)
                output_batch, loss = model.validation_step(data_batch, iteration)
                self.callbacks.on_validation_step_end(model, data_batch, output_batch, loss, iteration=iteration)
        self.callbacks.on_validation_end(model, iteration=iteration)
|
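# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). How the config object and the data
# loaders are constructed is project-specific; `make_config()` and
# `make_dataloaders()` below are hypothetical placeholders, not cosmos_predict1
# APIs, and the Model construction is assumed to take the config.
#
#   config = make_config()                                        # hypothetical helper
#   trainer = Trainer(config)
#   model = Model(config)                                         # or a project-specific Model subclass
#   dataloader_train, dataloader_val = make_dataloaders(config)   # hypothetical helper
#   trainer.train(model, dataloader_train, dataloader_val)
# ---------------------------------------------------------------------------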