AbstractPhil committed on
Commit
14c3b94
·
verified ·
1 Parent(s): e4ba6b1

Create trainer.py

Browse files
Files changed (1) hide show
  1. trainer.py +653 -0
trainer.py ADDED
@@ -0,0 +1,653 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ============================================================================
2
+ # TinyFlux Training Cell - Full Featured
3
+ # ============================================================================
4
+ # Run the model cell before this one (defines TinyFlux, TinyFluxConfig)
5
+ # Dataset: AbstractPhil/flux-schnell-teacher-latents
6
+ # Uploads checkpoints to: AbstractPhil/tiny-flux
7
+ # ============================================================================
8
+
9
+ import torch
10
+ import torch.nn.functional as F
11
+ from torch.utils.data import DataLoader
12
+ from datasets import load_dataset
13
+ from transformers import T5EncoderModel, T5Tokenizer, CLIPTextModel, CLIPTokenizer
14
+ from huggingface_hub import HfApi, hf_hub_download
15
+ from safetensors.torch import save_file, load_file
16
+ from torch.utils.tensorboard import SummaryWriter
17
+ from tqdm.auto import tqdm
18
+ import numpy as np
19
+ import math
20
+ import os
21
+ import json
22
+ from datetime import datetime
23
+
24
# ============================================================================
# CONFIG
# ============================================================================
BATCH_SIZE = 4    # samples per forward pass
GRAD_ACCUM = 2    # gradient-accumulation factor -> effective batch of 8
LR = 1e-4         # peak learning rate (linear warmup + cosine decay below)
EPOCHS = 10
MAX_SEQ = 128     # T5 token length used by encode_prompt
MIN_SNR = 5.0     # gamma for min-SNR loss weighting
SHIFT = 3.0       # timestep shift factor for flux_shift
DEVICE = "cuda"
# Prefer bf16 where the GPU supports it; otherwise fall back to fp16.
DTYPE = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16

# HuggingFace Hub
HF_REPO = "AbstractPhil/tiny-flux"
SAVE_EVERY = 500 # steps - local save
UPLOAD_EVERY = 500 # steps - hub upload
SAMPLE_EVERY = 250 # steps - generate samples
LOG_EVERY = 10 # steps - tensorboard

# Checkpoint loading target
# Options:
# None or "latest" - load most recent checkpoint
# "best" - load best model
# int (e.g. 1500) - load specific step
# "hub:step_1000" - load specific checkpoint from hub
# "local:path/to/checkpoint.safetensors" - load specific local file
# "none" - start fresh, ignore existing checkpoints
LOAD_TARGET = "latest"

# Local paths
CHECKPOINT_DIR = "./tiny_flux_checkpoints"
LOG_DIR = "./tiny_flux_logs"
SAMPLE_DIR = "./tiny_flux_samples"

# Create working directories up front so later saves never fail on a
# missing parent directory.
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
os.makedirs(LOG_DIR, exist_ok=True)
os.makedirs(SAMPLE_DIR, exist_ok=True)
62
+
63
# ============================================================================
# HF HUB SETUP
# ============================================================================
print("Setting up HuggingFace Hub...")
api = HfApi()  # uses the ambient HF token (env var or cached login)

# Best-effort repo creation: a failure (no network / no permission) is
# reported but does not abort training.
try:
    api.create_repo(repo_id=HF_REPO, exist_ok=True, repo_type="model")
    print(f"✓ Repo ready: {HF_REPO}")
except Exception as e:
    print(f"Note: {e}")

# ============================================================================
# TENSORBOARD
# ============================================================================
# One run directory per launch, named by timestamp.
run_name = datetime.now().strftime("%Y%m%d_%H%M%S")
writer = SummaryWriter(log_dir=os.path.join(LOG_DIR, run_name))
print(f"✓ Tensorboard: {LOG_DIR}/{run_name}")

