AbstractPhil committed on
Commit
56dce31
·
verified ·
1 Parent(s): 457e2ff

Create trainer.py

Browse files
Files changed (1) hide show
  1. trainer.py +600 -0
trainer.py ADDED
@@ -0,0 +1,600 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ============================================================================
2
+ # TinyFlux-Deep Training Cell
3
+ # ============================================================================
4
+ # Trains the deep variant with frozen ported layers
5
+ # Config: 25 single blocks, 15 double blocks, 4 attention heads
6
+ # hidden_size: 512 (4 heads * 128 head_dim)
7
+ # Repo: AbstractPhil/tiny-flux-deep
8
+ #
9
+ # USAGE: Run model.py cell first, then this cell
10
+ # ============================================================================
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.nn.functional as F
15
+ from torch.utils.data import DataLoader
16
+ from datasets import load_dataset
17
+ from transformers import T5EncoderModel, T5Tokenizer, CLIPTextModel, CLIPTokenizer
18
+ from huggingface_hub import HfApi, hf_hub_download
19
+ from safetensors.torch import save_file, load_file
20
+ from torch.utils.tensorboard import SummaryWriter
21
+ from tqdm.auto import tqdm
22
+ import numpy as np
23
+ import math
24
+ from typing import Tuple, Optional, Dict
25
+ import os
26
+ from datetime import datetime
27
+ from dataclasses import dataclass
28
+
29
+ # ============================================================================
30
+ # CUDA OPTIMIZATIONS
31
+ # ============================================================================
32
# Global backend switches: cuDNN autotuning plus TF32 for matmul and cuDNN.
import warnings

torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
torch.set_float32_matmul_precision('high')

# Silence warnings whose message mentions TF32.
warnings.filterwarnings('ignore', message='.*TF32.*')
39
+
40
+ # ============================================================================
41
+ # CONFIG
42
+ # ============================================================================
43
+ BATCH_SIZE = 16
44
+ GRAD_ACCUM = 2 # Effective batch = 32
45
+ LR = 5e-5 # Lower LR for fine-tuning frozen model
46
+ EPOCHS = 10
47
+ MAX_SEQ = 128
48
+ MIN_SNR = 5.0
49
+ SHIFT = 3.0
50
+ DEVICE = "cuda"
51
+ DTYPE = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
52
+
53
+ # HuggingFace Hub
54
+ HF_REPO = "AbstractPhil/tiny-flux-deep"
55
+ SAVE_EVERY = 500
56
+ UPLOAD_EVERY = 500
57
+ SAMPLE_EVERY = 250
58
+ LOG_EVERY = 10
59
+
60
+ # Checkpoint loading
61
+ LOAD_TARGET = "hub" # "hub", "latest", "best", "none"
62
+ RESUME_STEP = None
63
+
64
+ # Dataset
65
+ DATASET_REPO = "AbstractPhil/flux-schnell-teacher-latents"
66
+ DATASET_CONFIG = "train_simple_512"
67
+
68
+ # Paths
69
+ CHECKPOINT_DIR = "./tiny_flux_deep_checkpoints"
70
+ LOG_DIR = "./tiny_flux_deep_logs"
71
+ SAMPLE_DIR = "./tiny_flux_deep_samples"
72
+ ENCODING_CACHE_DIR = "./encoding_cache"
73
+
74
+ os.makedirs(CHECKPOINT_DIR, exist_ok=True)
75
+ os.makedirs(LOG_DIR, exist_ok=True)
76
+ os.makedirs(SAMPLE_DIR, exist_ok=True)
77
+ os.makedirs(ENCODING_CACHE_DIR, exist_ok=True)
78
+
79
+ # ============================================================================
80
+ # FROZEN LAYER POSITIONS (from porting)
81
+ # ============================================================================
82
+ # Single blocks: old 0→0, old 1→{8,12,16}, old 2→24
83
+ FROZEN_SINGLE_POSITIONS = {0, 8, 12, 16, 24}
84
+
85
+ # Double blocks: old 0→0, old 1→{4,7,10}, old 2→14
86
+ FROZEN_DOUBLE_POSITIONS = {0, 4, 7, 10, 14}
87
+
88
+ # ============================================================================
89
+ # MODEL CONFIG
90
+ # ============================================================================
91
@dataclass
class TinyFluxDeepConfig:
    """Deep variant: 512 hidden, 4 heads, 25 single, 15 double.

    Raises:
        ValueError: if ``hidden_size`` is not equal to
            ``num_attention_heads * attention_head_dim``.
    """
    hidden_size: int = 512
    num_attention_heads: int = 4
    attention_head_dim: int = 128
    in_channels: int = 16
    patch_size: int = 1
    joint_attention_dim: int = 768
    pooled_projection_dim: int = 768
    num_double_layers: int = 15
    num_single_layers: int = 25
    mlp_ratio: float = 4.0
    axes_dims_rope: Tuple[int, int, int] = (16, 56, 56)
    guidance_embeds: bool = True

    def __post_init__(self):
        # The file header states hidden_size = heads * head_dim; fail early
        # with a clear message instead of a shape error deep in the model.
        expected = self.num_attention_heads * self.attention_head_dim
        if self.hidden_size != expected:
            raise ValueError(
                f"hidden_size ({self.hidden_size}) must equal "
                f"num_attention_heads * attention_head_dim ({expected})"
            )
