import torch
import torch.nn as nn

class VJEPATrainer:
    def __init__(self, model, lr=1e-4, num_timesteps=1000):
        self.model = model
        self.optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
        self.criterion = nn.MSELoss()
        # Standard DDPM linear beta schedule; cumulative alphas are
        # precomputed so add_noise can corrupt frames at any timestep.
        self.num_timesteps = num_timesteps
        betas = torch.linspace(1e-4, 0.02, num_timesteps)
        self.alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

    def add_noise(self, frames, noise, timesteps):
        # DDPM forward process: x_t = sqrt(abar_t) * x_0 + sqrt(1 - abar_t) * eps
        a_bar = self.alphas_cumprod.to(frames.device)[timesteps]
        a_bar = a_bar.view(-1, *([1] * (frames.dim() - 1)))  # broadcast over batch
        return a_bar.sqrt() * frames + (1.0 - a_bar).sqrt() * noise

    def train_step(self, video, text, future_frames):
        # Mask the future: drop the last temporal segment so the encoder
        # sees only context frames (video assumed to be B x C x T x H x W).
        masked_video = video[:, :, :-1]
        context_emb = self.model(masked_video, text)

        # Diffusion objective: noise the future frames at a random timestep,
        # then predict that noise conditioned on the context embedding.
        noise = torch.randn_like(future_frames)
        timesteps = torch.randint(0, self.num_timesteps, (video.shape[0],),
                                  device=video.device)
        noisy_frames = self.add_noise(future_frames, noise, timesteps)
        pred = self.model.diffusion_decoder(
            noisy_frames,
            timesteps,
            encoder_hidden_states=context_emb,
        ).sample

        loss = self.criterion(pred, noise)
        self.optimizer.zero_grad()  # clear stale gradients before backprop
        loss.backward()
        self.optimizer.step()
        return loss.item()
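
A minimal smoke test of the trainer might look like the sketch below. `DummyVJEPA`, its layer choices, and the tensor shapes are illustrative assumptions, since the real model's interface is not shown here; the stub only mirrors the calls `train_step` makes, including a diffusers-style output object with a `.sample` field.

from types import SimpleNamespace
import torch
import torch.nn as nn

# Hypothetical stand-in for the real model, matching the interface
# assumed by train_step above.
class DummyVJEPA(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Conv3d(3, 16, kernel_size=1)   # context embedding
        self.denoiser = nn.Conv3d(3, 3, kernel_size=1)   # epsilon predictor

    def forward(self, video, text):
        return self.encoder(video)

    def diffusion_decoder(self, noisy, timesteps, encoder_hidden_states=None):
        # Mimic a diffusers-style output object exposing .sample
        return SimpleNamespace(sample=self.denoiser(noisy))

trainer = VJEPATrainer(DummyVJEPA())
video = torch.randn(2, 3, 8, 16, 16)    # B x C x T x H x W context clip
future = torch.randn(2, 3, 4, 16, 16)   # future frames to be predicted
print(trainer.train_step(video, text=None, future_frames=future))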