haiphamcse committed
Commit be5a273 · verified · 1 Parent(s): f2c29a3

Upload train_jit.py with huggingface_hub

Files changed (1)
  1. train_jit.py +334 -0
train_jit.py ADDED
@@ -0,0 +1,334 @@
+ #!/usr/bin/env python3
+ """
+ CFM training with unconditional JiT (jit_model_unconditional.JiT).
+ Mirrors train_cfm_unet.py (data, TensorBoard, checkpoints); model + YAML differ.
+
+ JiT expects forward(x, t), while torchdyn's NeuralODE calls f(t, x); CFMFlowWrapper bridges the two.
+ """
+
+ from __future__ import annotations
+
+ import argparse
+ import os
+ from dataclasses import dataclass
+ from pathlib import Path
+ from typing import Any
+
+ import torch
+ import torch.nn as nn
+ import torchvision
+ from torch.utils.data import DataLoader
+ from torch.utils.tensorboard import SummaryWriter
+ from torchvision.transforms import v2
+ from torchdyn.core import NeuralODE
+
+ from torchcfm.conditional_flow_matching import ConditionalFlowMatcher
+
+ from jit import JiT
+
+ try:
+     import yaml  # type: ignore[import-untyped]
+ except ImportError as e:  # pragma: no cover
+     raise ImportError("Please `pip install pyyaml` to use --config.") from e
+
+ # Reuse dataset helpers from UNet trainer (same CLI for data)
+ from train_unet import load_training_dataset
+
+
+ def parse_args() -> argparse.Namespace:
+     p = argparse.ArgumentParser(description="Train unconditional JiT with Conditional Flow Matching")
+
+     p.add_argument(
+         "--dataset",
+         type=str,
+         default="imagenette",
+         choices=["cifar10", "imagenette"],
+         help="Training dataset",
+     )
+     p.add_argument("--data-root", type=str, default=".", help="Root for dataset download/cache")
+     p.add_argument("--cifar-split", type=str, default="train", choices=["train", "test"])
+     p.add_argument("--imagenette-split", type=str, default="train", choices=["train", "val"])
+     p.add_argument("--imagenette-size", type=str, default="160px", choices=["160px", "320px", "full"])
+     p.add_argument("--single-class", action="store_true")
+     p.add_argument("--class-id", type=int, default=0)
+     p.add_argument("--batch-size", type=int, default=64)
+     p.add_argument("--num-workers", type=int, default=4)
+
+     p.add_argument("--epochs", type=int, default=30)
+     p.add_argument("--device", type=str, default=None, help="cuda | cpu (default: auto)")
+     p.add_argument("--log-interval", type=int, default=100)
+     p.add_argument("--seed", type=int, default=0)
+
+     p.add_argument("--save-dir", type=str, default="./runs/cfm_jit/checkpoints")
+     p.add_argument(
+         "--log-dir",
+         type=str,
+         default="./runs/cfm_jit/tensorboard",
+     )
+     p.add_argument("--run-name", type=str, default=None)
+
+     p.add_argument(
+         "--config",
+         type=str,
+         default=None,
+         help="YAML with JiT + CFM hyperparameters (default: jit_config.yaml next to this script)",
+     )
+
+     return p.parse_args()
+
+
+ def _dim_from_yaml(value: Any) -> tuple[int, int, int]:
+     if isinstance(value, (list, tuple)) and len(value) == 3:
+         return (int(value[0]), int(value[1]), int(value[2]))
+     raise ValueError("YAML 'dim' must be [C, H, W]")
+
+
+ @dataclass
+ class JiTTrainConfig:
+     sigma: float
+     dim: tuple[int, int, int]
+     lr: float
+     weight_decay: float
+     inference_steps: int
+     vis_batch_size: int
+     input_size: int
+     patch_size: int
+     hidden_size: int
+     depth: int
+     num_heads: int
+     mlp_ratio: float
+     attn_drop: float
+     proj_drop: float
+     bottleneck_dim: int
+     in_context_len: int
+     in_context_start: int
+
+
+ REQUIRED_JIT_YAML_KEYS = (
+     "sigma",
+     "dim",
+     "lr",
+     "weight_decay",
+     "inference_steps",
+     "vis_batch_size",
+     "input_size",
+     "patch_size",
+     "hidden_size",
+     "depth",
+     "num_heads",
+     "mlp_ratio",
+     "attn_drop",
+     "proj_drop",
+     "bottleneck_dim",
+     "in_context_len",
+     "in_context_start",
+ )
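+
+ # For reference, a jit_config.yaml satisfying the keys above might look like
+ # this (illustrative values, not the shipped config):
+ #
+ #   sigma: 0.0
+ #   dim: [3, 32, 32]
+ #   lr: 1.0e-4
+ #   weight_decay: 0.0
+ #   inference_steps: 100
+ #   vis_batch_size: 16
+ #   input_size: 32
+ #   patch_size: 4
+ #   hidden_size: 384
+ #   depth: 12
+ #   num_heads: 6
+ #   mlp_ratio: 4.0
+ #   attn_drop: 0.0
+ #   proj_drop: 0.0
+ #   bottleneck_dim: 64
+ #   in_context_len: 0
+ #   in_context_start: 0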
+
+
+ def load_jit_config_yaml(path: str | os.PathLike[str]) -> JiTTrainConfig:
+     path = Path(path)
+     if not path.is_file():
+         raise FileNotFoundError(f"Config file not found: {path.resolve()}")
+
+     with open(path, encoding="utf-8") as f:
+         raw = yaml.safe_load(f)
+     if raw is None or not isinstance(raw, dict):
+         raise ValueError(f"Config must be a YAML mapping: {path}")
+
+     missing = [k for k in REQUIRED_JIT_YAML_KEYS if k not in raw]
+     if missing:
+         raise ValueError(f"Missing keys in {path}: {missing}")
+
+     dim = _dim_from_yaml(raw["dim"])
+     input_size = int(raw["input_size"])
+     if dim[1] != input_size or dim[2] != input_size:
+         raise ValueError(f"dim {dim} must match input_size×input_size ({input_size})")
+
+     return JiTTrainConfig(
+         sigma=float(raw["sigma"]),
+         dim=dim,
+         lr=float(raw["lr"]),
+         weight_decay=float(raw["weight_decay"]),
+         inference_steps=int(raw["inference_steps"]),
+         vis_batch_size=int(raw["vis_batch_size"]),
+         input_size=input_size,
+         patch_size=int(raw["patch_size"]),
+         hidden_size=int(raw["hidden_size"]),
+         depth=int(raw["depth"]),
+         num_heads=int(raw["num_heads"]),
+         mlp_ratio=float(raw["mlp_ratio"]),
+         attn_drop=float(raw["attn_drop"]),
+         proj_drop=float(raw["proj_drop"]),
+         bottleneck_dim=int(raw["bottleneck_dim"]),
+         in_context_len=int(raw["in_context_len"]),
+         in_context_start=int(raw["in_context_start"]),
+     )
+
+
+ def build_jit(cfg: JiTTrainConfig) -> JiT:
+     c = cfg.dim[0]
+     return JiT(
+         input_size=cfg.input_size,
+         patch_size=cfg.patch_size,
+         in_channels=c,
+         hidden_size=cfg.hidden_size,
+         depth=cfg.depth,
+         num_heads=cfg.num_heads,
+         mlp_ratio=cfg.mlp_ratio,
+         attn_drop=cfg.attn_drop,
+         proj_drop=cfg.proj_drop,
+         bottleneck_dim=cfg.bottleneck_dim,
+         in_context_len=cfg.in_context_len,
+         in_context_start=cfg.in_context_start,
+     )
+
+
+ class CFMFlowWrapper(nn.Module):
+     """
+     torchdyn's NeuralODE expects f(t, x) returning dx/dt.
+     - velocity mode: JiT predicts v directly, so return model(x, t).
+     - x_pred mode (JiT as denoiser): JiT predicts x1 (clean); v = (x_pred - x) / (1 - t).
+     """
+
+     def __init__(self, model: JiT, prediction_mode: str = "velocity", t_eps: float = 1e-5):
+         super().__init__()
+         self.model = model
+         # Honor the constructor argument (was hard-coded to "x_pred").
+         self.prediction_mode = prediction_mode
+         self.t_eps = t_eps
+
+     def forward(self, t: torch.Tensor, x: torch.Tensor, y=None, *args, **kwargs) -> torch.Tensor:
+         batch = x.shape[0]
+         t_flat = torch.as_tensor(t, device=x.device, dtype=torch.float32).reshape(-1)
+         if t_flat.numel() == 1:
+             t_flat = t_flat.expand(batch)
+         elif t_flat.shape[0] != batch:
+             t_flat = t_flat[:batch]
+
+         if self.prediction_mode == "x_pred":
+             x_pred = self.model(x, t_flat)
+             one_minus_t = (1.0 - t_flat).clamp(min=self.t_eps)
+             t_bc = one_minus_t.reshape(-1, *([1] * (x.dim() - 1)))
+             return (x_pred - x) / t_bc
+         return self.model(x, t_flat)
+
+
+ def main() -> None:
+     args = parse_args()
+     default_cfg = Path(__file__).resolve().parent / "jit_config.yaml"
+     config_path = Path(args.config).resolve() if args.config else default_cfg
+     cfg = load_jit_config_yaml(config_path)
+     print(f"Loaded JiT config from: {config_path}")
+
+     torch.manual_seed(args.seed)
+     if torch.cuda.is_available():
+         torch.cuda.manual_seed_all(args.seed)
+
+     device = torch.device(args.device or ("cuda" if torch.cuda.is_available() else "cpu"))
+     print(f"Using device: {device}")
+
+     os.makedirs(args.save_dir, exist_ok=True)
+
+     tb_dir = os.path.join(args.log_dir, args.run_name) if args.run_name else args.log_dir
+     os.makedirs(tb_dir, exist_ok=True)
+     writer = SummaryWriter(log_dir=tb_dir)
+     writer.add_text("config/args", str(vars(args)), 0)
+     writer.add_text("config/jit_yaml", config_path.read_text(encoding="utf-8"), 0)
+
+     transforms = v2.Compose(
+         [
+             v2.ToImage(),  # v2.ToTensor() is deprecated; ToImage + ToDtype is its v2 equivalent
+             v2.ToDtype(torch.float32, scale=True),
+             v2.Resize((cfg.input_size, cfg.input_size)),
+             v2.Normalize(mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0]),  # no-op; keeps images in [0, 1]
+         ]
+     )
+     train_dataset = load_training_dataset(args, transforms)
+     print(f"Dataset: {args.dataset}, size={len(train_dataset)}")
+
+     train_dataloader = DataLoader(
+         train_dataset,
+         batch_size=args.batch_size,
+         shuffle=True,
+         num_workers=args.num_workers,
+         pin_memory=device.type == "cuda",
+     )
+
+     total_optimizer_steps = len(train_dataloader) * args.epochs
+
+     fm = ConditionalFlowMatcher(sigma=cfg.sigma)
+     net_model = build_jit(cfg).to(device)
+     # x_pred mode matches the x-prediction objective used in the training loop below.
+     ode_net = CFMFlowWrapper(net_model, prediction_mode="x_pred")
+
+     optim = torch.optim.AdamW(net_model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
+     scheduler = torch.optim.lr_scheduler.LinearLR(optim, total_iters=max(total_optimizer_steps, 1))
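+     # Note: LinearLR with only total_iters set ramps the LR from lr/3 (its default
+     # start_factor) up to the full lr over training, i.e. a warmup, not a decay.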
+     t_span = torch.linspace(0, 1, cfg.inference_steps + 1, device=device)
+
+     c, h, w = cfg.dim
+     global_step = 0
+     best_loss = float("inf")
+
+     for ep in range(args.epochs):
+         net_model.train()
+         epoch_loss = 0.0
+         num_batches = 0
+
+         for data in train_dataloader:
+             x1 = data[0].to(device, non_blocking=True)
+             x0 = torch.randn_like(x1)
+             t, xt, ut = fm.sample_location_and_conditional_flow(x0, x1)
+             t_b = t.reshape(-1).float()
+             # Velocity-prediction alternative (the direct CFM objective):
+             #   vt = net_model(xt, t_b)
+             #   loss = torch.mean((vt - ut) ** 2)
+             # x-prediction objective: the model outputs an estimate of the clean
+             # image x1; convert both prediction and target to velocities.
+             x_pred = net_model(xt, t_b)
+             one_minus_t = (1.0 - t_b).clamp(min=1.0e-5)
+             t_bc = one_minus_t.reshape(-1, *([1] * (xt.dim() - 1)))
+             v_pred = (x_pred - xt) / t_bc
+             v_target = (x1 - xt) / t_bc
+             loss = torch.mean((v_target - v_pred) ** 2)
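+             # Since v_pred - v_target = (x_pred - x1) / (1 - t), this velocity MSE
+             # equals the MSE on x_pred weighted by 1 / (1 - t)^2, which up-weights
+             # late (t -> 1) timesteps.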
+
+             optim.zero_grad(set_to_none=True)
+             loss.backward()
+             optim.step()
+             scheduler.step()
+
+             epoch_loss += loss.item()
+             num_batches += 1
+
+             writer.add_scalar("train/loss_step", loss.item(), global_step)
+             writer.add_scalar("train/lr", scheduler.get_last_lr()[0], global_step)
+
+             if global_step % args.log_interval == 0:
+                 print(f"[step {global_step}] loss = {loss.item():.6f}")
+
+             global_step += 1
+
+         avg_epoch_loss = epoch_loss / max(num_batches, 1)
+         writer.add_scalar("train/loss_epoch", avg_epoch_loss, ep)
+         print(f"[epoch {ep}] avg loss = {avg_epoch_loss:.6f}")
+
+         # Sample a visualization grid by integrating the learned flow from noise.
+         net_model.eval()
+         node = NeuralODE(ode_net, solver="euler")
+         with torch.no_grad():
+             x_vis = torch.randn(cfg.vis_batch_size, c, h, w, device=device)
+             traj = node.trajectory(x_vis, t_span=t_span)
+             x_final = traj[-1].clamp(0.0, 1.0).cpu()
+             grid = torchvision.utils.make_grid(x_final, nrow=4, padding=2, normalize=False)
+             writer.add_image("samples/neural_ode_final", grid, ep)
+
+         if ep % 30 == 0:
+             ckpt_path = os.path.join(args.save_dir, f"model_epoch_{ep}.pt")
+             torch.save(net_model.state_dict(), ckpt_path)
+
+         if ep == 0 or avg_epoch_loss < best_loss:
+             best_loss = avg_epoch_loss
+             torch.save(net_model.state_dict(), os.path.join(args.save_dir, "model_best.pt"))
+
+     writer.close()
+     print(f"Done. Checkpoints: {args.save_dir}")
+     print(f"TensorBoard: tensorboard --logdir {tb_dir}")
+
+
+ if __name__ == "__main__":
+     main()
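+
+ # Example invocation (flags from parse_args above; values are illustrative):
+ #   python train_jit.py --dataset cifar10 --epochs 30 --run-name jit_cifar10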