106
+
107
+ # ============================================================================
108
+ # HF HUB SETUP
109
+ # ============================================================================
110
print("Setting up HuggingFace Hub...")
api = HfApi()
# Best-effort: a failure to create/verify the repo is reported, not fatal.
try:
    api.create_repo(
        repo_id=HF_REPO,
        exist_ok=True,
        repo_type="model",
    )
    print(f"✓ Repo ready: {HF_REPO}")
except Exception as e:
    print(f"Note: {e}")

# ============================================================================
# TENSORBOARD
# ============================================================================
# One log sub-directory per run, named after the start time.
run_name = datetime.now().strftime("%Y%m%d_%H%M%S")
_run_log_dir = os.path.join(LOG_DIR, run_name)
writer = SummaryWriter(log_dir=_run_log_dir)
print(f"✓ Tensorboard: {LOG_DIR}/{run_name}")

# ============================================================================
# LOAD DATASET
# ============================================================================
print("\nLoading dataset...")
ds = load_dataset(
    DATASET_REPO,
    DATASET_CONFIG,
    split="train",
)
print(f"Samples: {len(ds)} ({DATASET_CONFIG})")
131
+
132
+ # ============================================================================
133
+ # LOAD TEXT ENCODERS
134
+ # ============================================================================
135
print("\nLoading flan-t5-base...")
t5_tok = T5Tokenizer.from_pretrained("google/flan-t5-base")
t5_enc = T5EncoderModel.from_pretrained("google/flan-t5-base", torch_dtype=DTYPE)
t5_enc = t5_enc.to(DEVICE).eval()

print("Loading CLIP-L...")
clip_tok = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
clip_enc = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14", torch_dtype=DTYPE)
clip_enc = clip_enc.to(DEVICE).eval()

# Both text encoders are used for inference only; disable their gradients.
t5_enc.requires_grad_(False)
clip_enc.requires_grad_(False)

# ============================================================================
# LOAD VAE
# ============================================================================
print("Loading Flux VAE...")
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",
    subfolder="vae",
    torch_dtype=DTYPE,
)
vae = vae.to(DEVICE).eval()
# The VAE is only used to decode samples; it is never trained here.
vae.requires_grad_(False)
158
+
159
+ # ============================================================================
160
+ # BATCHED ENCODING
161
+ # ============================================================================
162
@torch.inference_mode()
def encode_prompts_batched(prompts: list) -> tuple:
    """Encode a batch of prompts with both text encoders.

    Args:
        prompts: list of prompt strings.

    Returns:
        ``(t5_hidden, clip_pooled)``: the T5 encoder's ``last_hidden_state``
        (prompts padded/truncated to ``MAX_SEQ`` tokens) and the CLIP text
        model's ``pooler_output`` (prompts padded/truncated to 77 tokens).
        Both tensors live on ``DEVICE``.
    """
    # Tokenizer options shared by T5 and CLIP; only max_length differs.
    shared = dict(padding="max_length", truncation=True, return_tensors="pt")

    t5_tokens = t5_tok(prompts, max_length=MAX_SEQ, **shared).to(DEVICE)
    t5_hidden = t5_enc(
        input_ids=t5_tokens.input_ids,
        attention_mask=t5_tokens.attention_mask,
    ).last_hidden_state

    clip_tokens = clip_tok(prompts, max_length=77, **shared).to(DEVICE)
    clip_result = clip_enc(
        input_ids=clip_tokens.input_ids,
        attention_mask=clip_tokens.attention_mask,
    )

    return t5_hidden, clip_result.pooler_output
