krystv committed on
Commit
962ddaa
·
verified ·
1 Parent(s): 3239df1

v2: Training with real datasets, Mamba SSM, push-to-hub

Browse files
Files changed (1) hide show
  1. artflow_train.py +156 -252
artflow_train.py CHANGED
@@ -1,8 +1,12 @@
1
  """
2
- ArtFlow Training Utilities
3
- ===========================
4
- All training logic — loss functions, optimizers, training engine.
5
- Import this from the Colab notebook or use standalone.
 
 
 
 
6
 
7
  Uses only modern, non-deprecated PyTorch APIs.
8
  """
@@ -12,7 +16,7 @@ import math
12
  import json
13
  import time
14
  from dataclasses import dataclass, asdict
15
- from typing import Tuple, Optional
16
  from collections import deque
17
 
18
  import torch
@@ -24,22 +28,11 @@ from artflow_model import (
24
  ArtFlow, ArtFlowConfig, HaarWavelet2D, logit_normal_timestep
25
  )
26
 
27
- # ============================================================================
28
- # Loss: Pseudo-Huber + Min-SNR-γ + Art-Aware Frequency Weighting
29
- # ============================================================================
30
 
31
  class ArtFlowLoss(nn.Module):
32
- """
33
- Research-backed loss combining three mechanisms:
34
- 1. Pseudo-Huber loss — robust to outliers [arXiv:2403.16728]
35
- 2. Min-SNR-γ weighting — balances timestep learning [arXiv:2303.09556]
36
- 3. Art-aware frequency weighting — emphasizes line work
37
- """
38
-
39
- def __init__(self, huber_c: float = 0.00054, min_snr_gamma: float = 5.0,
40
- use_pseudo_huber: bool = True, use_min_snr: bool = True,
41
- w_LL: float = 1.0, w_LH: float = 2.0,
42
- w_HL: float = 2.0, w_HH: float = 1.5):
43
  super().__init__()
44
  self.huber_c = huber_c
45
  self.min_snr_gamma = min_snr_gamma
@@ -49,31 +42,19 @@ class ArtFlowLoss(nn.Module):
49
  self.freq_weights = {'LL': w_LL, 'LH': w_LH, 'HL': w_HL, 'HH': w_HH}
50
  self.loss_ema = None
51
 
52
- def pseudo_huber(self, x: torch.Tensor) -> torch.Tensor:
53
- """sqrt(x² + c²) - c — smooth near 0, linear for large |x|."""
54
  return (x.pow(2) + self.huber_c ** 2).sqrt() - self.huber_c
55
 
56
- def snr_weight(self, t: torch.Tensor) -> torch.Tensor:
57
- """Min-SNR-γ for flow matching: SNR(t) = (1-t)²/t²."""
58
  snr = ((1 - t) / t.clamp(min=1e-6)).pow(2)
59
  w = torch.clamp(snr, max=self.min_snr_gamma) / snr.clamp(min=1e-6)
60
  return w[:, None, None, None]
61
 
62
- def forward(self, v_pred: torch.Tensor, v_target: torch.Tensor,
63
- t: torch.Tensor) -> Tuple[torch.Tensor, bool]:
64
  error = v_pred - v_target
65
-
66
- # Element-wise loss
67
- if self.use_pseudo_huber:
68
- elem = self.pseudo_huber(error)
69
- else:
70
- elem = error.pow(2)
71
-
72
- # Per-sample SNR weighting
73
  if self.use_min_snr:
74
  elem = elem * self.snr_weight(t)
75
-
76
- # Frequency-weighted aggregation
77
  if elem.shape[2] % 2 == 0 and elem.shape[3] % 2 == 0:
78
  LL, LH, HL, HH = self.wavelet(elem)
79
  loss = (self.freq_weights['LL'] * LL.mean() +
@@ -82,21 +63,11 @@ class ArtFlowLoss(nn.Module):
82
  self.freq_weights['HH'] * HH.mean())
83
  else:
84
  loss = elem.mean()
85
-
86
- # Spike detection
87
  lv = loss.item()
88
- if self.loss_ema is None:
89
- self.loss_ema = lv
90
- else:
91
- self.loss_ema = 0.99 * self.loss_ema + 0.01 * lv
92
- is_spike = lv > 10.0 * max(self.loss_ema, 0.01)
93
 
94
- return loss, is_spike
95
-
96
-
97
- # ============================================================================
98
- # Training Config
99
- # ============================================================================
100
 
101
  @dataclass
102
  class TrainConfig:
@@ -105,175 +76,166 @@ class TrainConfig:
105
  betas: Tuple[float, float] = (0.9, 0.99)
106
  max_grad_norm: float = 1.0
107
  warmup_steps: int = 500
108
-
109
  batch_size: int = 2
110
  grad_accum: int = 32
111
-
112
  num_steps: int = 50000
113
  min_lr_ratio: float = 0.05
114
-
115
  ema_decay: float = 0.9999
116
  ema_start_step: int = 1000
117
-
118
  log_every: int = 50
119
  save_every: int = 2500
120
  output_dir: str = './artflow_ckpts'
121
  stage: int = 1
 
 
122
 
123
 