# ============================================================================
# LOAD DATASET
# ============================================================================
# Precomputed teacher latents + prompts; text embeddings are computed on the
# fly by collate() below.
print("\nLoading dataset...")
ds = load_dataset("AbstractPhil/flux-schnell-teacher-latents", split="train")
print(f"Samples: {len(ds)}")
88
+
89
# ============================================================================
# LOAD TEXT ENCODERS
# ============================================================================
print("\nLoading flan-t5-base (768 dim)...")
t5_tok = T5Tokenizer.from_pretrained("google/flan-t5-base")
t5_enc = T5EncoderModel.from_pretrained("google/flan-t5-base", torch_dtype=DTYPE).to(DEVICE).eval()

print("Loading CLIP-L...")
clip_tok = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
clip_enc = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14", torch_dtype=DTYPE).to(DEVICE).eval()

# Both text encoders are frozen; only TinyFlux is trained.
for p in t5_enc.parameters(): p.requires_grad = False
for p in clip_enc.parameters(): p.requires_grad = False

# ============================================================================
# LOAD VAE FOR SAMPLE GENERATION
# ============================================================================
# Only needed to decode latents into viewable images for periodic previews.
print("Loading Flux VAE for samples...")
from diffusers import AutoencoderKL  # deferred import: only used from here on
vae = AutoencoderKL.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",
    subfolder="vae",
    torch_dtype=DTYPE
).to(DEVICE).eval()
for p in vae.parameters(): p.requires_grad = False
114
+
115
# ============================================================================
# ENCODING HELPERS
# ============================================================================
@torch.no_grad()
def encode_prompt(prompt):
    """Encode one prompt with both frozen text encoders.

    Returns:
        (t5_sequence, clip_pooled): the T5 ``last_hidden_state`` (sequence
        embeddings, padded/truncated to MAX_SEQ tokens) and the CLIP-L
        ``pooler_output`` (single pooled vector), both with batch dim 1.
    """
    t5_tokens = t5_tok(
        prompt,
        max_length=MAX_SEQ,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    ).to(DEVICE)
    t5_sequence = t5_enc(
        input_ids=t5_tokens.input_ids,
        attention_mask=t5_tokens.attention_mask,
    ).last_hidden_state

    clip_tokens = clip_tok(
        prompt,
        max_length=77,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    ).to(DEVICE)
    clip_result = clip_enc(
        input_ids=clip_tokens.input_ids,
        attention_mask=clip_tokens.attention_mask,
    )
    return t5_sequence, clip_result.pooler_output
126
+
127
# ============================================================================
# FLOW MATCHING HELPERS
# ============================================================================
def flux_shift(t, s=SHIFT):
    """Warp timesteps t in [0, 1] by the Flux/SD3 shift factor s.

    Monotone map with fixed points at 0 and 1; s > 1 pushes mass toward 1.
    """
    scaled = s * t
    return scaled / (1 + (s - 1) * t)
132
+
133
+ def flux_shift_inverse(t_shifted, s=SHIFT):
134
+ """Inverse of flux_shift for sampling."""
135
+ return t_shifted / (s - (s - 1) * t_shifted)
136
+
137
def min_snr_weight(t, gamma=MIN_SNR):
    """Min-SNR-gamma loss weight: min(SNR, gamma) / SNR.

    SNR = (t / (1 - t))^2 for the interpolation x_t = (1-t)*noise + t*data;
    both the divisor and 1 - t are clamped to avoid division by zero.
    """
    signal_to_noise = (t / (1 - t).clamp(min=1e-5)).pow(2)
    capped = torch.clamp(signal_to_noise, max=gamma)
    return capped / signal_to_noise.clamp(min=1e-5)
