AbstractPhil committed on
Commit
609e926
·
verified ·
1 Parent(s): 6ea73f8

Create trainer.py

Browse files
Files changed (1) hide show
  1. trainer.py +547 -0
trainer.py ADDED
@@ -0,0 +1,547 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # =====================================================================================
2
+ # SD1.5 Flow-Matching Trainer — David-Driven Block Penalties (HF-loaded)
3
+ # Quartermaster: Mirel
4
+ # - BaseConfig at top
5
+ # - Functionality (teacher/student/david/assessor/fusion/trainer)
6
+ # - Activations at bottom
7
+ # =====================================================================================
8
+ from __future__ import annotations
9
+ import os, json, math, random
10
+ from dataclasses import dataclass, asdict
11
+ from pathlib import Path
12
+ from typing import Dict, List, Tuple, Optional
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+ from torch.utils.data import Dataset, DataLoader
18
+ from torch.utils.tensorboard import SummaryWriter
19
+ from tqdm import tqdm
20
+
21
+ # Diffusers
22
+ from diffusers import StableDiffusionPipeline, DDPMScheduler
23
+ from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
24
+
25
+ # Repo deps (present in your repo)
26
+ from geovocab2.train.model.core.geo_david_collective import GeoDavidCollective
27
+ from geovocab2.data.prompt.symbolic_tree import SynthesisSystem
28
+
29
+ # HF / safetensors
30
+ from huggingface_hub import snapshot_download
31
+ from safetensors.torch import load_file
32
+
33
+
34
+ # =====================================================================================
35
+ # 1) CONFIG (BaseConfig)
36
+ # =====================================================================================
37
@dataclass
class BaseConfig:
    """All knobs for the SD1.5 flow-matching + David-penalty training run.

    Instantiating the config creates the output / checkpoint / HF-cache
    directories as a side effect (see ``__post_init__``).
    """
    run_name: str = "sd15_flowmatch_david_hf"
    out_dir: str = "./runs/sd15_flowmatch_david_hf"
    ckpt_dir: str = "./checkpoints_sd15_flow_david_hf"
    save_every: int = 1  # checkpoint every N epochs

    # Data
    num_samples: int = 200_000
    batch_size: int = 32
    num_workers: int = 2
    seed: int = 42

    # Models / Blocks
    model_id: str = "runwayml/stable-diffusion-v1-5"
    active_blocks: Tuple[str, ...] = ("down_0","down_1","down_2","down_3","mid","up_0","up_1","up_2","up_3")
    pooling: str = "mean"  # mean | max | adaptive

    # Flow training
    epochs: int = 10
    lr: float = 1e-4
    weight_decay: float = 1e-3
    grad_clip: float = 1.0
    amp: bool = True

    global_flow_weight: float = 1.0
    block_penalty_weight: float = 0.125  # deliberately low: block penalties are auxiliary
    use_local_flow_heads: bool = False
    local_flow_weight: float = 1.0

    # KD (optional)
    use_kd: bool = True
    kd_weight: float = 0.25

    # David (ALWAYS used, HF)
    david_repo_id: str = "AbstractPhil/geo-david-collective-sd15-base-e40"
    david_cache_dir: str = "./_hf_david_cache"
    david_state_key: Optional[str] = None  # None→raw state; or "model_state_dict" if ckpt-style

    # Fusion: λ_b = w_b * (1 + α·e_t + β·e_p + δ·(1−coh))
    alpha_timestep: float = 0.5
    beta_pattern: float = 0.25
    delta_incoherence: float = 0.25
    lambda_min: float = 0.5
    lambda_max: float = 3.0

    # Block weights (overridden by HF config if present).
    # FIX: annotation corrected to Optional — the field legitimately defaults
    # to None and is filled in __post_init__; the old `Dict[str, float] = None`
    # annotation misrepresented the accepted values.
    block_weights: Optional[Dict[str, float]] = None

    # Scheduler
    num_train_timesteps: int = 1000

    # Inference
    sample_steps: int = 30
    guidance_scale: float = 7.5

    def __post_init__(self):
        # Side effect: ensure all run directories exist up front.
        Path(self.out_dir).mkdir(parents=True, exist_ok=True)
        Path(self.ckpt_dir).mkdir(parents=True, exist_ok=True)
        Path(self.david_cache_dir).mkdir(parents=True, exist_ok=True)
        if self.block_weights is None:
            # Default per-block loss weights (mid weighted highest).
            self.block_weights = {'down_0':0.7,'down_1':0.9,'down_2':1.0,'down_3':1.1,'mid':1.2,'up_0':1.1,'up_1':1.0,'up_2':0.9,'up_3':0.7}