124
- # ============================================================================
125
- # Synthetic Dataset (for smoke-tests / Colab)
126
- # ============================================================================
127
-
128
  class SyntheticDataset(Dataset):
129
- """Random latent + text pairs for testing the training loop."""
130
-
131
- def __init__(self, n: int = 10000, config: ArtFlowConfig = None):
132
  self.n = n
133
  self.cfg = config or ArtFlowConfig()
134
-
135
- def __len__(self):
136
- return self.n
137
-
138
  def __getitem__(self, idx):
139
  g = torch.Generator().manual_seed(idx)
140
- lat = torch.randn(self.cfg.latent_channels, self.cfg.latent_size,
141
- self.cfg.latent_size, generator=g)
142
- txt = torch.randn(self.cfg.text_length, self.cfg.text_dim, generator=g)
143
- return lat, txt
144
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
145
 
146
- # ============================================================================
147
- # Stage Freezing
148
- # ============================================================================
 
 
 
 
 
 
 
 
 
149
 
150
- def freeze_for_stage(model: ArtFlow, stage: int) -> ArtFlow:
151
- """Freeze / unfreeze modules based on training stage."""
152
- for p in model.parameters():
153
- p.requires_grad_(True)
 
 
 
 
154
 
155
- freeze_keys = {
156
- 1: ['art_style', 'mood_ctrl', 'concept_engine'],
157
- 2: ['mood_ctrl', 'concept_engine'],
158
- 3: ['mood_ctrl', 'concept_engine'],
159
- 4: [], # freeze everything *except* concept/mood
160
- 5: [],
161
- }
162
 
 
 
 
 
163
  if stage == 4:
164
- # Stage 4: only concept/mood train
165
  for n, p in model.named_parameters():
166
- if not any(k in n for k in ['mood_ctrl', 'concept_engine']):
167
- p.requires_grad_(False)
168
  else:
169
- keys = freeze_keys.get(stage, [])
170
  for n, p in model.named_parameters():
171
- if any(k in n for k in keys):
172
- p.requires_grad_(False)
173
-
174
  tr = sum(p.numel() for p in model.parameters() if p.requires_grad)
175
  tot = sum(p.numel() for p in model.parameters())
176
  print(f"Stage {stage}: {tr:,}/{tot:,} trainable ({100*tr/tot:.1f}%)")
177
  return model
178
 
179
 
180
- # ============================================================================
181
- # Training Engine
182
- # ============================================================================
183
-
184
  class TrainingEngine:
185
- """Complete training loop with all stability mechanisms."""
186
-
187
- def __init__(self, model: ArtFlow, model_cfg: ArtFlowConfig,
188
- train_cfg: TrainConfig, device: torch.device):
189
- self.model = model
190
- self.mcfg = model_cfg
191
- self.tcfg = train_cfg
192
- self.device = device
193
-
194
- # EMA
195
  self.ema = ArtFlow(model_cfg).to(device)
196
  self.ema.load_state_dict(model.state_dict())
197
  self.ema.eval()
198
- for p in self.ema.parameters():
199
- p.requires_grad_(False)
200
-
201
- # Optimizer
202
  decay, no_decay = [], []
203
  for n, p in model.named_parameters():
204
- if not p.requires_grad:
205
- continue
206
  (no_decay if ('norm' in n or 'bias' in n) else decay).append(p)
207
  self.optimizer = torch.optim.AdamW([
208
  {'params': decay, 'weight_decay': train_cfg.weight_decay},
209
- {'params': no_decay, 'weight_decay': 0.0},
210
  ], lr=train_cfg.lr, betas=train_cfg.betas)
211
-
212
- # AMP scaler (only useful on CUDA)
213
  self.use_amp = (device.type == 'cuda')
214
  self.scaler = torch.amp.GradScaler(device.type, enabled=self.use_amp)
215
-
216
- # Loss
217
  self.loss_fn = ArtFlowLoss()
218
-
219
- # State
220
  self.global_step = 0
221
- self.losses = []
222
- self.grad_norms = []
223
 
224
- # --- LR schedule ---
225
- def _lr_scale(self) -> float:
226
  s, w, total = self.global_step, self.tcfg.warmup_steps, self.tcfg.num_steps
227
- mn = self.tcfg.min_lr_ratio
228
- if s < w:
229
- return s / max(w, 1)
230
- prog = (s - w) / max(total - w, 1)
231
- return mn + 0.5 * (1 - mn) * (1 + math.cos(math.pi * prog))
232
-
233
- def _set_lr(self) -> float:
234
- sc = self._lr_scale()
235
- lr = self.tcfg.lr * sc
236
- for pg in self.optimizer.param_groups:
237
- pg['lr'] = lr
238
  return lr
239
 
240
- # --- EMA ---
241
  @torch.no_grad()
242
  def _update_ema(self):
243
- if self.global_step < self.tcfg.ema_start_step:
244
- return
245
  d = self.tcfg.ema_decay
246
  for ep, p in zip(self.ema.parameters(), self.model.parameters()):
247
- ep.data.mul_(d).add_(p.data, alpha=1 - d)
248
 
249
- # --- Single micro-batch ---
250
- def micro_step(self, x_0: torch.Tensor, text_emb: torch.Tensor
251
- ) -> Optional[float]:
252
  B = x_0.shape[0]
