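"""Gradient clipping callback with tensor-parallel (TP) and FSDP-aware norm handling for cosmos_transfer1 training."""
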
from dataclasses import dataclass

import torch
from megatron.core import parallel_state
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

from cosmos_transfer1.utils import distributed
from cosmos_transfer1.utils.callback import GradClip as GradClipImage
from cosmos_transfer1.utils.callback import _fused_nan_to_num
from cosmos_transfer1.utils.model import Model


@dataclass
class _MagnitudeRecord:
    """Running accumulator for gradient-norm magnitudes, averaged over the recorded iterations."""

    state: float = 0
    iter_count: int = 0

    def reset(self) -> None:
        self.state = 0
        self.iter_count = 0

    def update(self, cur_state: torch.Tensor) -> None:
        self.state += cur_state
        self.iter_count += 1

    def get_stat(self) -> float:
        # Report the average accumulated magnitude (0 if nothing was recorded), then reset.
        if self.iter_count > 0:
            avg_state = self.state / self.iter_count
            avg_state = avg_state.item()
        else:
            avg_state = 0
        self.reset()
        return avg_state


class GradClip(GradClipImage):
    """
    Gradient clipping callback that adds tensor-parallel (TP) support on top of the image GradClip.
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # Keep separate gradient-norm records for image and video batches.
        self.img_mag_log = _MagnitudeRecord()
        self.video_mag_log = _MagnitudeRecord()
        self._cur_state = None

    def on_training_step_start(self, model: Model, data_batch: dict[str, torch.Tensor], iteration: int = 0) -> None:
        # Choose which record the upcoming step's gradient norm is accumulated into.
        if model.is_image_batch(data_batch):
            self._cur_state = self.img_mag_log
        else:
            self._cur_state = self.video_mag_log

    def on_before_optimizer_step(
        self,
        model_ddp: distributed.DistributedDataParallel,
        optimizer: torch.optim.Optimizer,
        scheduler: torch.optim.lr_scheduler.LRScheduler,
        grad_scaler: torch.amp.GradScaler,
        iteration: int = 0,
    ) -> None:
        del optimizer, scheduler
        # Unwrap DDP so gradient collection and clipping see the underlying module.
        if isinstance(model_ddp, distributed.DistributedDataParallel):
            model = model_ddp.module
        else:
            model = model_ddp
        params = []
        # Optionally descend into the submodule named by self.model_key before clipping.
        if self.model_key is not None:
            items = self.model_key.split(".")
            for item in items:
                model = getattr(model, item)
        # If requested, replace NaN/Inf gradients in place before clipping.
        if self.force_finite:
            for param in model.parameters():
                if param.grad is not None:
                    params.append(param.grad)

            _fused_nan_to_num(params)

        # Pick the clipping implementation that matches the parallelism setup:
        # FSDP's sharded clip, the TP-aware clip on the wrapped module, or plain torch clipping.
        if isinstance(model, FSDP) and self.fsdp_enabled:
            total_norm = model.clip_grad_norm_(self.clip_norm)
        else:
            if parallel_state.is_initialized() and parallel_state.get_tensor_model_parallel_world_size() > 1:
                total_norm = model_ddp.module.clip_grad_norm_(self.clip_norm)
            else:
                total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), self.clip_norm, foreach=True)

        self._cur_state.update(total_norm)
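

# Illustrative sketch only (not part of the callback API): shows how _MagnitudeRecord
# accumulates per-step gradient norms and reports a running average. The tensors below are
# hypothetical stand-ins for the total_norm values recorded by GradClip above.
if __name__ == "__main__":
    record = _MagnitudeRecord()
    for norm in (torch.tensor(1.0), torch.tensor(3.0)):
        record.update(norm)
    print(record.get_stat())  # 2.0; the record resets itself after reporting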