99
+
100
+
101
+ # =====================================================================================
102
+ # 2) DATA
103
+ # =====================================================================================
104
class SymbolicPromptDataset(Dataset):
    """Synthetic text-prompt dataset backed by SynthesisSystem.

    Each item is a dict holding a generated prompt string and a random
    DDPM timestep in [0, 999].
    """

    def __init__(self, n: int, seed: int = 42):
        self.n = n
        # NOTE: seeds the *global* random module (affects other users of `random`).
        random.seed(seed)
        self.sys = SynthesisSystem(seed=seed)

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        record = self.sys.synthesize(complexity=random.choice([1, 2, 3, 4, 5]))
        timestep = random.randint(0, 999)
        return {"prompt": record['text'], "t": timestep}
117
+
118
def collate(batch: List[dict]):
    """Collate prompt/timestep samples into a batch dict.

    Returns the prompts as a list of strings, the timesteps as a
    LongTensor, and coarse timestep bins (integer division by 10).
    """
    prompts, steps = [], []
    for sample in batch:
        prompts.append(sample["prompt"])
        steps.append(sample["t"])
    t = torch.as_tensor(steps, dtype=torch.long)
    return {"prompts": prompts, "t": t, "t_bins": t // 10}
123
+
124
+
125
+ # =====================================================================================
126
+ # 3) HOOKS + POOLING
127
+ # =====================================================================================
128
class HookBank:
    """Captures per-block feature maps from a UNet via forward hooks.

    Hooks are installed on the down/mid/up blocks whose names appear in
    ``active``; after a forward pass, ``bank`` maps block name -> the block's
    output tensor (first element when the block returns a tuple/list).
    """

    def __init__(self, unet: UNet2DConditionModel, active: Tuple[str, ...]):
        self.active = set(active)
        self.bank: Dict[str, torch.Tensor] = {}
        self.hooks: List[torch.utils.hooks.RemovableHandle] = []
        self._register(unet)

    def _register(self, unet: UNet2DConditionModel):
        def capture(name):
            def hook(_module, _inputs, output):
                # Some blocks return (hidden_states, residuals, ...); keep the first.
                self.bank[name] = output[0] if isinstance(output, (tuple, list)) else output
            return hook

        candidates = [(f"down_{i}", blk) for i, blk in enumerate(unet.down_blocks)]
        candidates.append(("mid", unet.mid_block))
        candidates += [(f"up_{i}", blk) for i, blk in enumerate(unet.up_blocks)]
        for name, block in candidates:
            if name in self.active:
                self.hooks.append(block.register_forward_hook(capture(name)))

    def clear(self):
        """Drop captured features (call before each forward pass)."""
        self.bank.clear()

    def close(self):
        """Remove every installed hook and forget the handles."""
        for handle in self.hooks:
            handle.remove()
        self.hooks.clear()
154
+
155
def spatial_pool(x: torch.Tensor, name: str, policy: str) -> torch.Tensor:
    """Collapse a [B, C, H, W] feature map to [B, C].

    ``policy`` is one of "mean", "max", or "adaptive" (mean for down/mid
    blocks, max for up blocks). Raises ValueError on anything else.
    """
    if policy == "adaptive":
        # Down/mid blocks pool by average; up blocks by maximum.
        policy = "mean" if (name.startswith("down") or name == "mid") else "max"
    if policy == "mean":
        return x.mean(dim=(2, 3))
    if policy == "max":
        return x.amax(dim=(2, 3))
    raise ValueError(f"Unknown pooling: {policy}")
161
+
162
+
163
+ # =====================================================================================
164
+ # 4) TEACHER (SD1.5)
165
+ # =====================================================================================
166
class SD15Teacher(nn.Module):
    """Frozen SD1.5 pipeline used as the training teacher.

    Provides: text encoding, epsilon prediction + hooked per-block features,
    and the (alpha, sigma) noise-schedule coefficients. All parameters are
    frozen in __init__.
    """

    def __init__(self, cfg: BaseConfig, device: str):
        super().__init__()
        # fp16 pipeline; safety checker disabled for training throughput.
        self.pipe = StableDiffusionPipeline.from_pretrained(cfg.model_id, torch_dtype=torch.float16, safety_checker=None).to(device)
        self.unet: UNet2DConditionModel = self.pipe.unet
        self.text_encoder = self.pipe.text_encoder
        self.tokenizer = self.pipe.tokenizer
        # Feature hooks on the configured down/mid/up blocks.
        self.hooks = HookBank(self.unet, cfg.active_blocks)
        # NOTE(review): a fresh DDPMScheduler is built here rather than reusing
        # self.pipe.scheduler — assumed to match SD1.5's training schedule
        # defaults; confirm beta schedule matches the pipeline's.
        self.sched = DDPMScheduler(num_train_timesteps=cfg.num_train_timesteps)
        self.device = device
        # Teacher is frozen: no gradients anywhere.
        for p in self.parameters(): p.requires_grad_(False)

    @torch.no_grad()
    def encode(self, prompts: List[str]) -> torch.Tensor:
        """Tokenize and CLIP-encode prompts -> [B, 77, 768] hidden states."""
        tok = self.tokenizer(prompts, padding="max_length", max_length=self.tokenizer.model_max_length,
                             truncation=True, return_tensors="pt")
        return self.text_encoder(tok.input_ids.to(self.device))[0]

    @torch.no_grad()
    def forward_eps_and_feats(self, x_t: torch.Tensor, t: torch.LongTensor, ehs: torch.Tensor):
        """Run the teacher UNet; return (eps prediction, per-block features).

        Features are detached and upcast to float32; bank is cleared first so
        only this forward's activations are returned.
        """
        self.hooks.clear()
        eps_hat = self.unet(x_t, t, encoder_hidden_states=ehs).sample
        feats = {k: v.detach().float() for k, v in self.hooks.bank.items()}
        return eps_hat.float(), feats

    def alpha_sigma(self, t: torch.LongTensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return (sqrt(ᾱ_t), sqrt(1−ᾱ_t)) shaped [B,1,1,1] for broadcasting."""
        ac = self.sched.alphas_cumprod.to(self.device)[t]
        alpha = ac.sqrt().view(-1,1,1,1).float()
        sigma = (1.0 - ac).sqrt().view(-1,1,1,1).float()
        return alpha, sigma
196
+
197
+
198
+ # =====================================================================================
199
+ # 5) STUDENT (v-pred) + LOCAL FLOW HEADS
200
+ # =====================================================================================
201
class StudentUNet(nn.Module):
    """Trainable UNet initialized as an exact copy of the teacher's weights.

    Predicts v (velocity) instead of epsilon. Optionally attaches lazy 1x1
    conv "local flow heads" per hooked block.
    """

    def __init__(self, teacher_unet: UNet2DConditionModel, active_blocks: Tuple[str,...], use_local_heads: bool):
        super().__init__()
        # Fresh module with the same architecture, then copy all weights.
        self.unet = UNet2DConditionModel.from_config(teacher_unet.config)
        self.unet.load_state_dict(teacher_unet.state_dict(), strict=True)
        self.hooks = HookBank(self.unet, active_blocks)
        self.use_local_heads = use_local_heads
        self.local_heads = nn.ModuleDict()

    def _ensure_heads(self, feats: Dict[str, torch.Tensor]):
        """Lazily create one Conv2d(c -> 4, 1x1) head per hooked block.

        NOTE(review): heads are created on the *first forward*, i.e. after any
        optimizer built over self.parameters() at construction time — those
        head parameters would then never be optimized. Verify against the
        trainer when use_local_heads is enabled.
        """
        if not self.use_local_heads: return
        if len(self.local_heads) == len(feats): return

        # Get dtype from main UNet
        target_dtype = next(self.unet.parameters()).dtype

        for name, f in feats.items():
            c = f.shape[1]  # channel count of this block's feature map
            if name not in self.local_heads:
                head = nn.Conv2d(c, 4, kernel_size=1)
                # Convert head to match UNet dtype
                head = head.to(dtype=target_dtype, device=f.device)
                self.local_heads[name] = head

    def forward(self, x_t: torch.Tensor, t: torch.LongTensor, ehs: torch.Tensor):
        """Return (v prediction, dict of per-block hooked features)."""
        self.hooks.clear()
        v_hat = self.unet(x_t, t, encoder_hidden_states=ehs).sample
        feats = {k: v for k, v in self.hooks.bank.items()}  # Keep original dtype
        self._ensure_heads(feats)
        return v_hat, feats
231
+
232
+ # =====================================================================================
233
+ # 6) DAVID LOADER (HF) + ASSESSOR + FUSION
234
+ # =====================================================================================
235
class DavidLoader:
    """
    Downloads HF repo (config + safetensors), instantiates GeoDavidCollective with HF config,
    loads weights, returns a frozen model + the parsed HF config.

    Side effects: downloads into cfg.david_cache_dir and, when the HF config
    carries block_weights, overwrites cfg.block_weights in place.
    """
    def __init__(self, cfg: BaseConfig, device: str):
        self.cfg = cfg
        self.device = device
        # NOTE(review): local_dir_use_symlinks is deprecated in recent
        # huggingface_hub releases — confirm against the pinned version.
        self.repo_dir = snapshot_download(repo_id=cfg.david_repo_id, local_dir=cfg.david_cache_dir, local_dir_use_symlinks=False)
        self.config_path = os.path.join(self.repo_dir, "config.json")
        self.weights_path = os.path.join(self.repo_dir, "model.safetensors")
        with open(self.config_path, "r") as f:
            self.hf_config = json.load(f)
        # Instantiate GeoDavidCollective from HF config
        self.gdc = GeoDavidCollective(
            block_configs=self.hf_config["block_configs"],
            num_timestep_bins=int(self.hf_config["num_timestep_bins"]),
            num_patterns_per_bin=int(self.hf_config["num_patterns_per_bin"]),
            # Fall back to uniform 1.0 weights when the repo omits them.
            block_weights=self.hf_config.get("block_weights", {k:1.0 for k in self.hf_config["block_configs"].keys()}),
            loss_config=self.hf_config.get("loss_config", {})
        ).to(device).eval()
        # Load weights. strict=False silently tolerates missing/unexpected
        # keys — a key mismatch will NOT raise here.
        state = load_file(self.weights_path)
        self.gdc.load_state_dict(state, strict=False)
        # David is frozen: assessor only, never trained here.
        for p in self.gdc.parameters(): p.requires_grad_(False)
        # Report
        print(f"✓ David loaded from HF: {self.repo_dir}")
        print(f" blocks={len(self.hf_config['block_configs'])} bins={self.hf_config['num_timestep_bins']} patterns={self.hf_config['num_patterns_per_bin']}")
        # Override block weights in main cfg if provided
        if "block_weights" in self.hf_config:
            cfg.block_weights = self.hf_config["block_weights"]
266
+
267
class DavidAssessor(nn.Module):
    """
    Uses David to score STUDENT pooled features (per block) and timesteps.
    Produces:
      e_t[name] : timestep CE error proxy (scalar)
      e_p[name] : pattern CE error proxy if logits present, else 0
      coh[name] : coherence proxy (avg Cantor alpha if provided, else 1)
    """
    def __init__(self, gdc: GeoDavidCollective, pooling: str):
        super().__init__()
        self.gdc = gdc          # frozen GeoDavidCollective (see DavidLoader)
        self.pooling = pooling  # spatial_pool policy: mean | max | adaptive

    def _pool(self, feats: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Pool each [B,C,H,W] feature map to [B,C] per block."""
        return {k: spatial_pool(v, k, self.pooling) for k, v in feats.items()}

    @torch.no_grad()
    def forward(self, feats_student: Dict[str, torch.Tensor], t: torch.LongTensor
                ) -> Tuple[Dict[str,float], Dict[str,float], Dict[str,float]]:
        Zs = self._pool(feats_student)  # [B,C] per block
        outs = self.gdc(Zs, t.float())  # forward for predictions/logits

        e_t, e_p, coh = {}, {}, {}

        # Probe David's output dict for logits under several plausible key
        # names — the exact schema of GeoDavidCollective's output is not
        # visible here, hence the defensive lookup.
        # timestep logits
        ts_key = None
        for key in ["timestep_logits", "logits_timestep", "timestep_head_logits"]:
            if key in outs: ts_key = key; break
        # pattern logits (optional)
        pt_key = None
        for key in ["pattern_logits", "logits_pattern", "pattern_head_logits"]:
            if key in outs: pt_key = key; break

        # NOTE(review): bin width of 10 (1000 timesteps -> 100 bins) is
        # hard-coded; assumed to equal the HF config's num_timestep_bins —
        # TODO confirm, otherwise cross_entropy targets can go out of range.
        t_bins = (t // 10).to(next(self.gdc.parameters()).device)
        if ts_key is not None:
            # Expect dict per block or a tensor across blocks; support both
            ts_logits = outs[ts_key]
            if isinstance(ts_logits, dict):
                for name, L in ts_logits.items():
                    ce = F.cross_entropy(L, t_bins, reduction="mean")
                    e_t[name] = float(ce.item())
            else:
                # single head: broadcast same CE to all blocks
                ce = F.cross_entropy(ts_logits, t_bins, reduction="mean")
                for name in Zs.keys():
                    e_t[name] = float(ce.item())
        else:
            for name in Zs.keys(): e_t[name] = 0.0

        if pt_key is not None:
            pt_logits = outs[pt_key]
            # If no labels for pattern, use entropy as "error" proxy
            # (normalized by log(K) so it lies in [0, 1]).
            if isinstance(pt_logits, dict):
                for name, L in pt_logits.items():
                    P = L.softmax(-1)
                    ent = -(P * (P.clamp_min(1e-9)).log()).sum(-1).mean()
                    e_p[name] = float(ent.item() / math.log(P.shape[-1]))
            else:
                P = pt_logits.softmax(-1)
                ent = -(P * (P.clamp_min(1e-9)).log()).sum(-1).mean()
                for name in Zs.keys():
                    e_p[name] = float(ent.item() / math.log(P.shape[-1]))
        else:
            for name in Zs.keys(): e_p[name] = 0.0

        # Cantor alphas / coherence. Best-effort: the broad except keeps the
        # assessor usable when the loaded David lacks get_cantor_alphas().
        alphas = {}
        try:
            alphas = self.gdc.get_cantor_alphas()  # dict of scalars
        except Exception:
            alphas = {}
        avg_alpha = float(sum(alphas.values())/max(len(alphas),1)) if alphas else 1.0
        # Single global coherence value broadcast to every block.
        for name in Zs.keys():
            coh[name] = avg_alpha  # higher=more coherent

        return e_t, e_p, coh
342
+
343
class BlockPenaltyFusion:
    """Fuses David's per-block error signals into penalty multipliers.

    λ_b = clamp(w_b · (1 + α·e_t + β·e_p + δ·(1 − coh)), [λ_min, λ_max])
    """

    def __init__(self, cfg: BaseConfig):
        self.cfg = cfg

    def lambdas(self, e_t: Dict[str, float], e_p: Dict[str, float], coh: Dict[str, float]) -> Dict[str, float]:
        cfg = self.cfg
        out: Dict[str, float] = {}
        for name, base in cfg.block_weights.items():
            # Missing entries default to "no error" / "fully coherent".
            scale = 1.0
            scale += cfg.alpha_timestep * float(e_t.get(name, 0.0))
            scale += cfg.beta_pattern * float(e_p.get(name, 0.0))
            scale += cfg.delta_incoherence * (1.0 - float(coh.get(name, 1.0)))
            clamped = min(cfg.lambda_max, max(cfg.lambda_min, base * scale))
            out[name] = float(clamped)
        return out
354
+
355
+
356
+ # =====================================================================================
357
+ # 7) TRAINER + INFERENCE
358
+ # =====================================================================================
359
class FlowMatchDavidTrainer:
    """Trains a v-prediction student UNet against a frozen SD1.5 teacher.

    Total loss per step:
        L = global_flow_weight * MSE(v_hat, v*)                        (global flow)
          + block_penalty_weight * Σ_b λ_b·(kd_w·KD_b + lf_w·LF_b)     (per-block)
    where λ_b are David-driven multipliers from BlockPenaltyFusion.

    Fixes vs. the original:
      * AMP path now calls scaler.unscale_(opt) BEFORE clip_grad_norm_ —
        previously the clip threshold was applied to loss-scaled gradients,
        making grad_clip effectively meaningless under AMP.
      * sample() creates latents in the student's parameter dtype (fp16,
        copied from the fp16 teacher) and casts the scheduler output back —
        previously float32 latents hit the fp16 UNet outside autocast.
    """

    def __init__(self, cfg: BaseConfig, device: str = "cuda"):
        self.cfg = cfg
        self.device = device

        # Data: synthetic prompts + random timesteps.
        self.dataset = SymbolicPromptDataset(cfg.num_samples, cfg.seed)
        self.loader = DataLoader(self.dataset, batch_size=cfg.batch_size, shuffle=True,
                                 num_workers=cfg.num_workers, pin_memory=True, collate_fn=collate)

        # Teacher (frozen) & Student (trainable copy of the teacher's UNet).
        self.teacher = SD15Teacher(cfg, device).eval()
        self.student = StudentUNet(self.teacher.unet, cfg.active_blocks, cfg.use_local_flow_heads).to(device)

        # David (frozen assessor, loaded from HF).
        self.david_loader = DavidLoader(cfg, device)
        self.david = self.david_loader.gdc
        # Assessor + Fusion
        self.assessor = DavidAssessor(self.david, cfg.pooling)
        self.fusion = BlockPenaltyFusion(cfg)

        # Opt/Sched/AMP.
        # NOTE(review): local flow heads are created lazily on the first
        # forward, AFTER this optimizer is built, so they would never be
        # updated. Only relevant when use_local_flow_heads=True.
        self.opt = torch.optim.AdamW(self.student.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
        self.sched = torch.optim.lr_scheduler.CosineAnnealingLR(self.opt, T_max=cfg.epochs * len(self.loader))
        self.scaler = torch.cuda.amp.GradScaler(enabled=cfg.amp)

        # Logs
        self.writer = SummaryWriter(log_dir=os.path.join(cfg.out_dir, cfg.run_name))

    # ---------- math helpers ----------
    def _v_star(self, x_t, t, eps_hat):
        """Target velocity v* = α·ε̂ − σ·x̂0, with x̂0 recovered from ε̂."""
        alpha, sigma = self.teacher.alpha_sigma(t)
        x0_hat = (x_t - sigma * eps_hat) / (alpha + 1e-8)
        return alpha * eps_hat - sigma * x0_hat

    def _down_like(self, tgt: torch.Tensor, ref: torch.Tensor) -> torch.Tensor:
        """Bilinearly resize `tgt` to `ref`'s spatial size."""
        return F.interpolate(tgt, size=ref.shape[-2:], mode="bilinear", align_corners=False)

    def _kd_cos(self, s: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """Cosine-distance KD loss between pooled student/teacher features."""
        s = F.normalize(s, dim=-1); t = F.normalize(t, dim=-1)
        return 1.0 - (s*t).sum(-1).mean()

    # ---------- training ----------
    def train(self):
        """Run the full loop; checkpoints every cfg.save_every epochs."""
        cfg = self.cfg
        gstep = 0
        for ep in range(cfg.epochs):
            self.student.train()
            pbar = tqdm(self.loader, desc=f"Epoch {ep+1}/{cfg.epochs}")
            acc = {"L":0.0, "Lf":0.0, "Lb":0.0}

            for it, batch in enumerate(pbar):
                prompts = batch["prompts"]
                t = batch["t"].to(self.device)

                with torch.no_grad():
                    ehs = self.teacher.encode(prompts)

                # Random latents (no image data — pure prompt-driven matching).
                x_t = torch.randn(len(prompts), 4, 64, 64, device=self.device, dtype=torch.float16)

                # Teacher targets: eps prediction + features -> v* target.
                with torch.no_grad():
                    eps_hat, t_feats_spatial = self.teacher.forward_eps_and_feats(x_t.half(), t, ehs)
                    v_star = self._v_star(x_t, t, eps_hat)

                with torch.cuda.amp.autocast(enabled=cfg.amp):
                    # Student forward (v prediction + hooked features).
                    v_hat, s_feats_spatial = self.student(x_t, t, ehs)
                    L_flow = F.mse_loss(v_hat, v_star)

                    # David assessor on STUDENT pooled features.
                    e_t, e_p, coh = self.assessor(s_feats_spatial, t)
                    lam = self.fusion.lambdas(e_t, e_p, coh)

                    # Per-block KD + local flow, weighted by λ_b.
                    L_blocks = torch.zeros((), device=self.device)
                    for name, s_feat in s_feats_spatial.items():
                        # KD: pooled cosine distance vs teacher features.
                        L_kd = torch.zeros((), device=self.device)
                        if cfg.use_kd:
                            s_pool = spatial_pool(s_feat, name, cfg.pooling)
                            t_pool = spatial_pool(t_feats_spatial[name], name, cfg.pooling)
                            L_kd = self._kd_cos(s_pool, t_pool)
                        # Local flow head vs downsampled v*.
                        L_lf = torch.zeros((), device=self.device)
                        if cfg.use_local_flow_heads and name in self.student.local_heads:
                            v_loc = self.student.local_heads[name](s_feat)
                            v_ds = self._down_like(v_star, v_loc)
                            L_lf = F.mse_loss(v_loc, v_ds)
                        L_blocks = L_blocks + lam.get(name,1.0) * (cfg.kd_weight * L_kd + cfg.local_flow_weight * L_lf)

                    L_total = cfg.global_flow_weight*L_flow + cfg.block_penalty_weight*L_blocks

                self.opt.zero_grad(set_to_none=True)
                if cfg.amp:
                    self.scaler.scale(L_total).backward()
                    # FIX: unscale gradients before clipping so grad_clip is
                    # applied to true gradient norms, not loss-scaled ones.
                    self.scaler.unscale_(self.opt)
                    nn.utils.clip_grad_norm_(self.student.parameters(), cfg.grad_clip)
                    self.scaler.step(self.opt); self.scaler.update()
                else:
                    L_total.backward()
                    nn.utils.clip_grad_norm_(self.student.parameters(), cfg.grad_clip)
                    self.opt.step()
                self.sched.step(); gstep += 1

                acc["L"] += float(L_total.item())
                acc["Lf"] += float(L_flow.item())
                acc["Lb"] += float(L_blocks.item())

                if it % 50 == 0:
                    self.writer.add_scalar("train/total", float(L_total.item()), gstep)
                    self.writer.add_scalar("train/flow", float(L_flow.item()), gstep)
                    self.writer.add_scalar("train/blocks",float(L_blocks.item()), gstep)
                    # log a few lambdas
                    for k in list(lam.keys())[:4]:
                        self.writer.add_scalar(f"lambda/{k}", lam[k], gstep)

                pbar.set_postfix({"L": f"{float(L_total.item()):.4f}", "Lf": f"{float(L_flow.item()):.4f}", "Lb": f"{float(L_blocks.item()):.4f}"})
                # Release the big tensors before the next batch's allocations.
                del x_t, eps_hat, v_star, v_hat, s_feats_spatial, t_feats_spatial

            n = len(self.loader)
            print(f"\n[Epoch {ep+1}] L={acc['L']/n:.4f} | L_flow={acc['Lf']/n:.4f} | L_blocks={acc['Lb']/n:.4f}")
            self.writer.add_scalar("epoch/total", acc['L']/n, ep+1)
            self.writer.add_scalar("epoch/flow", acc['Lf']/n, ep+1)
            self.writer.add_scalar("epoch/blocks",acc['Lb']/n, ep+1)

            if (ep+1) % cfg.save_every == 0:
                self._save(ep+1, gstep)

        self._save("final", gstep)
        self.writer.close()

    def _save(self, tag, gstep):
        """Write a full training checkpoint (config, student, opt, sched)."""
        path = Path(self.cfg.ckpt_dir) / f"{self.cfg.run_name}_{tag}.pt"
        torch.save({
            "cfg": asdict(self.cfg),
            "student": self.student.state_dict(),
            "opt": self.opt.state_dict(),
            "sched": self.sched.state_dict(),
            "gstep": gstep
        }, path)
        print(f"✓ Saved: {path}")

    # ---------- Inference (v-pred sampling; use teacher VAE for decode) ----------
    @torch.no_grad()
    def sample(self, prompts: List[str], steps: Optional[int]=None, guidance: Optional[float]=None) -> torch.Tensor:
        """Classifier-free-guided sampling with the student (v-pred) UNet.

        Returns decoded images clamped to [-1, 1].
        """
        steps = steps or self.cfg.sample_steps
        guidance = guidance if guidance is not None else self.cfg.guidance_scale
        cond_e = self.teacher.encode(prompts)
        uncond_e = self.teacher.encode([""]*len(prompts))
        sched = self.teacher.sched
        sched.set_timesteps(steps, device=self.device)
        # FIX: latents must match the student's parameter dtype (fp16 — weights
        # are copied from the fp16 teacher); float32 latents would fail in the
        # fp16 UNet since this loop runs without autocast.
        model_dtype = next(self.student.parameters()).dtype
        x_t = torch.randn(len(prompts), 4, 64, 64, device=self.device, dtype=model_dtype)

        for t_scalar in sched.timesteps:
            t = torch.full((x_t.shape[0],), t_scalar, device=self.device, dtype=torch.long)
            v_u, _ = self.student(x_t, t, uncond_e)
            v_c, _ = self.student(x_t, t, cond_e)
            v_hat = v_u + guidance*(v_c - v_u)

            # v-pred -> (x0, eps): x0 = (α·x − σ·v)/(α²+σ²), eps = (x − α·x0)/σ.
            alpha, sigma = self.teacher.alpha_sigma(t)
            denom = (alpha**2 + sigma**2)
            x0_hat = (alpha * x_t - sigma * v_hat) / (denom + 1e-8)
            eps_hat = (x_t - alpha * x0_hat) / (sigma + 1e-8)

            step = sched.step(model_output=eps_hat, timestep=t_scalar, sample=x_t)
            # FIX: scheduler math can upcast to fp32; cast back for the next
            # fp16 UNet call.
            x_t = step.prev_sample.to(model_dtype)

        # 0.18215 is the SD1.5 VAE latent scaling factor.
        imgs = self.teacher.pipe.vae.decode(x_t / 0.18215).sample
        return imgs.clamp(-1,1)
529
+
530
+
531
+ # =====================================================================================
532
+ # 8) ACTIVATION
533
+ # =====================================================================================
534
def main():
    """Entry point: build the config, train, then run a sampling smoke test."""
    cfg = BaseConfig()
    print(json.dumps(asdict(cfg), indent=2))

    use_cuda = torch.cuda.is_available()
    if not use_cuda:
        print("⚠️ A100 strongly recommended.")
    trainer = FlowMatchDavidTrainer(cfg, device="cuda" if use_cuda else "cpu")
    trainer.train()

    # Quick sanity pass through the sampler.
    trainer.sample(["a castle at sunset"], steps=10, guidance=7.0)
    print("✓ Inference sanity done.")


if __name__ == "__main__":
    main()