171
+
172
+ # ============================================================================
173
+ # PRE-ENCODE PROMPTS
174
+ # ============================================================================
175
print("\nPre-encoding prompts...")
PRECOMPUTE_ENCODINGS = True
# The cache name includes the dataset config and size, so a different
# config or dataset length gets its own cache file.
cache_file = os.path.join(ENCODING_CACHE_DIR, f"encodings_{DATASET_CONFIG}_{len(ds)}.pt")

if PRECOMPUTE_ENCODINGS:
    if os.path.exists(cache_file):
        print(f"Loading cached encodings from {cache_file}...")
        cached = torch.load(cache_file, weights_only=True)
        all_t5_embeds = cached["t5_embeds"]
        all_clip_pooled = cached["clip_pooled"]
        print(f"✓ Loaded cached encodings")
    else:
        print("Encoding prompts (will cache)...")
        all_prompts = ds["prompt"]

        encode_batch_size = 64
        t5_chunks = []
        clip_chunks = []

        for i in tqdm(range(0, len(all_prompts), encode_batch_size), desc="Encoding"):
            batch_prompts = all_prompts[i:i + encode_batch_size]
            t5_out, clip_out = encode_prompts_batched(batch_prompts)
            # Move each chunk to the CPU right away so GPU memory stays flat.
            t5_chunks.append(t5_out.cpu())
            clip_chunks.append(clip_out.cpu())

        all_t5_embeds = torch.cat(t5_chunks, dim=0)
        all_clip_pooled = torch.cat(clip_chunks, dim=0)

        # Write to a temporary file and rename it into place, so an
        # interrupted save never leaves a truncated cache that a later run
        # would find via os.path.exists() and then fail to load.
        tmp_cache_file = cache_file + ".tmp"
        torch.save({"t5_embeds": all_t5_embeds, "clip_pooled": all_clip_pooled}, tmp_cache_file)
        os.replace(tmp_cache_file, cache_file)
        print(f"✓ Saved encoding cache")
205
+
206
+ # ============================================================================
207
+ # FLOW MATCHING HELPERS
208
+ # ============================================================================
209
def flux_shift(t, s=SHIFT):
    """Warp timestep(s) ``t`` with the shift ``s``: ``s*t / (1 + (s-1)*t)``.

    Maps 0 to 0 and 1 to 1; ``s == 1`` is the identity.
    """
    denominator = 1 + (s - 1) * t
    return s * t / denominator
211
+
212
def min_snr_weight(t, gamma=MIN_SNR):
    """Return per-sample loss weights ``min(snr, gamma) / snr``.

    ``snr`` is taken as ``(t / (1 - t))**2``; both denominators are clamped
    to at least 1e-5 to avoid division by zero.

    Args:
        t: tensor of timesteps.
        gamma: upper bound applied to the SNR in the numerator.
    """
    one_minus_t = (1 - t).clamp(min=1e-5)
    snr = (t / one_minus_t).pow(2)
    capped_snr = torch.clamp(snr, max=gamma)
    return capped_snr / snr.clamp(min=1e-5)
