"""Stage 1: SLAT-Interior VAE Pre-training."""
import os
import sys
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from accelerate import Accelerator
from omegaconf import OmegaConf
from tqdm import tqdm


def main():
# Load config
config_path = sys.argv[1] if len(sys.argv) > 1 else "configs/vae_pretrain.yaml"
config = OmegaConf.load(config_path)
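    # For reference, an illustrative vae_pretrain.yaml covering every key this
    # script reads (the values below are placeholders, not the project's
    # actual settings):
    #
    #   model:     {latent_dim: 8, base_resolution: 64}
    #   training:  {batch_size: 4, gradient_accumulation: 2, max_steps: 200000}
    #   optimizer: {lr: 1.0e-4, weight_decay: 0.01, betas: [0.9, 0.99]}
    #   scheduler: {warmup_steps: 10000}
    #   data:      {dataset: data/interiorfusion, num_workers: 8, pin_memory: true}
    #   loss:
    #     kl_divergence:      {weight: 1.0e-6}
    #     depth_consistency:  {weight: 0.5}
    #     normal_consistency: {weight: 0.5}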
# Initialize accelerator
accelerator = Accelerator(
mixed_precision="bf16",
gradient_accumulation_steps=config.training.gradient_accumulation,
)
# Build model
from interiorfusion.models.slat_vae import SLATInteriorVAE
model = SLATInteriorVAE(
latent_dim=config.model.latent_dim,
base_resolution=config.model.base_resolution,
)
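    # SLATInteriorVAE is assumed to expose encode(occupancy, materials) ->
    # (z, mu, logvar), decode(z) -> (pred_shape, pred_material), plus
    # predict_depth/predict_normal heads, matching the calls in the loop below.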
# Optimizer
optimizer = torch.optim.AdamW(
model.parameters(),
lr=config.optimizer.lr,
weight_decay=config.optimizer.weight_decay,
betas=tuple(config.optimizer.betas),
)
# Scheduler
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
optimizer,
T_0=config.scheduler.warmup_steps,
T_mult=2,
)
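    # Note: CosineAnnealingWarmRestarts performs no linear warmup. T_0 is the
    # length of the first cosine cycle (in scheduler.step() calls) and T_mult=2
    # doubles each subsequent cycle; the "warmup_steps" config key is reused
    # here as that first-cycle length.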
# Data loader
from interiorfusion.data.dataset import InteriorFusionDataset
dataset = InteriorFusionDataset(
root=config.data.dataset,
split="train",
resolution=config.model.base_resolution,
)
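    # InteriorFusionDataset is assumed to yield dicts with "occupancy",
    # "materials", "depth", and "normal" voxel tensors at base_resolution.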
dataloader = DataLoader(
dataset,
batch_size=config.training.batch_size,
shuffle=True,
num_workers=config.data.num_workers,
pin_memory=config.data.pin_memory,
)
# Prepare with accelerator
model, optimizer, dataloader, scheduler = accelerator.prepare(
model, optimizer, dataloader, scheduler
)
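    # prepare() wraps the model for (potentially distributed) training, shards
    # the dataloader across processes, and moves everything to the right
    # device, so no manual .to(device) calls are needed. NOTE: the loop below
    # calls custom methods (encode/decode/predict_*) on the prepared model,
    # which assumes single-process training; a DDP-wrapped model does not
    # expose these methods directly and would need accelerator.unwrap_model,
    # at the cost of gradient synchronization.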
# Training loop
global_step = 0
    for epoch in range(1000):  # generous cap; the real stop is config.training.max_steps
model.train()
epoch_loss = 0.0
for batch in tqdm(dataloader, desc=f"Epoch {epoch}"):
with accelerator.accumulate(model):
# Forward
occupancy = batch["occupancy"] # [B, 1, N, N, N]
materials = batch["materials"] # [B, 4, N, N, N]
depth = batch["depth"] # [B, 1, N, N, N]
normal = batch["normal"] # [B, 3, N, N, N]
# Encode
z, mu, logvar = model.encode(occupancy, materials)
# Decode
pred_shape, pred_material = model.decode(z)
# Decode depth and normal from shape
pred_depth = model.predict_depth(pred_shape)
pred_normal = model.predict_normal(pred_shape)
# Losses
loss_recon = F.l1_loss(pred_shape, occupancy) + \
F.l1_loss(pred_material, materials)
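                # Closed-form KL between the diagonal Gaussian posterior
                # q(z|x) = N(mu, sigma^2) and the standard normal prior:
                # KL = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)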
                loss_kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
                loss_kl = loss_kl / occupancy.shape[0]  # per-sample, so the weight is batch-size independent
                loss_kl = loss_kl * config.loss.kl_divergence.weight
loss_depth = F.l1_loss(pred_depth, depth) * config.loss.depth_consistency.weight
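                # Normal loss: one minus the mean cosine similarity along the
                # channel dim, penalizing angular error between predicted and
                # ground-truth normals.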
loss_normal = (1 - F.cosine_similarity(
pred_normal, normal, dim=1
).mean()) * config.loss.normal_consistency.weight
loss = loss_recon + loss_kl + loss_depth + loss_normal
# Backward
accelerator.backward(loss)
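                # sync_gradients is True only on the last micro-batch of each
                # accumulation window, so clipping and the optimizer/scheduler
                # steps run once per effective batch.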
if accelerator.sync_gradients:
accelerator.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
scheduler.step()
optimizer.zero_grad()
global_step += 1
epoch_loss += loss.item()
# Logging
if global_step % 100 == 0:
accelerator.print(
f"Step {global_step}: "
f"loss={loss.item():.4f} "
f"recon={loss_recon.item():.4f} "
f"kl={loss_kl.item():.4f} "
f"depth={loss_depth.item():.4f} "
f"normal={loss_normal.item():.4f}"
)
# Checkpoint
if global_step % 5000 == 0:
accelerator.wait_for_everyone()
if accelerator.is_main_process:
unwrapped_model = accelerator.unwrap_model(model)
checkpoint_path = f"checkpoints/vae_step{global_step}.pt"
os.makedirs("checkpoints", exist_ok=True)
torch.save({
"model": unwrapped_model.state_dict(),
"optimizer": optimizer.state_dict(),
"scheduler": scheduler.state_dict(),
"step": global_step,
"config": OmegaConf.to_container(config),
}, checkpoint_path)
print(f"Saved checkpoint: {checkpoint_path}")
# Early stopping on step limit
if global_step >= config.training.max_steps:
accelerator.print("Reached max steps. Training complete.")
return
avg_loss = epoch_loss / len(dataloader)
accelerator.print(f"Epoch {epoch} complete. Avg loss: {avg_loss:.4f}")


if __name__ == "__main__":
main()