stevee00 committed on
Commit
44963e7
·
verified ·
1 Parent(s): 7c0e853

Upload scripts/train_vae.py

Browse files
Files changed (1) hide show
  1. scripts/train_vae.py +160 -0
scripts/train_vae.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Stage 1: SLAT-Interior VAE Pre-training."""
2
+
3
+ import os
4
+ import sys
5
+ from pathlib import Path
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ from torch.utils.data import DataLoader
11
+ from accelerate import Accelerator
12
+ from omegaconf import OmegaConf
13
+ from tqdm import tqdm
14
+
15
+
16
# Cadence constants for the training loop (counted in loop iterations,
# i.e. one per batch drawn from the dataloader).
_LOG_EVERY = 100
_CHECKPOINT_EVERY = 5000
_MAX_EPOCHS = 1000


def _compute_losses(model, batch, loss_cfg):
    """Run one encode/decode pass and return the weighted loss terms.

    Args:
        model: Object exposing ``encode``, ``decode``, ``predict_depth`` and
            ``predict_normal``.
        batch: Mapping with ``"occupancy"``, ``"materials"``, ``"depth"``
            and ``"normal"`` tensors.
        loss_cfg: Config node providing ``kl_divergence.weight``,
            ``depth_consistency.weight`` and ``normal_consistency.weight``.

    Returns:
        Tuple ``(loss, loss_recon, loss_kl, loss_depth, loss_normal)`` where
        ``loss`` is the sum of the other four terms.
    """
    occupancy = batch["occupancy"]  # [B, 1, N, N, N]
    materials = batch["materials"]  # [B, 4, N, N, N]
    depth = batch["depth"]          # [B, 1, N, N, N]
    normal = batch["normal"]        # [B, 3, N, N, N]

    # NOTE(review): these are custom methods called on the object returned
    # by accelerator.prepare(), not its forward(). Verify that mixed
    # precision and multi-process gradient sync apply on this path.
    z, mu, logvar = model.encode(occupancy, materials)
    pred_shape, pred_material = model.decode(z)

    # Depth and normal are predicted from the reconstructed shape.
    pred_depth = model.predict_depth(pred_shape)
    pred_normal = model.predict_normal(pred_shape)

    loss_recon = (
        F.l1_loss(pred_shape, occupancy)
        + F.l1_loss(pred_material, materials)
    )

    # KL term is summed over every element (batch included), unlike the
    # mean-reduced terms around it; its config weight must account for that.
    loss_kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    loss_kl = loss_kl * loss_cfg.kl_divergence.weight

    loss_depth = F.l1_loss(pred_depth, depth) * loss_cfg.depth_consistency.weight

    loss_normal = (
        1 - F.cosine_similarity(pred_normal, normal, dim=1).mean()
    ) * loss_cfg.normal_consistency.weight

    loss = loss_recon + loss_kl + loss_depth + loss_normal
    return loss, loss_recon, loss_kl, loss_depth, loss_normal


def _save_checkpoint(accelerator, model, optimizer, scheduler, config, global_step):
    """Write model/optimizer/scheduler state to ``checkpoints/`` on the main process.

    All processes synchronise first; only the main process writes the file.
    """
    accelerator.wait_for_everyone()
    if not accelerator.is_main_process:
        return

    unwrapped_model = accelerator.unwrap_model(model)
    checkpoint_path = f"checkpoints/vae_step{global_step}.pt"
    os.makedirs("checkpoints", exist_ok=True)
    torch.save({
        "model": unwrapped_model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "scheduler": scheduler.state_dict(),
        "step": global_step,
        "config": OmegaConf.to_container(config),
    }, checkpoint_path)
    print(f"Saved checkpoint: {checkpoint_path}")


def main() -> None:
    """Train the SLAT-Interior VAE from a YAML config.

    The config path is taken from ``sys.argv[1]`` and defaults to
    ``configs/vae_pretrain.yaml``. Training stops when
    ``config.training.max_steps`` iterations have run or after
    ``_MAX_EPOCHS`` epochs, whichever comes first. A checkpoint is written
    every ``_CHECKPOINT_EVERY`` iterations and once more when training
    ends, so the final weights are never lost.
    """
    # Load config
    config_path = sys.argv[1] if len(sys.argv) > 1 else "configs/vae_pretrain.yaml"
    config = OmegaConf.load(config_path)

    # Initialize accelerator
    accelerator = Accelerator(
        mixed_precision="bf16",
        gradient_accumulation_steps=config.training.gradient_accumulation,
    )

    # Build model (project import kept local, as in the original script)
    from interiorfusion.models.slat_vae import SLATInteriorVAE
    model = SLATInteriorVAE(
        latent_dim=config.model.latent_dim,
        base_resolution=config.model.base_resolution,
    )

    # Optimizer
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config.optimizer.lr,
        weight_decay=config.optimizer.weight_decay,
        betas=tuple(config.optimizer.betas),
    )

    # Scheduler: cosine annealing with warm restarts; the first cycle
    # length is read from ``scheduler.warmup_steps``.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer,
        T_0=config.scheduler.warmup_steps,
        T_mult=2,
    )

    # Data loader
    from interiorfusion.data.dataset import InteriorFusionDataset
    dataset = InteriorFusionDataset(
        root=config.data.dataset,
        split="train",
        resolution=config.model.base_resolution,
    )
    dataloader = DataLoader(
        dataset,
        batch_size=config.training.batch_size,
        shuffle=True,
        num_workers=config.data.num_workers,
        pin_memory=config.data.pin_memory,
    )

    # Prepare with accelerator
    model, optimizer, dataloader, scheduler = accelerator.prepare(
        model, optimizer, dataloader, scheduler
    )

    # Training loop
    # NOTE(review): global_step advances once per batch, so with gradient
    # accumulation it counts micro-batches rather than optimizer updates.
    # Confirm that max_steps in the config is expressed in the same unit.
    global_step = 0
    last_saved_step = 0  # avoids writing the same checkpoint twice at exit
    for epoch in range(_MAX_EPOCHS):
        model.train()
        epoch_loss = 0.0

        for batch in tqdm(dataloader, desc=f"Epoch {epoch}"):
            with accelerator.accumulate(model):
                loss, loss_recon, loss_kl, loss_depth, loss_normal = _compute_losses(
                    model, batch, config.loss
                )

                # Backward
                accelerator.backward(loss)

                # Clip only when gradients are actually synchronised.
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(model.parameters(), 1.0)

                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()

            global_step += 1
            epoch_loss += loss.item()

            # Logging
            if global_step % _LOG_EVERY == 0:
                accelerator.print(
                    f"Step {global_step}: "
                    f"loss={loss.item():.4f} "
                    f"recon={loss_recon.item():.4f} "
                    f"kl={loss_kl.item():.4f} "
                    f"depth={loss_depth.item():.4f} "
                    f"normal={loss_normal.item():.4f}"
                )

            # Periodic checkpoint
            if global_step % _CHECKPOINT_EVERY == 0:
                _save_checkpoint(
                    accelerator, model, optimizer, scheduler, config, global_step
                )
                last_saved_step = global_step

            # Early stopping on step limit
            if global_step >= config.training.max_steps:
                # Persist the final weights unless this step was just saved.
                if last_saved_step != global_step:
                    _save_checkpoint(
                        accelerator, model, optimizer, scheduler, config, global_step
                    )
                accelerator.print("Reached max steps. Training complete.")
                return

        avg_loss = epoch_loss / len(dataloader)
        accelerator.print(f"Epoch {epoch} complete. Avg loss: {avg_loss:.4f}")

    # Epoch cap reached before max_steps: still persist the final weights.
    if last_saved_step != global_step:
        _save_checkpoint(
            accelerator, model, optimizer, scheduler, config, global_step
        )
157
+
158
+
159
# Script entry point: start training only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()