215
+
216
+ # ============================================================================
217
+ # SAMPLING FUNCTION
218
+ # ============================================================================
219
@torch.inference_mode()
def generate_samples(model, prompts, num_steps=20, guidance_scale=3.5, H=64, W=64):
    """Generate images for ``prompts`` by Euler-integrating the model output.

    Starts from Gaussian noise at t=0 and steps along a shifted timestep
    schedule to t=1, then decodes the result with the VAE.

    Args:
        model: the network; called with keyword arguments ``hidden_states``,
            ``encoder_hidden_states``, ``pooled_projections``, ``timestep``,
            ``img_ids`` and ``guidance``.
        prompts: list of prompt strings (one image per prompt).
        num_steps: number of Euler steps.
        guidance_scale: value fed to the model's ``guidance`` input.
        H, W: latent height and width.

    Returns:
        Decoded images, rescaled from [-1, 1] and clamped to [0, 1].

    The model is switched to eval mode for sampling and restored to the mode
    it was in on entry, even if sampling raises.
    """
    B = len(prompts)
    C = 16  # channels of the latent that is reshaped and handed to the VAE

    was_training = model.training
    model.eval()
    try:
        t5_embeds, clip_pooleds = encode_prompts_batched(prompts)
        t5_embeds = t5_embeds.to(DTYPE)
        clip_pooleds = clip_pooleds.to(DTYPE)

        x = torch.randn(B, H * W, C, device=DEVICE, dtype=DTYPE)
        img_ids = TinyFluxDeep.create_img_ids(B, H, W, DEVICE)

        t_linear = torch.linspace(0, 1, num_steps + 1, device=DEVICE, dtype=DTYPE)
        timesteps = flux_shift(t_linear, s=SHIFT)

        # The guidance input is the same at every step; build it once.
        guidance = torch.full((B,), guidance_scale, device=DEVICE, dtype=DTYPE)

        for t_curr, t_next in zip(timesteps[:-1], timesteps[1:]):
            dt = t_next - t_curr
            t_batch = t_curr.expand(B).to(DTYPE)

            v_cond = model(
                hidden_states=x,
                encoder_hidden_states=t5_embeds,
                pooled_projections=clip_pooleds,
                timestep=t_batch,
                img_ids=img_ids,
                guidance=guidance,
            )
            x = x + v_cond * dt  # explicit Euler step

        # (B, H*W, C) -> (B, C, H, W) for the VAE decoder.
        latents = x.reshape(B, H, W, C).permute(0, 3, 1, 2)
        # NOTE(review): only scaling_factor is undone here; confirm whether
        # the training latents also require vae.config.shift_factor.
        latents = latents / vae.config.scaling_factor
        images = vae.decode(latents.to(vae.dtype)).sample
        images = (images / 2 + 0.5).clamp(0, 1)
        return images
    finally:
        model.train(was_training)
260
+
261
+
262
def save_samples(images, prompts, step, save_dir, upload=True):
    """Write sample images to disk, log a grid to TensorBoard, optionally upload.

    Each image is saved as ``step{step}_{i}_{prompt}.png`` in ``save_dir``
    and a 2-column grid as ``step{step}_grid.png``. The grid is also added to
    the module-level ``writer`` under the tag ``"samples"``.

    Args:
        images: batch of image tensors (one per prompt).
        prompts: prompt strings, used to build the file names.
        step: training step, used in file names and as the TensorBoard step.
        save_dir: output directory (created if missing).
        upload: when true, upload the grid to ``HF_REPO``; an upload failure
            is reported and otherwise ignored.
    """
    import re
    from torchvision.utils import make_grid, save_image

    os.makedirs(save_dir, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

    for i, (img, prompt) in enumerate(zip(images, prompts)):
        safe_prompt = prompt[:50].replace(" ", "_").replace("/", "-")
        # Replace anything else that is not safe in a file name
        # (e.g. ':', '?', '\\', quotes, newlines) with an underscore.
        safe_prompt = re.sub(r"[^\w.\-]", "_", safe_prompt)
        save_image(img, os.path.join(save_dir, f"step{step}_{i}_{safe_prompt}.png"))

    grid = make_grid(images, nrow=2, normalize=False)
    grid_path = os.path.join(save_dir, f"step{step}_grid.png")
    save_image(grid, grid_path)

    writer.add_image("samples", grid, step)

    if not upload:
        return

    # Uploading is best-effort: training must not stop on a network error.
    try:
        api.upload_file(
            path_or_fileobj=grid_path,
            path_in_repo=f"samples/{timestamp}_step_{step}.png",
            repo_id=HF_REPO,
        )
        print(f" ✓ Saved & uploaded {len(images)} samples")
    except Exception as e:
        print(f" ✓ Saved {len(images)} samples (upload failed: {e})")
288
+
289
+ # ============================================================================
290
+ # COLLATE
291
+ # ============================================================================
292
class IndexedDataset:
    """Dataset wrapper that tags every item with its position.

    ``wrapper[idx]`` returns a new dict built from ``ds[idx]`` with the extra
    key ``"__index__"`` set to ``idx``.
    """

    def __init__(self, ds):
        self.ds = ds

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx):
        # dict(...) copies the row, so the wrapped dataset's item is untouched.
        return dict(self.ds[idx], __index__=idx)
301
+
302
def collate_preencoded(batch):
    """Collate dataset items into a batch using the pre-computed encodings.

    Args:
        batch: list of items carrying ``"latent"`` and ``"__index__"`` keys.

    Returns:
        Dict with ``"latents"`` (stacked latents), ``"t5_embeds"`` and
        ``"clip_pooled"`` (rows of the module-level ``all_t5_embeds`` /
        ``all_clip_pooled`` selected by each item's index), all in ``DTYPE``.
    """
    row_ids = [sample["__index__"] for sample in batch]
    latent_tensors = [
        torch.tensor(np.array(sample["latent"]), dtype=DTYPE) for sample in batch
    ]
    t5_rows = all_t5_embeds[row_ids]
    clip_rows = all_clip_pooled[row_ids]
    return {
        "latents": torch.stack(latent_tensors),
        "t5_embeds": t5_rows.to(DTYPE),
        "clip_pooled": clip_rows.to(DTYPE),
    }