140
+
141
# ============================================================================
# SAMPLING FUNCTION
# ============================================================================
@torch.no_grad()
def generate_samples(model, prompts, num_steps=20, guidance_scale=3.5, H=64, W=64):
    """Generate sample images by Euler integration of the learned flow.

    Training uses x_t = (1 - t) * noise + t * data, so t=0 is pure noise and
    t=1 is clean data, with velocity target v = data - noise. Sampling must
    therefore start from noise at t=0 and integrate *upward* to t=1.
    (The previous descending linspace(1, 0) schedule fed t=1 — the data
    timestep — to the model while x was still pure noise.)

    The step grid is warped with flux_shift so inference timesteps follow the
    same shifted schedule used when sampling t during training; dt is taken
    per-step from the warped grid.

    Args:
        model: TinyFlux velocity model (left in train mode on return).
        prompts: list of prompt strings.
        num_steps: Euler integration steps.
        guidance_scale: embedded guidance value fed to the model (Flux-style
            conditioning, not classifier-free guidance — no uncond pass).
        H, W: latent grid size.

    Returns:
        Decoded images in [0, 1], shape (len(prompts), 3, 8*H, 8*W) per the
        VAE's upsampling — TODO confirm exact spatial factor.
    """
    model.eval()
    B = len(prompts)
    C = 16  # Flux VAE latent channels

    # Encode each prompt individually (encode_prompt returns batch dim 1).
    t5_embeds, clip_pooleds = [], []
    for p in prompts:
        t5_out, clip_pooled = encode_prompt(p)
        t5_embeds.append(t5_out.squeeze(0))
        clip_pooleds.append(clip_pooled.squeeze(0))
    t5_embeds = torch.stack(t5_embeds)
    clip_pooleds = torch.stack(clip_pooleds)

    # Start from pure noise: t = 0 in this parameterization.
    x = torch.randn(B, H * W, C, device=DEVICE, dtype=DTYPE)

    # Positional ids for the flattened latent grid.
    img_ids = TinyFlux.create_img_ids(B, H, W, DEVICE)

    # Ascending, shift-warped schedule 0 -> 1 (num_steps + 1 knots).
    ts_uniform = torch.linspace(0, 1, num_steps + 1, device=DEVICE)
    timesteps = flux_shift(ts_uniform)

    guidance = torch.full((B,), guidance_scale, device=DEVICE, dtype=DTYPE)

    for i in range(num_steps):
        t_batch = timesteps[i].expand(B)
        dt = (timesteps[i + 1] - timesteps[i]).item()  # per-step size on warped grid

        v = model(
            hidden_states=x,
            encoder_hidden_states=t5_embeds,
            pooled_projections=clip_pooleds,
            timestep=t_batch,
            img_ids=img_ids,
            guidance=guidance,
        )

        # Euler step toward data: x_{t+dt} = x_t + v * dt.
        x = x + v * dt

    # (B, H*W, C) -> (B, C, H, W)
    latents = x.reshape(B, H, W, C).permute(0, 3, 1, 2)

    # Decode with VAE.
    # NOTE(review): only scaling_factor is undone here; the Flux VAE also has
    # a shift_factor — whether it applies depends on how the teacher latents
    # were stored. Verify against the dataset's encoding.
    latents = latents / vae.config.scaling_factor
    images = vae.decode(latents.float()).sample
    images = (images / 2 + 0.5).clamp(0, 1)

    model.train()
    return images
197
+
198
def save_samples(images, prompts, step, save_dir):
    """Write each sample image to disk and log a grid + prompts to tensorboard."""
    from torchvision.utils import make_grid, save_image

    # One PNG per prompt, with a filesystem-safe slug in the filename.
    for idx, (img, prompt) in enumerate(zip(images, prompts)):
        safe_prompt = prompt[:50].replace(" ", "_").replace("/", "-")
        out_path = os.path.join(save_dir, f"step{step}_{idx}_{safe_prompt}.png")
        save_image(img, out_path)

    # Tensorboard: a 2-column image grid plus the prompt text.
    writer.add_image("samples", make_grid(images, nrow=2, normalize=False), step)
    writer.add_text("sample_prompts", "\n".join(prompts), step)

    print(f" ✓ Saved {len(images)} samples")
