# InteriorFusion Training Guide

## Hardware Requirements

| Stage | GPUs | VRAM Each | Duration | Cost (Cloud) |
|-------|------|-----------|----------|--------------|
| VAE Pre-training | 8× A100 (80GB) | 80GB | 7 days | ~$15K |
| Structure DiT | 32× A100 (80GB) | 80GB | 14 days | ~$30K |
| Material DiT | 16× A100 (80GB) | 80GB | 7 days | ~$15K |
| Fine-tuning | 8× A100 (80GB) | 80GB | 3 days | ~$5K |
| **Total** | **Variable** | — | **~4 weeks** | **~$65K** |

- Minimum viable: 8× A100 (run all stages sequentially, with longer durations)
- Budget option: 8× RTX 4090 (24GB); requires gradient accumulation and roughly 2× the training time

## Stage 1: SLAT-Interior VAE Pre-training

### Architecture

- **Encoder**: Sparse 3D convolutional U-Net
  - Input: dense occupancy grid O ∈ {0,1}^(N³), where N = 256, 512, or 1024
  - Sparse convolution layers with channel-to-space shortcuts
  - 16× spatial compression (1024³ → 64³ latent)
- **Decoder**:
  - Sparse conv upsampler with skip connections
  - Early pruning: predicts a binary mask of active children before each upsampling step
  - Outputs: per-voxel shape features + material features

### Training Configuration

```yaml
# configs/vae_pretrain.yaml
model:
  latent_dim: 64
  base_resolution: 256
  target_resolution: 1024

optimizer:
  type: AdamW
  lr: 1.0e-4
  weight_decay: 0.01
  betas: [0.9, 0.999]

scheduler:
  type: cosine_with_restarts
  warmup_steps: 10000

training:
  batch_size: 8  # per GPU
  num_gpus: 8
  effective_batch_size: 64
  max_steps: 200000
  gradient_accumulation: 1
  mixed_precision: bf16

curriculum:
  - resolution: 256
    steps: 50000
    lr: 1.0e-4
  - resolution: 512
    steps: 100000
    lr: 1.0e-4
  - resolution: 1024
    steps: 50000
    lr: 5.0e-5

data:
  dataset: InteriorFusion-Train
  num_workers: 8
  pin_memory: true

loss:
  reconstruction:
    weight: 1.0
    type: l1
  kl_divergence:
    weight: 1.0e-3
  depth_consistency:
    weight: 0.5
    type: l1
  normal_consistency:
    weight: 0.3
    type: cosine
  edge_preservation:
    weight: 0.2
    type: l1
```

### Loss Functions

```python
import torch
import torch.nn.functional as F

def vae_loss(pred_shape, pred_material, target_shape, target_material,
             pred_depth, target_depth, pred_normal, target_normal,
             mu, logvar):
    # Reconstruction: L1 on shape and material features
    loss_recon = F.l1_loss(pred_shape, target_shape) + \
                 F.l1_loss(pred_material, target_material)

    # KL divergence, normalized per sample so the 1e-3 weight
    # (loss.kl_divergence in the config) is batch-size independent
    loss_kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / mu.shape[0]
    loss_kl = loss_kl * 1e-3

    # Depth consistency
    loss_depth = F.l1_loss(pred_depth, target_depth)

    # Normal consistency: cosine distance between predicted and target normals
    loss_normal = 1 - F.cosine_similarity(pred_normal, target_normal, dim=-1).mean()

    # Weights mirror configs/vae_pretrain.yaml; the edge-preservation
    # term (weight 0.2 in the config) is not included in this snippet.
    return loss_recon + loss_kl + 0.5 * loss_depth + 0.3 * loss_normal
```

## Stage 2: Structure DiT (Rectified Flow)

### Architecture

- **DiT model**: Flow-matching transformer
  - Width: 1536
  - Depth: 30 blocks
  - Heads: 12
  - MLP hidden dim: 8192
  - Parameters: ~1.3B
- **Conditioning encoders**:
  - Image: DINOv3-L (frozen, 1024-dim features)
  - Depth: custom CNN encoder (256-dim)
  - Layout: Transformer encoder on SpatialLM tokens (512-dim)
  - Semantic: Mask2Former feature pyramid (256-dim)
- **Conditioning fusion**: cross-attention + AdaLN-single modulation

### Training Configuration

```yaml
# configs/dit_structure.yaml
model:
  width: 1536
  depth: 30
  num_heads: 12
  mlp_ratio: 8192  # MLP hidden dimension, not a width multiplier

optimizer:
  type: AdamW
  lr: 1.0e-4
  weight_decay: 0.01

scheduler:
  type: linear_warmup_cosine
  warmup_steps: 10000

training:
  batch_size: 8  # per GPU
  num_gpus: 32
  effective_batch_size: 256
  max_steps: 400000
  mixed_precision: bf16

curriculum:
  - resolution: 256
    steps: 100000
    lr: 1.0e-4
  - resolution: 512
    steps: 200000
    lr: 1.0e-4
  - resolution: 1024
    steps: 100000
    lr: 2.0e-5

data:
  dataset: InteriorFusion-Train
  num_workers: 8

flow_matching:
  sigma_min: 0.001
  sigma_max: 80.0
  p_mean: -1.2
  p_std: 1.2

loss:
  flow_matching:
    weight: 1.0
  depth_guidance:
    weight: 0.3
```
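The `p_mean` / `p_std` entries suggest a logit-normal timestep distribution, as used in SD3-style rectified-flow training, rather than the uniform `t` drawn in the reference loss below. A minimal sketch under that assumption follows; the helper name `sample_timesteps` is illustrative, not part of the codebase.

```python
import torch

def sample_timesteps(batch_size, p_mean=-1.2, p_std=1.2, device="cpu"):
    """Logit-normal timestep sampling (assumed reading of
    flow_matching.p_mean / p_std in configs/dit_structure.yaml)."""
    # Draw from N(p_mean, p_std), then squash to (0, 1) with a sigmoid.
    # This concentrates samples at intermediate noise levels.
    u = torch.randn(batch_size, device=device) * p_std + p_mean
    return torch.sigmoid(u)
```

Swapping this in for the uniform `torch.rand` draw in the loss below would bias training toward mid-trajectory timesteps; whether InteriorFusion actually does so is an assumption based on the config keys.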
### Flow Matching Loss

```python
import torch
import torch.nn.functional as F

def flow_matching_loss(model, x_1, cond_img, cond_depth, cond_layout,
                       cond_semantic):
    """
    Rectified flow matching for 3D generation.
    x_1: target structured latent (from the VAE encoder)
    """
    # Sample noise
    x_0 = torch.randn_like(x_1)

    # Sample one timestep per batch element
    t = torch.rand(x_1.shape[0], device=x_1.device)
    t_b = t.view(-1, *([1] * (x_1.dim() - 1)))  # broadcast over latent dims

    # Linear interpolation along the straight noise-to-data path
    x_t = (1 - t_b) * x_0 + t_b * x_1

    # Model predicts the velocity field
    v_pred = model(x_t, t, cond_img, cond_depth, cond_layout, cond_semantic)

    # Target velocity of the straight path (constant in t)
    v_target = x_1 - x_0

    return F.mse_loss(v_pred, v_target)
```

## Stage 3: Material DiT

### Architecture

- Same DiT backbone as Stage 2
- Additional conditioning: the geometry latent generated in Stage 2
- Output: per-voxel material features (albedo RGB, metallic, roughness, normal XYZ)

### Training

```yaml
# configs/dit_material.yaml
training:
  batch_size: 16  # per GPU
  num_gpus: 16
  effective_batch_size: 256
  max_steps: 200000

loss:
  albedo:
    weight: 1.0
    type: l1
  metallic_roughness:
    weight: 0.5
    type: l1
  normal:
    weight: 0.5
    type: cosine
  perceptual:
    weight: 0.3
    type: lpips
    network: vgg
  rendering:
    weight: 0.5
    type: mse  # rendered vs. ground-truth images
```

## Stage 4: Real-World Fine-tuning

### LoRA Configuration

```yaml
# configs/finetune_lora.yaml
lora:
  rank: 32
  alpha: 32
  target_modules:
    - "attention.qkv"
    - "attention.proj"
    - "mlp.fc1"
    - "mlp.fc2"
  dropout: 0.0

training:
  batch_size: 4
  num_gpus: 8
  max_steps: 50000
  lr: 1.0e-5

data:
  dataset: InteriorFusion-Real  # ScanNet + HM3D
  weight: 1.0
```

### RL Fine-tuning (Optional)

```yaml
# configs/rl_finetune.yaml
rl:
  algorithm: GRPO
  group_size: 8
  reward_weights:
    depth_consistency: 0.25
    point_cloud_consistency: 0.25
    pose_stability: 0.25
    edit_quality: 0.25
  vggt_model: "facebook/VGGT-1B"  # for geometric rewards

training:
  num_iterations: 10000
  lr: 1.0e-6
  kl_penalty: 0.01
```

## Distributed Training

### Using Accelerate / DeepSpeed

```bash
# Launch with DeepSpeed ZeRO-3
accelerate launch --config_file configs/accelerate_deepspeed.yaml \
  scripts/train_vae.py --config configs/vae_pretrain.yaml
```

```yaml
# configs/accelerate_deepspeed.yaml
distributed_type: DEEPSPEED
deepspeed_config:
  zero_stage: 3
  offload_optimizer_device: none
  offload_param_device: none
  gradient_accumulation_steps: 1
  gradient_clipping: 1.0
  train_batch_size: auto
  train_micro_batch_size_per_gpu: auto
```

### LR Scaling for Distributed Training

Following Grendel-GS:

```python
import math

def scale_lr_for_distributed(base_lr, batch_size):
    """Square-root LR scaling; batch_size is the scaling factor
    relative to the single-GPU baseline configuration."""
    return base_lr * math.sqrt(batch_size)

def scale_adam_betas_for_distributed(beta1, beta2, batch_size):
    """Exponential momentum scaling: beta -> beta ** batch_size."""
    return beta1 ** batch_size, beta2 ** batch_size
```

## Checkpointing & Resumption

```python
# Save full training state so runs can be resumed exactly
checkpoint = {
    'model': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    'scheduler': scheduler.state_dict(),
    'step': step,
    'epoch': epoch,
    'best_val_loss': best_val_loss,
    'config': OmegaConf.to_container(config),
}
torch.save(checkpoint, f'checkpoints/stage1_step{step}.pt')
```
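The snippet above covers saving only. A minimal restore sketch, assuming the same `model`, `optimizer`, and `scheduler` objects are constructed before loading (the helper name `resume_from_checkpoint` is illustrative, not part of the repo):

```python
import torch

def resume_from_checkpoint(path, model, optimizer, scheduler, device="cuda"):
    """Restore the training state saved by the snippet above."""
    checkpoint = torch.load(path, map_location=device)
    model.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    scheduler.load_state_dict(checkpoint['scheduler'])
    # Return the counters so LR schedules and resolution curricula stay aligned
    return checkpoint['step'], checkpoint['epoch'], checkpoint['best_val_loss']
```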
## Validation Metrics

| Metric | Target | How to Compute |
|--------|--------|----------------|
| Chamfer Distance | < 0.01 | Point cloud comparison |
| F-Score @ 0.1 | > 0.80 | Precision/recall on surface |
| LPIPS | < 0.06 | Perceptual similarity |
| PSNR | > 28 | Rendering quality |
| SSIM | > 0.90 | Structural similarity |
| Layout IoU | > 0.85 | Room layout accuracy |
| Object Detection mAP | > 0.70 | Furniture detection |
| Scale Error | < 5% | Metric depth consistency |
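For reference, the geometric rows above can be computed in a few lines of plain PyTorch. This is a generic brute-force sketch, assuming point clouds arrive as pre-sampled `(N, 3)` tensors; it is not the project's actual evaluation code.

```python
import torch

def chamfer_and_fscore(pred_pts, gt_pts, threshold=0.1):
    """Chamfer distance and F-score between two point clouds.

    pred_pts: (N, 3), gt_pts: (M, 3). Brute-force O(N*M) distances;
    fine for a few thousand sampled points per scene.
    """
    # Pairwise Euclidean distances, then nearest neighbor each way
    d = torch.cdist(pred_pts, gt_pts)      # (N, M)
    d_pred_to_gt = d.min(dim=1).values     # for each predicted point
    d_gt_to_pred = d.min(dim=0).values     # for each ground-truth point

    chamfer = d_pred_to_gt.mean() + d_gt_to_pred.mean()

    # F-score @ threshold: harmonic mean of precision and recall
    precision = (d_pred_to_gt < threshold).float().mean()
    recall = (d_gt_to_pred < threshold).float().mean()
    fscore = 2 * precision * recall / (precision + recall + 1e-8)

    return chamfer.item(), fscore.item()

def scale_error(pred_depth, gt_depth, mask=None):
    """Relative metric-scale error (%) between predicted and GT depth."""
    if mask is not None:
        pred_depth, gt_depth = pred_depth[mask], gt_depth[mask]
    # Least-squares scale factor aligning prediction to ground truth;
    # a perfectly metric prediction has s == 1
    s = (pred_depth * gt_depth).sum() / (pred_depth.pow(2).sum() + 1e-8)
    return (s - 1).abs().item() * 100
```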