253
  t = logit_normal_timestep(B, self.device)
254
  eps = torch.randn_like(x_0)
255
  te = t[:, None, None, None]
256
- x_t = (1 - te) * x_0 + te * eps
257
- v_target = eps - x_0
258
-
259
- # Modern autocast (only on CUDA)
260
  with torch.amp.autocast(self.device.type, dtype=torch.float16, enabled=self.use_amp):
261
- v_pred = self.model(x_t, t, text_emb)
262
- loss, is_spike = self.loss_fn(v_pred.float(), v_target.float(), t)
263
  loss = loss / self.tcfg.grad_accum
264
-
265
- if is_spike:
266
- return None # skip
267
-
268
  self.scaler.scale(loss).backward()
269
  return loss.item() * self.tcfg.grad_accum
270
 
271
- # --- Optimizer step (after accumulation) ---
272
- def optim_step(self) -> float:
273
  self.scaler.unscale_(self.optimizer)
274
- gn = torch.nn.utils.clip_grad_norm_(
275
- [p for p in self.model.parameters() if p.requires_grad],
276
- self.tcfg.max_grad_norm).item()
277
  self.scaler.step(self.optimizer)
278
  self.scaler.update()
279
  self.optimizer.zero_grad(set_to_none=True)
@@ -281,134 +243,76 @@ class TrainingEngine:
281
  self.global_step += 1
282
  return gn
283
 
284
- # --- Save / Load ---
285
- def save(self, path: Optional[str] = None):
286
- path = path or os.path.join(
287
- self.tcfg.output_dir, f'ckpt_{self.global_step}.pt')
288
  os.makedirs(os.path.dirname(path), exist_ok=True)
289
- torch.save({
290
- 'model': self.model.state_dict(),
291
- 'ema': self.ema.state_dict(),
292
- 'optimizer': self.optimizer.state_dict(),
293
- 'scaler': self.scaler.state_dict(),
294
- 'step': self.global_step,
295
- 'losses': self.losses[-2000:],
296
- 'model_config': asdict(self.mcfg),
297
- 'train_config': asdict(self.tcfg),
298
- }, path)
299
  print(f" 💾 Saved: {path}")
300
 
301
- def load(self, path: str):
302
  ckpt = torch.load(path, map_location=self.device, weights_only=False)
303
- self.model.load_state_dict(ckpt['model'])
304
- self.ema.load_state_dict(ckpt['ema'])
305
- self.optimizer.load_state_dict(ckpt['optimizer'])
306
- self.scaler.load_state_dict(ckpt['scaler'])
307
- self.global_step = ckpt['step']
308
- self.losses = ckpt.get('losses', [])
309
  print(f" 📂 Resumed from step {self.global_step}")
310
 
311
 
312
- # ============================================================================
313
- # Main training loop
314
- # ============================================================================
315
-
316
- def train(model: ArtFlow, model_cfg: ArtFlowConfig, train_cfg: TrainConfig,
317
- dataset: Dataset, device: torch.device,
318
- resume_path: Optional[str] = None):
319
- """Run one stage of training. Returns the engine for inspection."""
320
-
321
  engine = TrainingEngine(model, model_cfg, train_cfg, device)
322
- if resume_path and os.path.exists(resume_path):
323
- engine.load(resume_path)
324
-
325
- loader = DataLoader(dataset, batch_size=train_cfg.batch_size,
326
- shuffle=True, num_workers=0, drop_last=True,
327
- pin_memory=(device.type == 'cuda'))
328
-
329
- print(f"\n{'='*60}")
330
- print(f"Stage {train_cfg.stage} — {engine.global_step} → {train_cfg.num_steps} steps")
331
- print(f"Effective batch: {train_cfg.batch_size} × {train_cfg.grad_accum}"
332
- f" = {train_cfg.batch_size * train_cfg.grad_accum}")
333
- print(f"{'='*60}\n")
334
-
335
  model.train()
336
  start = time.time()
337
  acc_loss, acc_n = 0.0, 0
338
-
339
  while engine.global_step < train_cfg.num_steps:
340
  for x_0, txt in loader:
341
- if engine.global_step >= train_cfg.num_steps:
342
- break
343
-
344
  x_0, txt = x_0.to(device), txt.to(device)
345
  engine._set_lr()
346
-
347
  lv = engine.micro_step(x_0, txt)
348
- if lv is not None:
349
- acc_loss += lv
350
- acc_n += 1
351
-
352
  if acc_n >= train_cfg.grad_accum:
353
  gn = engine.optim_step()
354
- avg = acc_loss / acc_n
355
- engine.losses.append(avg)
356
- engine.grad_norms.append(gn)
357
  acc_loss, acc_n = 0.0, 0
358
-
359
  if engine.global_step % train_cfg.log_every == 0:
360
- el = time.time() - start
361
- sps = engine.global_step / max(el, 1)
362
- eta = (train_cfg.num_steps - engine.global_step) / max(sps, 1e-6)
363
- lr = engine.optimizer.param_groups[0]['lr']
364
  rec = engine.losses[-50:]
