class Text2VideoTrainer:
    """Minimal supervised trainer for a text-to-video model using MSE loss.

    Each call to :meth:`train_step` performs one forward pass, one backward
    pass and one optimizer update on a single batch.
    """

    def __init__(self, model, optimizer, device):
        """Store the training components.

        Args:
            model: Callable mapping a text batch to a generated video tensor.
                Presumably a ``torch.nn.Module``; if it exposes ``train()``,
                that is called at the start of every step.
            optimizer: Optimizer over ``model``'s parameters; its
                ``zero_grad()`` and ``step()`` are called each step.
            device: Device that tensor-like inputs are moved onto before the
                forward pass. ``None`` leaves inputs where they are.
        """
        self.model = model
        self.optimizer = optimizer
        self.device = device

    def _to_device(self, batch):
        """Return *batch* moved to ``self.device`` if it supports ``.to``."""
        # Raw text (e.g. a list of strings) has no ``.to``; leave it for the
        # model to tokenize/encode itself.
        if self.device is None or not hasattr(batch, "to"):
            return batch
        return batch.to(self.device)

    def train_step(self, text_batch, video_batch):
        """Run one optimization step and return the scalar loss.

        Args:
            text_batch: Whatever ``self.model`` accepts as input. It is moved
                to ``self.device`` only if it has a ``.to`` method.
            video_batch: Target tensor, compared against the model output
                with mean-squared error. It is moved to ``self.device``.

        Returns:
            The MSE loss for this batch, as a Python float.
        """
        # Local import: the visible file never binds ``F`` at module level.
        from torch.nn import functional as F

        # Ensure dropout/batch-norm are in training mode even if a previous
        # evaluation pass left the model in eval mode.
        train_mode = getattr(self.model, "train", None)
        if callable(train_mode):
            train_mode()

        text_batch = self._to_device(text_batch)
        video_batch = self._to_device(video_batch)

        self.optimizer.zero_grad()
        generated_video = self.model(text_batch)
        loss = F.mse_loss(generated_video, video_batch)
        loss.backward()
        self.optimizer.step()
        return loss.item()