# Wrap the dataset so each item carries the index used to look up its encodings.
ds = IndexedDataset(ds)
num_workers = 8
313
+
314
+ # ============================================================================
315
+ # FREEZE PORTED LAYERS
316
+ # ============================================================================
317
def freeze_ported_layers(model, single_positions=None, double_positions=None):
    """Freeze layers that were ported from TinyFlux.

    A parameter is frozen when its name contains ``single_blocks.{pos}.`` or
    ``double_blocks.{pos}.`` for one of the given positions. Every other
    parameter is explicitly made trainable. A summary is printed.

    Args:
        model: module whose parameters are updated in place.
        single_positions: iterable of single-block indices to freeze;
            defaults to ``FROZEN_SINGLE_POSITIONS``.
        double_positions: iterable of double-block indices to freeze;
            defaults to ``FROZEN_DOUBLE_POSITIONS``.

    Returns:
        The same ``model`` object.
    """
    if single_positions is None:
        single_positions = FROZEN_SINGLE_POSITIONS
    if double_positions is None:
        double_positions = FROZEN_DOUBLE_POSITIONS

    # The trailing dot keeps position 1 from matching blocks 10-19, and a
    # substring test still matches names wrapped by a prefix (e.g. "_orig_mod.").
    frozen_markers = tuple(
        [f"single_blocks.{pos}." for pos in single_positions]
        + [f"double_blocks.{pos}." for pos in double_positions]
    )

    frozen_count = 0
    trainable_count = 0
    for name, param in model.named_parameters():
        should_freeze = any(marker in name for marker in frozen_markers)
        param.requires_grad = not should_freeze
        if should_freeze:
            frozen_count += param.numel()
        else:
            trainable_count += param.numel()

    total = frozen_count + trainable_count
    # Guard against a model without parameters (would divide by zero).
    ratio = trainable_count / total * 100 if total else 0.0

    print(f"\nFrozen params: {frozen_count:,}")
    print(f"Trainable params: {trainable_count:,}")
    print(f"Total: {total:,}")
    print(f"Trainable ratio: {ratio:.1f}%")

    return model
350
+
351
+ # ============================================================================
352
+ # CHECKPOINT FUNCTIONS
353
+ # ============================================================================
354
# State-dict keys that may legitimately be absent from a saved checkpoint;
# load_checkpoint subtracts them before warning about missing keys.
EXPECTED_MISSING = {
    'time_in.sin_basis',
    'guidance_in.sin_basis',
    'rope.freqs_0',
    'rope.freqs_1',
    'rope.freqs_2',
}


def load_weights(path):
    """Read a model state dict from ``path``.

    ``.safetensors`` files are read with ``load_file``. Any other file is
    read with ``torch.load``; if the result has a ``"model"`` or
    ``"state_dict"`` entry that entry is used, otherwise the loaded object
    itself. If any key starts with ``"_orig_mod."``, that substring is
    removed from every key.

    Returns:
        Mapping of parameter name to tensor.
    """
    if path.endswith(".safetensors"):
        state_dict = load_file(path)
    else:
        # NOTE(security): weights_only=False unpickles arbitrary objects;
        # only use this branch with checkpoint files you trust.
        ckpt = torch.load(path, map_location=DEVICE, weights_only=False)
        fallback = ckpt.get("state_dict", ckpt)
        state_dict = ckpt.get("model", fallback)

    prefix = "_orig_mod."
    has_prefixed_key = False
    for key in state_dict.keys():
        if key.startswith(prefix):
            has_prefixed_key = True
            break

    if has_prefixed_key:
        state_dict = {key.replace(prefix, ""): value for key, value in state_dict.items()}

    return state_dict