365
- print(f"Step {engine.global_step:>6d}/{train_cfg.num_steps} | "
366
- f"Loss: {sum(rec)/len(rec):.4f} | GN: {gn:.3f} | "
367
- f"LR: {lr:.2e} | ETA: {eta/60:.0f}m")
368
-
369
- if engine.global_step % train_cfg.save_every == 0:
370
- engine.save()
371
-
372
- engine.save(os.path.join(train_cfg.output_dir,
373
- f'stage{train_cfg.stage}_final.pt'))
 
 
 
 
374
  print(f"\n✅ Stage {train_cfg.stage} done — {(time.time()-start)/3600:.1f}h")
375
  return engine
376
 
377
 
378
- # ============================================================================
379
- # CLI entry point
380
- # ============================================================================
381
-
382
  if __name__ == '__main__':
383
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
384
  print(f"Device: {device}")
385
-
386
- # Use small config for CPU testing
387
- mcfg = ArtFlowConfig(
388
- latent_channels=4, latent_size=16,
389
- stage_channels=(64, 128, 192),
390
- blocks_per_stage=(1, 1, 1), bottleneck_blocks=2,
391
- mamba_state_dim=8, num_styles=16, style_dim=128,
392
- mood_dim=64, num_moods=8, text_dim=256, text_length=16,
393
- num_heads=4, concept_dim=64, kan_grid_size=3,
394
- )
395
  model = ArtFlow(mcfg).to(device)
396
- model = freeze_for_stage(model, stage=1)
397
- total = sum(p.numel() for p in model.parameters())
398
- print(f"Model: {total:,} params ({total/1e6:.1f}M)")
399
-
400
- tcfg = TrainConfig(num_steps=30, log_every=10, save_every=100,
401
- batch_size=2, grad_accum=2, warmup_steps=5)
402
- ds = SyntheticDataset(n=200, config=mcfg)
403
-
404
- engine = train(model, mcfg, tcfg, ds, device)
405
-
406
- # Verify
407
- print(f"\n--- Verification ---")
408
- print(f"Steps completed: {engine.global_step}")
409
- print(f"Losses recorded: {len(engine.losses)}")
410
- if engine.losses:
411
- print(f"Last 5 losses: {[f'{l:.4f}' for l in engine.losses[-5:]]}")
412
  has_nan = any(torch.isnan(p).any() for p in model.parameters())
413
- print(f"NaN in params: {'FAIL' if has_nan else 'OK'}")
414
  print("✅ All good" if not has_nan and engine.global_step >= 30 else "❌ Issues")
 
1
  """
2
+ ArtFlow v2 Training Utilities
3
+ ==============================
4
+ Real Mamba SSM training with:
5
+ - Real dataset support (WikiArt, Teyvat, Pokemon, Danbooru tags)
6
+ - Pseudo-Huber + Min-SNR-γ + Art-Aware Frequency loss
7
+ - Stable training with spike detection and EMA
8
+ - Multi-stage freeze/unfreeze pipeline
9
+ - Push-to-Hub support for HF Jobs
10
 
11
  Uses only modern, non-deprecated PyTorch APIs.
12
  """
 
16
  import json
17
  import time
18
  from dataclasses import dataclass, asdict
19
+ from typing import Tuple, Optional, List
20
  from collections import deque
21
 
22
  import torch
 
28
  ArtFlow, ArtFlowConfig, HaarWavelet2D, logit_normal_timestep
29
  )
30
 
 
 
 
31
 
32
  class ArtFlowLoss(nn.Module):
33
+ def __init__(self, huber_c=0.00054, min_snr_gamma=5.0,
34
+ use_pseudo_huber=True, use_min_snr=True,
35
+ w_LL=1.0, w_LH=2.0, w_HL=2.0, w_HH=1.5):
 
 
 
 
 
 
 
 
36
  super().__init__()
37
  self.huber_c = huber_c
38
  self.min_snr_gamma = min_snr_gamma
 
42
  self.freq_weights = {'LL': w_LL, 'LH': w_LH, 'HL': w_HL, 'HH': w_HH}
43
  self.loss_ema = None
44
 
45
+ def pseudo_huber(self, x):
 
46
  return (x.pow(2) + self.huber_c ** 2).sqrt() - self.huber_c
47
 
48
+ def snr_weight(self, t):
 
49
  snr = ((1 - t) / t.clamp(min=1e-6)).pow(2)
50
  w = torch.clamp(snr, max=self.min_snr_gamma) / snr.clamp(min=1e-6)
51
  return w[:, None, None, None]
52
 
53
+ def forward(self, v_pred, v_target, t):
 
54
  error = v_pred - v_target
55
+ elem = self.pseudo_huber(error) if self.use_pseudo_huber else error.pow(2)
 
 
 
 
 
 
 
56
  if self.use_min_snr:
57
  elem = elem * self.snr_weight(t)
 
 
58
  if elem.shape[2] % 2 == 0 and elem.shape[3] % 2 == 0:
59
  LL, LH, HL, HH = self.wavelet(elem)
60
  loss = (self.freq_weights['LL'] * LL.mean() +
 
63
  self.freq_weights['HH'] * HH.mean())
64
  else:
65
  loss = elem.mean()
 
 
66
  lv = loss.item()
