File size: 940 Bytes
3316641
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
class VJEPATrainer:
    """Trainer for a V-JEPA-style model: encode a masked (context) portion of a
    video plus text, then train a diffusion decoder to predict the noise added
    to the held-out future frames (standard epsilon-prediction objective).

    Parameters
    ----------
    model : nn.Module
        Must be callable as ``model(masked_video, text)`` returning context
        embeddings, and expose a ``diffusion_decoder`` whose output has a
        ``.sample`` attribute (diffusers-style UNet convention).
    lr : float
        AdamW learning rate.
    """

    # Number of diffusion timesteps; must match the randint upper bound in
    # train_step. NOTE(review): assumed DDPM linear beta schedule — confirm
    # against the scheduler used elsewhere in the project.
    NUM_TRAIN_TIMESTEPS = 1000

    def __init__(self, model, lr=1e-4):
        self.model = model
        self.optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
        self.criterion = nn.MSELoss()
        # Precompute cumulative alphas for the forward (noising) process:
        # q(x_t | x_0) = sqrt(acum_t) * x_0 + sqrt(1 - acum_t) * eps
        betas = torch.linspace(1e-4, 0.02, self.NUM_TRAIN_TIMESTEPS)
        self._alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

    def add_noise(self, original, noise, timesteps):
        """Diffuse `original` to timestep `timesteps` with the given noise.

        FIX: this method was called by train_step but never defined, making
        train_step raise AttributeError. Implements the standard DDPM
        closed-form forward process.
        """
        acum = self._alphas_cumprod.to(original.device)[timesteps]
        # Broadcast per-sample scalars over the remaining tensor dims.
        while acum.dim() < original.dim():
            acum = acum.unsqueeze(-1)
        return acum.sqrt() * original + (1.0 - acum).sqrt() * noise

    def train_step(self, video, text, future_frames):
        """Run one optimization step; returns the scalar loss value.

        Assumes the temporal axis of `video` is the last dimension — the
        context is everything except the final segment. TODO confirm layout.
        """
        # FIX: zero stale gradients — the original accumulated them across
        # calls, corrupting every update after the first.
        self.optimizer.zero_grad()

        # Mask the future: keep all but the last temporal segment as context.
        masked_video = video[:, :, :-1]
        context_emb = self.model(masked_video, text)

        # Epsilon-prediction objective: noise the future frames at a random
        # timestep and train the decoder to recover the noise.
        noise = torch.randn_like(future_frames)
        # FIX: create timesteps on the input's device so GPU training works.
        timesteps = torch.randint(
            0, self.NUM_TRAIN_TIMESTEPS, (video.shape[0],), device=video.device
        )
        noisy_frames = self.add_noise(future_frames, noise, timesteps)

        pred = self.model.diffusion_decoder(
            noisy_frames,
            timesteps,
            encoder_hidden_states=context_emb,
        ).sample

        loss = self.criterion(pred, noise)
        loss.backward()
        self.optimizer.step()
        return loss.item()