|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch |
|
|
import torch.utils.data |
|
|
|
|
|
from cosmos_predict1.tokenizer.training.checkpointer import TokenizerCheckpointer |
|
|
from cosmos_predict1.utils import ema, misc |
|
|
from cosmos_predict1.utils.model import Model |
|
|
from cosmos_predict1.utils.trainer import Trainer |
|
|
|
|
|
|
|
|
class TokenizerTrainer(Trainer):
    """Trainer specialised for tokenizer models.

    Builds on ``Trainer`` and changes two things: the checkpointer is a
    ``TokenizerCheckpointer``, and ``validate`` runs every validation batch
    through the model twice -- once directly and once inside
    ``ema.ema_scope`` -- merging both sets of outputs before the step-end
    callbacks are invoked.

    Attributes:
        model_config: The ``config.model.config`` entry of the trainer config.
        checkpointer (TokenizerCheckpointer): Checkpointer object used to
            save/load model weights and optimizer states.
        training_timer (misc.Timer): Timer object to time code blocks and
            functions. NOTE(review): not assigned in this class; presumably
            set up by the base ``Trainer`` -- confirm.
    """

    def __init__(self, config):
        """Initialise the base trainer, then attach tokenizer-specific state.

        Args:
            config: Trainer configuration. This constructor reads
                ``config.model.config``, ``config.checkpoint`` and
                ``config.job`` from it.
        """
        super().__init__(config)
        self.model_config = config.model.config
        # self.callbacks is read here right after the base constructor runs;
        # presumably the base Trainer creates it -- TODO confirm.
        self.checkpointer = TokenizerCheckpointer(config.checkpoint, config.job, callbacks=self.callbacks)

    @torch.no_grad()
    def validate(self, model: Model, dataloader_val: torch.utils.data.DataLoader, iteration: int = 0) -> None:
        """Run the validation loop over ``dataloader_val``.

        Each batch is moved to ``"cuda"`` and evaluated twice: a plain
        ``model.validation_step`` call, followed by a second call with
        ``ema_model=True`` made inside ``ema.ema_scope``. The outputs of the
        second call are merged into the outputs of the first, and the loss
        from the second call is the one reported to the callbacks.

        The loop stops early once ``self.config.trainer.max_val_iter`` batches
        have been processed, unless that setting is ``None``.

        Args:
            model (Model): The PyTorch model. It is put into eval mode and is
                not switched back by this method.
            dataloader_val (torch.utils.data.DataLoader): The validation data loader.
            iteration (int): Current iteration number, forwarded to the
                callbacks and to ``model.validation_step``.
        """
        self.callbacks.on_validation_start(model, dataloader_val, iteration=iteration)
        model.eval()

        for batch_idx, batch in enumerate(dataloader_val):
            # Optional cap on the number of validation batches.
            batch_cap = self.config.trainer.max_val_iter
            if batch_cap is not None and batch_idx >= batch_cap:
                break

            batch = misc.to(batch, device="cuda")
            self.callbacks.on_validation_step_start(model, batch, iteration=iteration)

            # First pass: the loss from this call is intentionally discarded.
            outputs, _ = model.validation_step(batch, iteration)

            # Second pass: always executed; the scope is only enabled when the
            # model config turns EMA on.
            with ema.ema_scope(model, enabled=model.config.ema.enabled):
                ema_outputs, ema_loss = model.validation_step(batch, iteration, ema_model=True)

            # Entries from the second pass overwrite same-named entries from the first.
            outputs.update(ema_outputs)
            self.callbacks.on_validation_step_end(model, batch, outputs, ema_loss, iteration=iteration)

        self.callbacks.on_validation_end(model, iteration=iteration)
|
|
|