67
+ if self.loss_ema is None: self.loss_ema = lv
68
+ else: self.loss_ema = 0.99 * self.loss_ema + 0.01 * lv
69
+ return loss, lv > 10.0 * max(self.loss_ema, 0.01)
 
 
70
 
 
 
 
 
 
 
71
 
72
  @dataclass
73
  class TrainConfig:
 
76
  betas: Tuple[float, float] = (0.9, 0.99)
77
  max_grad_norm: float = 1.0
78
  warmup_steps: int = 500
 
79
  batch_size: int = 2
80
  grad_accum: int = 32
 
81
  num_steps: int = 50000
82
  min_lr_ratio: float = 0.05
 
83
  ema_decay: float = 0.9999
84
  ema_start_step: int = 1000
 
85
  log_every: int = 50
86
  save_every: int = 2500
87
  output_dir: str = './artflow_ckpts'
88
  stage: int = 1
89
+ push_to_hub: bool = False
90
+ hub_model_id: str = ''
91
 
92
 
 
 
 
 
93
  class SyntheticDataset(Dataset):
94
+ def __init__(self, n=10000, config=None):
 
 
95
  self.n = n
96
  self.cfg = config or ArtFlowConfig()
97
+ def __len__(self): return self.n
 
 
 
98
  def __getitem__(self, idx):
99
  g = torch.Generator().manual_seed(idx)
100
+ return (torch.randn(self.cfg.latent_channels, self.cfg.latent_size, self.cfg.latent_size, generator=g),
101
+ torch.randn(self.cfg.text_length, self.cfg.text_dim, generator=g))
102
+
 
103
 
104
class RealArtDataset(Dataset):
    """Real illustration dataset from HF Hub (WikiArt, Teyvat, Pokemon, etc.).

    Images are resized, normalised and pushed through a small, randomly
    initialised and frozen two-layer conv stack (``pseudo_encoder``) to obtain
    a latent. Captions are not encoded by a language model: each caption
    seeds a random generator that produces a placeholder embedding.
    """

    def __init__(self, dataset_name="huggan/wikiart", config=None, max_samples=None,
                 split="train", text_dim=768, text_length=77):
        """Load the dataset and build the image transform / pseudo encoder.

        Args:
            dataset_name: Hub dataset id passed to ``datasets.load_dataset``.
            config: Object exposing ``latent_size`` and ``latent_channels``;
                a default ``ArtFlowConfig()`` is created when omitted.
            max_samples: Optional cap on the number of samples kept.
            split: Dataset split to load.
            text_dim: Width of the placeholder text embedding.
            text_length: Number of rows in the placeholder text embedding.
        """
        self.cfg = config or ArtFlowConfig()
        self.text_dim, self.text_length = text_dim, text_length
        self.latent_size = self.cfg.latent_size
        self.latent_channels = self.cfg.latent_channels

        print(f"Loading dataset: {dataset_name} ...")
        from datasets import load_dataset
        import torchvision.transforms as T

        try:
            ds = load_dataset(dataset_name, split=split, trust_remote_code=True)
        except Exception as e:
            # Best-effort fallback: stream the split and materialise (at most
            # max_samples of) it into an in-memory dataset.
            print(f" Streaming: {e}")
            ds = load_dataset(dataset_name, split=split, streaming=True, trust_remote_code=True)
            items = []
            for i, item in enumerate(ds):
                if max_samples and i >= max_samples:
                    break
                items.append(item)
            from datasets import Dataset as HFD
            ds = HFD.from_list(items)

        if max_samples and len(ds) > max_samples:
            ds = ds.select(range(max_samples))
        self.ds = ds
        self.columns = ds.column_names
        # Pick the first matching column name for each role (None if absent).
        self.image_col = next((c for c in ['image', 'img', 'pixel_values'] if c in self.columns), None)
        self.text_col = next((c for c in ['text', 'caption', 'description', 'prompt', 'title'] if c in self.columns), None)
        self.style_col = next((c for c in ['style', 'genre', 'artist'] if c in self.columns), None)

        # An 8x pixel-to-latent ratio: stride-4 conv followed by stride-2 conv.
        target_px = self.latent_size * 8
        self.transform = T.Compose([T.Resize((target_px, target_px)), T.ToTensor(), T.Normalize([0.5], [0.5])])
        self.pseudo_encoder = nn.Sequential(
            nn.Conv2d(3, 32, 4, stride=4), nn.SiLU(),
            nn.Conv2d(32, self.latent_channels, 4, stride=2, padding=1))
        for p in self.pseudo_encoder.parameters():
            p.requires_grad_(False)
        print(f" Loaded {len(self.ds)} samples | img={self.image_col} txt={self.text_col} style={self.style_col}")

    def __len__(self):
        """Return the number of loaded samples."""
        return len(self.ds)

    @staticmethod
    def _text_seed(text):
        """Map a caption to a generator seed that is stable across processes.

        Built-in ``hash()`` on strings is salted per interpreter process, so
        seeding from it gives a different embedding for the same caption on
        every run (and in every worker process). CRC32 has no such salt.
        """
        import zlib
        return zlib.crc32(text.encode('utf-8')) % (2 ** 31)

    def __getitem__(self, idx):
        """Return ``(latent, text_emb)`` for sample ``idx``.

        Falls back to a random latent when the sample has no image, and to an
        unseeded random embedding when it has no (non-empty) caption.
        """
        item = self.ds[idx]

        if self.image_col and item.get(self.image_col) is not None:
            img = item[self.image_col]
            if hasattr(img, 'convert'):
                img = img.convert('RGB')
            with torch.no_grad():
                latent = self.pseudo_encoder(self.transform(img).unsqueeze(0)).squeeze(0)
            # Safety net in case the encoder output is not latent_size square.
            if latent.shape[1] != self.latent_size or latent.shape[2] != self.latent_size:
                latent = F.interpolate(latent.unsqueeze(0), size=(self.latent_size, self.latent_size),
                                       mode='bilinear', align_corners=False).squeeze(0)
        else:
            latent = torch.randn(self.latent_channels, self.latent_size, self.latent_size)

        if self.text_col and item.get(self.text_col):
            text = str(item[self.text_col])
            g = torch.Generator().manual_seed(self._text_seed(text))
            text_emb = torch.randn(self.text_length, self.text_dim, generator=g) * 0.1
            # Rows up to the caption's word count get double magnitude.
            text_emb[:min(len(text.split()), self.text_length)] *= 2.0
        else:
            text_emb = torch.randn(self.text_length, self.text_dim) * 0.1
        return latent, text_emb
