# InteriorFusion Training Guide

## Hardware Requirements

| Stage | GPUs | VRAM Each | Duration | Cost (Cloud) |
|-------|------|-----------|----------|--------------|
| VAE Pre-training | 8× A100 (80GB) | 80GB | 7 days | ~$15K |
| Structure DiT | 32× A100 (80GB) | 80GB | 14 days | ~$30K |
| Material DiT | 16× A100 (80GB) | 80GB | 7 days | ~$15K |
| Fine-tuning | 8× A100 (80GB) | 80GB | 3 days | ~$5K |
| **Total** | **Variable** | — | **~31 days** | **~$65K** |

Minimum viable: 8× A100 (run all stages sequentially; expect a longer total duration).
Budget option: 8× RTX 4090 (24GB each); requires gradient accumulation and roughly 2× the training time (see the sketch below).
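
On the budget setup, gradient accumulation recovers the large effective batch. A minimal sketch with a toy model (the real loop would use the InteriorFusion model and dataloader):

```python
import torch
import torch.nn as nn

# Toy stand-ins for the real model/optimizer/dataloader
model = nn.Linear(16, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
data = [(torch.randn(2, 16), torch.randn(2, 1)) for _ in range(8)]

accum_steps = 4  # effective batch = per_gpu_batch * num_gpus * accum_steps

optimizer.zero_grad()
for i, (x, y) in enumerate(data):
    loss = nn.functional.mse_loss(model(x), y) / accum_steps  # scale so grads average
    loss.backward()
    if (i + 1) % accum_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
```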

## Stage 1: SLAT-Interior VAE Pre-training

### Architecture
- **Encoder**: Sparse 3D convolutional U-Net
  - Input: dense occupancy grid O ∈ {0,1}^(N³), where N ∈ {256, 512, 1024}
  - Sparse convolution layers with channel-to-space shortcuts
  - 16× spatial compression per axis (1024³ → 64³ latent)

- **Decoder**:
  - Sparse conv upsampler with skip connections
  - Early pruning: predict a binary mask for active children before upsampling
  - Outputs: per-voxel shape features + material features (a dense stand-in sketch follows)
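
The sparse U-Net itself depends on a sparse-convolution backend, but the compression arithmetic is easy to check with a dense stand-in. A minimal sketch (illustrative names; four stride-2 `Conv3d` layers give the 16× per-axis reduction):

```python
import torch
import torch.nn as nn


class DenseEncoderStub(nn.Module):
    """Dense stand-in for the sparse encoder: four stride-2 convolutions
    compress each spatial axis by 16 (so 1024^3 -> 64^3 at full scale)."""

    def __init__(self, latent_dim: int = 64):
        super().__init__()
        chans = [1, 32, 64, 128, 2 * latent_dim]  # final 2x for (mu, logvar)
        layers = []
        for cin, cout in zip(chans[:-1], chans[1:]):
            layers += [nn.Conv3d(cin, cout, 3, stride=2, padding=1), nn.SiLU()]
        self.net = nn.Sequential(*layers[:-1])    # no activation on the last conv

    def forward(self, occupancy: torch.Tensor):
        h = self.net(occupancy)                   # (B, 2*latent_dim, N/16, N/16, N/16)
        mu, logvar = h.chunk(2, dim=1)
        return mu, logvar


# Quick shape check at N=64 (a dense 1024^3 grid would not fit in memory):
mu, logvar = DenseEncoderStub()(torch.zeros(1, 1, 64, 64, 64))
print(mu.shape)  # torch.Size([1, 64, 4, 4, 4])
```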

### Training Configuration
```yaml
# configs/vae_pretrain.yaml
model:
  latent_dim: 64
  base_resolution: 256
  target_resolution: 1024

optimizer:
  type: AdamW
  lr: 1.0e-4
  weight_decay: 0.01
  betas: [0.9, 0.999]

scheduler:
  type: cosine_with_restarts
  warmup_steps: 10000

training:
  batch_size: 8            # per GPU
  num_gpus: 8
  effective_batch_size: 64
  max_steps: 200000
  gradient_accumulation: 1
  mixed_precision: bf16

curriculum:
  - resolution: 256
    steps: 50000
    lr: 1.0e-4
  - resolution: 512
    steps: 100000
    lr: 1.0e-4
  - resolution: 1024
    steps: 50000
    lr: 5.0e-5

data:
  dataset: InteriorFusion-Train
  num_workers: 8
  pin_memory: true

loss:
  reconstruction:
    weight: 1.0
    type: l1
  kl_divergence:
    weight: 1.0e-3
  depth_consistency:
    weight: 0.5
    type: l1
  normal_consistency:
    weight: 0.3
    type: cosine
  edge_preservation:
    weight: 0.2
    type: l1
```
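
The curriculum above partitions the 200K steps into three resolution phases. A small helper (hypothetical, mirroring the YAML schema above) to resolve the active phase from the global step:

```python
def curriculum_stage(step, schedule):
    """Return (resolution, lr) for the curriculum entry active at `step`.

    `schedule` is a list of dicts shaped like the YAML above, each with
    'resolution', 'steps', and 'lr' keys.
    """
    boundary = 0
    for stage in schedule:
        boundary += stage["steps"]
        if step < boundary:
            return stage["resolution"], stage["lr"]
    return schedule[-1]["resolution"], schedule[-1]["lr"]


# curriculum_stage(120_000, cfg["curriculum"]) -> (512, 1e-4)
```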

### Loss Functions

```python
import torch
import torch.nn.functional as F


def vae_loss(pred_shape, pred_material, target_shape, target_material,
             pred_depth, target_depth, pred_normal, target_normal, mu, logvar):
    """Stage-1 VAE objective; term weights match configs/vae_pretrain.yaml.
    (The edge-preservation term from the config, weight 0.2, is not shown here.)
    """
    # Reconstruction: L1 on shape and material features
    loss_recon = F.l1_loss(pred_shape, target_shape) + \
                 F.l1_loss(pred_material, target_material)

    # KL divergence, summed over latent dims and averaged over the batch
    # so the 1e-3 weight is batch-size invariant
    loss_kl = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp()).flatten(1).sum(dim=1).mean()
    loss_kl = loss_kl * 1e-3

    # Depth consistency
    loss_depth = F.l1_loss(pred_depth, target_depth)

    # Normal consistency (1 - cosine similarity)
    loss_normal = 1 - F.cosine_similarity(pred_normal, target_normal, dim=-1).mean()

    return loss_recon + loss_kl + 0.5 * loss_depth + 0.3 * loss_normal
```

## Stage 2: Structure DiT (Rectified Flow)

### Architecture
- **DiT model**: Flow-matching transformer
  - Width: 1536
  - Depth: 30 blocks
  - Heads: 12
  - MLP hidden dim: 8192
  - Parameters: ~1.3B

- **Conditioning encoders**:
  - Image: DINOv3-L (frozen, 1024-dim features)
  - Depth: custom CNN encoder (256-dim)
  - Layout: Transformer encoder on SpatialLM tokens (512-dim)
  - Semantic: Mask2Former feature pyramid (256-dim)

- **Conditioning fusion**: cross-attention + AdaLN-single modulation (sketched below)
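
For reference, AdaLN-single (in the PixArt-α sense) computes all modulation parameters from one shared MLP, with only a learned per-block offset. A minimal sketch of the general technique (illustrative names, not InteriorFusion's exact block):

```python
import torch
import torch.nn as nn


class AdaLNSingle(nn.Module):
    """One shared MLP maps the timestep embedding to six modulation vectors
    (shift/scale/gate for attention and for the MLP); each transformer block
    adds its own learned offset instead of owning a full AdaLN MLP."""

    def __init__(self, width: int, num_blocks: int):
        super().__init__()
        self.shared = nn.Sequential(nn.SiLU(), nn.Linear(width, 6 * width))
        self.block_offsets = nn.Parameter(torch.zeros(num_blocks, 6 * width))

    def forward(self, t_emb: torch.Tensor, block_idx: int):
        mod = self.shared(t_emb) + self.block_offsets[block_idx]
        # (shift_attn, scale_attn, gate_attn, shift_mlp, scale_mlp, gate_mlp)
        return mod.chunk(6, dim=-1)
```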

### Training Configuration
```yaml
# configs/dit_structure.yaml
model:
  width: 1536
  depth: 30
  num_heads: 12
  mlp_hidden_dim: 8192

optimizer:
  type: AdamW
  lr: 1.0e-4
  weight_decay: 0.01

scheduler:
  type: linear_warmup_cosine
  warmup_steps: 10000

training:
  batch_size: 8            # per GPU
  num_gpus: 32
  effective_batch_size: 256
  max_steps: 400000
  mixed_precision: bf16

curriculum:
  - resolution: 256
    steps: 100000
    lr: 1.0e-4
  - resolution: 512
    steps: 200000
    lr: 1.0e-4
  - resolution: 1024
    steps: 100000
    lr: 2.0e-5

data:
  dataset: InteriorFusion-Train
  num_workers: 8

flow_matching:
  sigma_min: 0.001
  sigma_max: 80.0
  p_mean: -1.2
  p_std: 1.2

loss:
  flow_matching:
    weight: 1.0
  depth_guidance:
    weight: 0.3
```
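
The `p_mean`/`p_std` values above match a logit-normal timestep density, as used in SD3-style rectified flow; under that reading, timestep sampling would be:

```python
import torch


def sample_timesteps(batch_size, p_mean=-1.2, p_std=1.2, device="cpu"):
    """Logit-normal timesteps: t = sigmoid(z), z ~ N(p_mean, p_std**2)."""
    z = torch.randn(batch_size, device=device) * p_std + p_mean
    return torch.sigmoid(z)
```

The loss below samples t uniformly for simplicity; swap in `sample_timesteps` to use the logit-normal schedule.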

### Flow Matching Loss

```python
import torch
import torch.nn.functional as F


def flow_matching_loss(model, x_1, cond_img, cond_depth, cond_layout, cond_semantic):
    """Rectified flow matching for 3D generation.

    x_1: target structured latent (from the VAE encoder).
    """
    # Sample noise
    x_0 = torch.randn_like(x_1)

    # Sample timestep, uniform on [0, 1]
    t = torch.rand(x_1.shape[0], device=x_1.device)

    # Broadcast t over all non-batch dims (works for 4D or 5D latents)
    t_b = t.view(-1, *([1] * (x_1.dim() - 1)))

    # Linear interpolation between noise and data
    x_t = (1 - t_b) * x_0 + t_b * x_1

    # Model predicts velocity
    v_pred = model(x_t, t, cond_img, cond_depth, cond_layout, cond_semantic)

    # Target velocity of the straight path
    v_target = x_1 - x_0

    return F.mse_loss(v_pred, v_target)
```
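
At inference, the learned velocity field is integrated from noise (t = 0) to data (t = 1). A minimal Euler sampler consistent with the training convention above (assumes the same `model` signature):

```python
import torch


@torch.no_grad()
def sample(model, shape, conds, num_steps=25, device="cuda"):
    """Euler integration of dx/dt = v(x, t) from t=0 (noise) to t=1 (data)."""
    x = torch.randn(shape, device=device)
    ts = torch.linspace(0.0, 1.0, num_steps + 1, device=device)
    for i in range(num_steps):
        t = ts[i].expand(shape[0])
        v = model(x, t, *conds)
        x = x + (ts[i + 1] - ts[i]) * v
    return x
```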

## Stage 3: Material DiT

### Architecture
- Same DiT backbone as Stage 2
- Additional conditioning: generated geometry latent
- Output: per-voxel material features (albedo RGB, metallic, roughness, normal XYZ); a split helper follows
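
The eight material channels can be unpacked by position. A tiny helper, assuming the channel order listed above in a channel-first tensor:

```python
import torch


def split_material_features(feats):
    """feats: (B, 8, D, H, W) -> albedo (3), metallic (1), roughness (1), normal (3)."""
    albedo, metallic, roughness, normal = torch.split(feats, [3, 1, 1, 3], dim=1)
    return albedo, metallic, roughness, normal
```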

### Training
```yaml
# configs/dit_material.yaml
training:
  batch_size: 16           # per GPU
  num_gpus: 16
  effective_batch_size: 256
  max_steps: 200000

loss:
  albedo:
    weight: 1.0
    type: l1
  metallic_roughness:
    weight: 0.5
    type: l1
  normal:
    weight: 0.5
    type: cosine
  perceptual:
    weight: 0.3
    type: lpips
    network: vgg
  rendering:
    weight: 0.5
    type: mse              # rendered vs. ground truth
```
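
The perceptual term maps onto the `lpips` package (matching `type: lpips, network: vgg` above); note that LPIPS expects inputs scaled to [-1, 1]:

```python
import lpips

# VGG-backed LPIPS, as configured above
lpips_vgg = lpips.LPIPS(net="vgg")


def perceptual_loss(rendered, target):
    """rendered, target: (B, 3, H, W) images in [-1, 1]."""
    return lpips_vgg(rendered, target).mean()
```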

## Stage 4: Real-World Fine-tuning

### LoRA Configuration
```yaml
# configs/finetune_lora.yaml
lora:
  rank: 32
  alpha: 32
  target_modules:
    - "attention.qkv"
    - "attention.proj"
    - "mlp.fc1"
    - "mlp.fc2"
  dropout: 0.0

training:
  batch_size: 4
  num_gpus: 8
  max_steps: 50000
  lr: 1.0e-5

data:
  dataset: InteriorFusion-Real   # ScanNet + HM3D
  weight: 1.0
```
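
If the DiT is a plain PyTorch module, the config above maps directly onto Hugging Face `peft` (the module names are taken from the config; adjust them to the actual model's naming):

```python
from peft import LoraConfig, get_peft_model

lora_cfg = LoraConfig(
    r=32,
    lora_alpha=32,
    target_modules=["attention.qkv", "attention.proj", "mlp.fc1", "mlp.fc2"],
    lora_dropout=0.0,
)
model = get_peft_model(model, lora_cfg)
model.print_trainable_parameters()  # sanity-check the trainable fraction
```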

### RL Fine-tuning (Optional)
```yaml
# configs/rl_finetune.yaml
rl:
  algorithm: GRPO
  group_size: 8
  reward_weights:
    depth_consistency: 0.25
    point_cloud_consistency: 0.25
    pose_stability: 0.25
    edit_quality: 0.25

  vggt_model: "facebook/VGGT-1B"   # for geometric rewards

training:
  num_iterations: 10000
  lr: 1.0e-6
  kl_penalty: 0.01
```
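
GRPO scores each group of `group_size` rollouts against the group's own statistics; the scalar reward here is the weighted sum of the four terms above. A sketch of the advantage computation:

```python
import torch


def grpo_advantages(rewards, group_size=8):
    """rewards: (num_groups * group_size,) scalars, grouped contiguously.

    Returns group-relative advantages (r - group_mean) / (group_std + eps).
    """
    r = rewards.view(-1, group_size)
    adv = (r - r.mean(dim=1, keepdim=True)) / (r.std(dim=1, keepdim=True) + 1e-6)
    return adv.view(-1)
```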

## Distributed Training

### Using Accelerate / DeepSpeed
```bash
# Launch with DeepSpeed ZeRO-3
accelerate launch --config_file configs/accelerate_deepspeed.yaml \
    scripts/train_vae.py --config configs/vae_pretrain.yaml
```

```yaml
# configs/accelerate_deepspeed.yaml
distributed_type: DEEPSPEED
deepspeed_config:
  zero_stage: 3
  offload_optimizer_device: none
  offload_param_device: none
  gradient_accumulation_steps: 1
  gradient_clipping: 1.0
  train_batch_size: auto
  train_micro_batch_size_per_gpu: auto
```

### LR Scaling for Distributed Training
Following Grendel-GS, the learning rate scales with the square root of the batch-size ratio and the Adam betas decay exponentially with it (the ratio is taken against the batch size at which the base hyperparameters were tuned):
```python
import math


def scale_lr_for_distributed(base_lr, batch_size_ratio):
    """Square-root LR scaling; `batch_size_ratio` = new batch size divided
    by the batch size at which `base_lr` was tuned."""
    return base_lr * math.sqrt(batch_size_ratio)


def scale_adam_betas_for_distributed(beta1, beta2, batch_size_ratio):
    """Exponential momentum scaling: beta' = beta ** ratio."""
    return beta1 ** batch_size_ratio, beta2 ** batch_size_ratio
```

## Checkpointing & Resumption

```python
import torch
from omegaconf import OmegaConf

checkpoint = {
    'model': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    'scheduler': scheduler.state_dict(),
    'step': step,
    'epoch': epoch,
    'best_val_loss': best_val_loss,
    'config': OmegaConf.to_container(config),
}

torch.save(checkpoint, f'checkpoints/stage1_step{step}.pt')
```
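
Resuming mirrors the save; offsetting the step avoids re-running the checkpointed iteration:

```python
ckpt = torch.load('checkpoints/stage1_step200000.pt', map_location='cpu')
model.load_state_dict(ckpt['model'])
optimizer.load_state_dict(ckpt['optimizer'])
scheduler.load_state_dict(ckpt['scheduler'])
start_step = ckpt['step'] + 1
```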

## Validation Metrics

| Metric | Target | How to Compute |
|--------|--------|----------------|
| Chamfer Distance | < 0.01 | Point cloud comparison (sketch below) |
| F-Score @ 0.1 | > 0.80 | Precision/recall on surface points at a 0.1 distance threshold |
| LPIPS | < 0.06 | Perceptual similarity |
| PSNR | > 28 dB | Rendering quality |
| SSIM | > 0.90 | Structural similarity |
| Layout IoU | > 0.85 | Room layout accuracy |
| Object Detection mAP | > 0.70 | Furniture detection |
| Scale Error | < 5% | Metric depth consistency |
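
A minimal symmetric Chamfer distance for the first metric (brute-force `torch.cdist`; fine for a few thousand points, chunk or use a KD-tree beyond that):

```python
import torch


def chamfer_distance(pred_pts, gt_pts):
    """pred_pts: (N, 3), gt_pts: (M, 3); symmetric mean nearest-neighbor distance."""
    d = torch.cdist(pred_pts, gt_pts)  # (N, M) pairwise distances
    return d.min(dim=1).values.mean() + d.min(dim=0).values.mean()
```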