AbstractPhil committed on
Commit
4dad82a
·
verified ·
1 Parent(s): c63ba55

Introduced new trainer with improved systems and an included timestep

Browse files
Files changed (1) hide show
  1. trainer_v2.py +862 -0
trainer_v2.py ADDED
@@ -0,0 +1,862 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # =====================================================================================
2
+ # SD1.5 Flow-Matching Trainer — David-Driven Adaptive Timestep Sampling
3
+ # Quartermaster: Mirel
4
+ # NEW: David-weighted timesteps + SD3 shift + adaptive chaos
5
+ # =====================================================================================
6
+ from __future__ import annotations
7
+ import os, json, math, random, re, shutil
8
+ from dataclasses import dataclass, asdict
9
+ from pathlib import Path
10
+ from typing import Dict, List, Tuple, Optional
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.nn.functional as F
15
+ from torch.utils.data import Dataset, DataLoader
16
+ from torch.utils.tensorboard import SummaryWriter
17
+ from tqdm import tqdm
18
+
19
+ # Diffusers
20
+ from diffusers import StableDiffusionPipeline, DDPMScheduler
21
+ from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
22
+
23
+ # Repo deps
24
+ from geovocab2.train.model.core.geo_david_collective import GeoDavidCollective
25
+ from geovocab2.data.prompt.symbolic_tree import SynthesisSystem
26
+
27
+ # HF / safetensors
28
+ from huggingface_hub import snapshot_download, HfApi, create_repo, hf_hub_download
29
+ from safetensors.torch import load_file
30
+
31
+
32
+ # =====================================================================================
33
+ # 1) CONFIG
34
+ # =====================================================================================
35
@dataclass
class BaseConfig:
    """Hyperparameters and paths for the SD1.5 flow-matching / David-guided run.

    Output directories are created eagerly in ``__post_init__`` so downstream
    code can assume they exist.
    """
    run_name: str = "sd15_flowmatch_david_weighted"
    out_dir: str = "./runs/sd15_flowmatch_david_weighted"
    ckpt_dir: str = "./checkpoints_sd15_flow_david_weighted"
    save_every: int = 1  # save a checkpoint every N epochs

    # Data
    num_samples: int = 200_000
    batch_size: int = 32
    num_workers: int = 2
    seed: int = 42

    # Models / Blocks
    model_id: str = "runwayml/stable-diffusion-v1-5"
    active_blocks: Tuple[str, ...] = ("down_0","down_1","down_2","down_3","mid","up_0","up_1","up_2","up_3")
    pooling: str = "mean"  # spatial pooling policy: "mean" | "max" | "adaptive"

    # Flow training
    epochs: int = 10
    lr: float = 1e-4
    weight_decay: float = 1e-3
    grad_clip: float = 1.0
    amp: bool = True

    global_flow_weight: float = 1.0
    block_penalty_weight: float = 0.2
    use_local_flow_heads: bool = False
    local_flow_weight: float = 1.0

    # KD
    use_kd: bool = True
    kd_weight: float = 0.25

    # David
    david_repo_id: str = "AbstractPhil/geo-david-collective-sd15-base-e40"
    david_cache_dir: str = "./_hf_david_cache"
    david_state_key: Optional[str] = None

    # Fusion
    alpha_timestep: float = 0.5
    beta_pattern: float = 0.25
    delta_incoherence: float = 0.25
    lambda_min: float = 0.5
    lambda_max: float = 3.0

    # BUGFIX: annotated Optional — the field defaults to None (a dataclass must
    # not use a mutable dict literal as a default) and is filled in __post_init__.
    block_weights: Optional[Dict[str, float]] = None

    # Timestep Weighting (David-guided adaptive sampling)
    use_timestep_weighting: bool = True
    use_david_weights: bool = True
    timestep_shift: float = 3.0   # SD3-style shift (higher = bias toward clean)
    base_jitter: int = 5          # Base ±jitter around bin center
    adaptive_chaos: bool = True   # Scale jitter by pattern difficulty
    profile_samples: int = 500    # Samples to profile David's difficulty

    # Scheduler
    num_train_timesteps: int = 1000

    # Inference
    sample_steps: int = 30
    guidance_scale: float = 7.5

    # HuggingFace
    hf_repo_id: Optional[str] = "AbstractPhil/sd15-flow-matching"
    upload_every_epoch: bool = True
    continue_training: bool = True

    def __post_init__(self):
        """Create output directories and fill the default per-block weights."""
        Path(self.out_dir).mkdir(parents=True, exist_ok=True)
        Path(self.ckpt_dir).mkdir(parents=True, exist_ok=True)
        Path(self.david_cache_dir).mkdir(parents=True, exist_ok=True)
        if self.block_weights is None:
            self.block_weights = {'down_0':0.7,'down_1':0.9,'down_2':1.0,'down_3':1.1,'mid':1.2,'up_0':1.1,'up_1':1.0,'up_2':0.9,'up_3':0.7}