368
+
369
+
370
def save_checkpoint(model, optimizer, scheduler, step, epoch, loss, path):
    """Save model weights and training state as two sibling files.

    The weights go to ``<path without extension>.safetensors``; ``step``,
    ``epoch``, ``loss`` and the optimizer/scheduler state dicts go to
    ``path`` via ``torch.save``. Any ``"_orig_mod."`` substring is removed
    from the weight names when at least one key starts with it.

    Args:
        model: module whose ``state_dict()`` is saved.
        optimizer, scheduler: objects whose ``state_dict()`` is saved.
        step, epoch, loss: bookkeeping values stored with the state.
        path: destination of the training-state file (normally ``*.pt``).
            It should not itself end in ``.safetensors``, or the state file
            would overwrite the weights file.

    Returns:
        Path of the written ``.safetensors`` weights file.
    """
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)

    state_dict = model.state_dict()
    if any(k.startswith("_orig_mod.") for k in state_dict.keys()):
        state_dict = {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()}

    # Swap only the final extension. str.replace(".pt", ...) would also
    # rewrite ".pt" inside directory names, and would leave a path without
    # ".pt" unchanged so the two files collided.
    weights_path = os.path.splitext(path)[0] + ".safetensors"
    save_file(state_dict, weights_path)

    torch.save({
        "step": step, "epoch": epoch, "loss": loss,
        "optimizer": optimizer.state_dict(),
        "scheduler": scheduler.state_dict(),
    }, path)
    print(f" ✓ Saved checkpoint: step {step}")
    return weights_path
387
+
388
+
389
def upload_checkpoint(weights_path, step):
    """Upload a weights file to ``HF_REPO`` as ``checkpoints/step_{step}.safetensors``.

    Best-effort: any exception is reported with a warning and swallowed.
    """
    remote_name = f"checkpoints/step_{step}.safetensors"
    try:
        api.upload_file(
            path_or_fileobj=weights_path,
            path_in_repo=remote_name,
            repo_id=HF_REPO,
        )
        print(f" ✓ Uploaded step {step}")
    except Exception as e:
        print(f" ⚠ Upload failed: {e}")
395
+
396
+
397
def _apply_weights(model, weights):
    """Load ``weights`` non-strictly into ``model`` and report key mismatches."""
    missing, unexpected = model.load_state_dict(weights, strict=False)
    actual_missing = set(missing) - EXPECTED_MISSING
    if actual_missing:
        print(f" ⚠ Missing: {list(actual_missing)[:5]}...")
    else:
        print(f" ✓ Missing only precomputed buffers (OK)")
    if unexpected:
        print(f" ⚠ Unexpected: {unexpected[:5]}...")


def _checkpoint_step(filename):
    """Return N for a ``step_N.safetensors`` file name, else ``None``."""
    prefix, suffix = "step_", ".safetensors"
    if not (filename.startswith(prefix) and filename.endswith(suffix)):
        return None
    digits = filename[len(prefix):-len(suffix)]
    return int(digits) if digits.isdigit() else None


def load_checkpoint(model, target):
    """Load initial weights into ``model`` according to ``target``.

    Args:
        model: module to load the weights into (non-strict load).
        target: ``"none"``/``None`` (no loading), ``"hub"``
            (``model.safetensors`` from ``HF_REPO``), ``"latest"``
            (highest-numbered ``step_N.safetensors`` in ``CHECKPOINT_DIR``)
            or ``"best"`` (``best.safetensors`` in ``CHECKPOINT_DIR``).

    Returns:
        ``(start_step, start_epoch)``. Only ``"latest"`` returns a non-zero
        step (parsed from the file name); the epoch is always 0. Optimizer
        and scheduler state are not restored here.

    A failed hub download/load is reported and training starts with whatever
    weights the model currently holds.
    """
    if target == "none" or target is None:
        print("Starting from scratch (no checkpoint)")
        return 0, 0

    if target == "hub":
        try:
            weights_path = hf_hub_download(repo_id=HF_REPO, filename="model.safetensors")
            _apply_weights(model, load_weights(weights_path))
            print(f"✓ Loaded from hub: {HF_REPO}")
        except Exception as e:
            print(f"Hub load failed: {e}")
        return 0, 0

    if target == "latest":
        # Find latest local checkpoint; ignore files without a numeric step.
        steps = {}
        for filename in os.listdir(CHECKPOINT_DIR):
            found = _checkpoint_step(filename)
            if found is not None:
                steps[found] = filename
        if steps:
            step = max(steps)
            latest = steps[step]
            _apply_weights(model, load_weights(os.path.join(CHECKPOINT_DIR, latest)))
            print(f"✓ Loaded local: {latest}")
            return step, 0
        print(f"No local checkpoints found in {CHECKPOINT_DIR}")
        return 0, 0

    if target == "best":
        # save_checkpoint(..., "best.pt") writes its weights to best.safetensors.
        best_path = os.path.join(CHECKPOINT_DIR, "best.safetensors")
        if os.path.exists(best_path):
            _apply_weights(model, load_weights(best_path))
            print(f"✓ Loaded local: best.safetensors")
        else:
            print(f"No best checkpoint found in {CHECKPOINT_DIR}")
        return 0, 0

    print(f"Unknown load target {target!r}; starting without loading weights")
    return 0, 0
