class VJEPATrainer:
    """Trainer for a V-JEPA-style video model with a diffusion decoder head.

    The model consumes a context (all but the last video segment, plus text),
    and the diffusion decoder is trained with standard epsilon-prediction:
    noise is added to the future frames and the decoder must predict that noise.
    """

    def __init__(self, model, lr=1e-4):
        """Set up the optimizer and loss.

        Args:
            model: Module exposing ``__call__(video, text) -> context_emb`` and a
                ``diffusion_decoder`` whose output has a ``.sample`` attribute.
            lr: AdamW learning rate (default 1e-4).
        """
        self.model = model
        self.optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
        # Epsilon-prediction objective: MSE between predicted and true noise.
        self.criterion = nn.MSELoss()

    def train_step(self, video, text, future_frames):
        """Run one optimization step; returns the scalar loss value.

        Args:
            video: Video tensor; the last entry along dim 2 is treated as the
                "future" segment and removed from the context.
                # assumes layout (batch, ..., segments, ...) with segments on
                # dim 2 — TODO confirm against the model's expected input.
            text: Text conditioning passed through to the model.
            future_frames: Ground-truth future frames to be noised and denoised.

        Returns:
            float: the training loss for this step.
        """
        # Clear gradients from the previous step — without this, .backward()
        # accumulates into stale grads and corrupts every step after the first.
        self.optimizer.zero_grad()

        # Mask future video segments: drop the last segment so the encoder
        # only sees context, never the target.
        masked_video = video[:, :, :-1]
        context_emb = self.model(masked_video, text)

        # Predict future frames with diffusion (epsilon-prediction).
        noise = torch.randn_like(future_frames)
        # Keep timesteps on the same device as the inputs; a bare CPU tensor
        # breaks mixed-device training. 1000 = assumed number of diffusion
        # steps — TODO confirm it matches the decoder's noise schedule.
        timesteps = torch.randint(0, 1000, (video.shape[0],), device=video.device)
        # NOTE(review): add_noise is not defined in this class — presumably
        # provided by a subclass/mixin or a scheduler; verify it exists.
        noisy_frames = self.add_noise(future_frames, noise, timesteps)
        pred = self.model.diffusion_decoder(
            noisy_frames, timesteps, encoder_hidden_states=context_emb
        ).sample

        # Target is the injected noise, not the clean frames.
        loss = self.criterion(pred, noise)
        loss.backward()
        self.optimizer.step()
        return loss.item()