|
|
import os |
|
|
|
|
|
import torch |
|
|
from torch.optim import AdamW |
|
|
import torch.nn.functional as F |
|
|
|
|
|
from diffusers_sv3d.pipelines.stable_video_diffusion.pipeline_stable_video_3d_diffusion import ( |
|
|
StableVideo3DDiffusionPipeline, |
|
|
) |
|
|
|
|
|
|
|
|
# NOTE(review): BATCH_SIZE is not referenced by the training loop in this file;
# the dummy batch shapes there are hard-coded. Confirm whether it is still needed.
BATCH_SIZE = 1
LR = 1e-5  # AdamW learning rate for the unfrozen UNet parameters.
NUM_EPOCHS = 10  # Number of optimisation iterations (one random batch each).
SAVE_DIR = "checkpoints"  # Created by train() before loading the pipeline.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Location of the pretrained SV3D diffusers checkpoint. It can be overridden
# with the SV3D_PATH environment variable so the script is not tied to one
# machine; the default is the original author's local path. Relative override
# values are resolved against the current working directory.
SV3D_PATH = os.path.abspath(
    os.environ.get(
        "SV3D_PATH",
        "/home/hubert/projects/sv3d-pbr/sv3d_diffusers/pretrained_sv3d",
    )
)
|
|
|
|
|
|
|
|
def train():
    """Fine-tune one SV3D UNet conv layer on random placeholder data.

    Loads ``StableVideo3DDiffusionPipeline`` from ``SV3D_PATH`` in float16 and
    moves it to ``DEVICE``. It then freezes every UNet parameter except those
    whose name contains ``down_blocks.2.resnets.0.spatial_res_block.conv1``.
    Finally it runs ``NUM_EPOCHS`` AdamW steps on the MSE between the UNet
    output and random target noise. All inputs are random tensors; no dataset
    is involved yet.

    Raises:
        ValueError: if no UNet parameter matches the unfreeze pattern, because
            the optimizer would otherwise have nothing to update.
    """
    os.makedirs(SAVE_DIR, exist_ok=True)
    # NOTE(review): nothing is written to SAVE_DIR yet; add checkpoint saving
    # (e.g. the trainable parameters' state) once real training is wired up.

    pipeline = StableVideo3DDiffusionPipeline.from_pretrained(
        SV3D_PATH,
        revision="fp16",
        torch_dtype=torch.float16,
    )
    pipeline.to(DEVICE)
    unet = pipeline.unet

    # Freeze the whole UNet, then re-enable gradients only for the target layer.
    unfreeze_pattern = "down_blocks.2.resnets.0.spatial_res_block.conv1"
    for param in unet.parameters():
        param.requires_grad = False

    trainable = []
    for name, param in unet.named_parameters():
        if unfreeze_pattern in name:
            param.requires_grad = True
            trainable.append(param)
            print(f"Unfreezing: {name}")

    if not trainable:
        raise ValueError(
            f"No UNet parameters matched {unfreeze_pattern!r}; nothing to train."
        )

    trainable_params = sum(p.numel() for p in trainable)
    total_params = sum(p.numel() for p in unet.parameters())
    print(f"Trainable parameters: {trainable_params:,} / {total_params:,} ({trainable_params/total_params:.2%})")

    # NOTE(review): the trainable weights are float16, so AdamW keeps its state
    # in float16 too. Its default eps (1e-8) is below float16's smallest
    # subnormal, and updates can become inf/NaN. Consider fp32 master weights
    # with torch.autocast + GradScaler, or a larger eps.
    optimizer = AdamW(trainable, lr=LR)

    unet.train()
    for epoch in range(NUM_EPOCHS):
        optimizer.zero_grad()

        # Random placeholder inputs, allocated directly on DEVICE. Shapes are
        # kept from the original experiment.
        # TODO(review): confirm the meaning of each axis against the SV3D UNet.
        latents = torch.randn((6, 21, 8, 72, 72), dtype=torch.float16, device=DEVICE)
        t = 0.123
        encoder_hidden_states = torch.randn((126, 1, 1024), dtype=torch.float16, device=DEVICE)
        added_time_ids = torch.randn((6, 21), dtype=torch.float16, device=DEVICE)
        target_noise = torch.randn((6, 21, 8, 72, 72), dtype=torch.float16, device=DEVICE)

        output = pipeline.unet(
            latents,
            t,
            encoder_hidden_states=encoder_hidden_states,
            added_time_ids=[added_time_ids],
        )
        # Diffusers UNets normally return an output object or tuple whose first
        # element is the prediction. Accept a bare tensor too, in case this
        # custom UNet returns one directly.
        noise_pred = output if isinstance(output, torch.Tensor) else output[0]

        # Compute the loss in float32 to avoid fp16 overflow/underflow in the
        # squared error; gradients still flow back to the fp16 weights.
        loss = F.mse_loss(noise_pred.float(), target_noise.float())
        loss.backward()
        optimizer.step()

        print(f"epoch {epoch}: loss={loss.item():.6f} noise_pred shape={tuple(noise_pred.shape)}")


if __name__ == "__main__":
    train()
|
|
|