| """Stage 1: SLAT-Interior VAE Pre-training.""" |
|
|
| import os |
| import sys |
| from pathlib import Path |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch.utils.data import DataLoader |
| from accelerate import Accelerator |
| from omegaconf import OmegaConf |
| from tqdm import tqdm |
|
|
|
|
def _save_checkpoint(accelerator, model, optimizer, scheduler, step, config):
    """Save a full training checkpoint (main process only).

    Synchronizes all processes first so every rank has finished its
    optimizer step before the main process serializes the state.
    """
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        unwrapped_model = accelerator.unwrap_model(model)
        checkpoint_path = f"checkpoints/vae_step{step}.pt"
        os.makedirs("checkpoints", exist_ok=True)
        torch.save({
            "model": unwrapped_model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "scheduler": scheduler.state_dict(),
            "step": step,
            "config": OmegaConf.to_container(config),
        }, checkpoint_path)
        accelerator.print(f"Saved checkpoint: {checkpoint_path}")


def main():
    """Run Stage 1: SLAT-Interior VAE pre-training.

    Usage::

        python <script> [config_path]

    Loads an OmegaConf YAML config (default ``configs/vae_pretrain.yaml``),
    builds the VAE, optimizer, scheduler and dataloader, wraps them with
    HuggingFace Accelerate (bf16 + gradient accumulation), and trains until
    ``config.training.max_steps`` micro-batches have been processed.
    Checkpoints go to ``checkpoints/vae_step<N>.pt``.
    """
    config_path = sys.argv[1] if len(sys.argv) > 1 else "configs/vae_pretrain.yaml"
    config = OmegaConf.load(config_path)

    accelerator = Accelerator(
        mixed_precision="bf16",
        gradient_accumulation_steps=config.training.gradient_accumulation,
    )

    device = accelerator.device

    # Deferred imports keep config parsing/CLI errors fast and avoid pulling
    # in heavy project modules until they are actually needed.
    from interiorfusion.models.slat_vae import SLATInteriorVAE
    model = SLATInteriorVAE(
        latent_dim=config.model.latent_dim,
        base_resolution=config.model.base_resolution,
    )

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config.optimizer.lr,
        weight_decay=config.optimizer.weight_decay,
        betas=tuple(config.optimizer.betas),
    )

    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer,
        T_0=config.scheduler.warmup_steps,
        T_mult=2,
    )

    from interiorfusion.data.dataset import InteriorFusionDataset
    dataset = InteriorFusionDataset(
        root=config.data.dataset,
        split="train",
        resolution=config.model.base_resolution,
    )
    dataloader = DataLoader(
        dataset,
        batch_size=config.training.batch_size,
        shuffle=True,
        num_workers=config.data.num_workers,
        pin_memory=config.data.pin_memory,
    )

    model, optimizer, dataloader, scheduler = accelerator.prepare(
        model, optimizer, dataloader, scheduler
    )

    # Intervals are configurable; defaults preserve the previous hard-coded
    # behavior (log every 100 steps, checkpoint every 5000, up to 1000 epochs).
    log_interval = config.training.get("log_interval", 100)
    ckpt_interval = config.training.get("checkpoint_interval", 5000)
    max_epochs = config.training.get("max_epochs", 1000)

    # NOTE(review): global_step counts micro-batches, not optimizer updates —
    # with gradient accumulation > 1, max_steps/log/checkpoint intervals are
    # all in micro-batch units. Confirm this matches the config's intent.
    global_step = 0
    for epoch in range(max_epochs):
        model.train()
        epoch_loss = 0.0

        for batch in tqdm(dataloader, desc=f"Epoch {epoch}"):
            with accelerator.accumulate(model):
                occupancy = batch["occupancy"]
                materials = batch["materials"]
                depth = batch["depth"]
                normal = batch["normal"]

                # Encode to latent (z sampled, mu/logvar for the KL term).
                z, mu, logvar = model.encode(occupancy, materials)

                # Decode back to shape occupancy + material fields.
                pred_shape, pred_material = model.decode(z)

                # Auxiliary geometric predictions for consistency losses.
                pred_depth = model.predict_depth(pred_shape)
                pred_normal = model.predict_normal(pred_shape)

                loss_recon = F.l1_loss(pred_shape, occupancy) + \
                             F.l1_loss(pred_material, materials)

                # KL divergence of q(z|x) from N(0, I), averaged per sample.
                # Fix: the previous sum over the whole batch made the
                # effective KL weight scale with batch size.
                batch_size = occupancy.size(0)
                loss_kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / batch_size
                loss_kl = loss_kl * config.loss.kl_divergence.weight

                loss_depth = F.l1_loss(pred_depth, depth) * config.loss.depth_consistency.weight

                # 1 - cosine similarity along the channel dim (dim=1).
                loss_normal = (1 - F.cosine_similarity(
                    pred_normal, normal, dim=1
                ).mean()) * config.loss.normal_consistency.weight

                loss = loss_recon + loss_kl + loss_depth + loss_normal

                accelerator.backward(loss)

                # Clip only on real optimizer steps (gradients fully accumulated).
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(model.parameters(), 1.0)

                # Accelerate-prepared optimizer/scheduler internally skip
                # no-op steps during accumulation micro-batches.
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()

            global_step += 1
            epoch_loss += loss.item()

            if global_step % log_interval == 0:
                accelerator.print(
                    f"Step {global_step}: "
                    f"loss={loss.item():.4f} "
                    f"recon={loss_recon.item():.4f} "
                    f"kl={loss_kl.item():.4f} "
                    f"depth={loss_depth.item():.4f} "
                    f"normal={loss_normal.item():.4f}"
                )

            if global_step % ckpt_interval == 0:
                _save_checkpoint(accelerator, model, optimizer, scheduler,
                                 global_step, config)

            if global_step >= config.training.max_steps:
                # Fix: save a final checkpoint before exiting; previously up
                # to ckpt_interval steps of training were silently lost here.
                _save_checkpoint(accelerator, model, optimizer, scheduler,
                                 global_step, config)
                accelerator.print("Reached max steps. Training complete.")
                return

        avg_loss = epoch_loss / len(dataloader)
        accelerator.print(f"Epoch {epoch} complete. Avg loss: {avg_loss:.4f}")
|
|
|
|
# Entry point guard: run training only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|