AbstractPhil commited on
Commit
4adc3d8
·
verified ·
1 Parent(s): 55886a5

Update trainer_colab.py

Browse files
Files changed (1) hide show
  1. trainer_colab.py +280 -111
trainer_colab.py CHANGED
@@ -36,9 +36,9 @@ DTYPE = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
36
 
37
  # HuggingFace Hub
38
  HF_REPO = "AbstractPhil/tiny-flux"
39
- SAVE_EVERY = 500 # steps - local save
40
- UPLOAD_EVERY = 500 # steps - hub upload
41
- SAMPLE_EVERY = 250 # steps - generate samples
42
  LOG_EVERY = 10 # steps - tensorboard
43
 
44
  # Checkpoint loading target
@@ -47,10 +47,14 @@ LOG_EVERY = 10 # steps - tensorboard
47
  # "best" - load best model
48
  # int (e.g. 1500) - load specific step
49
  # "hub:step_1000" - load specific checkpoint from hub
50
- # "local:path/to/checkpoint.safetensors" - load specific local file
51
  # "none" - start fresh, ignore existing checkpoints
52
  LOAD_TARGET = "latest"
53
 
 
 
 
 
54
  # Local paths
55
  CHECKPOINT_DIR = "./tiny_flux_checkpoints"
56
  LOG_DIR = "./tiny_flux_logs"
@@ -127,14 +131,37 @@ def encode_prompt(prompt):
127
  # ============================================================================
128
  # FLOW MATCHING HELPERS
129
  # ============================================================================
 
 
 
 
 
 
 
 
 
 
 
130
  def flux_shift(t, s=SHIFT):
 
 
 
 
 
 
 
131
  return s * t / (1 + (s - 1) * t)
132
 
133
  def flux_shift_inverse(t_shifted, s=SHIFT):
134
- """Inverse of flux_shift for sampling."""
135
  return t_shifted / (s - (s - 1) * t_shifted)
136
 
137
  def min_snr_weight(t, gamma=MIN_SNR):
 
 
 
 
 
138
  snr = (t / (1 - t).clamp(min=1e-5)).pow(2)
139
  return torch.clamp(snr, max=gamma) / snr.clamp(min=1e-5)
140
 
@@ -143,7 +170,12 @@ def min_snr_weight(t, gamma=MIN_SNR):
143
  # ============================================================================
144
  @torch.no_grad()
145
  def generate_samples(model, prompts, num_steps=20, guidance_scale=3.5, H=64, W=64):
146
- """Generate sample images using Euler sampling."""
 
 
 
 
 
147
  model.eval()
148
  B = len(prompts)
149
  C = 16 # VAE channels