216
+
217
# ============================================================================
# COLLATE
# ============================================================================
def collate(batch):
    """Assemble a training batch: stack latents, encode each prompt on the fly.

    Returns a dict with stacked "latents" (moved to DEVICE), per-sample
    "t5_embeds" and "clip_pooled" stacks, and the raw "prompts" list.
    """
    latent_list = []
    t5_list = []
    pooled_list = []
    prompt_list = []
    for sample in batch:
        latent_list.append(torch.tensor(np.array(sample["latent"]), dtype=DTYPE))
        seq_embed, pooled_embed = encode_prompt(sample["prompt"])
        t5_list.append(seq_embed.squeeze(0))
        pooled_list.append(pooled_embed.squeeze(0))
        prompt_list.append(sample["prompt"])
    batch_out = {
        "latents": torch.stack(latent_list).to(DEVICE),
        "t5_embeds": torch.stack(t5_list),
        "clip_pooled": torch.stack(pooled_list),
        "prompts": prompt_list,
    }
    return batch_out
234
+
235
# ============================================================================
# CHECKPOINT FUNCTIONS
# ============================================================================
def save_checkpoint(model, optimizer, scheduler, step, epoch, loss, path):
    """Save weights (.safetensors) plus trainer state (.pt) locally.

    Args:
        path: target .pt path; weights are written to the sibling
            .safetensors file derived from it.

    Returns:
        The path of the safetensors weights file.
    """
    parent = os.path.dirname(path)
    os.makedirs(parent if parent else ".", exist_ok=True)

    # Model weights go in safetensors; optimizer/scheduler state stays in .pt.
    weights_path = path.replace(".pt", ".safetensors")
    save_file(model.state_dict(), weights_path)

    torch.save(
        {
            "step": step,
            "epoch": epoch,
            "loss": loss,
            "optimizer": optimizer.state_dict(),
            "scheduler": scheduler.state_dict(),
        },
        path,
    )
    print(f" ✓ Saved checkpoint: step {step}")
    return weights_path
255
+
256
def upload_checkpoint(weights_path, step, config, include_logs=True):
    """Upload checkpoint to HuggingFace Hub.

    Best-effort: pushes the weights, the model config, and (optionally) the
    tensorboard logs and sample images. Any failure is caught and printed so
    an upload problem never interrupts training.
    """
    try:
        # Upload weights under checkpoints/ so each step is kept separately.
        api.upload_file(
            path_or_fileobj=weights_path,
            path_in_repo=f"checkpoints/step_{step}.safetensors",
            repo_id=HF_REPO,
            commit_message=f"Checkpoint step {step}",
        )

        # Upload config (serialized fresh from the model config object).
        config_path = os.path.join(CHECKPOINT_DIR, "config.json")
        with open(config_path, "w") as f:
            json.dump(config.__dict__, f, indent=2)
        api.upload_file(
            path_or_fileobj=config_path,
            path_in_repo="config.json",
            repo_id=HF_REPO,
        )

        # Upload tensorboard logs (whole log dir, all runs).
        if include_logs and os.path.exists(LOG_DIR):
            api.upload_folder(
                folder_path=LOG_DIR,
                path_in_repo="logs",
                repo_id=HF_REPO,
                commit_message=f"Logs at step {step}",
            )

        # Upload samples (only if any have been generated yet).
        if os.path.exists(SAMPLE_DIR) and os.listdir(SAMPLE_DIR):
            api.upload_folder(
                folder_path=SAMPLE_DIR,
                path_in_repo="samples",
                repo_id=HF_REPO,
                commit_message=f"Samples at step {step}",
            )

        print(f" ✓ Uploaded to {HF_REPO}")
    except Exception as e:
        # Network/auth problems are logged but never fatal.
        print(f" ⚠ Upload failed: {e}")
