| | import os |
| |
|
| | import torch |
| | from torch.optim import AdamW |
| | import torch.nn.functional as F |
| |
|
| | from diffusers_sv3d.pipelines.stable_video_diffusion.pipeline_stable_video_3d_diffusion import ( |
| | StableVideo3DDiffusionPipeline, |
| | ) |
| |
|
| | |
# --- Run configuration -------------------------------------------------------
# NOTE: BATCH_SIZE is declared but not referenced anywhere in this file.
BATCH_SIZE = 1
# Learning rate handed to the AdamW optimizer in train().
LR = 1e-5
# Number of iterations of the outer loop in train().
NUM_EPOCHS = 10
# Directory created at the start of train().
SAVE_DIR = "checkpoints"
# Use CUDA when it is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Location of the pretrained SV3D weights. The original developer path stays
# the default, but it can be overridden with the SV3D_PATH environment variable
# so the script runs on other machines without editing the source. An unset or
# empty variable falls back to the default.
_DEFAULT_SV3D_PATH = "/home/hubert/projects/sv3d-pbr/sv3d_diffusers/pretrained_sv3d"
SV3D_PATH = os.path.abspath(os.environ.get("SV3D_PATH") or _DEFAULT_SV3D_PATH)
| |
|
| |
|
def train():
    """Load the SV3D pipeline, unfreeze one UNet layer and run forward passes.

    Steps, in order:
      1. Create ``SAVE_DIR`` if it does not exist.
      2. Load ``StableVideo3DDiffusionPipeline`` from ``SV3D_PATH`` in float16
         and move it to ``DEVICE``.
      3. Freeze every UNet parameter, then re-enable gradients only for
         parameters whose name contains
         ``down_blocks.2.resnets.0.spatial_res_block.conv1``.
      4. Print the trainable / total parameter counts and build an ``AdamW``
         optimizer over the trainable parameters.
      5. For ``NUM_EPOCHS`` iterations, feed random tensors through the UNet
         and print the shape of the result.

    NOTE(review): as written, the loop only performs a forward pass.
    ``target_noise`` is generated but not consumed, and no loss, backward pass
    or ``optimizer.step()`` is executed, so no weights are updated and nothing
    is written to ``SAVE_DIR``.
    """
    os.makedirs(SAVE_DIR, exist_ok=True)

    # Load the half-precision checkpoint and place it on the target device.
    pipeline = StableVideo3DDiffusionPipeline.from_pretrained(
        SV3D_PATH, revision="fp16", torch_dtype=torch.float16
    )
    pipeline.to(DEVICE)
    unet = pipeline.unet

    # Freeze everything first ...
    for weight in unet.parameters():
        weight.requires_grad = False

    # ... then selectively re-enable gradients for the chosen layer.
    unfreeze_key = "down_blocks.2.resnets.0.spatial_res_block.conv1"
    for name, weight in unet.named_parameters():
        if unfreeze_key not in name:
            continue
        weight.requires_grad = True
        print(f"Unfreezing: {name}")

    # Single pass: tally element counts and collect the trainable tensors.
    trainable_params = 0
    total_params = 0
    trainable_weights = []
    for weight in unet.parameters():
        n_elements = weight.numel()
        total_params += n_elements
        if weight.requires_grad:
            trainable_params += n_elements
            trainable_weights.append(weight)
    print(f"Trainable parameters: {trainable_params:,} / {total_params:,} ({trainable_params/total_params:.2%})")

    optimizer = AdamW(trainable_weights, lr=LR)

    # Fixed dummy-input description shared by every iteration.
    # Presumably (batch, frames, channels, height, width) -- TODO confirm
    # against the UNet's expected layout.
    latent_shape = (6, 21, 8, 72, 72)
    half = torch.float16
    t = 0.123

    for _epoch in range(NUM_EPOCHS):
        unet.train()
        optimizer.zero_grad()

        # Random stand-in inputs; the draw order matches the original so the
        # RNG stream is consumed identically.
        latents = torch.randn(latent_shape, dtype=half).to(DEVICE)
        encoder_hidden_states = torch.randn((126, 1, 1024), dtype=half).to(DEVICE)
        added_tim_ids = torch.randn((6, 21), dtype=half).to(DEVICE)
        target_noise = torch.randn(latent_shape, dtype=half).to(DEVICE)  # not used yet

        unet_kwargs = {
            "encoder_hidden_states": encoder_hidden_states,
            "added_time_ids": [added_tim_ids],
        }
        noise_pred = unet(latents, t, **unet_kwargs)

        print(noise_pred.shape)
| | |
| | |
| | |
| | |
| |
|
| | |
| |
|
# Script entry point: run the training routine only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    train()
| |
|