|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Tokenizer callbacks extended from base callbacks.""" |
|
|
|
|
|
import math |
|
|
from typing import Any, Optional |
|
|
|
|
|
import numpy as np |
|
|
import torch |
|
|
from torch._dynamo.eval_frame import OptimizedModule as torch_OptimizedModule |
|
|
|
|
|
from cosmos_predict1.utils import callback, distributed, log |
|
|
from cosmos_predict1.utils.config import Config |
|
|
from cosmos_predict1.utils.model import Model |
|
|
from cosmos_predict1.utils.trainer import Trainer |
|
|
|
|
|
_UINT8_MAX_F = float(np.iinfo(np.uint8).max) |
|
|
_VIDEO_CONSISTENCY_LOSS = "video_consistency" |
|
|
|
|
|
|
|
|
def make_video_grid(video, nrow=None, padding=1):
    r"""Make a grid of videos for visualization.

    Args:
        video (tensor): video of size B x C x T x H x W, with values in [0, 1].
        nrow (int): number of rows in the grid; defaults to ceil(sqrt(B)).
        padding (int): size of paddings between videos.

    Returns:
        list: T uint8 numpy arrays of shape (grid_H, grid_W, C), one per frame.
    """
    b, c, t, h, w = video.shape
    # B x C x T x H x W -> B x T x H x W x C, then quantize to uint8.
    video = video.permute(0, 2, 3, 4, 1)
    video = (video.cpu().detach().numpy() * _UINT8_MAX_F).astype("uint8")
    if nrow is None:
        nrow = math.ceil(math.sqrt(b))
    ncol = math.ceil(b / nrow)
    video_grid = np.zeros((t, (padding + h) * nrow + padding, (padding + w) * ncol + padding, c), dtype="uint8")

    # Paste each clip into its grid cell. Distinct names (`row`, `col`) are used
    # for the grid coordinates so the channel count `c` is not shadowed.
    for i in range(b):
        row = i // ncol
        col = i % ncol
        start_r = (padding + h) * row
        start_c = (padding + w) * col
        video_grid[:, start_r : start_r + h, start_c : start_c + w] = video[i]

    # Split the grid tensor into a list of per-frame images.
    return [video_grid[i] for i in range(t)]
|
|
|
|
|
|
|
|
def compute_weight_norm(model):
    """Return the L2 norm of every parameter in ``model``, keyed by parameter name.

    Args:
        model: any ``torch.nn.Module`` (iterated via ``named_parameters()``).

    Returns:
        dict: mapping of parameter name to its scalar L2 norm.

    Raises:
        ValueError: if any parameter contains NaN values.
    """
    weight_norm = dict()
    for layer_name, param in model.named_parameters():
        if torch.isnan(param).any():
            # Fixed message: this function inspects the weights themselves,
            # not their gradients (that is compute_grad_norm's job).
            raise ValueError(f"[weight] {layer_name} NaN detected in weights")
        weight_norm[layer_name] = torch.norm(param, p=2).item()
    return weight_norm
|
|
|
|
|
|
|
|
def compute_grad_norm(model):
    """Collect the L2 norm of each parameter's gradient, keyed by parameter name.

    Parameters whose ``.grad`` is ``None`` are skipped.

    Raises:
        ValueError: if any gradient contains NaN values.
    """
    grad_norm = {}
    for layer_name, param in model.named_parameters():
        grad = param.grad
        if grad is None:
            continue
        if torch.isnan(grad).any():
            raise ValueError(f"[grad] {layer_name} NaN detected in gradients")
        grad_norm[f"{layer_name}"] = torch.norm(grad, p=2).item()
    return grad_norm
|
|
|
|
|
|
|
|
class AdaptCkptStateDict(callback.Callback):
    """Remaps checkpoint keys between compiled and uncompiled module naming.

    torch.compile() wraps the network so its parameters live under
    ``network._orig_mod`` (and the EMA copy under ``ema.network-_orig_mod``).
    On save we strip that wrapper prefix so checkpoints stay loadable by
    uncompiled models; on load we re-insert it so a compiled model can
    consume a plain checkpoint.
    """

    def __init__(self, config: Config, trainer: Trainer):
        super().__init__(config, trainer)

    def on_save_checkpoint(self, model: Model, state_dict: dict[Any, Any]) -> None:
        """Adapt the state dict should the model be a compiled one."""
        if not isinstance(model.network, torch_OptimizedModule):
            return

        def _uncompiled_key(k: str) -> str:
            # Replace only the first (prefix) occurrence: without count=1,
            # str.replace would also rewrite the substring anywhere deeper
            # in the key (e.g. "network.sub_network.weight").
            if k.startswith("network._orig_mod"):
                return k.replace("network._orig_mod", "network", 1)
            elif k.startswith("ema.network-_orig_mod"):
                return k.replace("ema.network-_orig_mod", "ema.network", 1)
            return k

        state_dict["model"] = {_uncompiled_key(k): v for k, v in state_dict["model"].items()}

    def on_load_checkpoint(self, model: Model, state_dict: dict[Any, Any]) -> None:
        """Adapt the state dict should the model be a compiled one."""
        if not isinstance(model.network, torch_OptimizedModule):
            return

        def _compiled_key(k: str) -> str:
            # count=1 keeps the rewrite anchored to the prefix; see
            # _uncompiled_key in on_save_checkpoint for the failure mode.
            if k.startswith("network."):
                return k.replace("network", "network._orig_mod", 1)
            elif k.startswith("ema.network-"):
                return k.replace("ema.network", "ema.network-_orig_mod", 1)
            return k

        state_dict["model"] = {_compiled_key(k): v for k, v in state_dict["model"].items()}
|
|
|
|
|
|
|
|
class GradClipCallback(callback.GradClipCallback):
    """The verbose tokenizer callback for gradient clipping."""

    def __init__(self, grad_clip_norm: float, config: Config, trainer: Trainer, verbose: bool):
        super().__init__(config, trainer, grad_clip_norm)
        self.verbose = verbose

    def on_before_optimizer_step(
        self,
        model_ddp: distributed.DistributedDataParallel,
        optimizer: torch.optim.Optimizer,
        scheduler: torch.optim.lr_scheduler.LRScheduler,
        grad_scaler: torch.amp.GradScaler,
        iteration: int = 0,
    ) -> None:
        """Unscale gradients, clip them to ``grad_clip_norm``, and optionally log clips."""
        # Unscale first so clipping operates on true gradient magnitudes.
        grad_scaler.unscale_(optimizer)
        total_norm = torch.nn.utils.clip_grad_norm_(model_ddp.module.parameters(), max_norm=self.grad_clip_norm)
        if torch.isnan(total_norm):
            raise ValueError("[gradient clipping] NaN detected in gradient norms")
        # Only report finite norms that actually exceeded the clip threshold.
        if self.verbose and torch.isfinite(total_norm) and total_norm > self.grad_clip_norm:
            if model_ddp.module.network.training:
                message = f"[net:{iteration:07d}] Gradient norm {total_norm} > {self.grad_clip_norm}. Clipping gradients."
            else:
                message = f"[unknown:{iteration:07d}] Gradient norm {total_norm} > {self.grad_clip_norm}. Clipping gradients."
            log.warning(message)
|
|
|
|
|
|
|
|
class ExpandLossMask(callback.Callback):
    """Dilates the batch's ``loss_mask`` via max pooling before each training step."""

    def __init__(self, kernel_size: int, config: Config, trainer: Trainer):
        super().__init__(config, trainer)
        # Spatial pooling window; temporal axis (5D case) is never pooled.
        self.kernel_size = kernel_size

    def on_training_step_start(self, model: Model, data: dict[str, Any], iteration: int = 0) -> None:
        """Expand loss_mask with max pooling (to cover some partial human regions)"""
        if "loss_mask" not in data:
            return

        mask = data["loss_mask"]
        assert mask.ndim in (4, 5), "ndim of loss_mask must be 4 or 5"

        k = self.kernel_size
        if mask.ndim == 4:
            # 2D pooling over (H, W); stride 1 + padding k//2 keeps the size for odd k.
            expanded = torch.nn.functional.max_pool2d(mask, k, stride=1, padding=k // 2)
        else:
            # 3D input (B, C, T, H, W): pool only spatially, leave T untouched.
            expanded = torch.nn.functional.max_pool3d(
                mask,
                (1, k, k),
                stride=1,
                padding=(0, k // 2, k // 2),
            )
        data["loss_mask"] = expanded
|
|
|
|
|
|
|
|
class TorchCompile(callback.Callback):
    """
    Callback to use torch.compile() on network or modules in losses(FlowLoss and PerceptualLoss) or both.
    We compile them at later iteration as it prevents NCCL timeouts when times are very unstable during first iterations
    """

    # Raised dynamo cache limit so recompilations don't silently fall back to eager.
    _TORCH_DYNAMO_CACHE_SIZE = 128

    def __init__(
        self,
        compile_after_iterations: int = 8,
        compile_network: bool = False,
        compile_loss: bool = False,
        compile_loss_keys: Optional[list[str]] = None,
    ):
        """Configure deferred compilation.

        Args:
            compile_after_iterations: number of training steps to wait before compiling.
            compile_network: whether to compile ``model.network``.
            compile_loss: whether to compile the loss modules named in ``compile_loss_keys``.
            compile_loss_keys: keys of ``model.loss.loss_modules`` to compile;
                defaults to ``["flow", "perceptual"]``.
        """
        # NOTE(review): the base callback.Callback.__init__ is not called here,
        # mirroring the original implementation — confirm the base tolerates this.
        self.initial_iteration: Optional[int] = None
        self.compile_after_iterations: int = compile_after_iterations

        self.compile_network: bool = compile_network
        self.compile_loss: bool = compile_loss

        # Materialize the default here instead of using a mutable default argument,
        # which would be shared across all instances.
        self.compile_loss_keys: list[str] = ["flow", "perceptual"] if compile_loss_keys is None else compile_loss_keys

        if self.compile_network or self.compile_loss:
            torch._dynamo.config.cache_size_limit = TorchCompile._TORCH_DYNAMO_CACHE_SIZE

            # Make OptimizedModule.training mirror the wrapped module's flag
            # (read-only): the setter/deleter are no-ops so train()/eval()
            # toggles only take effect on _orig_mod.
            torch_OptimizedModule.training = property(
                lambda self: self._orig_mod.training, lambda self, value: None, lambda self: None
            )

    def on_training_step_start(self, model: Model, data: dict[str, torch.Tensor], iteration: int = 0) -> None:
        """Compile the configured targets once ``compile_after_iterations`` steps have elapsed."""
        if not (self.compile_network or self.compile_loss):
            return

        if self.initial_iteration is None:
            log.info(f"Compilation will be done on iteration {iteration + self.compile_after_iterations}")
            self.initial_iteration = iteration

        if self.compile_network:
            # NOTE(review): this warning fires on every step until the compile
            # iteration; presumably intended as a persistent heads-up — confirm.
            if model.config.ema.enabled is True and model.config.ema.torch_compile_buffer_renaming is False:
                log.warning(
                    '"model.config.ema.torch_compile_buffer_renaming" should be turned on for the EMA to work with torch.compile(), network will not be compiled'
                )

        if iteration - self.initial_iteration == self.compile_after_iterations:
            if self.compile_network:
                if model.config.ema.enabled is True and model.config.ema.torch_compile_buffer_renaming is False:
                    log.warning(
                        '"model.config.ema.torch_compile_buffer_renaming" should be turned on for the EMA to work with torch.compile(), skipping network compilation'
                    )
                else:
                    log.info("Compiling network")
                    model.network = torch.compile(model.network, dynamic=False)

            if self.compile_loss:
                for key in self.compile_loss_keys:
                    if key not in model.loss.loss_modules:
                        log.warning(f"Loss module for compilation with key: {key} not found")
                    elif getattr(model.loss.loss_modules[key], "checkpoint_activations", None) is True:
                        # torch.compile() cannot trace through activation checkpointing here.
                        log.warning(
                            f"torch.compile() doesn't work with activation checkpointing, skipping compilation for loss with key: {key}"
                        )
                    else:
                        log.info(f"Compiling loss with key: {key}")
                        model.loss.loss_modules[key].torch_compile()
|
|
|