# sv3d_diff/train.py
# Source: hkomp's "Add model and code" commit (a13d12f) on Hugging Face.
import os
import torch
from torch.optim import AdamW
import torch.nn.functional as F
from diffusers_sv3d.pipelines.stable_video_diffusion.pipeline_stable_video_3d_diffusion import (
StableVideo3DDiffusionPipeline,
)
# Configuration
BATCH_SIZE = 1          # per-step batch size (currently unused below — dummy tensors are hard-coded)
LR = 1e-5               # AdamW learning rate for the unfrozen parameters
NUM_EPOCHS = 10         # number of passes over the (dummy) training step
SAVE_DIR = "checkpoints"  # where checkpoints are meant to be written
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Local path to the pretrained SV3D weights; machine-specific — TODO make configurable.
SV3D_PATH = os.path.abspath("/home/hubert/projects/sv3d-pbr/sv3d_diffusers/pretrained_sv3d")
def train():
    """Fine-tune a single conv layer of the SV3D UNet on dummy data.

    Loads the pretrained pipeline in fp16, freezes the whole UNet except
    ``down_blocks.2.resnets.0.spatial_res_block.conv1``, and runs
    ``NUM_EPOCHS`` forward/backward steps against random targets.

    NOTE(review): inputs are random placeholder tensors — presumably a
    smoke test until a real dataloader is wired in.
    """
    # Create the checkpoint directory up front.
    # TODO(review): SAVE_DIR is never written to — add checkpoint saving.
    os.makedirs(SAVE_DIR, exist_ok=True)

    # Load the pretrained SV3D pipeline in half precision.
    pipeline = StableVideo3DDiffusionPipeline.from_pretrained(
        SV3D_PATH,
        revision="fp16",
        torch_dtype=torch.float16,
    )
    pipeline.to(DEVICE)

    # Freeze the entire UNet first...
    for param in pipeline.unet.parameters():
        param.requires_grad = False
    # ...then unfreeze exactly one conv layer to keep the experiment cheap.
    for name, param in pipeline.unet.named_parameters():
        if "down_blocks.2.resnets.0.spatial_res_block.conv1" in name:
            param.requires_grad = True
            print(f"Unfreezing: {name}")

    # Report how much of the model is actually trainable.
    trainable_params = sum(p.numel() for p in pipeline.unet.parameters() if p.requires_grad)
    total_params = sum(p.numel() for p in pipeline.unet.parameters())
    print(f"Trainable parameters: {trainable_params:,} / {total_params:,} ({trainable_params/total_params:.2%})")

    # Optimize only the unfrozen parameters.
    optimizer = AdamW([p for p in pipeline.unet.parameters() if p.requires_grad], lr=LR)

    # Training loop — one dummy step per epoch.
    for epoch in range(NUM_EPOCHS):
        pipeline.unet.train()
        optimizer.zero_grad()

        # Dummy inputs, shape (batch=6, frames=21, channels=8, h=72, w=72).
        # Created directly on DEVICE to avoid a CPU->GPU copy.
        latents = torch.randn((6, 21, 8, 72, 72), dtype=torch.float16, device=DEVICE)
        t = 0.123  # scalar timestep; NOTE(review): confirm the scheduler's timestep convention
        encoder_hidden_states = torch.randn((126, 1, 1024), dtype=torch.float16, device=DEVICE)
        # Fixed typo: was `added_tim_ids`.
        added_time_ids = torch.randn((6, 21), dtype=torch.float16, device=DEVICE)
        target_noise = torch.randn((6, 21, 8, 72, 72), dtype=torch.float16, device=DEVICE)

        # The UNet call returns a diffusers ModelOutput wrapper by default;
        # the predicted-noise tensor is its `.sample` field. (The original
        # printed `.shape` on the wrapper itself, which raises.)
        noise_pred = pipeline.unet(
            latents,
            t,
            encoder_hidden_states=encoder_hidden_states,
            added_time_ids=[added_time_ids],
        ).sample
        print(noise_pred.shape)

        # Loss + backward + step. (These were commented out in the original,
        # making the loop a no-op.) Compute the MSE in fp32 for stability
        # since the model runs in fp16.
        loss = F.mse_loss(noise_pred.float(), target_noise.float())
        loss.backward()
        optimizer.step()
        print(f"Epoch {epoch}: loss={loss.item():.6f}")
# Script entry point: run the training loop when executed directly.
if __name__ == "__main__":
    train()