167
 
 
 
 
 
 
 
 
168
 
169
def freeze_for_stage(model, stage):
    """Set ``requires_grad`` on ``model``'s parameters for a training stage.

    All parameters are first made trainable, then:

    * stage 1 freezes parameters whose name contains ``art_style``,
      ``mood_ctrl`` or ``concept_engine``;
    * stages 2 and 3 freeze ``mood_ctrl`` and ``concept_engine``;
    * stage 4 freezes everything *except* ``mood_ctrl`` / ``concept_engine``;
    * any other stage leaves every parameter trainable.

    Matching is by substring on the parameter name. A one-line summary of the
    trainable/total parameter counts is printed.

    Args:
        model: Module exposing ``parameters()`` and ``named_parameters()``.
        stage: Integer stage id.

    Returns:
        The same ``model`` object, modified in place.
    """
    late_modules = ('mood_ctrl', 'concept_engine')
    frozen_by_stage = {
        1: ('art_style',) + late_modules,
        2: late_modules,
        3: late_modules,
    }

    for p in model.parameters():
        p.requires_grad_(True)

    if stage == 4:
        for name, p in model.named_parameters():
            if not any(key in name for key in late_modules):
                p.requires_grad_(False)
    else:
        # Resolve the key list once instead of once per parameter.
        keys = frozen_by_stage.get(stage, ())
        for name, p in model.named_parameters():
            if any(key in name for key in keys):
                p.requires_grad_(False)

    tr = sum(p.numel() for p in model.parameters() if p.requires_grad)
    tot = sum(p.numel() for p in model.parameters())
    # Guard the percentage so a parameter-less module does not divide by zero.
    pct = 100 * tr / tot if tot else 0.0
    print(f"Stage {stage}: {tr:,}/{tot:,} trainable ({pct:.1f}%)")
    return model
183
 
184
 
 
 
 
 
185
  class TrainingEngine:
186
def __init__(self, model, model_cfg, train_cfg, device):
    """Wire up the EMA copy, optimizer, AMP scaler, loss and bookkeeping.

    Args:
        model: The ``ArtFlow`` model being trained.
        model_cfg: Config used to build the EMA copy of the model.
        train_cfg: Training hyper-parameters (lr, betas, weight decay, ...).
        device: ``torch.device`` the EMA copy is moved to; AMP is enabled
            only when its type is ``'cuda'``.
    """
    self.model = model
    self.mcfg = model_cfg
    self.tcfg = train_cfg
    self.device = device

    # EMA shadow model: same weights, eval mode, never receives gradients.
    ema_model = ArtFlow(model_cfg).to(device)
    ema_model.load_state_dict(model.state_dict())
    ema_model.eval()
    for ema_param in ema_model.parameters():
        ema_param.requires_grad_(False)
    self.ema = ema_model

    # Only trainable parameters are optimised; anything whose name contains
    # 'norm' or 'bias' is exempt from weight decay.
    with_decay = []
    without_decay = []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if 'norm' in name or 'bias' in name:
            without_decay.append(param)
        else:
            with_decay.append(param)
    param_groups = [
        {'params': with_decay, 'weight_decay': train_cfg.weight_decay},
        {'params': without_decay, 'weight_decay': 0.0},
    ]
    self.optimizer = torch.optim.AdamW(
        param_groups, lr=train_cfg.lr, betas=train_cfg.betas)

    self.use_amp = (device.type == 'cuda')
    self.scaler = torch.amp.GradScaler(device.type, enabled=self.use_amp)

    self.loss_fn = ArtFlowLoss()

    self.global_step = 0
    self.losses = []
    self.grad_norms = []
 
205
 
206
+ def _lr_scale(self):
 
207
  s, w, total = self.global_step, self.tcfg.warmup_steps, self.tcfg.num_steps
