class Text2VideoTrainer:
    """Minimal supervised trainer that regresses model output onto target videos.

    Each call to :meth:`train_step` performs one forward pass, computes a
    mean-squared-error loss against the target batch, back-propagates, and
    applies one optimizer update.
    """

    def __init__(self, model, optimizer, device):
        """Store the model, its optimizer, and the device inputs are moved to.

        Args:
            model: Callable mapping a text batch to a generated-video tensor
                (presumably a ``torch.nn.Module`` -- TODO confirm with callers).
            optimizer: Object providing ``zero_grad()`` and ``step()``, expected
                to update ``model``'s parameters.
            device: Target passed to ``.to()`` on each input batch before the
                forward pass (e.g. ``"cpu"`` or ``"cuda"``). ``None`` leaves
                inputs where they are.
        """
        self.model = model
        self.optimizer = optimizer
        self.device = device

    def _to_device(self, batch):
        """Return *batch* moved to ``self.device`` if it supports ``.to``.

        Inputs without a ``.to`` method (e.g. a list of raw strings) are
        returned unchanged, so non-tensor text inputs keep working.
        """
        if self.device is None or not hasattr(batch, "to"):
            return batch
        return batch.to(self.device)

    def train_step(self, text_batch, video_batch):
        """Run one optimization step and return the scalar loss.

        NOTE(review): this does not call ``self.model.train()``. Callers are
        responsible for putting the model in the desired mode.

        Args:
            text_batch: Model input. It is moved to ``self.device`` when it
                supports ``.to``.
            video_batch: Target tensor. It must have the same shape as the
                model output.

        Returns:
            The MSE loss for this batch as a Python float.

        Raises:
            ValueError: If the model output and ``video_batch`` differ in
                shape. ``F.mse_loss`` would otherwise broadcast them with only
                a warning and return a meaningless loss.
        """
        # Local import keeps the method usable regardless of module imports.
        import torch.nn.functional as F

        # The device given to __init__ was previously ignored; move inputs so
        # they live on the same device as the model.
        text_batch = self._to_device(text_batch)
        video_batch = self._to_device(video_batch)

        self.optimizer.zero_grad()
        generated_video = self.model(text_batch)
        if generated_video.shape != video_batch.shape:
            raise ValueError(
                f"generated video shape {tuple(generated_video.shape)} does not "
                f"match target shape {tuple(video_batch.shape)}"
            )
        loss = F.mse_loss(generated_video, video_batch)
        loss.backward()
        self.optimizer.step()
        return loss.item()