109
+
110
+
111
+ # =====================================================================================
112
+ # 2) DAVID-WEIGHTED TIMESTEP SAMPLER
113
+ # =====================================================================================
114
class DavidWeightedTimestepSampler:
    """
    Samples timesteps weighted by David's inherent difficulty + SD3 shift + adaptive chaos.
    """
    def __init__(self, num_timesteps=1000, num_bins=100, shift=3.0, base_jitter=5, adaptive_chaos=True):
        # NOTE(review): profiling below bins timesteps with `t // 10`, which only
        # lines up with num_bins when num_timesteps // num_bins == 10 — confirm
        # before changing either default.
        self.num_timesteps = num_timesteps
        self.num_bins = num_bins
        self.shift = shift
        self.base_jitter = base_jitter
        self.adaptive_chaos = adaptive_chaos

        self.difficulty_weights = None  # Timestep difficulty
        self.pattern_difficulty = None  # Pattern confusion per bin

    def _apply_shift(self, t: float) -> float:
        """Apply SD3-style timestep shift (operates on normalized t ∈ [0,1])."""
        # shift <= 0 disables shifting entirely (identity mapping).
        if self.shift <= 0:
            return t
        return self.shift * t / (1.0 + (self.shift - 1.0) * t)

    def compute_difficulty_from_david(self, david, teacher, device, num_samples=500):
        """Profile David's confusion patterns to create difficulty map.

        Feeds random fp16 latents through the frozen teacher UNet, pools the
        hooked block activations, and scores them with David. Accumulates, per
        10-step timestep bin:
          * classification accuracy on the timestep head (-> difficulty_weights,
            a sampling distribution that is higher where David is less accurate),
          * normalized pattern-logit entropy (-> pattern_difficulty in [0.1, 1]).
        Returns difficulty_weights.
        """
        print("🔍 Profiling David's timestep & pattern difficulty...")

        david.eval()
        teacher.eval()

        # Track David's accuracy and pattern entropy per bin
        correct_per_bin = torch.zeros(self.num_bins)
        total_per_bin = torch.zeros(self.num_bins)
        entropy_per_bin = torch.zeros(self.num_bins)
        entropy_count_per_bin = torch.zeros(self.num_bins)

        with torch.no_grad():
            for _ in tqdm(range(num_samples // 32), desc="Profiling David", leave=False):
                # Random latents and timesteps
                # NOTE(review): batch 32 and a 4x64x64 latent are hard-coded here
                # (SD1.5 @ 512px) — confirm they match the training setup.
                x = torch.randn(32, 4, 64, 64, device=device, dtype=torch.float16)
                t = torch.randint(0, self.num_timesteps, (32,), device=device)
                t_bins = (t // 10)

                # Dummy conditioning
                # NOTE(review): random 77x768 "text" embeddings stand in for real
                # CLIP output during profiling — features are unconditioned noise.
                ehs = torch.randn(32, 77, 768, device=device, dtype=torch.float16)

                # Get teacher features
                teacher.hooks.clear()
                _ = teacher.unet(x, t, encoder_hidden_states=ehs)
                feats = {k: v.float() for k, v in teacher.hooks.bank.items()}

                # Pool features
                pooled = {name: f.mean(dim=(2, 3)) for name, f in feats.items()}

                # Get David's outputs
                outputs = david(pooled, t.float())

                # 1. Timestep difficulty (from classification error)
                # David's output key names vary by version; probe known aliases.
                ts_key = None
                for key in ["timestep_logits", "logits_timestep", "timestep_head_logits"]:
                    if key in outputs:
                        ts_key = key
                        break

                if ts_key:
                    ts_logits = outputs[ts_key]
                    if isinstance(ts_logits, dict):
                        # Per-block logits: average into a single prediction.
                        ts_logits = torch.stack(list(ts_logits.values())).mean(0)

                    preds = ts_logits.argmax(dim=-1)
                    for pred, true_bin in zip(preds, t_bins):
                        bin_idx = true_bin.item()
                        correct_per_bin[bin_idx] += (pred == true_bin).float().item()
                        total_per_bin[bin_idx] += 1

                # 2. Pattern difficulty (from entropy)
                pt_key = None
                for key in ["pattern_logits", "logits_pattern", "pattern_head_logits"]:
                    if key in outputs:
                        pt_key = key
                        break

                if pt_key:
                    pt_logits = outputs[pt_key]
                    if isinstance(pt_logits, dict):
                        pt_logits = torch.stack(list(pt_logits.values())).mean(0)

                    P = pt_logits.softmax(-1)
                    ent = -(P * P.clamp_min(1e-9).log()).sum(-1)
                    norm_ent = ent / math.log(P.shape[-1])  # Normalize by max entropy

                    for i, true_bin in enumerate(t_bins):
                        bin_idx = true_bin.item()
                        entropy_per_bin[bin_idx] += norm_ent[i].item()
                        entropy_count_per_bin[bin_idx] += 1

        # Compute timestep difficulty (inverse of accuracy)
        accuracy_per_bin = correct_per_bin / (total_per_bin.clamp(min=1))
        timestep_difficulty = (1.0 - accuracy_per_bin) + 0.1  # Higher = harder; +0.1 keeps every bin sampleable
        self.difficulty_weights = timestep_difficulty / timestep_difficulty.sum()

        # Compute pattern difficulty (average entropy per bin)
        self.pattern_difficulty = entropy_per_bin / (entropy_count_per_bin.clamp(min=1))
        self.pattern_difficulty = self.pattern_difficulty.clamp(min=0.1, max=1.0)

        print(f"✓ David difficulty map computed:")
        print(f"   Avg timestep accuracy: {accuracy_per_bin.mean():.2%}")
        print(f"   Hardest timestep bin: {accuracy_per_bin.argmin().item()} ({accuracy_per_bin.min():.2%} acc)")
        print(f"   Easiest timestep bin: {accuracy_per_bin.argmax().item()} ({accuracy_per_bin.max():.2%} acc)")
        print(f"   Avg pattern entropy: {self.pattern_difficulty.mean():.3f}")

        return self.difficulty_weights

    def sample(self, batch_size: int) -> List[int]:
        """Sample timesteps with David weighting + shift + adaptive chaos."""
        if self.difficulty_weights is None:
            # Fallback to uniform
            return [random.randint(0, self.num_timesteps - 1) for _ in range(batch_size)]

        timesteps = []
        for _ in range(batch_size):
            # 1. Sample bin weighted by David's difficulty
            bin_idx = torch.multinomial(self.difficulty_weights, 1).item()

            # 2. Get bin center as normalized t
            bin_center_raw = bin_idx * (self.num_timesteps // self.num_bins) + (self.num_timesteps // self.num_bins) // 2
            t_normalized = bin_center_raw / self.num_timesteps

            # 3. Apply SD3 shift
            t_shifted = self._apply_shift(t_normalized)

            # 4. Add adaptive chaos (jitter scaled by pattern difficulty)
            if self.adaptive_chaos:
                chaos_scale = self.pattern_difficulty[bin_idx].item()
                jitter = int(self.base_jitter * (0.5 + chaos_scale))  # 0.5-1.5x base jitter
            else:
                jitter = self.base_jitter

            # 5. Convert back to raw timestep with jitter
            t_raw = int(t_shifted * self.num_timesteps)
            t_raw += random.randint(-jitter, jitter)
            t_raw = max(0, min(self.num_timesteps - 1, t_raw))

            timesteps.append(t_raw)

        return timesteps
257
+
258
+
259
+ # =====================================================================================
260
+ # 3) DATA
261
+ # =====================================================================================
262
class SymbolicPromptDataset(Dataset):
    """Synthetic symbolic-prompt dataset that also pre-samples a training timestep.

    Each item is ``{"prompt": str, "t": int}``. When a ``timestep_sampler`` is
    supplied its `sample` method chooses `t`; otherwise `t` is drawn uniformly
    from ``[0, num_timesteps)``.
    """
    def __init__(self, n:int, seed:int=42, timestep_sampler=None, num_timesteps:int=1000):
        self.n = n
        self.timestep_sampler = timestep_sampler
        # GENERALIZED: the uniform fallback previously hard-coded 0..999; it now
        # follows num_timesteps (default preserves the old behavior exactly).
        self.num_timesteps = num_timesteps
        # NOTE(review): seeds the *global* RNG — shared with any other caller of
        # the random module.
        random.seed(seed)
        self.sys = SynthesisSystem(seed=seed)

    def __len__(self): return self.n

    def __getitem__(self, idx):
        # Prompts are generated on the fly; `idx` is ignored.
        r = self.sys.synthesize(complexity=random.choice([1,2,3,4,5]))
        prompt = r['text']

        if self.timestep_sampler:
            t = self.timestep_sampler.sample(1)[0]
        else:
            # Uniform fallback over the full schedule.
            t = random.randint(0, self.num_timesteps - 1)

        return {"prompt": prompt, "t": t}
281
+
282
def collate(batch: List[dict]):
    """Collate prompt/timestep samples into a batch dict.

    Returns the raw prompt strings, a LongTensor of timesteps, and the coarse
    10-step timestep bins used by David's classification heads.
    """
    timesteps = torch.tensor([item["t"] for item in batch], dtype=torch.long)
    return {
        "prompts": [item["prompt"] for item in batch],
        "t": timesteps,
        "t_bins": timesteps // 10,
    }
287
+
288
+
289
+ # =====================================================================================
290
+ # 4) HOOKS + POOLING
291
+ # =====================================================================================
292
class HookBank:
    """Captures intermediate activations from selected UNet blocks via forward hooks.

    Activations land in ``self.bank`` keyed by block label ("down_i", "mid",
    "up_i"); call ``clear()`` before each forward pass and ``close()`` to detach.
    """

    def __init__(self, unet: UNet2DConditionModel, active: Tuple[str, ...]):
        self.active = set(active)
        self.bank: Dict[str, torch.Tensor] = {}
        self.hooks: List[torch.utils.hooks.RemovableHandle] = []
        self._register(unet)

    def _register(self, unet: UNet2DConditionModel):
        def make_hook(label):
            def capture(module, inputs, output):
                # Some blocks return tuples; keep only the primary tensor.
                primary = output[0] if isinstance(output, (tuple, list)) else output
                self.bank[label] = primary
            return capture

        for idx, block in enumerate(unet.down_blocks):
            label = f"down_{idx}"
            if label in self.active:
                self.hooks.append(block.register_forward_hook(make_hook(label)))
        if "mid" in self.active:
            self.hooks.append(unet.mid_block.register_forward_hook(make_hook("mid")))
        for idx, block in enumerate(unet.up_blocks):
            label = f"up_{idx}"
            if label in self.active:
                self.hooks.append(block.register_forward_hook(make_hook(label)))

    def clear(self):
        """Drop captured activations (call before each forward pass)."""
        self.bank.clear()

    def close(self):
        """Remove every registered hook and forget the handles."""
        for handle in self.hooks:
            handle.remove()
        self.hooks.clear()
318
+
319
def spatial_pool(x: torch.Tensor, name: str, policy: str) -> torch.Tensor:
    """Collapse the spatial dims (H, W) of a feature map to one value per channel.

    policy:
        "mean"     – average pool everywhere
        "max"      – max pool everywhere
        "adaptive" – mean for encoder/mid blocks, max for decoder ("up_*") blocks

    Raises ValueError for any other policy string.
    """
    if policy == "mean":
        return x.mean(dim=(2, 3))
    if policy == "max":
        return x.amax(dim=(2, 3))
    if policy == "adaptive":
        prefers_mean = name.startswith("down") or name == "mid"
        return x.mean(dim=(2, 3)) if prefers_mean else x.amax(dim=(2, 3))
    raise ValueError(f"Unknown pooling: {policy}")
325
+
326
+
327
+ # =====================================================================================
328
+ # 5) TEACHER
329
+ # =====================================================================================
330
class SD15Teacher(nn.Module):
    """Frozen SD1.5 pipeline used as the distillation teacher.

    Exposes the tokenizer/text-encoder for prompt conditioning, the UNet for
    eps prediction (with block activations captured by HookBank), and the DDPM
    schedule's alpha/sigma terms for building velocity targets.
    """
    def __init__(self, cfg: BaseConfig, device: str):
        super().__init__()
        # fp16 pipeline; safety checker disabled.
        self.pipe = StableDiffusionPipeline.from_pretrained(cfg.model_id, torch_dtype=torch.float16, safety_checker=None).to(device)
        self.unet: UNet2DConditionModel = self.pipe.unet
        self.text_encoder = self.pipe.text_encoder
        self.tokenizer = self.pipe.tokenizer
        self.hooks = HookBank(self.unet, cfg.active_blocks)
        self.sched = DDPMScheduler(num_train_timesteps=cfg.num_train_timesteps)
        self.device = device
        # Teacher is fully frozen: no gradients anywhere.
        for p in self.parameters(): p.requires_grad_(False)

    @torch.no_grad()
    def encode(self, prompts: List[str]) -> torch.Tensor:
        """Tokenize and encode prompts into CLIP hidden states (padded/truncated to max length)."""
        tok = self.tokenizer(prompts, padding="max_length", max_length=self.tokenizer.model_max_length,
                             truncation=True, return_tensors="pt")
        return self.text_encoder(tok.input_ids.to(self.device))[0]

    @torch.no_grad()
    def forward_eps_and_feats(self, x_t: torch.Tensor, t: torch.LongTensor, ehs: torch.Tensor):
        """Run the teacher UNet; return (eps prediction, captured block features), both fp32."""
        self.hooks.clear()
        eps_hat = self.unet(x_t, t, encoder_hidden_states=ehs).sample
        feats = {k: v.detach().float() for k, v in self.hooks.bank.items()}
        return eps_hat.float(), feats

    def alpha_sigma(self, t: torch.LongTensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return sqrt(alpha_bar_t) and sqrt(1 - alpha_bar_t), shaped (B,1,1,1) for broadcasting."""
        ac = self.sched.alphas_cumprod.to(self.device)[t]
        alpha = ac.sqrt().view(-1,1,1,1).float()
        sigma = (1.0 - ac).sqrt().view(-1,1,1,1).float()
        return alpha, sigma
360
+
361
+
362
+ # =====================================================================================
363
+ # 6) STUDENT
364
+ # =====================================================================================
365
class StudentUNet(nn.Module):
    """Trainable copy of the teacher UNet, weight-initialized from the teacher.

    Optionally attaches lazily-built 1x1 conv "local flow heads" to each hooked
    block so per-block velocity predictions can be supervised directly.
    """
    def __init__(self, teacher_unet: UNet2DConditionModel, active_blocks: Tuple[str,...], use_local_heads: bool):
        super().__init__()
        # Fresh UNet from config, then copy the teacher's weights exactly.
        self.unet = UNet2DConditionModel.from_config(teacher_unet.config)
        self.unet.load_state_dict(teacher_unet.state_dict(), strict=True)
        self.hooks = HookBank(self.unet, active_blocks)
        self.use_local_heads = use_local_heads
        self.local_heads = nn.ModuleDict()

    def _ensure_heads(self, feats: Dict[str, torch.Tensor]):
        """Lazily build one 1x1 conv head per hooked block.

        Heads are created on first forward because each block's channel count
        is only known once features have been captured.
        """
        if not self.use_local_heads: return
        if len(self.local_heads) == len(feats): return

        target_dtype = next(self.unet.parameters()).dtype

        for name, f in feats.items():
            c = f.shape[1]
            if name not in self.local_heads:
                # Project the block's channels down to the 4 latent channels.
                head = nn.Conv2d(c, 4, kernel_size=1)
                head = head.to(dtype=target_dtype, device=f.device)
                self.local_heads[name] = head

    def forward(self, x_t: torch.Tensor, t: torch.LongTensor, ehs: torch.Tensor):
        """Predict velocity and return it with the hooked block activations."""
        self.hooks.clear()
        v_hat = self.unet(x_t, t, encoder_hidden_states=ehs).sample
        feats = {k: v for k, v in self.hooks.bank.items()}
        self._ensure_heads(feats)
        return v_hat, feats
393
+
394
+
395
+ # =====================================================================================
396
+ # 7) DAVID + ASSESSOR + FUSION
397
+ # =====================================================================================
398
class DavidLoader:
    """Downloads and instantiates the frozen GeoDavidCollective ("David") from HF Hub."""
    def __init__(self, cfg: BaseConfig, device: str):
        self.cfg = cfg
        self.device = device
        self.repo_dir = snapshot_download(repo_id=cfg.david_repo_id, local_dir=cfg.david_cache_dir, local_dir_use_symlinks=False)
        self.config_path = os.path.join(self.repo_dir, "config.json")
        self.weights_path = os.path.join(self.repo_dir, "model.safetensors")
        with open(self.config_path, "r") as f:
            self.hf_config = json.load(f)

        self.gdc = GeoDavidCollective(
            block_configs=self.hf_config["block_configs"],
            num_timestep_bins=int(self.hf_config["num_timestep_bins"]),
            num_patterns_per_bin=int(self.hf_config["num_patterns_per_bin"]),
            # Fall back to uniform weights when the repo config omits them.
            block_weights=self.hf_config.get("block_weights", {k:1.0 for k in self.hf_config["block_configs"].keys()}),
            loss_config=self.hf_config.get("loss_config", {})
        ).to(device).eval()

        # strict=False tolerates key mismatches between checkpoint and model.
        state = load_file(self.weights_path)
        self.gdc.load_state_dict(state, strict=False)
        # David is frozen — assessment only, never trained here.
        for p in self.gdc.parameters(): p.requires_grad_(False)

        print(f"✓ David loaded from HF: {self.repo_dir}")
        print(f"  blocks={len(self.hf_config['block_configs'])} bins={self.hf_config['num_timestep_bins']} patterns={self.hf_config['num_patterns_per_bin']}")

        # Side effect: adopt David's block weights into the shared config when provided.
        if "block_weights" in self.hf_config:
            cfg.block_weights = self.hf_config["block_weights"]
425
+
426
class DavidAssessor(nn.Module):
    """Scores student block features with the frozen David collective.

    forward() returns three per-block dicts of floats:
      e_t  – timestep-classification cross-entropy (David's confusion about t),
      e_p  – pattern-logit entropy normalized to [0, 1],
      coh  – average Cantor coherence alpha (one value broadcast to all blocks).
    """
    def __init__(self, gdc: GeoDavidCollective, pooling: str):
        super().__init__()
        self.gdc = gdc
        self.pooling = pooling

    def _pool(self, feats: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        # Collapse spatial maps to per-channel vectors before feeding David.
        return {k: spatial_pool(v, k, self.pooling) for k, v in feats.items()}

    @torch.no_grad()
    def forward(self, feats_student: Dict[str, torch.Tensor], t: torch.LongTensor
                ) -> Tuple[Dict[str,float], Dict[str,float], Dict[str,float]]:
        Zs = self._pool(feats_student)
        outs = self.gdc(Zs, t.float())
        e_t, e_p, coh = {}, {}, {}

        # David's output key names vary across versions; probe known aliases.
        ts_key = None
        for key in ["timestep_logits", "logits_timestep", "timestep_head_logits"]:
            if key in outs: ts_key = key; break

        pt_key = None
        for key in ["pattern_logits", "logits_pattern", "pattern_head_logits"]:
            if key in outs: pt_key = key; break

        # Ground-truth 10-step bins, moved onto David's device.
        t_bins = (t // 10).to(next(self.gdc.parameters()).device)
        if ts_key is not None:
            ts_logits = outs[ts_key]
            if isinstance(ts_logits, dict):
                # Per-block heads: one CE score per block.
                for name, L in ts_logits.items():
                    ce = F.cross_entropy(L, t_bins, reduction="mean")
                    e_t[name] = float(ce.item())
            else:
                # Single head: broadcast the same CE to every block.
                ce = F.cross_entropy(ts_logits, t_bins, reduction="mean")
                for name in Zs.keys():
                    e_t[name] = float(ce.item())
        else:
            for name in Zs.keys(): e_t[name] = 0.0

        if pt_key is not None:
            pt_logits = outs[pt_key]
            if isinstance(pt_logits, dict):
                for name, L in pt_logits.items():
                    P = L.softmax(-1)
                    ent = -(P * (P.clamp_min(1e-9)).log()).sum(-1).mean()
                    # Divide by log(num_classes) so entropy lands in [0, 1].
                    e_p[name] = float(ent.item() / math.log(P.shape[-1]))
            else:
                P = pt_logits.softmax(-1)
                ent = -(P * (P.clamp_min(1e-9)).log()).sum(-1).mean()
                for name in Zs.keys():
                    e_p[name] = float(ent.item() / math.log(P.shape[-1]))
        else:
            for name in Zs.keys(): e_p[name] = 0.0

        # Coherence is best-effort: defaults to 1.0 (fully coherent) when the
        # collective does not expose get_cantor_alphas().
        alphas = {}
        try:
            alphas = self.gdc.get_cantor_alphas()
        except Exception:
            alphas = {}
        avg_alpha = float(sum(alphas.values())/max(len(alphas),1)) if alphas else 1.0
        for name in Zs.keys():
            coh[name] = avg_alpha

        return e_t, e_p, coh
489
+
490
+ class BlockPenaltyFusion:
491
+ def __init__(self, cfg: BaseConfig): self.cfg = cfg
492
+ def lambdas(self, e_t:Dict[str,float], e_p:Dict[str,float], coh:Dict[str,float]) -> Dict[str,float]:
493
+ lam = {}
494
+ for name, base in self.cfg.block_weights.items():
495
+ val = base * (1.0
496
+ + self.cfg.alpha_timestep * float(e_t.get(name,0.0))
497
+ + self.cfg.beta_pattern * float(e_p.get(name,0.0))
498
+ + self.cfg.delta_incoherence * (1.0 - float(coh.get(name,1.0))))
499
+ lam[name] = float(max(self.cfg.lambda_min, min(self.cfg.lambda_max, val)))
500
+ return lam
501
+
502
+
503
+ # =====================================================================================
504
+ # 8) TRAINER
505
+ # =====================================================================================
506
+ class FlowMatchDavidTrainer:
507
    def __init__(self, cfg: BaseConfig, device: str = "cuda"):
        """Build David, teacher, sampler, data pipeline, student, optimizer; then resume.

        Construction order matters: David and the teacher must exist before the
        timestep sampler can profile difficulty; the DataLoader needs the
        sampler; checkpoint resume needs the optimizer and scheduler.
        """
        self.cfg = cfg
        self.device = device
        self.start_epoch = 0
        self.start_gstep = 0

        # Initialize David first (needed for timestep sampler)
        self.david_loader = DavidLoader(cfg, device)
        self.david = self.david_loader.gdc
        self.assessor = DavidAssessor(self.david, cfg.pooling)
        self.fusion = BlockPenaltyFusion(cfg)

        # Initialize teacher (needed for David profiling)
        self.teacher = SD15Teacher(cfg, device).eval()

        # Initialize timestep sampler
        self.timestep_sampler = None
        if cfg.use_timestep_weighting:
            print("\n" + "="*70)
            print("🎯 ADAPTIVE TIMESTEP SAMPLING ENABLED")
            print(f"   David weighting: {cfg.use_david_weights}")
            print(f"   SD3 shift: {cfg.timestep_shift}")
            print(f"   Base jitter: ±{cfg.base_jitter}")
            print(f"   Adaptive chaos: {cfg.adaptive_chaos}")

            # Shift is disabled (0.0) when David weighting is off.
            self.timestep_sampler = DavidWeightedTimestepSampler(
                num_timesteps=cfg.num_train_timesteps,
                num_bins=100,
                shift=cfg.timestep_shift if cfg.use_david_weights else 0.0,
                base_jitter=cfg.base_jitter,
                adaptive_chaos=cfg.adaptive_chaos
            )

            if cfg.use_david_weights:
                # Profile David's per-bin accuracy/entropy to build the
                # difficulty-weighted sampling distribution.
                self.timestep_sampler.compute_difficulty_from_david(
                    david=self.david,
                    teacher=self.teacher,
                    device=device,
                    num_samples=cfg.profile_samples
                )
            print("="*70 + "\n")

        # Initialize dataset with sampler
        self.dataset = SymbolicPromptDataset(cfg.num_samples, cfg.seed, self.timestep_sampler)
        self.loader = DataLoader(self.dataset, batch_size=cfg.batch_size, shuffle=True,
                                 num_workers=cfg.num_workers, pin_memory=True, collate_fn=collate)

        # Initialize student
        self.student = StudentUNet(self.teacher.unet, cfg.active_blocks, cfg.use_local_flow_heads).to(device)

        self.opt = torch.optim.AdamW(self.student.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
        self.sched = torch.optim.lr_scheduler.CosineAnnealingLR(self.opt, T_max=cfg.epochs * len(self.loader))
        self.scaler = torch.cuda.amp.GradScaler(enabled=cfg.amp)

        # Checkpoint resume: an emergency snapshot (local first, then the HF
        # backup repo) takes priority over the regular per-epoch checkpoints.
        emergency_path = Path("./EMERGENCY_SAVE_SUCCESS.pt")
        if not emergency_path.exists():
            print("\n🔍 Emergency checkpoint not found locally, checking HuggingFace...")
            emergency_path = self._download_emergency_checkpoint()

        if emergency_path and emergency_path.exists():
            self._load_emergency_checkpoint(emergency_path)
        elif cfg.continue_training:
            self._load_latest_from_hf()

        self.writer = SummaryWriter(log_dir=os.path.join(cfg.out_dir, cfg.run_name))
573
+
574
+ def _download_emergency_checkpoint(self) -> Optional[Path]:
575
+ """Download emergency checkpoint from HuggingFace backup repo."""
576
+ emergency_repo = "AbstractPhil/sd15-flow-emergency-backup"
577
+ emergency_file = "EMERGENCY_SAVE_SUCCESS.pt"
578
+
579
+ try:
580
+ print(f"📥 Downloading emergency checkpoint from {emergency_repo}...")
581
+ local_path = hf_hub_download(
582
+ repo_id=emergency_repo,
583
+ filename=emergency_file,
584
+ repo_type="model",
585
+ cache_dir="./_emergency_cache"
586
+ )
587
+
588
+ target_path = Path("./EMERGENCY_SAVE_SUCCESS.pt")
589
+ shutil.copy(local_path, target_path)
590
+
591
+ size_mb = target_path.stat().st_size / 1e6
592
+ print(f"✅ Downloaded emergency checkpoint ({size_mb:.1f} MB)")
593
+ return target_path
594
+
595
+ except Exception as e:
596
+ print(f"⚠️ Could not download emergency checkpoint: {e}")
597
+ return None
598
+
599
    def _load_emergency_checkpoint(self, path: Path):
        """Load emergency checkpoint with student_unet structure."""
        try:
            print(f"\n🚨 Found emergency checkpoint: {path}")
            # NOTE(review): torch.load unpickles arbitrary objects and the file
            # may come from the HF backup repo — only load trusted checkpoints.
            checkpoint = torch.load(path, map_location='cpu')

            if 'student_unet' in checkpoint:
                print("📦 Loading emergency checkpoint format...")
                # strict=False tolerates key mismatches between save/load formats.
                missing, unexpected = self.student.unet.load_state_dict(checkpoint['student_unet'], strict=False)
                print(f"✓ Loaded student UNet")

            if 'opt' in checkpoint:
                self.opt.load_state_dict(checkpoint['opt'])
                print("✓ Loaded optimizer state")

            if 'sched' in checkpoint:
                self.sched.load_state_dict(checkpoint['sched'])
                print("✓ Loaded scheduler state")

            if 'gstep' in checkpoint:
                self.start_gstep = checkpoint['gstep']
                # Epoch is reconstructed from the step count (approximate).
                self.start_epoch = self.start_gstep // len(self.loader)
                print(f"✓ Resuming from global step {self.start_gstep} (epoch ~{self.start_epoch})")

            print("✅ Emergency checkpoint loaded successfully!")

        except Exception as e:
            # Best-effort resume: a bad checkpoint must not abort training.
            print(f"⚠️ Failed to load emergency checkpoint: {e}")
627
+
628
    def _load_latest_from_hf(self):
        """Resume from the highest-epoch ``*_e<N>.pt`` checkpoint in the HF repo, if any."""
        if not self.cfg.hf_repo_id:
            return

        try:
            api = HfApi()
            print(f"\n🔍 Searching for latest checkpoint in {self.cfg.hf_repo_id}...")

            files = api.list_repo_files(repo_id=self.cfg.hf_repo_id, repo_type="model")
            epochs = []
            for f in files:
                if f.endswith('.pt'):
                    # Checkpoint filenames end in "_e<epoch>.pt" (see _save).
                    match = re.search(r'_e(\d+)\.pt$', f)
                    if match:
                        epochs.append((int(match.group(1)), f))

            if not epochs:
                return

            latest_epoch, latest_file = max(epochs, key=lambda x: x[0])
            print(f"📥 Downloading: {latest_file}")

            local_path = hf_hub_download(
                repo_id=self.cfg.hf_repo_id,
                filename=latest_file,
                repo_type="model",
                cache_dir=self.cfg.ckpt_dir
            )

            # NOTE(review): torch.load unpickles arbitrary objects — only
            # resume from repos you control.
            checkpoint = torch.load(local_path, map_location='cpu')

            # Accept both save formats: bare UNet ('student_unet') or the full
            # StudentUNet wrapper ('student', as written by _save).
            if 'student_unet' in checkpoint:
                self.student.unet.load_state_dict(checkpoint['student_unet'], strict=False)
            elif 'student' in checkpoint:
                self.student.load_state_dict(checkpoint['student'], strict=False)

            if 'opt' in checkpoint:
                self.opt.load_state_dict(checkpoint['opt'])
            if 'sched' in checkpoint:
                self.sched.load_state_dict(checkpoint['sched'])

            self.start_epoch = latest_epoch
            self.start_gstep = latest_epoch * len(self.loader)

            print(f"✅ Resuming from epoch {self.start_epoch + 1}")
            # Free the CPU copy of the checkpoint before training allocations.
            del checkpoint
            torch.cuda.empty_cache()

        except Exception as e:
            # Best-effort resume: fall through to training from scratch.
            print(f"⚠️ Failed to load from HF: {e}")
678
+
679
+ def _v_star(self, x_t, t, eps_hat):
680
+ alpha, sigma = self.teacher.alpha_sigma(t)
681
+ x0_hat = (x_t - sigma * eps_hat) / (alpha + 1e-8)
682
+ return alpha * eps_hat - sigma * x0_hat
683
+
684
+ def _down_like(self, tgt: torch.Tensor, ref: torch.Tensor) -> torch.Tensor:
685
+ return F.interpolate(tgt, size=ref.shape[-2:], mode="bilinear", align_corners=False)
686
+
687
+ def _kd_cos(self, s: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
688
+ s = F.normalize(s, dim=-1); t = F.normalize(t, dim=-1)
689
+ return 1.0 - (s*t).sum(-1).mean()
690
+
691
+ def train(self):
692
+ cfg = self.cfg
693
+ gstep = self.start_gstep
694
+
695
+ for ep in range(self.start_epoch, cfg.epochs):
696
+ self.student.train()
697
+ pbar = tqdm(self.loader, desc=f"Epoch {ep+1}/{cfg.epochs}",
698
+ dynamic_ncols=True, leave=True, position=0)
699
+ acc = {"L":0.0, "Lf":0.0, "Lb":0.0}
700
+
701
+ for it, batch in enumerate(pbar):
702
+ prompts = batch["prompts"]
703
+ t = batch["t"].to(self.device)
704
+
705
+ with torch.no_grad():
706
+ ehs = self.teacher.encode(prompts)
707
+
708
+ x_t = torch.randn(len(prompts), 4, 64, 64, device=self.device, dtype=torch.float16)
709
+
710
+ with torch.no_grad():
711
+ eps_hat, t_feats_spatial = self.teacher.forward_eps_and_feats(x_t.half(), t, ehs)
712
+ v_star = self._v_star(x_t, t, eps_hat)
713
+
714
+ with torch.cuda.amp.autocast(enabled=cfg.amp):
715
+ v_hat, s_feats_spatial = self.student(x_t, t, ehs)
716
+ L_flow = F.mse_loss(v_hat, v_star)
717
+
718
+ e_t, e_p, coh = self.assessor(s_feats_spatial, t)
719
+ lam = self.fusion.lambdas(e_t, e_p, coh)
720
+
721
+ L_blocks = torch.zeros((), device=self.device)
722
+ for name, s_feat in s_feats_spatial.items():
723
+ L_kd = torch.zeros((), device=self.device)
724
+ if cfg.use_kd:
725
+ s_pool = spatial_pool(s_feat, name, cfg.pooling)
726
+ t_pool = spatial_pool(t_feats_spatial[name], name, cfg.pooling)
727
+ L_kd = self._kd_cos(s_pool, t_pool)
728
+
729
+ L_lf = torch.zeros((), device=self.device)
730
+ if cfg.use_local_flow_heads and name in self.student.local_heads:
731
+ v_loc = self.student.local_heads[name](s_feat)
732
+ v_ds = self._down_like(v_star, v_loc)
733
+ L_lf = F.mse_loss(v_loc, v_ds)
734
+ L_blocks = L_blocks + lam.get(name,1.0) * (cfg.kd_weight * L_kd + cfg.local_flow_weight * L_lf)
735
+
736
+ L_total = cfg.global_flow_weight*L_flow + cfg.block_penalty_weight*L_blocks
737
+
738
+ self.opt.zero_grad(set_to_none=True)
739
+ if cfg.amp:
740
+ self.scaler.scale(L_total).backward()
741
+ nn.utils.clip_grad_norm_(self.student.parameters(), cfg.grad_clip)
742
+ self.scaler.step(self.opt); self.scaler.update()
743
+ else:
744
+ L_total.backward()
745
+ nn.utils.clip_grad_norm_(self.student.parameters(), cfg.grad_clip)
746
+ self.opt.step()
747
+ self.sched.step(); gstep += 1
748
+
749
+ acc["L"] += float(L_total.item())
750
+ acc["Lf"] += float(L_flow.item())
751
+ acc["Lb"] += float(L_blocks.item())
752
+
753
+ if it % 50 == 0:
754
+ self.writer.add_scalar("train/total", float(L_total.item()), gstep)
755
+ self.writer.add_scalar("train/flow", float(L_flow.item()), gstep)
756
+ self.writer.add_scalar("train/blocks",float(L_blocks.item()), gstep)
757
+ for k in list(lam.keys())[:4]:
758
+ self.writer.add_scalar(f"lambda/{k}", lam[k], gstep)
759
+
760
+ if it % 10 == 0 or it == len(self.loader) - 1:
761
+ pbar.set_postfix({
762
+ "L": f"{float(L_total.item()):.4f}",
763
+ "Lf": f"{float(L_flow.item()):.4f}",
764
+ "Lb": f"{float(L_blocks.item()):.4f}"
765
+ }, refresh=False)
766
+
767
+ del x_t, eps_hat, v_star, v_hat, s_feats_spatial, t_feats_spatial
768
+
769
+ pbar.close()
770
+
771
+ n = len(self.loader)
772
+ print(f"\n[Epoch {ep+1}] L={acc['L']/n:.4f} | L_flow={acc['Lf']/n:.4f} | L_blocks={acc['Lb']/n:.4f}")
773
+ self.writer.add_scalar("epoch/total", acc['L']/n, ep+1)
774
+ self.writer.add_scalar("epoch/flow", acc['Lf']/n, ep+1)
775
+ self.writer.add_scalar("epoch/blocks",acc['Lb']/n, ep+1)
776
+
777
+ if (ep+1) % cfg.save_every == 0:
778
+ self._save(ep+1, gstep)
779
+
780
+ self._save("final", gstep)
781
+ self.writer.close()
782
+
783
+ def _save(self, tag, gstep):
784
+ """Save checkpoint and upload to HuggingFace."""
785
+ pt_path = Path(self.cfg.ckpt_dir) / f"{self.cfg.run_name}_e{tag}.pt"
786
+ torch.save({
787
+ "cfg": asdict(self.cfg),
788
+ "student": self.student.state_dict(),
789
+ "opt": self.opt.state_dict(),
790
+ "sched": self.sched.state_dict(),
791
+ "gstep": gstep
792
+ }, pt_path)
793
+
794
+ size_mb = pt_path.stat().st_size / 1e6
795
+ print(f"✓ Saved checkpoint: {pt_path.name} ({size_mb:.1f} MB)")
796
+
797
+ if self.cfg.upload_every_epoch and self.cfg.hf_repo_id:
798
+ self._upload_to_hf(pt_path, tag)
799
+
800
+ def _upload_to_hf(self, path: Path, tag):
801
+ """Upload checkpoint to HuggingFace."""
802
+ try:
803
+ api = HfApi()
804
+ create_repo(self.cfg.hf_repo_id, exist_ok=True, private=False, repo_type="model")
805
+
806
+ print(f"📤 Uploading {path.name} to {self.cfg.hf_repo_id}...")
807
+ api.upload_file(
808
+ path_or_fileobj=str(path),
809
+ path_in_repo=path.name,
810
+ repo_id=self.cfg.hf_repo_id,
811
+ repo_type="model",
812
+ commit_message=f"Epoch {tag}"
813
+ )
814
+ print(f"✅ Uploaded: https://huggingface.co/{self.cfg.hf_repo_id}/{path.name}")
815
+
816
+ except Exception as e:
817
+ print(f"⚠️ Upload failed: {e}")
818
+
819
    @torch.no_grad()
    def sample(self, prompts: List[str], steps: Optional[int]=None, guidance: Optional[float]=None) -> torch.Tensor:
        """Generate images from text prompts with the student model.

        Runs classifier-free guidance over the teacher's scheduler timesteps,
        converting the student's v-prediction into an epsilon estimate for the
        scheduler step, then decodes the final latent with the teacher's VAE.

        Args:
            prompts: Text prompts, one per image.
            steps: Number of denoising steps; defaults to ``cfg.sample_steps``.
            guidance: CFG scale; defaults to ``cfg.guidance_scale``.

        Returns:
            Decoded images clamped to [-1, 1], shape (len(prompts), C, H, W).
        """
        # Fall back to configured defaults when the caller omits arguments.
        steps = steps or self.cfg.sample_steps
        guidance = guidance if guidance is not None else self.cfg.guidance_scale
        # Conditional and unconditional (empty-prompt) text embeddings for CFG.
        cond_e = self.teacher.encode(prompts)
        uncond_e = self.teacher.encode([""]*len(prompts))
        sched = self.teacher.sched
        sched.set_timesteps(steps, device=self.device)
        # Initial Gaussian latent; 4x64x64 matches an SD-style latent space.
        # NOTE(review): x_t is not scaled by sched.init_noise_sigma — confirm the
        # scheduler expects unit-variance noise at the first step.
        x_t = torch.randn(len(prompts), 4, 64, 64, device=self.device)

        for t_scalar in sched.timesteps:
            # Broadcast the scalar timestep across the batch.
            t = torch.full((x_t.shape[0],), t_scalar, device=self.device, dtype=torch.long)
            # Two forward passes: unconditional and conditional velocity.
            v_u, _ = self.student(x_t, t, uncond_e)
            v_c, _ = self.student(x_t, t, cond_e)
            # Classifier-free guidance combination.
            v_hat = v_u + guidance*(v_c - v_u)

            # Convert v-prediction to an epsilon estimate:
            # x0 = (alpha*x_t - sigma*v) / (alpha^2 + sigma^2); the denominator is
            # presumably ~1 for a variance-preserving schedule — confirm.
            alpha, sigma = self.teacher.alpha_sigma(t)
            denom = (alpha**2 + sigma**2)
            x0_hat = (alpha * x_t - sigma * v_hat) / (denom + 1e-8)
            eps_hat = (x_t - alpha * x0_hat) / (sigma + 1e-8)

            # NOTE(review): scheduler step is fed epsilon — assumes the teacher's
            # scheduler is epsilon-parameterized; verify its prediction_type.
            step = sched.step(model_output=eps_hat, timestep=t_scalar, sample=x_t)
            x_t = step.prev_sample

        # 0.18215 is the Stable Diffusion VAE latent scaling factor.
        imgs = self.teacher.pipe.vae.decode(x_t / 0.18215).sample
        return imgs.clamp(-1,1)
845
+
846
+
847
+ # =====================================================================================
848
+ # 9) MAIN
849
+ # =====================================================================================
850
def main():
    """Script entry point: build the default config, train, then smoke-test sampling."""
    cfg = BaseConfig()
    # Echo the full configuration so runs are self-describing in logs.
    print(json.dumps(asdict(cfg), indent=2))
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
        print("⚠️ A100 strongly recommended.")
    trainer = FlowMatchDavidTrainer(cfg, device=device)
    trainer.train()
    # Quick post-training sanity sample; result intentionally discarded.
    _ = trainer.sample(["a castle at sunset"], steps=10, guidance=7.0)
    print("✓ Training complete.")


if __name__ == "__main__":
    main()