298
+
299
def load_checkpoint(model, optimizer, scheduler, target):
    """
    Load checkpoint based on target specification.

    Args:
        target:
            None, "latest" - most recent checkpoint
            "best" - best model
            int (1500) - specific step
            "hub:step_1000" - specific hub checkpoint
            "local:/path/to/file.safetensors" - specific local file
            "none" - skip loading, start fresh

    Returns:
        (start_step, start_epoch) to resume from; (0, 0) when nothing loads.

    Note: every specific mode falls through to the "latest" search at the
    bottom when its own lookup fails, so a miss never aborts — it degrades
    to the most recent checkpoint or a fresh start.
    """
    if target == "none":
        print("Starting fresh (no checkpoint loading)")
        return 0, 0

    start_step, start_epoch = 0, 0

    # Parse target into a (load_mode, load_path) pair.
    if target is None or target == "latest":
        load_mode = "latest"
        load_path = None
    elif target == "best":
        load_mode = "best"
        load_path = None
    elif isinstance(target, int):
        load_mode = "step"
        load_path = target
    elif target.startswith("hub:"):
        load_mode = "hub"
        load_path = target[4:]  # Remove "hub:" prefix
    elif target.startswith("local:"):
        load_mode = "local"
        load_path = target[6:]  # Remove "local:" prefix
    else:
        # Unrecognized string: try to interpret it as a bare step number.
        print(f"Unknown target format: {target}, trying as step number")
        try:
            load_mode = "step"
            load_path = int(target)
        except:
            load_mode = "latest"
            load_path = None

    # Load based on mode
    if load_mode == "local":
        # Direct local file
        if os.path.exists(load_path):
            weights = load_file(load_path)
            model.load_state_dict(weights)
            # Try to find associated state file (same name, .pt extension)
            # to also restore optimizer/scheduler and resume counters.
            state_path = load_path.replace(".safetensors", ".pt")
            if os.path.exists(state_path):
                state = torch.load(state_path, weights_only=False)
                optimizer.load_state_dict(state["optimizer"])
                scheduler.load_state_dict(state["scheduler"])
                start_step = state.get("step", 0)
                start_epoch = state.get("epoch", 0)
            print(f"✓ Loaded local: {load_path} (step {start_step})")
            return start_step, start_epoch
        else:
            # Falls through to the "latest" search below.
            print(f"⚠ Local file not found: {load_path}")

    elif load_mode == "hub":
        # Specific hub checkpoint (weights only; no optimizer state on hub).
        try:
            filename = f"checkpoints/{load_path}.safetensors" if not load_path.endswith(".safetensors") else load_path
            local_path = hf_hub_download(repo_id=HF_REPO, filename=filename)
            weights = load_file(local_path)
            model.load_state_dict(weights)
            # Extract step from filename so the step counter resumes correctly.
            if "step_" in load_path:
                start_step = int(load_path.split("step_")[-1].replace(".safetensors", ""))
            print(f"✓ Loaded from Hub: (unknown) (step {start_step})")
            return start_step, start_epoch
        except Exception as e:
            print(f"⚠ Hub load failed: {e}")

    elif load_mode == "best":
        # Try hub best first (weights only, counters stay at 0).
        try:
            local_path = hf_hub_download(repo_id=HF_REPO, filename="model.safetensors")
            weights = load_file(local_path)
            model.load_state_dict(weights)
            print(f"✓ Loaded best model from Hub")
            return start_step, start_epoch
        except:
            # NOTE(review): bare except — any hub failure silently falls
            # through to the local best lookup below.
            pass
        # Try local best (optimizer/scheduler intentionally not restored here).
        best_path = os.path.join(CHECKPOINT_DIR, "best.safetensors")
        if os.path.exists(best_path):
            weights = load_file(best_path)
            model.load_state_dict(weights)
            state_path = best_path.replace(".safetensors", ".pt")
            if os.path.exists(state_path):
                state = torch.load(state_path, weights_only=False)
                start_step = state.get("step", 0)
                start_epoch = state.get("epoch", 0)
            print(f"✓ Loaded local best (step {start_step})")
            return start_step, start_epoch

    elif load_mode == "step":
        # Specific step number: hub first, then local checkpoint dir.
        step_num = load_path
        # Try hub
        try:
            filename = f"checkpoints/step_{step_num}.safetensors"
            local_path = hf_hub_download(repo_id=HF_REPO, filename=filename)
            weights = load_file(local_path)
            model.load_state_dict(weights)
            start_step = step_num
            print(f"✓ Loaded step {step_num} from Hub")
            return start_step, start_epoch
        except:
            # NOTE(review): bare except — falls through to local lookup.
            pass
        # Try local (also restores optimizer/scheduler if the .pt exists).
        local_path = os.path.join(CHECKPOINT_DIR, f"step_{step_num}.safetensors")
        if os.path.exists(local_path):
            weights = load_file(local_path)
            model.load_state_dict(weights)
            state_path = local_path.replace(".safetensors", ".pt")
            if os.path.exists(state_path):
                state = torch.load(state_path, weights_only=False)
                optimizer.load_state_dict(state["optimizer"])
                scheduler.load_state_dict(state["scheduler"])
                start_epoch = state.get("epoch", 0)
            start_step = step_num
            print(f"✓ Loaded local step {step_num}")
            return start_step, start_epoch
        print(f"⚠ Step {step_num} not found")

    # Default: latest
    # Try Hub first: list all checkpoints/step_*.safetensors and pick the
    # numerically highest step.
    try:
        files = api.list_repo_files(repo_id=HF_REPO)
        checkpoints = [f for f in files if f.startswith("checkpoints/step_") and f.endswith(".safetensors")]
        if checkpoints:
            checkpoints.sort(key=lambda x: int(x.split("step_")[-1].replace(".safetensors", "")))
            latest = checkpoints[-1]
            step = int(latest.split("step_")[-1].replace(".safetensors", ""))
            local_path = hf_hub_download(repo_id=HF_REPO, filename=latest)
            weights = load_file(local_path)
            model.load_state_dict(weights)
            start_step = step
            print(f"✓ Loaded latest from Hub: step {step}")
            return start_step, start_epoch
    except Exception as e:
        print(f"Hub check: {e}")

    # Try local: same highest-step selection over the checkpoint directory,
    # plus optimizer/scheduler restore when the matching .pt exists.
    if os.path.exists(CHECKPOINT_DIR):
        local_ckpts = [f for f in os.listdir(CHECKPOINT_DIR) if f.startswith("step_") and f.endswith(".safetensors")]
        if local_ckpts:
            local_ckpts.sort(key=lambda x: int(x.split("step_")[-1].replace(".safetensors", "")))
            latest = local_ckpts[-1]
            step = int(latest.split("step_")[-1].replace(".safetensors", ""))
            weights_path = os.path.join(CHECKPOINT_DIR, latest)
            weights = load_file(weights_path)
            model.load_state_dict(weights)
            state_path = weights_path.replace(".safetensors", ".pt")
            if os.path.exists(state_path):
                state = torch.load(state_path, weights_only=False)
                optimizer.load_state_dict(state["optimizer"])
                scheduler.load_state_dict(state["scheduler"])
                start_epoch = state.get("epoch", 0)
            start_step = step
            print(f"✓ Loaded latest local: step {step}")
            return start_step, start_epoch

    print("No checkpoint found, starting fresh")
    return 0, 0
