BetterModel / deepseek_python_20250816_58a3a6.py
atanu2531's picture
Upload deepseek_python_20250816_58a3a6.py
3316641 verified
class VJEPATrainer:
    """Run single optimisation steps for a model with a diffusion decoder.

    The wrapped ``model`` must be callable as ``model(video, text)`` and
    expose a ``diffusion_decoder`` attribute whose call result has a
    ``.sample`` field. The trainer optimises all of ``model.parameters()``
    with AdamW against an MSE loss between the decoder output and the
    noise that was mixed into the target frames.

    NOTE(review): ``train_step`` calls ``self.add_noise(frames, noise,
    timesteps)``, which is not defined in the visible part of this class.
    It has to be supplied by a subclass, a mixin, or code outside this view.
    """

    def __init__(self, model, lr=1e-4, num_train_timesteps=1000):
        """Store the model and build its optimiser and loss.

        Args:
            model: Module to train; see the class docstring for the
                interface it must provide.
            lr: Learning rate passed to ``torch.optim.AdamW``.
            num_train_timesteps: Exclusive upper bound for the randomly
                sampled diffusion timesteps (sampled from
                ``[0, num_train_timesteps)``).
        """
        self.model = model
        self.num_train_timesteps = num_train_timesteps
        self.optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
        self.criterion = nn.MSELoss()

    def train_step(self, video, text, future_frames):
        """Perform one forward/backward/optimiser step and return the loss.

        Args:
            video: Tensor with at least three dimensions; the last index
                along the third dimension is dropped before encoding
                (presumably the temporal axis -- TODO confirm layout).
            text: Conditioning input forwarded unchanged to the model.
            future_frames: Target tensor that is noised and handed to the
                diffusion decoder.

        Returns:
            The scalar loss of this step as a Python float.
        """
        # Gradients accumulate across backward() calls, so clear whatever
        # the previous step left behind before computing new ones.
        self.optimizer.zero_grad()

        # Hide the final segment so the context embedding cannot see it.
        masked_video = video[:, :, :-1]
        context_emb = self.model(masked_video, text)

        # Corrupt the targets with Gaussian noise at one random timestep per
        # batch element. The timesteps are created on the same device as the
        # frames to avoid CPU/accelerator mismatches downstream.
        noise = torch.randn_like(future_frames)
        timesteps = torch.randint(
            0,
            self.num_train_timesteps,
            (video.shape[0],),
            device=future_frames.device,
        )
        noisy_frames = self.add_noise(future_frames, noise, timesteps)

        pred = self.model.diffusion_decoder(
            noisy_frames,
            timesteps,
            encoder_hidden_states=context_emb,
        ).sample

        # The decoder is trained to recover the injected noise.
        loss = self.criterion(pred, noise)
        loss.backward()
        self.optimizer.step()
        return loss.item()