"""Forward-pass smoke test for fine-tuning a single layer of the SV3D UNet.

Loads the pretrained pipeline in float16, freezes every UNet parameter
except those whose name contains one specific substring, builds an AdamW
optimizer over the remaining trainable parameters, and then repeatedly
pushes random tensors through the UNet, printing the shape of the result.

The loss computation, backward pass and optimizer step are currently
disabled (kept as comments inside ``train``), so no weights are updated.
"""

import os

import torch
from torch.optim import AdamW
import torch.nn.functional as F

from diffusers_sv3d.pipelines.stable_video_diffusion.pipeline_stable_video_3d_diffusion import (
    StableVideo3DDiffusionPipeline,
)

# Configuration
BATCH_SIZE = 1
LR = 1e-5
NUM_EPOCHS = 10
SAVE_DIR = "checkpoints"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
SV3D_PATH = os.path.abspath("/home/hubert/projects/sv3d-pbr/sv3d_diffusers/pretrained_sv3d")


def _random_fp16(*shape):
    """Return a float16 standard-normal tensor of ``shape``, moved to DEVICE.

    The tensor is sampled first and transferred afterwards, so the draw
    comes from the default (CPU) random generator.
    """
    return torch.randn(shape, dtype=torch.float16).to(DEVICE)


def train():
    """Run NUM_EPOCHS UNet forward passes on random inputs.

    Side effects: creates SAVE_DIR if it does not exist, loads the pipeline
    from SV3D_PATH, and prints the unfrozen parameter names, the
    trainable/total parameter counts, and the output shape of every pass.
    Nothing is written to SAVE_DIR by this function.
    """
    os.makedirs(SAVE_DIR, exist_ok=True)

    pipeline = StableVideo3DDiffusionPipeline.from_pretrained(
        SV3D_PATH,
        revision="fp16",
        torch_dtype=torch.float16,
    )
    pipeline.to(DEVICE)
    unet = pipeline.unet

    # Freeze the whole UNet except parameters whose name contains this
    # substring; each parameter that stays trainable is reported.
    trainable_key = "down_blocks.2.resnets.0.spatial_res_block.conv1"
    for name, param in unet.named_parameters():
        keep_trainable = trainable_key in name
        param.requires_grad = keep_trainable
        if keep_trainable:
            print(f"Unfreezing: {name}")

    # The same list feeds both the statistics line and the optimizer, so
    # the optimizer only ever sees the unfrozen parameters.
    trainable = [p for p in unet.parameters() if p.requires_grad]
    trainable_count = sum(p.numel() for p in trainable)
    total_count = sum(p.numel() for p in unet.parameters())
    share = trainable_count / total_count
    print(f"Trainable parameters: {trainable_count:,} / {total_count:,} ({share:.2%})")

    optimizer = AdamW(trainable, lr=LR)

    for _ in range(NUM_EPOCHS):
        unet.train()
        optimizer.zero_grad()

        # Random stand-ins for real training data. The draws happen in a
        # fixed order: latents, conditioning, time ids, target.
        # NOTE(review): the axis meaning of these shapes is not visible
        # here -- confirm against the UNet's forward signature.
        latents = _random_fp16(6, 21, 8, 72, 72)
        t = 0.123
        encoder_hidden_states = _random_fp16(126, 1, 1024)
        added_time_ids = _random_fp16(6, 21)
        # Only consumed by the disabled loss below; still sampled so the
        # sequence of random draws stays the same.
        target_noise = _random_fp16(6, 21, 8, 72, 72)

        noise_pred = unet(
            latents,
            t,
            encoder_hidden_states=encoder_hidden_states,
            added_time_ids=[added_time_ids],
        )
        # NOTE(review): assumes the UNet call returns an object exposing
        # ``.shape`` directly -- confirm against the model's return type.
        print(noise_pred.shape)

        # Training step, currently disabled:
        # loss = F.mse_loss(noise_pred, target_noise)
        # loss.backward()
        # optimizer.step()


if __name__ == "__main__":
    train()