# Diffusion training loop for a text-conditioned 3-D U-Net.
# NOTE(review): `UNet3D`, `dataloader`, `tqdm` and `F` (torch.nn.functional)
# are not defined in this span; they are assumed to be imported/created
# earlier in the file.
NUM_EPOCHS = 250
NUM_TIMESTEPS = 1000  # number of discrete diffusion steps; t is drawn from [0, NUM_TIMESTEPS)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UNet3D().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

model.train()
for epoch in range(NUM_EPOCHS):
    # Accumulate the loss on-device so the host only syncs with the
    # accelerator once per epoch (when printing), not on every step.
    epoch_loss = torch.zeros((), device=device)
    num_batches = 0

    for batch in tqdm(dataloader, desc=f"Epoch {epoch}"):
        # NOTE(review): video is presumably (B, C, T, H, W); alpha_t below is
        # reshaped to broadcast over four trailing dims — confirm with dataset.
        video = batch["video"].to(device)
        text = batch["text"].to(device)

        # One random integer timestep per sample, created directly on the
        # target device. Shape (B, 1) — NOTE(review): confirm this is the
        # timestep shape UNet3D expects.
        t = torch.randint(0, NUM_TIMESTEPS, (video.shape[0], 1), device=device)
        t_frac = t.float() / NUM_TIMESTEPS  # normalized timestep in [0, 1)

        # Forward (noising) process with a linear cumulative-alpha schedule:
        # alpha_t = 1 at t = 0 (clean sample) and falls toward 0 as t
        # approaches NUM_TIMESTEPS (almost pure noise).
        noise = torch.randn_like(video)
        alpha_t = (1 - t_frac).view(-1, 1, 1, 1, 1)
        noisy_video = torch.sqrt(alpha_t) * video + torch.sqrt(1 - alpha_t) * noise

        # The model is trained to predict the injected noise (epsilon-prediction).
        pred_noise = model(noisy_video, t_frac, text)
        loss = F.mse_loss(pred_noise, noise)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.detach()
        num_batches += 1

    if num_batches == 0:
        # Without this guard the print below would fail with an opaque error.
        raise RuntimeError("dataloader yielded no batches; nothing to train on")

    # Report the mean over the whole epoch rather than only the last batch.
    print(f"Epoch {epoch}, Mean loss: {(epoch_loss / num_batches).item():.4f}")