|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Implements the forward op for training, validation, and inference.""" |
|
|
|
|
|
from typing import Any |
|
|
|
|
|
import torch |
|
|
|
|
|
from cosmos_predict1.tokenizer.training.datasets.utils import IMAGE_KEY, INPUT_KEY, MASK_KEY, RECON_KEY, VIDEO_KEY |
|
|
from cosmos_predict1.tokenizer.training.losses.continuous import RECON_CONSISTENCY_KEY, VIDEO_CONSISTENCY_LOSS |
|
|
from cosmos_predict1.utils import ema |
|
|
from cosmos_predict1.utils.lazy_config import LazyDict, instantiate |
|
|
from cosmos_predict1.utils.model import Model |
|
|
|
|
|
# Output-dict key under which the reconstruction is returned by the model steps.
PREDICTION = "prediction"
# Key used instead of PREDICTION when validation_step is called with ema_model=True.
EMA_PREDICTION = "ema_prediction"
|
|
|
|
|
|
|
|
class TokenizerModel(Model):
    """Tokenizer model bundling a network with its loss and metric.

    The network, loss and metric are instantiated from ``config``. The class
    exposes the training step, the validation step and an inference-only
    ``forward``; all of them run the network through ``_network_forward``.
    """

    # Key prefixes that are excluded from the state dict produced by
    # ``state_dict`` (and, as a consequence, ignored by ``load_state_dict``).
    _STATE_DICT_EXCLUDED_PREFIXES = (
        "loss.",
        "ema.loss-",
        "network.encoder.patcher",
        "network.decoder.unpatcher",
    )

    def __init__(self, config) -> None:
        """Instantiate network, loss, metric and (optionally) the EMA tracker.

        Args:
            config: Model config. This constructor reads ``network``, ``loss``,
                ``metric``, ``precision`` (name of a ``torch`` dtype attribute)
                and ``ema`` (``enabled``, ``beta``,
                ``torch_compile_buffer_renaming``).
        """
        super().__init__()
        self.config = config
        self.network = instantiate(config.network)
        self.loss = instantiate(config.loss)
        self.metric = instantiate(config.metric)
        self.precision = getattr(torch, config.precision)
        # ``self.ema`` only exists when EMA is enabled in the config.
        if self.config.ema.enabled:
            self.ema = ema.EMAModelTracker(
                self,
                beta=self.config.ema.beta,
                torch_compile_buffer_renaming=self.config.ema.torch_compile_buffer_renaming,
            )
        self.init_input_keys()

    def init_input_keys(self):
        """Set the batch keys under which image and video inputs are looked up."""
        self.image_key = IMAGE_KEY
        self.video_key = VIDEO_KEY

    def get_input_key(self, data_batch: dict[str, torch.Tensor]) -> str:
        """Return the key of the input tensor present in ``data_batch``.

        The image key takes precedence when both keys are present.

        Raises:
            ValueError: If neither the image key nor the video key is present.
        """
        if self.image_key in data_batch:
            return self.image_key
        if self.video_key in data_batch:
            return self.video_key
        raise ValueError("Input key not found in data_batch.")

    def init_optimizer_scheduler(
        self, optimizer_config: LazyDict, scheduler_config: LazyDict
    ) -> tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LRScheduler]:
        """Creates the optimizer and scheduler for the network.

        Both config objects are modified in place: the network parameters are
        attached to ``optimizer_config`` and the created optimizer is attached
        to ``scheduler_config`` before each one is instantiated.

        Args:
            optimizer_config: The optimizer config for the net.
            scheduler_config: The scheduler config for the net.

        Returns:
            optimizer (torch.optim.Optimizer): The net optimizer.
            scheduler (torch.optim.lr_scheduler.LRScheduler): The net optimization scheduler.
        """
        # Only the parameters of ``self.network`` are handed to the optimizer.
        optimizer_config.params = self.network.parameters()
        optimizer = instantiate(optimizer_config)
        scheduler_config.optimizer = optimizer
        scheduler = instantiate(scheduler_config)
        return optimizer, scheduler

    def on_train_start(self, memory_format: torch.memory_format = torch.preserve_format) -> None:
        """Cast the sub-modules to their training dtypes.

        The EMA tracker (when enabled) is cast to float32; the network and the
        loss are cast to ``self.precision`` with the requested memory format.
        """
        if self.config.ema.enabled:
            self.ema.to(dtype=torch.float32)
        self.network = self.network.to(dtype=self.precision, memory_format=memory_format)
        self.loss = self.loss.to(dtype=self.precision, memory_format=memory_format)

    def state_dict(
        self, destination: dict[str, Any] = None, prefix: str = "", keep_vars: bool = False
    ) -> dict[str, Any]:
        """Return the parent state dict without the excluded entries.

        Entries whose key starts with one of ``_STATE_DICT_EXCLUDED_PREFIXES``
        are dropped. The result is a new plain ``dict``; the mapping returned
        by the parent class is not modified.
        """
        original_state_dict = super().state_dict(destination, prefix, keep_vars)
        # ``str.startswith`` accepts a tuple, so one pass covers every prefix.
        return {
            key: value
            for key, value in original_state_dict.items()
            if not key.startswith(self._STATE_DICT_EXCLUDED_PREFIXES)
        }

    def load_state_dict(self, state_dict: Any, strict: bool = True) -> None:
        """Load the entries of ``state_dict`` that this model's state dict contains.

        Keys that are absent from ``self.state_dict()`` are ignored. The parent
        loader is always called with ``strict=False``; the ``strict`` argument
        here only controls whether missing keys raise afterwards.

        Raises:
            KeyError: If ``strict`` is true and some key of ``self.state_dict()``
                is missing from ``state_dict``. The matching entries have
                already been loaded at that point.
        """
        own_state = self.state_dict()
        filtered_state_dict = {k: v for k, v in state_dict.items() if k in own_state}

        super().load_state_dict(filtered_state_dict, strict=False)

        missing_keys = own_state.keys() - filtered_state_dict.keys()
        if missing_keys and strict:
            raise KeyError(f"Missing keys in state_dict: {missing_keys}")

    def _active_consistency_loss(self, data_batch: dict[str, torch.Tensor]):
        """Return the video consistency loss module if it applies to this batch.

        The module applies when it has a truthy ``enabled`` attribute and the
        batch input key is the video key; otherwise ``None`` is returned.
        """
        consistency_loss = self.loss.loss_modules[VIDEO_CONSISTENCY_LOSS]
        if not getattr(consistency_loss, "enabled", False):
            return None
        # Compare by value: identity comparison of strings is fragile.
        if self.get_input_key(data_batch) != self.video_key:
            return None
        return consistency_loss

    def _on_before_network_forward(self, data_batch: dict[str, torch.Tensor]) -> None:
        """Shuffle the video input in place when the consistency loss is active."""
        consistency_loss = self._active_consistency_loss(data_batch)
        if consistency_loss is not None:
            data_batch[self.video_key] = consistency_loss.shuffle(data_batch[self.video_key])

    def _on_after_network_forward(
        self, data_batch: dict[str, torch.Tensor], output_batch: dict[str, torch.Tensor]
    ) -> None:
        """Undo the pre-forward shuffle when the consistency loss is active.

        The video input is unshuffled in place. The reconstruction as produced
        by the network is stored under ``RECON_CONSISTENCY_KEY`` and the entry
        under ``RECON_KEY`` is replaced by its unshuffled version.
        """
        consistency_loss = self._active_consistency_loss(data_batch)
        if consistency_loss is None:
            return
        data_batch[self.video_key] = consistency_loss.unshuffle(data_batch[self.video_key])
        # Multiplying by ones yields a new tensor with the same values, so this
        # entry does not alias the RECON_KEY tensor.
        output_batch[RECON_CONSISTENCY_KEY] = torch.ones_like(output_batch[RECON_KEY]) * output_batch[RECON_KEY]
        output_batch[RECON_KEY] = consistency_loss.unshuffle(output_batch[RECON_KEY])

    def _network_forward(self, data_batch: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        """Run the network on the batch input, wrapped by the before/after hooks.

        Returns:
            The network output. When the network is not in training mode the
            output is converted with ``_asdict()`` first.
        """
        self._on_before_network_forward(data_batch)

        tensor_batch = data_batch[self.get_input_key(data_batch)]
        output_batch = self.network(tensor_batch)
        if not self.network.training:
            output_batch = output_batch._asdict()

        self._on_after_network_forward(data_batch, output_batch)
        return output_batch

    def _forward_and_compute_loss(self, data_batch: dict[str, torch.Tensor], iteration: int):
        """Run the network and evaluate the loss; shared by training and validation.

        Returns:
            Tuple ``(input_images, output_dict, loss_dict, loss_value)``.
        """
        # Resolved before the forward pass so a batch without an input key
        # fails immediately.
        input_key = self.get_input_key(data_batch)
        output_dict = self._network_forward(data_batch)
        input_images = data_batch[input_key]

        # Build the all-ones mask only when the batch does not supply one.
        if "loss_mask" in data_batch:
            mask = data_batch["loss_mask"]
        else:
            mask = torch.ones_like(input_images)
        inputs = {INPUT_KEY: input_images, MASK_KEY: mask}

        loss_dict, loss_value = self.loss(inputs, output_dict, iteration)
        return input_images, output_dict, loss_dict, loss_value

    def training_step(
        self,
        data_batch: dict[str, torch.Tensor],
        iteration: int,
    ) -> tuple[dict[str, torch.Tensor], torch.Tensor]:
        """Run one training forward pass and compute the loss.

        Args:
            data_batch: Batch containing the image or video input and,
                optionally, a ``"loss_mask"`` entry.
            iteration: Current iteration, forwarded to the loss.

        Returns:
            A dict holding the reconstruction under ``PREDICTION`` plus the
            entries of the loss dict, and the loss value.
        """
        _, output_dict, loss_dict, loss_value = self._forward_and_compute_loss(data_batch, iteration)
        return {PREDICTION: output_dict[RECON_KEY], **loss_dict}, loss_value

    @torch.no_grad()
    def validation_step(
        self,
        data_batch: dict[str, torch.Tensor],
        iteration: int,
        ema_model: bool = False,
    ) -> tuple[dict[str, torch.Tensor], torch.Tensor]:
        """Run one validation forward pass and compute loss and metrics.

        Args:
            data_batch: Batch containing the image or video input and,
                optionally, a ``"loss_mask"`` entry.
            iteration: Current iteration, forwarded to the loss and metric.
            ema_model: When true, the reconstruction is returned under
                ``EMA_PREDICTION`` instead of ``PREDICTION``.

        Returns:
            A dict holding the reconstruction plus the loss and metric
            entries, and the loss value.
        """
        input_images, output_dict, loss_dict, loss_value = self._forward_and_compute_loss(data_batch, iteration)
        metric_dict = self.metric(input_images, output_dict, iteration)
        loss_dict.update(metric_dict)
        prediction_key = EMA_PREDICTION if ema_model else PREDICTION
        return {prediction_key: output_dict[RECON_KEY], **loss_dict}, loss_value

    @torch.inference_mode()
    def forward(self, data_batch: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        """Run inference and return the reconstruction under ``PREDICTION``."""
        # Validate the batch up front; the returned key itself is not needed here.
        self.get_input_key(data_batch)
        output_dict = self._network_forward(data_batch)
        return {PREDICTION: output_dict[RECON_KEY]}
|
|
|