"""Stage 1: SLAT-Interior VAE Pre-training.""" import os import sys from pathlib import Path import torch import torch.nn as nn import torch.nn.functional as F from torch.utils.data import DataLoader from accelerate import Accelerator from omegaconf import OmegaConf from tqdm import tqdm def main(): # Load config config_path = sys.argv[1] if len(sys.argv) > 1 else "configs/vae_pretrain.yaml" config = OmegaConf.load(config_path) # Initialize accelerator accelerator = Accelerator( mixed_precision="bf16", gradient_accumulation_steps=config.training.gradient_accumulation, ) device = accelerator.device # Build model from interiorfusion.models.slat_vae import SLATInteriorVAE model = SLATInteriorVAE( latent_dim=config.model.latent_dim, base_resolution=config.model.base_resolution, ) # Optimizer optimizer = torch.optim.AdamW( model.parameters(), lr=config.optimizer.lr, weight_decay=config.optimizer.weight_decay, betas=tuple(config.optimizer.betas), ) # Scheduler scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts( optimizer, T_0=config.scheduler.warmup_steps, T_mult=2, ) # Data loader from interiorfusion.data.dataset import InteriorFusionDataset dataset = InteriorFusionDataset( root=config.data.dataset, split="train", resolution=config.model.base_resolution, ) dataloader = DataLoader( dataset, batch_size=config.training.batch_size, shuffle=True, num_workers=config.data.num_workers, pin_memory=config.data.pin_memory, ) # Prepare with accelerator model, optimizer, dataloader, scheduler = accelerator.prepare( model, optimizer, dataloader, scheduler ) # Training loop global_step = 0 for epoch in range(1000): model.train() epoch_loss = 0.0 for batch in tqdm(dataloader, desc=f"Epoch {epoch}"): with accelerator.accumulate(model): # Forward occupancy = batch["occupancy"] # [B, 1, N, N, N] materials = batch["materials"] # [B, 4, N, N, N] depth = batch["depth"] # [B, 1, N, N, N] normal = batch["normal"] # [B, 3, N, N, N] # Encode z, mu, logvar = model.encode(occupancy, materials) # Decode pred_shape, pred_material = model.decode(z) # Decode depth and normal from shape pred_depth = model.predict_depth(pred_shape) pred_normal = model.predict_normal(pred_shape) # Losses loss_recon = F.l1_loss(pred_shape, occupancy) + \ F.l1_loss(pred_material, materials) loss_kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) loss_kl = loss_kl * config.loss.kl_divergence.weight loss_depth = F.l1_loss(pred_depth, depth) * config.loss.depth_consistency.weight loss_normal = (1 - F.cosine_similarity( pred_normal, normal, dim=1 ).mean()) * config.loss.normal_consistency.weight loss = loss_recon + loss_kl + loss_depth + loss_normal # Backward accelerator.backward(loss) if accelerator.sync_gradients: accelerator.clip_grad_norm_(model.parameters(), 1.0) optimizer.step() scheduler.step() optimizer.zero_grad() global_step += 1 epoch_loss += loss.item() # Logging if global_step % 100 == 0: accelerator.print( f"Step {global_step}: " f"loss={loss.item():.4f} " f"recon={loss_recon.item():.4f} " f"kl={loss_kl.item():.4f} " f"depth={loss_depth.item():.4f} " f"normal={loss_normal.item():.4f}" ) # Checkpoint if global_step % 5000 == 0: accelerator.wait_for_everyone() if accelerator.is_main_process: unwrapped_model = accelerator.unwrap_model(model) checkpoint_path = f"checkpoints/vae_step{global_step}.pt" os.makedirs("checkpoints", exist_ok=True) torch.save({ "model": unwrapped_model.state_dict(), "optimizer": optimizer.state_dict(), "scheduler": scheduler.state_dict(), "step": global_step, "config": 

def main():
    # Load config (path can be overridden as the first CLI argument)
    config_path = sys.argv[1] if len(sys.argv) > 1 else "configs/vae_pretrain.yaml"
    config = OmegaConf.load(config_path)

    # Initialize accelerator; device placement is handled by accelerator.prepare()
    accelerator = Accelerator(
        mixed_precision="bf16",
        gradient_accumulation_steps=config.training.gradient_accumulation,
    )

    # Build model
    from interiorfusion.models.slat_vae import SLATInteriorVAE

    model = SLATInteriorVAE(
        latent_dim=config.model.latent_dim,
        base_resolution=config.model.base_resolution,
    )

    # Optimizer
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config.optimizer.lr,
        weight_decay=config.optimizer.weight_decay,
        betas=tuple(config.optimizer.betas),
    )

    # Scheduler
    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer,
        T_0=config.scheduler.warmup_steps,
        T_mult=2,
    )

    # Data loader
    from interiorfusion.data.dataset import InteriorFusionDataset

    dataset = InteriorFusionDataset(
        root=config.data.dataset,
        split="train",
        resolution=config.model.base_resolution,
    )
    dataloader = DataLoader(
        dataset,
        batch_size=config.training.batch_size,
        shuffle=True,
        num_workers=config.data.num_workers,
        pin_memory=config.data.pin_memory,
    )

    # Prepare with accelerator
    model, optimizer, dataloader, scheduler = accelerator.prepare(
        model, optimizer, dataloader, scheduler
    )

    # Training loop (epoch count is an upper bound; training stops at max_steps)
    global_step = 0
    for epoch in range(1000):
        model.train()
        epoch_loss = 0.0
        # Show the progress bar only on the local main process to avoid
        # duplicated bars under multi-GPU launches
        progress = tqdm(
            dataloader,
            desc=f"Epoch {epoch}",
            disable=not accelerator.is_local_main_process,
        )
        for batch in progress:
            with accelerator.accumulate(model):
                # Forward
                occupancy = batch["occupancy"]  # [B, 1, N, N, N]
                materials = batch["materials"]  # [B, 4, N, N, N]
                depth = batch["depth"]          # [B, 1, N, N, N]
                normal = batch["normal"]        # [B, 3, N, N, N]

                # Encode
                z, mu, logvar = model.encode(occupancy, materials)

                # Decode
                pred_shape, pred_material = model.decode(z)

                # Decode depth and normal from shape
                pred_depth = model.predict_depth(pred_shape)
                pred_normal = model.predict_normal(pred_shape)

                # Losses
                loss_recon = (
                    F.l1_loss(pred_shape, occupancy)
                    + F.l1_loss(pred_material, materials)
                )
                # KL divergence, normalized by batch size so its scale does not
                # change with the effective batch
                loss_kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
                loss_kl = loss_kl / occupancy.size(0) * config.loss.kl_divergence.weight
                loss_depth = (
                    F.l1_loss(pred_depth, depth) * config.loss.depth_consistency.weight
                )
                loss_normal = (
                    1 - F.cosine_similarity(pred_normal, normal, dim=1).mean()
                ) * config.loss.normal_consistency.weight
                loss = loss_recon + loss_kl + loss_depth + loss_normal

                # Backward
                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()

            # global_step counts micro-batches, not optimizer updates
            global_step += 1
            epoch_loss += loss.item()

            # Logging
            if global_step % 100 == 0:
                accelerator.print(
                    f"Step {global_step}: "
                    f"loss={loss.item():.4f} "
                    f"recon={loss_recon.item():.4f} "
                    f"kl={loss_kl.item():.4f} "
                    f"depth={loss_depth.item():.4f} "
                    f"normal={loss_normal.item():.4f}"
                )

            # Checkpoint
            if global_step % 5000 == 0:
                accelerator.wait_for_everyone()
                if accelerator.is_main_process:
                    unwrapped_model = accelerator.unwrap_model(model)
                    checkpoint_path = f"checkpoints/vae_step{global_step}.pt"
                    os.makedirs("checkpoints", exist_ok=True)
                    torch.save(
                        {
                            "model": unwrapped_model.state_dict(),
                            "optimizer": optimizer.state_dict(),
                            "scheduler": scheduler.state_dict(),
                            "step": global_step,
                            "config": OmegaConf.to_container(config),
                        },
                        checkpoint_path,
                    )
                    print(f"Saved checkpoint: {checkpoint_path}")

            # Early stopping on step limit
            if global_step >= config.training.max_steps:
                accelerator.print("Reached max steps. Training complete.")
                return

        avg_loss = epoch_loss / len(dataloader)
        accelerator.print(f"Epoch {epoch} complete. Avg loss: {avg_loss:.4f}")


if __name__ == "__main__":
    main()
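
# Example launch command (a sketch: the filename train_vae_stage1.py is an
# assumption for illustration; mixed precision and gradient accumulation are
# already configured inside the Accelerator, so no extra flags are needed):
#
#   accelerate launch train_vae_stage1.py configs/vae_pretrain.yaml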