470
+
471
# ============================================================================
# DATALOADER
# ============================================================================
# num_workers=0 is required: collate() runs the text encoders on the GPU,
# so it must execute in the main process.
loader = DataLoader(ds, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate, num_workers=0)

# ============================================================================
# MODEL
# ============================================================================
# TinyFlux / TinyFluxConfig are defined by the model cell run before this one.
config = TinyFluxConfig()
model = TinyFlux(config).to(DEVICE).to(DTYPE)
print(f"\nParams: {sum(p.numel() for p in model.parameters()):,}")
482
+
483
# ============================================================================
# OPTIMIZER & SCHEDULER
# ============================================================================
opt = torch.optim.AdamW(model.parameters(), lr=LR, betas=(0.9, 0.99), weight_decay=0.01)
# Optimizer steps, not batches: each step consumes GRAD_ACCUM batches.
total_steps = len(loader) * EPOCHS // GRAD_ACCUM
# Linear warmup for up to 500 steps, capped at 10% of the run.
warmup = min(500, total_steps // 10)
489
+
490
def lr_fn(step):
    """LR multiplier: linear warmup to 1.0, then cosine decay to 0.

    Reads the module-level `warmup` and `total_steps` values.
    """
    if step >= warmup:
        progress = (step - warmup) / (total_steps - warmup)
        return 0.5 * (1 + math.cos(math.pi * progress))
    return step / warmup
493
+
494
sched = torch.optim.lr_scheduler.LambdaLR(opt, lr_fn)

# ============================================================================
# LOAD CHECKPOINT
# ============================================================================
# Resumes model (and, when available, optimizer/scheduler) per LOAD_TARGET.
print(f"\nLoad target: {LOAD_TARGET}")
start_step, start_epoch = load_checkpoint(model, opt, sched, LOAD_TARGET)

# Log config to tensorboard so each run records its exact hyperparameters.
writer.add_text("config", json.dumps(config.__dict__, indent=2), 0)
writer.add_text("training_config", json.dumps({
    "batch_size": BATCH_SIZE,
    "grad_accum": GRAD_ACCUM,
    "lr": LR,
    "epochs": EPOCHS,
    "min_snr": MIN_SNR,
    "shift": SHIFT,
}, indent=2), 0)
512
+
513
# ============================================================================
# SAMPLE PROMPTS FOR PERIODIC GENERATION
# ============================================================================
# Fixed prompt set so sample quality is comparable across steps.
SAMPLE_PROMPTS = [
    "a photo of a cat sitting on a windowsill",
    "a beautiful sunset over mountains",
    "a portrait of a woman with red hair",
    "a futuristic cityscape at night",
]
522
+
523
# ============================================================================
# TRAINING
# ============================================================================
print(f"\nTraining {EPOCHS} epochs, {total_steps} total steps")
print(f"Resuming from step {start_step}, epoch {start_epoch}")
print(f"Save: {SAVE_EVERY}, Upload: {UPLOAD_EVERY}, Sample: {SAMPLE_EVERY}, Log: {LOG_EVERY}")

model.train()
step = start_step          # optimizer-step counter (resumes from checkpoint)
best = float("inf")        # best epoch-average loss seen so far

for ep in range(start_epoch, EPOCHS):
    ep_loss = 0
    ep_batches = 0
    pbar = tqdm(loader, desc=f"E{ep+1}")

    for i, batch in enumerate(pbar):
        lat = batch["latents"]
        t5 = batch["t5_embeds"]
        clip = batch["clip_pooled"]

        # Flatten spatial latents to token sequence: (B,C,H,W) -> (B,H*W,C).
        B, C, H, W = lat.shape
        x1 = lat.permute(0, 2, 3, 1).reshape(B, H*W, C)
        x0 = torch.randn_like(x1)  # noise endpoint

        # Logit-normal timestep sampling, then shift toward the noise end;
        # clamp away from 0/1 so min_snr_weight stays finite.
        t = torch.sigmoid(torch.randn(B, device=DEVICE))
        t = flux_shift(t).to(DTYPE).clamp(1e-4, 1-1e-4)

        # Rectified-flow interpolation: t=0 is noise (x0), t=1 is data (x1).
        t_exp = t.view(B, 1, 1)
        x_t = (1 - t_exp) * x0 + t_exp * x1
        v_target = x1 - x0  # constant velocity along the straight path

        img_ids = TinyFlux.create_img_ids(B, H, W, DEVICE)
        # Random embedded-guidance values in [1, 5) so inference can pick any.
        guidance = torch.rand(B, device=DEVICE, dtype=DTYPE) * 4 + 1

        with torch.autocast("cuda", dtype=DTYPE):
            pred = model(
                hidden_states=x_t,
                encoder_hidden_states=t5,
                pooled_projections=clip,
                timestep=t,
                img_ids=img_ids,
                guidance=guidance,
            )

        # Per-sample MSE, weighted by min-SNR, scaled for grad accumulation.
        loss_raw = F.mse_loss(pred, v_target, reduction="none").mean(dim=[1,2])
        snr_weights = min_snr_weight(t)
        loss = (loss_raw * snr_weights).mean() / GRAD_ACCUM
        loss.backward()

        # Optimizer step once every GRAD_ACCUM micro-batches.
        if (i + 1) % GRAD_ACCUM == 0:
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()
            sched.step()
            opt.zero_grad()
            step += 1

            # Tensorboard logging (un-scale the loss for readability).
            if step % LOG_EVERY == 0:
                writer.add_scalar("train/loss", loss.item() * GRAD_ACCUM, step)
                writer.add_scalar("train/lr", sched.get_last_lr()[0], step)
                writer.add_scalar("train/grad_norm", grad_norm.item(), step)
                writer.add_scalar("train/t_mean", t.mean().item(), step)
                writer.add_scalar("train/snr_weight_mean", snr_weights.mean().item(), step)

            # Generate samples
            if step % SAMPLE_EVERY == 0:
                print(f"\n Generating samples at step {step}...")
                images = generate_samples(model, SAMPLE_PROMPTS, num_steps=20)
                save_samples(images, SAMPLE_PROMPTS, step, SAMPLE_DIR)

            # Save checkpoint
            if step % SAVE_EVERY == 0:
                ckpt_path = os.path.join(CHECKPOINT_DIR, f"step_{step}.pt")
                weights_path = save_checkpoint(model, opt, sched, step, ep, loss.item(), ckpt_path)

                # Upload (nested: reuses the weights file just saved above)
                if step % UPLOAD_EVERY == 0:
                    upload_checkpoint(weights_path, step, config, include_logs=True)

        ep_loss += loss.item() * GRAD_ACCUM
        ep_batches += 1
        pbar.set_postfix(loss=f"{loss.item()*GRAD_ACCUM:.4f}", lr=f"{sched.get_last_lr()[0]:.1e}", step=step)

    # Epoch summary + best-model tracking.
    avg = ep_loss / max(ep_batches, 1)
    print(f"Epoch {ep+1} loss: {avg:.4f}")
    writer.add_scalar("train/epoch_loss", avg, ep + 1)

    if avg < best:
        best = avg
        best_path = os.path.join(CHECKPOINT_DIR, "best.pt")
        weights_path = save_checkpoint(model, opt, sched, step, ep, avg, best_path)

        # Push the new best as the repo's top-level model.safetensors.
        try:
            api.upload_file(
                path_or_fileobj=weights_path,
                path_in_repo="model.safetensors",
                repo_id=HF_REPO,
                commit_message=f"Best model (epoch {ep+1}, loss {avg:.4f})",
            )
            print(f" ✓ Uploaded best to {HF_REPO}")
        except Exception as e:
            print(f" ⚠ Upload failed: {e}")
626
+
627
# ============================================================================
# FINAL
# ============================================================================
# Save the final weights regardless of whether they beat the best epoch.
print("\nSaving final model...")
final_path = os.path.join(CHECKPOINT_DIR, "final.pt")
weights_path = save_checkpoint(model, opt, sched, step, EPOCHS, best, final_path)

# Final samples
print("Generating final samples...")
images = generate_samples(model, SAMPLE_PROMPTS, num_steps=20)
save_samples(images, SAMPLE_PROMPTS, step, SAMPLE_DIR)

# Final upload: weights, config, logs, and samples — best-effort.
# NOTE(review): this overwrites model.safetensors with the *final* weights,
# replacing the best-epoch upload made during training — confirm intended.
try:
    api.upload_file(path_or_fileobj=weights_path, path_in_repo="model.safetensors", repo_id=HF_REPO)
    config_path = os.path.join(CHECKPOINT_DIR, "config.json")
    with open(config_path, "w") as f:
        json.dump(config.__dict__, f, indent=2)
    api.upload_file(path_or_fileobj=config_path, path_in_repo="config.json", repo_id=HF_REPO)
    api.upload_folder(folder_path=LOG_DIR, path_in_repo="logs", repo_id=HF_REPO)
    api.upload_folder(folder_path=SAMPLE_DIR, path_in_repo="samples", repo_id=HF_REPO)
    print(f"\n✓ Training complete! https://huggingface.co/{HF_REPO}")
except Exception as e:
    print(f"\n⚠ Final upload failed: {e}")

writer.close()
print(f"Best loss: {best:.4f}")