433
+
434
+ # ============================================================================
435
+ # DATALOADER
436
+ # ============================================================================
437
# Worker-only DataLoader options are enabled only when workers are in use.
_has_workers = num_workers > 0
loader = DataLoader(
    ds,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=collate_preencoded,
    num_workers=num_workers,
    pin_memory=True,
    persistent_workers=_has_workers,
    prefetch_factor=4 if _has_workers else None,
)
443
+
444
+ # ============================================================================
445
+ # MODEL (assumes TinyFluxDeep is defined - run model cell first)
446
+ # ============================================================================
447
print("\nCreating TinyFlux-Deep model...")
config = TinyFluxDeepConfig()
model = TinyFluxDeep(config).to(DEVICE).to(DTYPE)
print(f"Params: {sum(p.numel() for p in model.parameters()):,}")

# ============================================================================
# LOAD & FREEZE
# ============================================================================
print(f"\nLoad target: {LOAD_TARGET}")
start_step, start_epoch = load_checkpoint(model, LOAD_TARGET)

# Freeze after loading so the ported blocks keep their loaded values.
print("\nFreezing ported layers...")
model = freeze_ported_layers(model)
print(f"Frozen single blocks: {sorted(FROZEN_SINGLE_POSITIONS)}")
print(f"Frozen double blocks: {sorted(FROZEN_DOUBLE_POSITIONS)}")

# Only optimize trainable params
trainable_params = [p for p in model.parameters() if p.requires_grad]
# This is a count of parameter tensors; they all go into one optimizer group.
print(f"Optimizing {len(trainable_params)} parameter tensors")
466
+
467
+ # ============================================================================
468
+ # OPTIMIZER
469
+ # ============================================================================
470
opt = torch.optim.AdamW(trainable_params, lr=LR, betas=(0.9, 0.99), weight_decay=0.01, fused=True)