@@ -157,18 +189,21 @@ def generate_samples(model, prompts, num_steps=20, guidance_scale=3.5, H=64, W=6
157
  t5_embeds = torch.stack(t5_embeds)
158
  clip_pooleds = torch.stack(clip_pooleds)
159
 
160
- # Start from noise
161
  x = torch.randn(B, H * W, C, device=DEVICE, dtype=DTYPE)
162
 
163
  # Create image IDs
164
  img_ids = TinyFlux.create_img_ids(B, H, W, DEVICE)
165
 
166
- # Euler sampling with uniform timesteps
167
- timesteps = torch.linspace(1, 0, num_steps + 1, device=DEVICE)[:-1]
168
 
169
- for i, t in enumerate(timesteps):
170
- t_batch = t.expand(B)
171
- dt = 1.0 / num_steps
 
 
 
172
 
173
  # Conditional prediction
174
  guidance = torch.full((B,), guidance_scale, device=DEVICE, dtype=DTYPE)
@@ -181,15 +216,15 @@ def generate_samples(model, prompts, num_steps=20, guidance_scale=3.5, H=64, W=6
181
  guidance=guidance,
182
  )
183
 
184
- # Euler step: x = x + v * dt (going from noise to data)
185
  x = x + v_cond * dt
186
 
187
  # Reshape to image format: (B, H*W, C) -> (B, C, H, W)
188
  latents = x.reshape(B, H, W, C).permute(0, 3, 1, 2)
189
 
190
- # Decode with VAE
191
  latents = latents / vae.config.scaling_factor
192
- images = vae.decode(latents.float()).sample
193
  images = (images / 2 + 0.5).clamp(0, 1)
194
 
195
  model.train()
@@ -235,6 +270,32 @@ def collate(batch):
235
  # ============================================================================
236
  # CHECKPOINT FUNCTIONS
237
  # ============================================================================
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
238
  def save_checkpoint(model, optimizer, scheduler, step, epoch, loss, path):
239
  """Save checkpoint locally."""
240
  os.makedirs(os.path.dirname(path) if os.path.dirname(path) else ".", exist_ok=True)
@@ -306,7 +367,7 @@ def load_checkpoint(model, optimizer, scheduler, target):
306
  "best" - best model
307
  int (1500) - specific step
308
  "hub:step_1000" - specific hub checkpoint
309
- "local:/path/to/file.safetensors" - specific local file
310
  "none" - skip loading, start fresh
311
  """
312
  if target == "none":
@@ -342,102 +403,163 @@ def load_checkpoint(model, optimizer, scheduler, target):
342
 
343
  # Load based on mode
344
  if load_mode == "local":
345
- # Direct local file
346
  if os.path.exists(load_path):
347
- weights = load_file(load_path)
348
  model.load_state_dict(weights)
349
- # Try to find associated state file
350
- state_path = load_path.replace(".safetensors", ".pt")
351
- if os.path.exists(state_path):
352
- state = torch.load(state_path, weights_only=False)
353
- optimizer.load_state_dict(state["optimizer"])
354
- scheduler.load_state_dict(state["scheduler"])
355
- start_step = state.get("step", 0)
356
- start_epoch = state.get("epoch", 0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
357
  print(f"✓ Loaded local: {load_path} (step {start_step})")
358
  return start_step, start_epoch
359
  else:
360
  print(f"⚠ Local file not found: {load_path}")
361
 
362
  elif load_mode == "hub":
363
- # Specific hub checkpoint
364
- try:
365
- filename = f"checkpoints/{load_path}.safetensors" if not load_path.endswith(".safetensors") else load_path
366
- local_path = hf_hub_download(repo_id=HF_REPO, filename=filename)
367
- weights = load_file(local_path)
368
- model.load_state_dict(weights)
369
- # Extract step from filename
370
- if "step_" in load_path:
371
- start_step = int(load_path.split("step_")[-1].replace(".safetensors", ""))
372
- print(f"✓ Loaded from Hub: {filename} (step {start_step})")
373
- return start_step, start_epoch
374
- except Exception as e:
375
- print(f" Hub load failed: {e}")
 
 
 
 
 
376
 
377
  elif load_mode == "best":
378
- # Try hub best first
379
- try:
380
- local_path = hf_hub_download(repo_id=HF_REPO, filename="model.safetensors")
381
- weights = load_file(local_path)
382
- model.load_state_dict(weights)
383
- print(f"✓ Loaded best model from Hub")
384
- return start_step, start_epoch
385
- except:
386
- pass
387
- # Try local best
388
- best_path = os.path.join(CHECKPOINT_DIR, "best.safetensors")
389
- if os.path.exists(best_path):
390
- weights = load_file(best_path)
391
- model.load_state_dict(weights)
392
- state_path = best_path.replace(".safetensors", ".pt")
393
- if os.path.exists(state_path):
394
- state = torch.load(state_path, weights_only=False)
395
- start_step = state.get("step", 0)
396
- start_epoch = state.get("epoch", 0)
397
- print(f" Loaded local best (step {start_step})")
398
- return start_step, start_epoch
 
 
 
 
 
 
399
 
400
  elif load_mode == "step":
401
  # Specific step number
402
  step_num = load_path
403
- # Try hub
404
- try:
405
- filename = f"checkpoints/step_{step_num}.safetensors"
406
- local_path = hf_hub_download(repo_id=HF_REPO, filename=filename)
407
- weights = load_file(local_path)
408
- model.load_state_dict(weights)
409
- start_step = step_num
410
- print(f"✓ Loaded step {step_num} from Hub")
411
- return start_step, start_epoch
412
- except:
413
- pass
414
- # Try local
415
- local_path = os.path.join(CHECKPOINT_DIR, f"step_{step_num}.safetensors")
416
- if os.path.exists(local_path):
417
- weights = load_file(local_path)
418
- model.load_state_dict(weights)
419
- state_path = local_path.replace(".safetensors", ".pt")
420
- if os.path.exists(state_path):
421
- state = torch.load(state_path, weights_only=False)
422
- optimizer.load_state_dict(state["optimizer"])
423
- scheduler.load_state_dict(state["scheduler"])
424
- start_epoch = state.get("epoch", 0)
425
- start_step = step_num
426
- print(f"✓ Loaded local step {step_num}")
427
- return start_step, start_epoch
 
 
 
 
 
 
 
 
 
428
  print(f"⚠ Step {step_num} not found")
429
 
430
  # Default: latest
431
- # Try Hub first
432
  try:
433
  files = api.list_repo_files(repo_id=HF_REPO)
434
- checkpoints = [f for f in files if f.startswith("checkpoints/step_") and f.endswith(".safetensors")]
435
  if checkpoints:
436
- checkpoints.sort(key=lambda x: int(x.split("step_")[-1].replace(".safetensors", "")))
 
 
 
437
  latest = checkpoints[-1]
438
- step = int(latest.split("step_")[-1].replace(".safetensors", ""))
439
  local_path = hf_hub_download(repo_id=HF_REPO, filename=latest)
440
- weights = load_file(local_path)
441
  model.load_state_dict(weights)
442
  start_step = step
443
  print(f"✓ Loaded latest from Hub: step {step}")
@@ -445,22 +567,33 @@ def load_checkpoint(model, optimizer, scheduler, target):
445
  except Exception as e:
446
  print(f"Hub check: {e}")
447
 
448
- # Try local
449
  if os.path.exists(CHECKPOINT_DIR):
450
- local_ckpts = [f for f in os.listdir(CHECKPOINT_DIR) if f.startswith("step_") and f.endswith(".safetensors")]
 
 
451
  if local_ckpts:
452
- local_ckpts.sort(key=lambda x: int(x.split("step_")[-1].replace(".safetensors", "")))
 
 
453
  latest = local_ckpts[-1]
454
- step = int(latest.split("step_")[-1].replace(".safetensors", ""))
455
  weights_path = os.path.join(CHECKPOINT_DIR, latest)
456
- weights = load_file(weights_path)
457
  model.load_state_dict(weights)
458
- state_path = weights_path.replace(".safetensors", ".pt")
 
459
  if os.path.exists(state_path):
460
- state = torch.load(state_path, weights_only=False)
461
- optimizer.load_state_dict(state["optimizer"])
462
- scheduler.load_state_dict(state["scheduler"])
463
- start_epoch = state.get("epoch", 0)
 
 
 
 
 
 
464
  start_step = step
465
  print(f"✓ Loaded latest local: step {step}")
466
  return start_step, start_epoch
@@ -479,6 +612,7 @@ loader = DataLoader(ds, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate,
479
  config = TinyFluxConfig()
480
  model = TinyFlux(config).to(DEVICE).to(DTYPE)
481
  print(f"\nParams: {sum(p.numel() for p in model.parameters()):,}")
 
482
 
483
  # ============================================================================
484
  # OPTIMIZER & SCHEDULER
@@ -499,6 +633,11 @@ sched = torch.optim.lr_scheduler.LambdaLR(opt, lr_fn)
499
  print(f"\nLoad target: {LOAD_TARGET}")
500
  start_step, start_epoch = load_checkpoint(model, opt, sched, LOAD_TARGET)
501
 
 
 
 
 
 
502
  # Log config to tensorboard
503
  writer.add_text("config", json.dumps(config.__dict__, indent=2), 0)
504
  writer.add_text("training_config", json.dumps({
@@ -537,26 +676,53 @@ for ep in range(start_epoch, EPOCHS):
537
  pbar = tqdm(loader, desc=f"E{ep+1}")
538
 
539
  for i, batch in enumerate(pbar):
540
- lat = batch["latents"]
541
  t5 = batch["t5_embeds"]
542
  clip = batch["clip_pooled"]
543
 
544
- B, C, H, W = lat.shape
545
- x1 = lat.permute(0, 2, 3, 1).reshape(B, H*W, C)
546
- x0 = torch.randn_like(x1)
547
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
548
  t = torch.sigmoid(torch.randn(B, device=DEVICE))
549
- t = flux_shift(t).to(DTYPE).clamp(1e-4, 1-1e-4)
 
 
 
 
550
 
551
- t_exp = t.view(B, 1, 1)
552
- x_t = (1 - t_exp) * x0 + t_exp * x1
553
- v_target = x1 - x0
554
 
 
555
  img_ids = TinyFlux.create_img_ids(B, H, W, DEVICE)
556
- guidance = torch.rand(B, device=DEVICE, dtype=DTYPE) * 4 + 1
557
 
 
 
 
 
558
  with torch.autocast("cuda", dtype=DTYPE):
559
- pred = model(
560
  hidden_states=x_t,
561
  encoder_hidden_states=t5,
562
  pooled_projections=clip,
@@ -565,7 +731,10 @@ for ep in range(start_epoch, EPOCHS):
565
  guidance=guidance,
566
  )
567
 
568
- loss_raw = F.mse_loss(pred, v_target, reduction="none").mean(dim=[1,2])
 
 
 
569
  snr_weights = min_snr_weight(t)
570
  loss = (loss_raw * snr_weights).mean() / GRAD_ACCUM
571
  loss.backward()
 
36
 
37
  # HuggingFace Hub
38
  HF_REPO = "AbstractPhil/tiny-flux"
39
+ SAVE_EVERY = 1000 # steps - local save
40
+ UPLOAD_EVERY = 1000 # steps - hub upload
41
+ SAMPLE_EVERY = 500 # steps - generate samples
42
  LOG_EVERY = 10 # steps - tensorboard
43
 
44
  # Checkpoint loading target
 
47
  # "best" - load best model
48
  # int (e.g. 1500) - load specific step
49
  # "hub:step_1000" - load specific checkpoint from hub
50
+ # "local:path/to/checkpoint.safetensors" or "local:path/to/checkpoint.pt"
51
  # "none" - start fresh, ignore existing checkpoints
52
  LOAD_TARGET = "latest"
53
 
54
+ # Manual resume step (set to override step from checkpoint, or None to use checkpoint's step)
55
+ # Useful when checkpoint doesn't contain step info
56
+ RESUME_STEP = None # e.g., 5000 to resume from step 5000
57
+
58
  # Local paths
59
  CHECKPOINT_DIR = "./tiny_flux_checkpoints"
60
  LOG_DIR = "./tiny_flux_logs"
 
131
  # ============================================================================
132
  # FLOW MATCHING HELPERS
133
  # ============================================================================
134
+ # Rectified Flow / Flow Matching formulation:
135
+ # x_t = (1-t) * x_0 + t * x_1
136
+ # where x_0 = noise, x_1 = data
137
+ # t=0: pure noise, t=1: pure data
138
+ # velocity v = x_1 - x_0 = data - noise
139
+ #
140
+ # Training: model learns to predict v given (x_t, t)
141
+ # Inference: start from noise (t=0), integrate to data (t=1)
142
+ # x_{t+dt} = x_t + v_pred * dt
143
+ # ============================================================================
144
+
145
  def flux_shift(t, s=SHIFT):
146
+ """Flux timestep shift for training distribution.
147
+
148
+ Shifts timesteps towards higher values (closer to data),
149
+ making training focus more on refining details.
150
+
151
+ s=3.0 (default): flux_shift(0.5) ≈ 0.75
152
+ """
153
  return s * t / (1 + (s - 1) * t)
154
 
155
  def flux_shift_inverse(t_shifted, s=SHIFT):
156
+ """Inverse of flux_shift."""
157
  return t_shifted / (s - (s - 1) * t_shifted)
158
 
159
  def min_snr_weight(t, gamma=MIN_SNR):
160
+ """Min-SNR weighting to balance loss across timesteps.
161
+
162
+ Downweights very easy timesteps (near t=0 or t=1).
163
+ gamma=5.0 is typical.
164
+ """
165
  snr = (t / (1 - t).clamp(min=1e-5)).pow(2)
166
  return torch.clamp(snr, max=gamma) / snr.clamp(min=1e-5)
167
 
 
170
  # ============================================================================
171
  @torch.no_grad()
172
  def generate_samples(model, prompts, num_steps=20, guidance_scale=3.5, H=64, W=64):
173
+ """Generate sample images using Euler sampling.
174
+
175
+ Flow matching: x_t = (1-t)*noise + t*data, v = data - noise
176
+ At t=0: pure noise. At t=1: pure data.
177
+ We integrate from t=0 to t=1.
178
+ """
179
  model.eval()
180
  B = len(prompts)
181
  C = 16 # VAE channels
 
189
  t5_embeds = torch.stack(t5_embeds)
190
  clip_pooleds = torch.stack(clip_pooleds)
191
 
192
+ # Start from pure noise (t=0)
193
  x = torch.randn(B, H * W, C, device=DEVICE, dtype=DTYPE)
194
 
195
  # Create image IDs
196
  img_ids = TinyFlux.create_img_ids(B, H, W, DEVICE)
197
 
198
+ # Euler sampling: t goes from 0 (noise) to 1 (data)
199
+ timesteps = torch.linspace(0, 1, num_steps + 1, device=DEVICE, dtype=DTYPE)
200
 
201
+ for i in range(num_steps):
202
+ t_curr = timesteps[i]
203
+ t_next = timesteps[i + 1]
204
+ dt = t_next - t_curr # positive
205
+
206
+ t_batch = t_curr.expand(B)
207
 
208
  # Conditional prediction
209
  guidance = torch.full((B,), guidance_scale, device=DEVICE, dtype=DTYPE)
 
216
  guidance=guidance,
217
  )
218
 
219
+ # Euler step: x_{t+dt} = x_t + v * dt
220
  x = x + v_cond * dt
221
 
222
  # Reshape to image format: (B, H*W, C) -> (B, C, H, W)
223
  latents = x.reshape(B, H, W, C).permute(0, 3, 1, 2)
224
 
225
+ # Decode with VAE (match VAE dtype)
226
  latents = latents / vae.config.scaling_factor
227
+ images = vae.decode(latents.to(vae.dtype)).sample
228
  images = (images / 2 + 0.5).clamp(0, 1)
229
 
230
  model.train()
 
270
  # ============================================================================
271
  # CHECKPOINT FUNCTIONS
272
  # ============================================================================
273
+ def load_weights(path):
274
+ """Load weights from .safetensors or .pt file."""
275
+ if path.endswith(".safetensors"):
276
+ return load_file(path)
277
+ elif path.endswith(".pt"):
278
+ ckpt = torch.load(path, map_location=DEVICE, weights_only=False)
279
+ if isinstance(ckpt, dict):
280
+ if "model" in ckpt:
281
+ return ckpt["model"]
282
+ elif "state_dict" in ckpt:
283
+ return ckpt["state_dict"]
284
+ else:
285
+ # Check if it looks like a state dict (has tensor values)
286
+ first_val = next(iter(ckpt.values()), None)
287
+ if isinstance(first_val, torch.Tensor):
288
+ return ckpt
289
+ # Otherwise might have optimizer etc, look for model keys
290
+ return ckpt
291
+ return ckpt
292
+ else:
293
+ # Try safetensors first, then pt
294
+ try:
295
+ return load_file(path)
296
+ except:
297
+ return torch.load(path, map_location=DEVICE, weights_only=False)
298
+
299
  def save_checkpoint(model, optimizer, scheduler, step, epoch, loss, path):
300
  """Save checkpoint locally."""
301
  os.makedirs(os.path.dirname(path) if os.path.dirname(path) else ".", exist_ok=True)
 
367
  "best" - best model
368
  int (1500) - specific step
369
  "hub:step_1000" - specific hub checkpoint
370
+ "local:/path/to/file.safetensors" or "local:/path/to/file.pt" - specific local file
371
  "none" - skip loading, start fresh
372
  """
373
  if target == "none":
 
403
 
404
  # Load based on mode
405
  if load_mode == "local":
406
+ # Direct local file (.pt or .safetensors)
407
  if os.path.exists(load_path):
408
+ weights = load_weights(load_path)
409
  model.load_state_dict(weights)
410
+
411
+ # Try to find associated state file for optimizer/scheduler
412
+ if load_path.endswith(".safetensors"):
413
+ state_path = load_path.replace(".safetensors", ".pt")
414
+ elif load_path.endswith(".pt"):
415
+ # The .pt file might contain everything
416
+ ckpt = torch.load(load_path, map_location=DEVICE, weights_only=False)
417
+ if isinstance(ckpt, dict):
418
+ # Debug: show what keys are in the checkpoint
419
+ non_tensor_keys = [k for k in ckpt.keys() if not isinstance(ckpt.get(k), torch.Tensor)]
420
+ if non_tensor_keys:
421
+ print(f" Checkpoint keys: {non_tensor_keys}")
422
+
423
+ # Extract step/epoch - try multiple common key names
424
+ start_step = ckpt.get("step", ckpt.get("global_step", ckpt.get("iteration", 0)))
425
+ start_epoch = ckpt.get("epoch", 0)
426
+
427
+ # Also check for nested state dict
428
+ if "state" in ckpt and isinstance(ckpt["state"], dict):
429
+ start_step = ckpt["state"].get("step", start_step)
430
+ start_epoch = ckpt["state"].get("epoch", start_epoch)
431
+
432
+ # Try to load optimizer/scheduler if present
433
+ if "optimizer" in ckpt:
434
+ try:
435
+ optimizer.load_state_dict(ckpt["optimizer"])
436
+ if "scheduler" in ckpt:
437
+ scheduler.load_state_dict(ckpt["scheduler"])
438
+ except Exception as e:
439
+ print(f" Note: Could not load optimizer state: {e}")
440
+ state_path = None
441
+ else:
442
+ state_path = load_path + ".pt"
443
+
444
+ if state_path and os.path.exists(state_path):
445
+ state = torch.load(state_path, map_location=DEVICE, weights_only=False)
446
+ try:
447
+ start_step = state.get("step", start_step)
448
+ start_epoch = state.get("epoch", start_epoch)
449
+ if "optimizer" in state:
450
+ optimizer.load_state_dict(state["optimizer"])
451
+ if "scheduler" in state:
452
+ scheduler.load_state_dict(state["scheduler"])
453
+ except Exception as e:
454
+ print(f" Note: Could not load optimizer state: {e}")
455
+
456
  print(f"✓ Loaded local: {load_path} (step {start_step})")
457
  return start_step, start_epoch
458
  else:
459
  print(f"⚠ Local file not found: {load_path}")
460
 
461
  elif load_mode == "hub":
462
+ # Specific hub checkpoint - try both extensions
463
+ for ext in [".safetensors", ".pt", ""]:
464
+ try:
465
+ if load_path.endswith((".safetensors", ".pt")):
466
+ filename = load_path if "/" in load_path else f"checkpoints/{load_path}"
467
+ else:
468
+ filename = f"checkpoints/{load_path}{ext}"
469
+ local_path = hf_hub_download(repo_id=HF_REPO, filename=filename)
470
+ weights = load_weights(local_path)
471
+ model.load_state_dict(weights)
472
+ # Extract step from filename
473
+ if "step_" in load_path:
474
+ start_step = int(load_path.split("step_")[-1].replace(".safetensors", "").replace(".pt", ""))
475
+ print(f"✓ Loaded from Hub: {filename} (step {start_step})")
476
+ return start_step, start_epoch
477
+ except Exception as e:
478
+ continue
479
+ print(f"⚠ Could not load from hub: {load_path}")
480
 
481
  elif load_mode == "best":
482
+ # Try hub best first (try both extensions)
483
+ for ext in [".safetensors", ".pt"]:
484
+ try:
485
+ filename = f"model{ext}" if ext else "model.safetensors"
486
+ local_path = hf_hub_download(repo_id=HF_REPO, filename=filename)
487
+ weights = load_weights(local_path)
488
+ model.load_state_dict(weights)
489
+ print(f"✓ Loaded best model from Hub")
490
+ return start_step, start_epoch
491
+ except:
492
+ continue
493
+
494
+ # Try local best (both extensions)
495
+ for ext in [".safetensors", ".pt"]:
496
+ best_path = os.path.join(CHECKPOINT_DIR, f"best{ext}")
497
+ if os.path.exists(best_path):
498
+ weights = load_weights(best_path)
499
+ model.load_state_dict(weights)
500
+ # Try to load optimizer state
501
+ state_path = best_path.replace(ext, ".pt") if ext == ".safetensors" else best_path
502
+ if os.path.exists(state_path):
503
+ state = torch.load(state_path, map_location=DEVICE, weights_only=False)
504
+ if isinstance(state, dict) and "step" in state:
505
+ start_step = state.get("step", 0)
506
+ start_epoch = state.get("epoch", 0)
507
+ print(f"✓ Loaded local best (step {start_step})")
508
+ return start_step, start_epoch
509
 
510
  elif load_mode == "step":
511
  # Specific step number
512
  step_num = load_path
513
+ # Try hub (both extensions)
514
+ for ext in [".safetensors", ".pt"]:
515
+ try:
516
+ filename = f"checkpoints/step_{step_num}{ext}"
517
+ local_path = hf_hub_download(repo_id=HF_REPO, filename=filename)
518
+ weights = load_weights(local_path)
519
+ model.load_state_dict(weights)
520
+ start_step = step_num
521
+ print(f"✓ Loaded step {step_num} from Hub")
522
+ return start_step, start_epoch
523
+ except:
524
+ continue
525
+
526
+ # Try local (both extensions)
527
+ for ext in [".safetensors", ".pt"]:
528
+ local_path = os.path.join(CHECKPOINT_DIR, f"step_{step_num}{ext}")
529
+ if os.path.exists(local_path):
530
+ weights = load_weights(local_path)
531
+ model.load_state_dict(weights)
532
+ state_path = local_path.replace(".safetensors", ".pt") if ext == ".safetensors" else local_path
533
+ if os.path.exists(state_path):
534
+ state = torch.load(state_path, map_location=DEVICE, weights_only=False)
535
+ if isinstance(state, dict):
536
+ try:
537
+ if "optimizer" in state:
538
+ optimizer.load_state_dict(state["optimizer"])
539
+ if "scheduler" in state:
540
+ scheduler.load_state_dict(state["scheduler"])
541
+ start_epoch = state.get("epoch", 0)
542
+ except:
543
+ pass
544
+ start_step = step_num
545
+ print(f"✓ Loaded local step {step_num}")
546
+ return start_step, start_epoch
547
  print(f"⚠ Step {step_num} not found")
548
 
549
  # Default: latest
550
+ # Try Hub first (both extensions)
551
  try:
552
  files = api.list_repo_files(repo_id=HF_REPO)
553
+ checkpoints = [f for f in files if f.startswith("checkpoints/step_") and (f.endswith(".safetensors") or f.endswith(".pt"))]
554
  if checkpoints:
555
+ # Sort by step number
556
+ def get_step(f):
557
+ return int(f.split("step_")[-1].replace(".safetensors", "").replace(".pt", ""))
558
+ checkpoints.sort(key=get_step)
559
  latest = checkpoints[-1]
560
+ step = get_step(latest)
561
  local_path = hf_hub_download(repo_id=HF_REPO, filename=latest)
562
+ weights = load_weights(local_path)
563
  model.load_state_dict(weights)
564
  start_step = step
565
  print(f"✓ Loaded latest from Hub: step {step}")
 
567
  except Exception as e:
568
  print(f"Hub check: {e}")
569
 
570
+ # Try local (both extensions)
571
  if os.path.exists(CHECKPOINT_DIR):
572
+ local_ckpts = [f for f in os.listdir(CHECKPOINT_DIR) if f.startswith("step_") and (f.endswith(".safetensors") or f.endswith(".pt"))]
573
+ # Filter to just weights files (not state .pt files that pair with .safetensors)
574
+ local_ckpts = [f for f in local_ckpts if not (f.endswith(".pt") and f.replace(".pt", ".safetensors") in local_ckpts)]
575
  if local_ckpts:
576
+ def get_step(f):
577
+ return int(f.split("step_")[-1].replace(".safetensors", "").replace(".pt", ""))
578
+ local_ckpts.sort(key=get_step)
579
  latest = local_ckpts[-1]
580
+ step = get_step(latest)
581
  weights_path = os.path.join(CHECKPOINT_DIR, latest)
582
+ weights = load_weights(weights_path)
583
  model.load_state_dict(weights)
584
+ # Try to load optimizer state
585
+ state_path = weights_path.replace(".safetensors", ".pt") if weights_path.endswith(".safetensors") else weights_path
586
  if os.path.exists(state_path):
587
+ state = torch.load(state_path, map_location=DEVICE, weights_only=False)
588
+ if isinstance(state, dict):
589
+ try:
590
+ if "optimizer" in state:
591
+ optimizer.load_state_dict(state["optimizer"])
592
+ if "scheduler" in state:
593
+ scheduler.load_state_dict(state["scheduler"])
594
+ start_epoch = state.get("epoch", 0)
595
+ except:
596
+ pass
597
  start_step = step
598
  print(f"✓ Loaded latest local: step {step}")
599
  return start_step, start_epoch
 
612
  config = TinyFluxConfig()
613
  model = TinyFlux(config).to(DEVICE).to(DTYPE)
614
  print(f"\nParams: {sum(p.numel() for p in model.parameters()):,}")
615
+ model = torch.compile(model, mode="default")
616
 
617
  # ============================================================================
618
  # OPTIMIZER & SCHEDULER
 
633
  print(f"\nLoad target: {LOAD_TARGET}")
634
  start_step, start_epoch = load_checkpoint(model, opt, sched, LOAD_TARGET)
635
 
636
+ # Override start_step if RESUME_STEP is set
637
+ if RESUME_STEP is not None:
638
+ print(f"Overriding start_step: {start_step} -> {RESUME_STEP}")
639
+ start_step = RESUME_STEP
640
+
641
  # Log config to tensorboard
642
  writer.add_text("config", json.dumps(config.__dict__, indent=2), 0)
643
  writer.add_text("training_config", json.dumps({
 
676
  pbar = tqdm(loader, desc=f"E{ep+1}")
677
 
678
  for i, batch in enumerate(pbar):
679
+ latents = batch["latents"] # Ground truth data (VAE encoded images)
680
  t5 = batch["t5_embeds"]
681
  clip = batch["clip_pooled"]
682
 
683
+ B, C, H, W = latents.shape
 
 
684
 
685
+ # ================================================================
686
+ # FLOW MATCHING FORMULATION
687
+ # ================================================================
688
+ # x_1 = data (what we want to generate)
689
+ # x_0 = noise (where we start at inference)
690
+ # x_t = (1-t)*x_0 + t*x_1 (linear interpolation)
691
+ #
692
+ # At t=0: x_t = x_0 (pure noise)
693
+ # At t=1: x_t = x_1 (pure data)
694
+ #
695
+ # Velocity field: v = dx/dt = x_1 - x_0
696
+ # Model learns to predict v given (x_t, t)
697
+ #
698
+ # At inference: start from noise, integrate v from t=0 to t=1
699
+ # ================================================================
700
+
701
+ # Reshape data to sequence format: (B, C, H, W) -> (B, H*W, C)
702
+ data = latents.permute(0, 2, 3, 1).reshape(B, H*W, C) # x_1
703
+ noise = torch.randn_like(data) # x_0
704
+
705
+ # Sample timesteps with logit-normal distribution + Flux shift
706
+ # This biases training towards higher t (closer to data)
707
  t = torch.sigmoid(torch.randn(B, device=DEVICE))
708
+ t = flux_shift(t, s=SHIFT).to(DTYPE).clamp(1e-4, 1-1e-4)
709
+
710
+ # Create noisy samples via linear interpolation
711
+ t_expanded = t.view(B, 1, 1)
712
+ x_t = (1 - t_expanded) * noise + t_expanded * data # Noisy sample at time t
713
 
714
+ # Target velocity: direction from noise to data
715
+ v_target = data - noise
 
716
 
717
+ # Create position IDs for RoPE
718
  img_ids = TinyFlux.create_img_ids(B, H, W, DEVICE)
 
719
 
720
+ # Random guidance scale (for CFG training)
721
+ guidance = torch.rand(B, device=DEVICE, dtype=DTYPE) * 4 + 1 # [1, 5]
722
+
723
+ # Forward pass: predict velocity
724
  with torch.autocast("cuda", dtype=DTYPE):
725
+ v_pred = model(
726
  hidden_states=x_t,
727
  encoder_hidden_states=t5,
728
  pooled_projections=clip,
 
731
  guidance=guidance,
732
  )
733
 
734
+ # Loss: MSE between predicted and target velocity
735
+ loss_raw = F.mse_loss(v_pred, v_target, reduction="none").mean(dim=[1, 2])
736
+
737
+ # Min-SNR weighting: downweight easy timesteps (near t=0 or t=1)
738
  snr_weights = min_snr_weight(t)
739
  loss = (loss_raw * snr_weights).mean() / GRAD_ACCUM
740
  loss.backward()