AlexWortega committed
Commit ccdcfe1 · verified · 1 Parent(s): 889bf64

Upload train_diffusion.py with huggingface_hub

Files changed (1)
  1. train_diffusion.py +287 -0
train_diffusion.py ADDED
#!/usr/bin/env python3
"""
Training script for a conditional DDPM on The Well datasets.
Includes periodic evaluation with WandB video logging.

Usage:
    python train_diffusion.py --dataset turbulent_radiative_layer_2D --wandb
    python train_diffusion.py --dataset active_matter --batch_size 4 --wandb
"""
import argparse
import logging
import math
import os
import time

import torch
import torch.nn as nn
from torch.amp import GradScaler, autocast
from tqdm import tqdm

from data_pipeline import create_dataloader, prepare_batch, get_channel_info
from unet import UNet
from diffusion import GaussianDiffusion
# --- logging setup (suppress noisy library logs) ---
logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger("train_diffusion")
logger.setLevel(logging.INFO)
_h = logging.StreamHandler()
_h.setFormatter(logging.Formatter("%(asctime)s [%(levelname)s] %(message)s", datefmt="%H:%M:%S"))
logger.addHandler(_h)
logger.propagate = False

# Also let eval_utils log through us
logging.getLogger("eval_utils").setLevel(logging.INFO)
logging.getLogger("eval_utils").addHandler(_h)
logging.getLogger("eval_utils").propagate = False


def cosine_lr(step, warmup, total, base_lr, min_lr=1e-6):
    if step < warmup:
        return base_lr * step / max(warmup, 1)
    progress = (step - warmup) / max(total - warmup, 1)
    return min_lr + 0.5 * (base_lr - min_lr) * (1 + math.cos(progress * math.pi))
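# Worked example of the schedule above (linear warmup, then cosine decay):
#   step < warmup:  lr = base_lr * step / warmup
#   otherwise:      lr = min_lr + 0.5 * (base_lr - min_lr)
#                        * (1 + cos(pi * (step - warmup) / (total - warmup)))
# With base_lr=1e-4 and warmup=1000: lr(500) = 5e-5, lr(1000) = 1e-4, and
# lr then decays smoothly toward min_lr=1e-6 at step `total`.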


def train(args):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"Device: {device}")

    # ---- WandB ----
    wandb_run = None
    if args.wandb:
        import wandb

        wandb_run = wandb.init(
            project="the-well-diffusion",
            name=f"{args.dataset}_bs{args.batch_size}_lr{args.lr}",
            config=vars(args),
        )
        logger.info(f"WandB run: {wandb_run.url}")

    # ---- Data: train ----
    logger.info(f"Loading training data: {args.dataset} (streaming={args.streaming})")
    train_loader, train_dataset = create_dataloader(
        dataset_name=args.dataset,
        split="train",
        batch_size=args.batch_size,
        n_steps_input=args.n_input,
        n_steps_output=args.n_output,
        num_workers=args.workers,
        streaming=args.streaming,
        local_path=args.local_path,
    )

    ch_info = get_channel_info(train_dataset)
    logger.info(f"Channel info: {ch_info}")

    c_in = ch_info["input_channels"]
    c_out = ch_info["output_channels"]

    # ---- Data: validation (single-step) ----
    logger.info("Loading validation data...")
    val_loader, _ = create_dataloader(
        dataset_name=args.dataset,
        split="valid",
        batch_size=args.batch_size,
        n_steps_input=args.n_input,
        n_steps_output=args.n_output,
        num_workers=0,
        streaming=args.streaming,
        local_path=args.local_path,
    )

    # ---- Data: rollout validation (multi-step output for GT comparison) ----
    logger.info(f"Loading rollout data (n_steps_output={args.n_rollout})...")
    rollout_loader, _ = create_dataloader(
        dataset_name=args.dataset,
        split="valid",
        batch_size=1,
        n_steps_input=args.n_input,
        n_steps_output=args.n_rollout,
        num_workers=0,
        streaming=args.streaming,
        local_path=args.local_path,
    )

    # ---- Model ----
    unet = UNet(
        in_channels=c_out + c_in,
        out_channels=c_out,
        base_ch=args.base_ch,
        ch_mults=tuple(args.ch_mults),
        n_res=args.n_res,
        attn_levels=tuple(args.attn_levels),
        dropout=args.dropout,
    )
    diffusion = GaussianDiffusion(unet, timesteps=args.timesteps).to(device)
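    # Conditioning note: in_channels = c_out + c_in suggests the noisy target
    # (c_out channels) is channel-concatenated with the conditioning input
    # frames (c_in channels), the standard setup for a conditional DDPM, and
    # the UNet predicts only the c_out target channels. training_loss below
    # presumably samples a timestep and optimizes the usual eps-prediction
    # MSE; the exact wiring lives in unet.py and diffusion.py.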

    n_params = sum(p.numel() for p in diffusion.parameters() if p.requires_grad)
    logger.info(f"Model parameters: {n_params:,}")

    if wandb_run:
        wandb_run.summary["n_params"] = n_params

    # ---- Optimizer ----
    optimizer = torch.optim.AdamW(diffusion.parameters(), lr=args.lr, weight_decay=args.wd)
    scaler = GradScaler("cuda", enabled=args.amp)
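    # AMP note: the training loop autocasts to bfloat16, whose exponent range
    # matches float32, so gradients are unlikely to underflow and the
    # GradScaler is largely an inert safeguard here; loss scaling is mainly
    # needed for float16 autocast.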

    # ---- Checkpoint resume ----
    start_epoch = 0
    global_step = 0
    if args.resume and os.path.exists(args.resume):
        ckpt = torch.load(args.resume, map_location=device, weights_only=False)
        diffusion.load_state_dict(ckpt["model"])
        optimizer.load_state_dict(ckpt["optimizer"])
        scaler.load_state_dict(ckpt["scaler"])
        start_epoch = ckpt["epoch"] + 1
        global_step = ckpt["global_step"]
        logger.info(f"Resumed from epoch {start_epoch}, step {global_step}")
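    # torch.load(weights_only=False) opts out of the restricted unpickler so
    # the plain Python objects stored in the checkpoint (the args dict,
    # ch_info) load back; only resume from checkpoints you created yourself.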

    # ---- Training loop ----
    os.makedirs(args.ckpt_dir, exist_ok=True)
    total_steps = args.epochs * len(train_loader)
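    # Caveat: len(train_loader) requires a sized dataset. If the streaming
    # pipeline yields an IterableDataset without __len__, this raises
    # TypeError and total_steps would have to be estimated another way.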

    logger.info(f"Starting training: {args.epochs} epochs, ~{total_steps} steps")
    logger.info(f"Eval every {args.eval_every} epochs, rollout {args.n_rollout} steps")

    for epoch in range(start_epoch, args.epochs):
        diffusion.train()
        epoch_loss = 0.0
        n_batches = 0
        t0 = time.time()

        pbar = tqdm(train_loader, desc=f"Epoch {epoch}", leave=False)
        for batch in pbar:
            try:
                x_cond, x_target = prepare_batch(batch, device)
            except Exception as e:
                logger.warning(f"Batch error: {e}, skipping")
                continue

            lr = cosine_lr(global_step, args.warmup, total_steps, args.lr)
            for pg in optimizer.param_groups:
                pg["lr"] = lr

            optimizer.zero_grad(set_to_none=True)

            with autocast(device_type="cuda", dtype=torch.bfloat16, enabled=args.amp):
                loss = diffusion.training_loss(x_target, x_cond)

            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            nn.utils.clip_grad_norm_(diffusion.parameters(), args.grad_clip)
            scaler.step(optimizer)
            scaler.update()
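            # Order matters above: unscale_ runs before clip_grad_norm_ so
            # clipping sees gradients at their true magnitude, and scaler.step
            # skips the optimizer update if any gradient is non-finite.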

            epoch_loss += loss.item()
            n_batches += 1
            global_step += 1

            pbar.set_postfix(loss=f"{loss.item():.4f}", lr=f"{lr:.2e}")

            if wandb_run and global_step % 20 == 0:
                wandb_run.log({"train/loss": loss.item(), "train/lr": lr}, step=global_step)

        avg_loss = epoch_loss / max(n_batches, 1)
        elapsed = time.time() - t0
        logger.info(
            f"Epoch {epoch}: loss={avg_loss:.4f}, batches={n_batches}, "
            f"time={elapsed:.1f}s, lr={lr:.2e}"
        )
        if wandb_run:
            wandb_run.log({"train/epoch_loss": avg_loss, "epoch": epoch}, step=global_step)

        # ---- Evaluation with video logging ----
        if (epoch + 1) % args.eval_every == 0:
            from eval_utils import run_evaluation

            logger.info("=" * 40)
            logger.info(f"EVALUATION at epoch {epoch}")
            logger.info("=" * 40)

            eval_metrics = run_evaluation(
                model=diffusion,
                val_loader=val_loader,
                rollout_loader=rollout_loader,
                device=device,
                global_step=global_step,
                wandb_run=wandb_run,
                n_val_batches=args.eval_batches,
                n_rollout=args.n_rollout,
                ddim_steps=args.ddim_steps,
            )

            logger.info(
                f"  val/mse={eval_metrics['val/mse']:.6f}, "
                f"rollout_mse_mean={eval_metrics['val/rollout_mse_mean']:.6f}"
            )
            logger.info("=" * 40)

        # ---- Checkpoint ----
        if (epoch + 1) % args.save_every == 0 or epoch == args.epochs - 1:
            ckpt_path = os.path.join(args.ckpt_dir, f"diffusion_ep{epoch:04d}.pt")
            torch.save(
                {
                    "epoch": epoch,
                    "global_step": global_step,
                    "model": diffusion.state_dict(),
                    "optimizer": optimizer.state_dict(),
                    "scaler": scaler.state_dict(),
                    "args": vars(args),
                    "ch_info": ch_info,
                },
                ckpt_path,
            )
            logger.info(f"Saved {ckpt_path}")

    if wandb_run:
        wandb_run.finish()
    logger.info("Training complete.")


def main():
    p = argparse.ArgumentParser(description="Train conditional DDPM on The Well")
    # Data
    p.add_argument("--dataset", default="turbulent_radiative_layer_2D")
    p.add_argument("--streaming", action="store_true", default=True)
    p.add_argument("--no-streaming", dest="streaming", action="store_false")
    p.add_argument("--local_path", default=None)
    p.add_argument("--batch_size", type=int, default=8)
    p.add_argument("--workers", type=int, default=0)
    p.add_argument("--n_input", type=int, default=1)
    p.add_argument("--n_output", type=int, default=1)
    # Model
    p.add_argument("--base_ch", type=int, default=64)
    p.add_argument("--ch_mults", type=int, nargs="+", default=[1, 2, 4, 8])
    p.add_argument("--n_res", type=int, default=2)
    p.add_argument("--attn_levels", type=int, nargs="+", default=[3])
    p.add_argument("--dropout", type=float, default=0.1)
    p.add_argument("--timesteps", type=int, default=1000)
    # Optimization
    p.add_argument("--lr", type=float, default=1e-4)
    p.add_argument("--wd", type=float, default=0.01)
    p.add_argument("--warmup", type=int, default=1000)
    p.add_argument("--grad_clip", type=float, default=1.0)
    p.add_argument("--amp", action="store_true", default=True)
    p.add_argument("--no-amp", dest="amp", action="store_false")
    p.add_argument("--epochs", type=int, default=100)
    # Evaluation
    p.add_argument("--eval_every", type=int, default=5, help="Eval every N epochs")
    p.add_argument("--eval_batches", type=int, default=4, help="Val batches for MSE")
    p.add_argument("--n_rollout", type=int, default=20, help="Rollout steps for video")
    p.add_argument("--ddim_steps", type=int, default=50, help="DDIM steps for eval sampling")
    # Checkpointing
    p.add_argument("--ckpt_dir", default="checkpoints/diffusion")
    p.add_argument("--save_every", type=int, default=5)
    p.add_argument("--resume", default=None)
    # Logging
    p.add_argument("--wandb", action="store_true", default=False)

    args = p.parse_args()
    train(args)


if __name__ == "__main__":
    main()
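# Example of resuming (hypothetical path; with the default --save_every 5 the
# first checkpoint is written at epoch 4):
#   python train_diffusion.py --dataset turbulent_radiative_layer_2D \
#       --resume checkpoints/diffusion/diffusion_ep0004.pt --wandb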