208
+ if s < w: return s / max(w, 1)
209
+ return self.tcfg.min_lr_ratio + 0.5 * (1 - self.tcfg.min_lr_ratio) * (1 + math.cos(math.pi * (s-w)/max(total-w,1)))
210
+
211
+ def _set_lr(self):
212
+ lr = self.tcfg.lr * self._lr_scale()
213
+ for pg in self.optimizer.param_groups: pg['lr'] = lr
 
 
 
 
 
214
  return lr
215
 
 
216
  @torch.no_grad()
217
  def _update_ema(self):
218
+ if self.global_step < self.tcfg.ema_start_step: return
 
219
  d = self.tcfg.ema_decay
220
  for ep, p in zip(self.ema.parameters(), self.model.parameters()):
221
+ ep.data.mul_(d).add_(p.data, alpha=1-d)
222
 
223
def micro_step(self, x_0, text_emb):
    """Run forward/backward for one micro-batch of the accumulation window.

    Samples a timestep ``t`` and noise ``eps`` per example, feeds the model
    the interpolation ``(1 - t) * x_0 + t * eps`` and regresses the model
    output onto ``eps - x_0``. The loss is divided by ``grad_accum`` before
    the (scaled) backward pass.

    Args:
        x_0: Batch of latents; indexed as a 4-D tensor with batch first.
        text_emb: Conditioning passed through to the model unchanged.

    Returns:
        The un-divided loss as a Python float, or ``None`` when the batch
        was skipped (flagged as a spike by the loss, or non-finite) and no
        gradients were accumulated.
    """
    B = x_0.shape[0]
    t = logit_normal_timestep(B, self.device)
    eps = torch.randn_like(x_0)
    te = t[:, None, None, None]  # broadcast t over (C, H, W)

    # NOTE(review): the flattened source suggests the loss is computed inside
    # the autocast region (with explicit .float() casts) — confirm.
    with torch.amp.autocast(self.device.type, dtype=torch.float16, enabled=self.use_amp):
        v_pred = self.model((1 - te) * x_0 + te * eps, t, text_emb)
        loss, spike = self.loss_fn(v_pred.float(), (eps - x_0).float(), t)
        loss = loss / self.tcfg.grad_accum

    # A NaN loss never compares greater than the spike threshold, so it is
    # not reported as a spike; skip non-finite losses explicitly rather than
    # back-propagating them into the accumulated gradients.
    if spike or not torch.isfinite(loss).item():
        return None

    self.scaler.scale(loss).backward()
    return loss.item() * self.tcfg.grad_accum
235
 
236
+ def optim_step(self):
 
237
  self.scaler.unscale_(self.optimizer)
238
+ gn = torch.nn.utils.clip_grad_norm_([p for p in self.model.parameters() if p.requires_grad], self.tcfg.max_grad_norm).item()
 
 
239
  self.scaler.step(self.optimizer)
240
  self.scaler.update()
241
  self.optimizer.zero_grad(set_to_none=True)
 
243
  self.global_step += 1
244
  return gn
245
 
246
def save(self, path=None):
    """Write a training checkpoint to ``path``.

    The checkpoint holds the model, EMA, optimizer and scaler state dicts,
    the current step, the last 2000 recorded losses and both configs as
    plain dicts.

    Args:
        path: Destination file. Defaults to
            ``<output_dir>/ckpt_<global_step>.pt``.
    """
    path = path or os.path.join(self.tcfg.output_dir, f'ckpt_{self.global_step}.pt')

    # A bare filename has an empty dirname, and os.makedirs('') raises
    # FileNotFoundError — only create the parent when there is one.
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)

    checkpoint = {
        'model': self.model.state_dict(),
        'ema': self.ema.state_dict(),
        'optimizer': self.optimizer.state_dict(),
        'scaler': self.scaler.state_dict(),
        'step': self.global_step,
        'losses': self.losses[-2000:],
        'model_config': asdict(self.mcfg),
        'train_config': asdict(self.tcfg),
    }
    torch.save(checkpoint, path)
    print(f" 💾 Saved: {path}")
254
 
255
def load(self, path):
    """Restore model, EMA, optimizer, scaler, step and loss history.

    Args:
        path: Checkpoint file with the keys ``model``, ``ema``,
            ``optimizer``, ``scaler`` and ``step``; ``losses`` is optional
            and defaults to an empty list.
    """
    # NOTE(security): weights_only=False lets torch.load unpickle arbitrary
    # objects — only load checkpoints from a trusted source.
    ckpt = torch.load(path, map_location=self.device, weights_only=False)

    # Restore in the same order the states are listed in the checkpoint.
    for key in ('model', 'ema', 'optimizer', 'scaler'):
        getattr(self, key).load_state_dict(ckpt[key])

    self.global_step = ckpt['step']
    self.losses = ckpt.get('losses', [])
    print(f" 📂 Resumed from step {self.global_step}")
261
 
262
 