# Number of optimizer steps over the whole run, and the linear warmup length.
total_steps = len(loader) * EPOCHS // GRAD_ACCUM
warmup = min(500, total_steps // 10)


def lr_fn(step):
    """Return the LR multiplier: linear warmup, then cosine decay to zero."""
    if step < warmup:
        return step / warmup
    # max(1, ...) avoids a ZeroDivisionError when total_steps == warmup
    # (e.g. a run with zero optimizer steps); min(1.0, ...) keeps the
    # multiplier at 0 instead of rising again past total_steps.
    decay_steps = max(1, total_steps - warmup)
    progress = min(1.0, (step - warmup) / decay_steps)
    return 0.5 * (1 + math.cos(math.pi * progress))


sched = torch.optim.lr_scheduler.LambdaLR(opt, lr_fn)

# NOTE(review): RESUME_STEP only overrides the step counter; the scheduler
# still starts at 0 and no optimizer state is restored — confirm intent.
if RESUME_STEP is not None:
    start_step = RESUME_STEP

# ============================================================================
# COMPILE (after freezing)
# ============================================================================
model = torch.compile(model, mode="default")
489
+
490
+ # Sample prompts
491
SAMPLE_PROMPTS = [
    "a photo of a cat sitting on a windowsill",
    "a beautiful sunset over mountains",
    "a portrait of a woman with red hair",
    "a futuristic cityscape at night",
]

# ============================================================================
# TRAINING LOOP
# ============================================================================
print(f"\n{'='*60}")
print(f"Training TinyFlux-Deep")
print(f"{'='*60}")
print(f"Epochs: {EPOCHS}, Steps: {total_steps}")
print(f"Batch: {BATCH_SIZE} x {GRAD_ACCUM} = {BATCH_SIZE * GRAD_ACCUM}")
print(f"LR: {LR}, Warmup: {warmup}")

model.train()
step = start_step  # counts optimizer steps, not micro-batches
best = float("inf")

for ep in range(start_epoch, EPOCHS):
    ep_loss = 0
    ep_batches = 0
    pbar = tqdm(loader, desc=f"E{ep + 1}")

    for i, batch in enumerate(pbar):
        latents = batch["latents"].to(DEVICE, non_blocking=True)
        t5 = batch["t5_embeds"].to(DEVICE, non_blocking=True)
        clip = batch["clip_pooled"].to(DEVICE, non_blocking=True)

        # (B, C, H, W) -> (B, H*W, C): one token per latent position.
        B, C, H, W = latents.shape
        data = latents.permute(0, 2, 3, 1).reshape(B, H * W, C)
        noise = torch.randn_like(data)

        # Logit-normal timesteps with flux shift
        t = torch.sigmoid(torch.randn(B, device=DEVICE))
        t = flux_shift(t, s=SHIFT).to(DTYPE).clamp(1e-4, 1 - 1e-4)

        # Linear interpolation between noise (t=0) and data (t=1); the
        # regression target is the constant velocity data - noise.
        t_expanded = t.view(B, 1, 1)
        x_t = (1 - t_expanded) * noise + t_expanded * data
        v_target = data - noise

        img_ids = TinyFluxDeep.create_img_ids(B, H, W, DEVICE)
        # Random guidance value in [1, 5) for each sample.
        guidance = torch.rand(B, device=DEVICE, dtype=DTYPE) * 4 + 1

        with torch.autocast("cuda", dtype=DTYPE):
            v_pred = model(
                hidden_states=x_t,
                encoder_hidden_states=t5,
                pooled_projections=clip,
                timestep=t,
                img_ids=img_ids,
                guidance=guidance,
            )

        loss_raw = F.mse_loss(v_pred, v_target, reduction="none").mean(dim=[1, 2])
        snr_weights = min_snr_weight(t)
        # Divide by GRAD_ACCUM so accumulated gradients average the micro-batches.
        loss = (loss_raw * snr_weights).mean() / GRAD_ACCUM
        loss.backward()

        # One host sync per micro-batch; this is the undivided loss used for
        # logging, the progress bar and checkpoints alike.
        loss_value = loss.item() * GRAD_ACCUM

        if (i + 1) % GRAD_ACCUM == 0:
            grad_norm = torch.nn.utils.clip_grad_norm_(trainable_params, 1.0)
            opt.step()
            sched.step()
            opt.zero_grad(set_to_none=True)
            step += 1

            if step % LOG_EVERY == 0:
                writer.add_scalar("train/loss", loss_value, step)
                writer.add_scalar("train/lr", sched.get_last_lr()[0], step)
                writer.add_scalar("train/grad_norm", grad_norm.item(), step)

            if step % SAMPLE_EVERY == 0:
                print(f"\n Generating samples at step {step}...")
                images = generate_samples(model, SAMPLE_PROMPTS, num_steps=20)
                save_samples(images, SAMPLE_PROMPTS, step, SAMPLE_DIR)

            if step % SAVE_EVERY == 0:
                ckpt_path = os.path.join(CHECKPOINT_DIR, f"step_{step}.pt")
                weights_path = save_checkpoint(model, opt, sched, step, ep, loss_value, ckpt_path)
                if step % UPLOAD_EVERY == 0:
                    upload_checkpoint(weights_path, step)

        ep_loss += loss_value
        ep_batches += 1
        pbar.set_postfix(loss=f"{loss_value:.4f}", step=step)

    # Drop gradients of an incomplete accumulation window so they do not
    # leak into the first optimizer step of the next epoch.
    opt.zero_grad(set_to_none=True)

    avg = ep_loss / max(ep_batches, 1)
    print(f"Epoch {ep + 1} loss: {avg:.4f}")

    if avg < best:
        best = avg
        weights_path = save_checkpoint(model, opt, sched, step, ep, avg, os.path.join(CHECKPOINT_DIR, "best.pt"))
        # Uploading is best-effort, but a failure is reported, not hidden.
        try:
            api.upload_file(path_or_fileobj=weights_path, path_in_repo="model.safetensors", repo_id=HF_REPO)
            print(f" ✓ Uploaded best model")
        except Exception as e:
            print(f" ⚠ Upload failed: {e}")
590
+
591
+ # ============================================================================
592
+ # FINAL
593
+ # ============================================================================
594
print(f"\n✓ Training complete! Best loss: {best:.4f}")

# Final samples
print("\nGenerating final samples...")
images = generate_samples(model, SAMPLE_PROMPTS, num_steps=30)
save_samples(images, SAMPLE_PROMPTS, step, SAMPLE_DIR)

# Close the writer only after save_samples, which still adds the final
# image grid to it.
writer.close()