263
def train(model, model_cfg, train_cfg, dataset, device, resume_path=None):
    """Run one training stage and return the engine for inspection.

    Builds a ``TrainingEngine``, optionally resumes from ``resume_path``,
    then loops over the data until ``engine.global_step`` reaches
    ``train_cfg.num_steps``. An optimizer step is taken once
    ``train_cfg.grad_accum`` micro-batches have been accumulated;
    micro-batches for which ``micro_step`` returns ``None`` are not counted.
    Progress is printed every ``log_every`` steps, checkpoints are written
    every ``save_every`` steps, and a final checkpoint is always written and
    optionally uploaded to the Hub.

    Args:
        model: Model to train (put into ``train()`` mode here).
        model_cfg: Model config forwarded to the engine.
        train_cfg: Training hyper-parameters and output/hub settings.
        dataset: Map-style dataset yielding ``(latent, text_embedding)``.
        device: Device batches are moved to.
        resume_path: Optional checkpoint to resume from, used if it exists.

    Returns:
        The ``TrainingEngine`` after training.

    Raises:
        ValueError: If the data loader would yield no batches (dataset
            smaller than ``batch_size`` with ``drop_last=True``); without
            this check the outer ``while`` loop would never terminate.
    """
    engine = TrainingEngine(model, model_cfg, train_cfg, device)
    if resume_path and os.path.exists(resume_path):
        engine.load(resume_path)

    loader = DataLoader(dataset, batch_size=train_cfg.batch_size, shuffle=True,
                        num_workers=0, drop_last=True, pin_memory=(device.type == 'cuda'))
    if len(loader) == 0:
        raise ValueError(
            f"dataset of size {len(dataset)} yields no full batch of "
            f"size {train_cfg.batch_size}")

    print(f"\n{'='*60}\nStage {train_cfg.stage} — {engine.global_step} → {train_cfg.num_steps} steps")
    print(f"Effective batch: {train_cfg.batch_size} × {train_cfg.grad_accum} = {train_cfg.batch_size*train_cfg.grad_accum}\n{'='*60}\n")

    model.train()
    start = time.time()
    # Steps restored from a checkpoint were not run in this session, so they
    # must not count towards the throughput used for the ETA.
    first_step = engine.global_step
    acc_loss, acc_n = 0.0, 0

    while engine.global_step < train_cfg.num_steps:
        for x_0, txt in loader:
            if engine.global_step >= train_cfg.num_steps:
                break
            x_0, txt = x_0.to(device), txt.to(device)
            engine._set_lr()

            lv = engine.micro_step(x_0, txt)
            if lv is not None:
                acc_loss += lv
                acc_n += 1

            if acc_n >= train_cfg.grad_accum:
                gn = engine.optim_step()
                engine.losses.append(acc_loss / acc_n)
                engine.grad_norms.append(gn)
                acc_loss, acc_n = 0.0, 0

                if engine.global_step % train_cfg.log_every == 0:
                    el = time.time() - start
                    sps = (engine.global_step - first_step) / max(el, 1)
                    rec = engine.losses[-50:]
                    print(f"Step {engine.global_step:>6d}/{train_cfg.num_steps} | Loss: {sum(rec)/len(rec):.4f} | "
                          f"GN: {gn:.3f} | LR: {engine.optimizer.param_groups[0]['lr']:.2e} | "
                          f"ETA: {(train_cfg.num_steps-engine.global_step)/max(sps,1e-6)/60:.0f}m")
                if engine.global_step % train_cfg.save_every == 0:
                    engine.save()

    final_path = os.path.join(train_cfg.output_dir, f'stage{train_cfg.stage}_final.pt')
    engine.save(final_path)

    if train_cfg.push_to_hub and train_cfg.hub_model_id:
        # Best effort: a failed upload is reported but does not fail training.
        try:
            from huggingface_hub import HfApi
            HfApi().upload_file(path_or_fileobj=final_path,
                                path_in_repo=f'stage{train_cfg.stage}_final.pt',
                                repo_id=train_cfg.hub_model_id)
            print(f" 📤 Pushed to {train_cfg.hub_model_id}")
        except Exception as e:
            print(f" ⚠️ Push failed: {e}")

    print(f"\n✅ Stage {train_cfg.stage} done — {(time.time()-start)/3600:.1f}h")
    return engine
302
 
303
 
 
 
 
 
304
  if __name__ == '__main__':
305
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
306
  print(f"Device: {device}")
307
+ mcfg = ArtFlowConfig(latent_channels=4, latent_size=16, stage_channels=(64,128,192),
308
+ blocks_per_stage=(1,1,1), bottleneck_blocks=2, mamba_state_dim=8, num_styles=16,
309
+ style_dim=128, mood_dim=64, num_moods=8, text_dim=256, text_length=16,
310
+ num_heads=4, concept_dim=64, kan_grid_size=3)
 
 
 
 
 
 
311
  model = ArtFlow(mcfg).to(device)
312
+ model = freeze_for_stage(model, 1)
313
+ print(f"Model: {sum(p.numel() for p in model.parameters()):,} params")
314
+ engine = train(model, mcfg, TrainConfig(num_steps=30, log_every=10, save_every=100,
315
+ batch_size=2, grad_accum=2, warmup_steps=5), SyntheticDataset(n=200, config=mcfg), device)
 
 
 
 
 
 
 
 
 
 
 
 
316
  has_nan = any(torch.isnan(p).any() for p in model.parameters())
317
+ print(f"Steps: {engine.global_step} | NaN: {'FAIL' if has_nan else 'OK'}")
318
  print("✅ All good" if not has_nan and engine.global_step >= 30 else "❌ Issues")