AbstractPhil committed · verified
Commit aae97e7 · 1 Parent(s): 11db6fd

Create trainer_v2.py

Files changed (1)
  1. trainer_v2.py +1199 -0
trainer_v2.py ADDED
@@ -0,0 +1,1199 @@
1
+ # ============================================================================
2
+ # TinyFlux-Deep Training Cell - Combined Dataset
3
+ # ============================================================================
4
+ # Datasets:
5
+ # - FFHQ portraits (40k × 3 prompts = ~120k)
6
+ # - flux-schnell-teacher-latents (train_simple_512 + train_512 + train_2_512 = ~40k)
7
+ # - SportFashion_512x512 (54.6k) - with background mask
8
+ # - SynthMoCap_smpl_512 (106k) - with SMPL body mask
9
+ # Total: ~320k samples
10
+ #
11
+ # All encoded with flan-t5-base (768 dim)
12
+ # Masked loss for foreground-focused training on product/body datasets
13
+ #
14
+ # USAGE: Run model.py cell first, then this cell
15
+ # This run converts tiny-flux-deep into the tiny-flux-deep-v2 variant.
16
+ # WARNING: resuming from existing tiny-flux-deep checkpoints will change the weights, so expect quality to shift until training settles.
17
+ # ============================================================================
18
+
19
+ import torch
20
+ import torch.nn as nn
21
+ import torch.nn.functional as F
22
+ from torch.utils.data import DataLoader, Dataset
23
+ from datasets import load_dataset, concatenate_datasets
24
+ from transformers import T5EncoderModel, T5Tokenizer, CLIPTextModel, CLIPTokenizer
25
+ from huggingface_hub import HfApi, hf_hub_download
26
+ from safetensors.torch import save_file, load_file
27
+ from torch.utils.tensorboard import SummaryWriter
28
+ from tqdm.auto import tqdm
29
+ import numpy as np
30
+ import math
31
+ import json
32
+ import random
33
+ from typing import Tuple, Optional, Dict, List
34
+ import os
35
+ from datetime import datetime
36
+ from PIL import Image
37
+
38
+ # ============================================================================
39
+ # CUDA OPTIMIZATIONS
40
+ # ============================================================================
41
+ torch.backends.cuda.matmul.allow_tf32 = True
42
+ torch.backends.cudnn.allow_tf32 = True
43
+ torch.backends.cudnn.benchmark = True
44
+ torch.set_float32_matmul_precision('high')
45
+
46
+ import warnings
47
+ warnings.filterwarnings('ignore', message='.*TF32.*')
48
+
49
+ # ============================================================================
50
+ # CONFIG
51
+ # ============================================================================
52
+ BATCH_SIZE = 16
53
+ GRAD_ACCUM = 2
54
+ LR = 3e-4
55
+ EPOCHS = 20
56
+ MAX_SEQ = 128
57
+ SHIFT = 3.0
58
+ DEVICE = "cuda"
59
+ DTYPE = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
60
+
61
+ ALLOW_WEIGHT_UPGRADE = True # Set to False to require exact weight match
62
+
63
+ # HuggingFace Hub
64
+ HF_REPO = "AbstractPhil/tiny-flux-deep"
65
+ SAVE_EVERY = 625
66
+ UPLOAD_EVERY = 625
67
+ SAMPLE_EVERY = 312
68
+ LOG_EVERY = 10
69
+ LOG_UPLOAD_EVERY = 625
70
+
71
+ # Checkpoint loading
72
+ LOAD_TARGET = "latest" # "hub", "latest", "best", "none"
73
+ RESUME_STEP = None
74
+
75
+ # ============================================================================
76
+ # DATASET CONFIG - Enable/disable datasets for this run
77
+ # ============================================================================
78
+ ENABLE_PORTRAIT = False
79
+ ENABLE_SCHNELL = True
80
+ ENABLE_SPORTFASHION = False # Disabled for disk space
81
+ ENABLE_SYNTHMOCAP = False # Disabled for disk space
82
+
83
+
84
+
85
+ # Dataset repos
86
+ PORTRAIT_REPO = "AbstractPhil/ffhq_flux_latents_repaired"
87
+ PORTRAIT_NUM_SHARDS = 11
88
+ SCHNELL_REPO = "AbstractPhil/flux-schnell-teacher-latents"
89
+ SCHNELL_CONFIGS = ["train_simple_512"] # Add "train_512", "train_2_512" as disk allows
90
+ SPORTFASHION_REPO = "Pianokill/SportFashion_512x512"
91
+ SYNTHMOCAP_REPO = "toyxyz/SynthMoCap_smpl_512"
92
+
93
+ # Masked loss config
94
+ # Weight foreground higher than background
95
+ FG_LOSS_WEIGHT = 2.0 # Foreground multiplier
96
+ BG_LOSS_WEIGHT = 0.5 # Background multiplier
97
+ USE_MASKED_LOSS = False
98
+
99
+ # Min-SNR weighting for flow matching
100
+ MIN_SNR_GAMMA = 5.0
101
+
102
+ # Paths
103
+ CHECKPOINT_DIR = "./tiny_flux_deep_checkpoints"
104
+ LOG_DIR = "./tiny_flux_deep_logs"
105
+ SAMPLE_DIR = "./tiny_flux_deep_samples"
106
+ ENCODING_CACHE_DIR = "./encoding_cache"
107
+ LATENT_CACHE_DIR = "./latent_cache"
108
+
109
+ os.makedirs(CHECKPOINT_DIR, exist_ok=True)
110
+ os.makedirs(LOG_DIR, exist_ok=True)
111
+ os.makedirs(SAMPLE_DIR, exist_ok=True)
112
+ os.makedirs(ENCODING_CACHE_DIR, exist_ok=True)
113
+ os.makedirs(LATENT_CACHE_DIR, exist_ok=True)
114
+
115
+ # ============================================================================
116
+ # REGULARIZATION CONFIG
117
+ # ============================================================================
118
+ TEXT_DROPOUT = 0.1
119
+ GUIDANCE_DROPOUT = 0.1
120
+ EMA_DECAY = 0.9999
121
+
122
+ # ============================================================================
123
+ # EMA
124
+ # ============================================================================
125
+ class EMA:
126
+ def __init__(self, model, decay=0.9999):
127
+ self.decay = decay
128
+ self.shadow = {}
129
+ self._backup = {}
130
+ if hasattr(model, '_orig_mod'):
131
+ state = model._orig_mod.state_dict()
132
+ else:
133
+ state = model.state_dict()
134
+ for k, v in state.items():
135
+ self.shadow[k] = v.clone().detach().float()  # keep the EMA shadow in fp32; bf16 cannot resolve the 1e-4 lerp steps
136
+
137
+ @torch.no_grad()
138
+ def update(self, model):
139
+ if hasattr(model, '_orig_mod'):
140
+ state = model._orig_mod.state_dict()
141
+ else:
142
+ state = model.state_dict()
143
+ for k, v in state.items():
144
+ if k in self.shadow:
145
+ self.shadow[k].lerp_(v.to(self.shadow[k].dtype), 1 - self.decay)
146
+
147
+ def apply_shadow_for_eval(self, model):
148
+ if hasattr(model, '_orig_mod'):
149
+ self._backup = {k: v.clone() for k, v in model._orig_mod.state_dict().items()}
150
+ model._orig_mod.load_state_dict(self.shadow)
151
+ else:
152
+ self._backup = {k: v.clone() for k, v in model.state_dict().items()}
153
+ model.load_state_dict(self.shadow)
154
+
155
+ def restore(self, model):
156
+ if hasattr(model, '_orig_mod'):
157
+ model._orig_mod.load_state_dict(self._backup)
158
+ else:
159
+ model.load_state_dict(self._backup)
160
+ self._backup = {}
161
+
162
+ def state_dict(self):
163
+ return {'shadow': self.shadow, 'decay': self.decay}
164
+
165
+ def load_state_dict(self, state):
166
+ self.shadow = {k: v.clone() for k, v in state['shadow'].items()}
167
+ self.decay = state.get('decay', self.decay)
168
+
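+ # The update above is the standard exponential moving average,
+ # shadow = decay * shadow + (1 - decay) * param, implemented via lerp_ with
+ # weight 1 - decay. apply_shadow_for_eval() swaps the EMA weights in for
+ # sampling and restore() puts the training weights back afterwards.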
169
+ # ============================================================================
170
+ # REGULARIZATION
171
+ # ============================================================================
172
+ def apply_text_dropout(t5_embeds, clip_pooled, dropout_prob=0.1):
173
+ B = t5_embeds.shape[0]
174
+ mask = torch.rand(B, device=t5_embeds.device) < dropout_prob
175
+ t5_embeds = t5_embeds.clone()
176
+ clip_pooled = clip_pooled.clone()
177
+ t5_embeds[mask] = 0
178
+ clip_pooled[mask] = 0
179
+ return t5_embeds, clip_pooled, mask
180
+
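+ # Standard conditioning dropout (CFG-style): for ~dropout_prob of the batch both
+ # the T5 sequence and the CLIP pooled vector are zeroed, so the model also learns
+ # an unconditional branch; the returned mask marks which samples were dropped.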
181
+ # ============================================================================
182
+ # MASKING UTILITIES
183
+ # ============================================================================
184
+ def detect_background_color(image: Image.Image, sample_size: int = 100) -> Tuple[int, int, int]:
185
+ """Detect dominant background color by sampling corners."""
186
+ img = np.array(image)
187
+ if len(img.shape) == 2:
188
+ img = np.stack([img] * 3, axis=-1)
189
+
190
+ h, w = img.shape[:2]
191
+ corners = [
192
+ img[:sample_size, :sample_size], # Top-left
193
+ img[:sample_size, -sample_size:], # Top-right
194
+ img[-sample_size:, :sample_size], # Bottom-left
195
+ img[-sample_size:, -sample_size:], # Bottom-right
196
+ ]
197
+
198
+ # Compute median color across corners
199
+ corner_pixels = np.concatenate([c.reshape(-1, 3) for c in corners], axis=0)
200
+ bg_color = np.median(corner_pixels, axis=0).astype(np.uint8)
201
+ return tuple(bg_color)
202
+
203
+
204
+ def create_product_mask(image: Image.Image, threshold: int = 30) -> np.ndarray:
205
+ """Create foreground mask for product images (non-background pixels)."""
206
+ img = np.array(image).astype(np.float32)
207
+ if len(img.shape) == 2:
208
+ img = np.stack([img] * 3, axis=-1)
209
+
210
+ bg_color = detect_background_color(image)
211
+ bg_color = np.array(bg_color, dtype=np.float32)
212
+
213
+ # Distance from background color
214
+ diff = np.sqrt(np.sum((img - bg_color) ** 2, axis=-1))
215
+ mask = (diff > threshold).astype(np.float32)
216
+
217
+ return mask
218
+
219
+
220
+ def create_smpl_mask(conditioning_image: Image.Image, threshold: int = 20) -> np.ndarray:
221
+ """Create body mask from SMPL conditioning render.
222
+
223
+ SynthMoCap uses green/teal background. Body is rendered as mesh.
224
+ Non-green pixels = body.
225
+ """
226
+ img = np.array(conditioning_image).astype(np.float32)
227
+ if len(img.shape) == 2:
228
+ return (img > threshold).astype(np.float32)
229
+
230
+ # Green background detection (high G, low R and B relative to G)
231
+ r, g, b = img[:, :, 0], img[:, :, 1], img[:, :, 2]
232
+
233
+ # Background is typically green/teal
234
+ # Body pixels have different color distribution
235
+ is_background = (g > r + 20) & (g > b + 20)
236
+ mask = (~is_background).astype(np.float32)
237
+
238
+ return mask
239
+
240
+
241
+ def downsample_mask_to_latent(mask: np.ndarray, latent_h: int = 64, latent_w: int = 64) -> torch.Tensor:
242
+ """Downsample pixel mask to latent space dimensions."""
243
+ # Bilinear resampling; fractional values along mask edges act as soft weights
244
+ mask_pil = Image.fromarray((mask * 255).astype(np.uint8))
245
+ mask_pil = mask_pil.resize((latent_w, latent_h), Image.Resampling.BILINEAR)
246
+ mask_latent = np.array(mask_pil).astype(np.float32) / 255.0
247
+ return torch.from_numpy(mask_latent)
248
+
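+ # The 512x512 pixel mask is resized to the 64x64 latent grid (the VAE's 8x
+ # spatial reduction), giving one weight per latent token.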
249
+ # ============================================================================
250
+ # HF HUB SETUP
251
+ # ============================================================================
252
+ print("Setting up HuggingFace Hub...")
253
+ api = HfApi()
254
+
255
+ # ============================================================================
256
+ # FLOW MATCHING HELPERS
257
+ # ============================================================================
258
+ def flux_shift(t, s=SHIFT):
259
+ return s * t / (1 + (s - 1) * t)
260
+
261
+ def min_snr_weight(t, gamma=MIN_SNR_GAMMA):
262
+ """Min-SNR weighting for flow matching to balance loss across timesteps."""
263
+ snr = (t / (1 - t).clamp(min=1e-5)).pow(2)
264
+ return torch.clamp(snr, max=gamma) / snr.clamp(min=1e-5)
265
+
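+ # Sketch of the math behind the two helpers above (not from the original file):
+ # flux_shift remaps timesteps as t' = s*t / (1 + (s-1)*t); for s > 1 it pushes
+ # t upward, e.g. s=3, t=0.5 -> t' = 1.5 / 2.0 = 0.75. With the interpolation
+ # used below, x_t = (1-t)*noise + t*data, the signal-to-noise ratio is
+ # SNR(t) = (t / (1-t))**2, and min-SNR weighting scales each sample's loss by
+ # min(SNR, gamma) / SNR, e.g. SNR(0.75) = 9 with gamma=5 gives weight 5/9 ≈ 0.56,
+ # so easy (high-SNR) timesteps do not dominate the average.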
266
+ # ============================================================================
267
+ # LOAD TEXT ENCODERS
268
+ # ============================================================================
269
+ print("Loading text encoders...")
270
+ t5_tok = T5Tokenizer.from_pretrained("google/flan-t5-base")
271
+ t5_enc = T5EncoderModel.from_pretrained("google/flan-t5-base", torch_dtype=DTYPE).to(DEVICE).eval()
272
+ for p in t5_enc.parameters():
273
+ p.requires_grad = False
274
+
275
+ clip_tok = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
276
+ clip_enc = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14", torch_dtype=DTYPE).to(DEVICE).eval()
277
+ for p in clip_enc.parameters():
278
+ p.requires_grad = False
279
+ print("✓ Text encoders loaded")
280
+
281
+ # ============================================================================
282
+ # LOAD VAE
283
+ # ============================================================================
284
+ print("Loading VAE...")
285
+ from diffusers import AutoencoderKL
286
+ vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=DTYPE).to(DEVICE).eval()
287
+ for p in vae.parameters():
288
+ p.requires_grad = False
289
+ VAE_SCALE = vae.config.scaling_factor
290
+ print(f"✓ VAE loaded (scale={VAE_SCALE})")
291
+
292
+ # ============================================================================
293
+ # ENCODING FUNCTIONS
294
+ # ============================================================================
295
+ @torch.no_grad()
296
+ def encode_prompt(prompt: str) -> Tuple[torch.Tensor, torch.Tensor]:
297
+ t5_inputs = t5_tok(prompt, return_tensors="pt", padding="max_length",
298
+ max_length=MAX_SEQ, truncation=True).to(DEVICE)
299
+ t5_out = t5_enc(**t5_inputs).last_hidden_state
300
+
301
+ clip_inputs = clip_tok(prompt, return_tensors="pt", padding="max_length",
302
+ max_length=77, truncation=True).to(DEVICE)
303
+ clip_out = clip_enc(**clip_inputs).pooler_output
304
+
305
+ return t5_out.squeeze(0), clip_out.squeeze(0)
306
+
307
+ @torch.no_grad()
308
+ def encode_prompts_batched(prompts: List[str], batch_size: int = 64) -> Tuple[torch.Tensor, torch.Tensor]:
309
+ all_t5 = []
310
+ all_clip = []
311
+
312
+ for i in tqdm(range(0, len(prompts), batch_size), desc="Encoding", leave=False):
313
+ batch = prompts[i:i+batch_size]
314
+
315
+ t5_inputs = t5_tok(batch, return_tensors="pt", padding="max_length",
316
+ max_length=MAX_SEQ, truncation=True).to(DEVICE)
317
+ t5_out = t5_enc(**t5_inputs).last_hidden_state
318
+ all_t5.append(t5_out.cpu())
319
+
320
+ clip_inputs = clip_tok(batch, return_tensors="pt", padding="max_length",
321
+ max_length=77, truncation=True).to(DEVICE)
322
+ clip_out = clip_enc(**clip_inputs).pooler_output
323
+ all_clip.append(clip_out.cpu())
324
+
325
+ return torch.cat(all_t5, dim=0), torch.cat(all_clip, dim=0)
326
+
327
+ @torch.no_grad()
328
+ def encode_image_to_latent(image: Image.Image) -> torch.Tensor:
329
+ """Encode PIL image to VAE latent."""
330
+ if image.mode != "RGB":
331
+ image = image.convert("RGB")
332
+
333
+ # Resize to 512x512 if needed
334
+ if image.size != (512, 512):
335
+ image = image.resize((512, 512), Image.Resampling.LANCZOS)
336
+
337
+ # To tensor and normalize
338
+ img_tensor = torch.from_numpy(np.array(image)).float() / 255.0
339
+ img_tensor = img_tensor.permute(2, 0, 1).unsqueeze(0) # [1, 3, H, W]
340
+ img_tensor = (img_tensor * 2.0 - 1.0).to(DEVICE, dtype=DTYPE)
341
+
342
+ # Encode
343
+ latent = vae.encode(img_tensor).latent_dist.sample()
344
+ latent = latent * VAE_SCALE
345
+
346
+ return latent.squeeze(0).cpu() # [16, 64, 64]
347
+
348
+ # ============================================================================
349
+ # LOAD DATASETS
350
+ # ============================================================================
351
+
352
+ # --- 1. Portrait Dataset (FFHQ) ---
353
+ portrait_ds = None
354
+ portrait_indices = []
355
+ portrait_prompts = []
356
+
357
+ if ENABLE_PORTRAIT:
358
+ print(f"\n[1/4] Loading portrait dataset from {PORTRAIT_REPO}...")
359
+ portrait_shards = []
360
+ for i in range(PORTRAIT_NUM_SHARDS):
361
+ split_name = f"train_{i:02d}"
362
+ print(f" Loading {split_name}...")
363
+ shard = load_dataset(PORTRAIT_REPO, split=split_name)
364
+ portrait_shards.append(shard)
365
+ portrait_ds = concatenate_datasets(portrait_shards)
366
+ print(f"✓ Portrait: {len(portrait_ds)} base samples")
367
+
368
+ # Extract triplicated prompts - batch read columns then iterate
369
+ print(" Extracting prompts (columnar)...")
370
+
371
+ # Batch read all three columns at once (fast Arrow read)
372
+ florence_list = list(portrait_ds["text_florence"])
373
+ llava_list = list(portrait_ds["text_llava"])
374
+ blip_list = list(portrait_ds["text_blip"])
375
+
376
+ # Build from Python lists (instant)
377
+ for i, (f, l, b) in enumerate(zip(florence_list, llava_list, blip_list)):
378
+ if f and f.strip():
379
+ portrait_indices.append(i)
380
+ portrait_prompts.append(f)
381
+ if l and l.strip():
382
+ portrait_indices.append(i)
383
+ portrait_prompts.append(l)
384
+ if b and b.strip():
385
+ portrait_indices.append(i)
386
+ portrait_prompts.append(b)
387
+ print(f" Expanded: {len(portrait_prompts)} samples (3 prompts/image)")
388
+ else:
389
+ print("\n[1/4] Portrait dataset DISABLED")
390
+
391
+ # --- 2. Schnell Teacher Dataset ---
392
+ schnell_ds = None
393
+ schnell_prompts = []
394
+
395
+ if ENABLE_SCHNELL:
396
+ print(f"\n[2/4] Loading schnell teacher dataset from {SCHNELL_REPO}...")
397
+ schnell_datasets = []
398
+ for config in SCHNELL_CONFIGS:
399
+ print(f" Loading {config}...")
400
+ ds = load_dataset(SCHNELL_REPO, config, split="train")
401
+ schnell_datasets.append(ds)
402
+ print(f" {len(ds)} samples")
403
+ schnell_ds = concatenate_datasets(schnell_datasets)
404
+ schnell_prompts = list(schnell_ds["prompt"])
405
+ print(f"✓ Schnell: {len(schnell_ds)} samples")
406
+ else:
407
+ print("\n[2/4] Schnell dataset DISABLED")
408
+
409
+ # --- 3. SportFashion Dataset ---
410
+ sportfashion_ds = None
411
+ sportfashion_prompts = []
412
+
413
+ if ENABLE_SPORTFASHION:
414
+ print(f"\n[3/4] Loading SportFashion dataset from {SPORTFASHION_REPO}...")
415
+ sportfashion_ds = load_dataset(SPORTFASHION_REPO, split="train")
416
+ sportfashion_prompts = list(sportfashion_ds["text"])
417
+ print(f"✓ SportFashion: {len(sportfashion_ds)} samples")
418
+ else:
419
+ print("\n[3/4] SportFashion dataset DISABLED")
420
+
421
+ # --- 4. SynthMoCap Dataset ---
422
+ synthmocap_ds = None
423
+ synthmocap_prompts = []
424
+
425
+ if ENABLE_SYNTHMOCAP:
426
+ print(f"\n[4/4] Loading SynthMoCap dataset from {SYNTHMOCAP_REPO}...")
427
+ synthmocap_ds = load_dataset(SYNTHMOCAP_REPO, split="train")
428
+ synthmocap_prompts = list(synthmocap_ds["text"])
429
+ print(f"✓ SynthMoCap: {len(synthmocap_ds)} samples")
430
+ else:
431
+ print("\n[4/4] SynthMoCap dataset DISABLED")
432
+
433
+ # ============================================================================
434
+ # ENCODE ALL PROMPTS
435
+ # ============================================================================
436
+ total_samples = len(portrait_prompts) + len(schnell_prompts) + len(sportfashion_prompts) + len(synthmocap_prompts)
437
+ print(f"\nTotal combined samples: {total_samples}")
438
+
439
+ def load_or_encode(cache_path, prompts, name):
440
+ if not prompts:
441
+ return None, None
442
+ if os.path.exists(cache_path):
443
+ print(f"Loading cached {name} encodings...")
444
+ cached = torch.load(cache_path)
445
+ return cached["t5_embeds"], cached["clip_pooled"]
446
+ else:
447
+ print(f"Encoding {len(prompts)} {name} prompts...")
448
+ t5, clip = encode_prompts_batched(prompts, batch_size=64)
449
+ torch.save({"t5_embeds": t5, "clip_pooled": clip}, cache_path)
450
+ print(f"✓ Cached to {cache_path}")
451
+ return t5, clip
452
+
453
+ # Cache paths and encoding
454
+ portrait_t5, portrait_clip = None, None
455
+ schnell_t5, schnell_clip = None, None
456
+ sportfashion_t5, sportfashion_clip = None, None
457
+ synthmocap_t5, synthmocap_clip = None, None
458
+
459
+ if portrait_prompts:
460
+ portrait_enc_cache = os.path.join(ENCODING_CACHE_DIR, f"portrait_encodings_{len(portrait_prompts)}.pt")
461
+ portrait_t5, portrait_clip = load_or_encode(portrait_enc_cache, portrait_prompts, "portrait")
462
+
463
+ if schnell_prompts:
464
+ schnell_enc_cache = os.path.join(ENCODING_CACHE_DIR, f"schnell_encodings_{len(schnell_prompts)}.pt")
465
+ schnell_t5, schnell_clip = load_or_encode(schnell_enc_cache, schnell_prompts, "schnell")
466
+
467
+ if sportfashion_prompts:
468
+ sportfashion_enc_cache = os.path.join(ENCODING_CACHE_DIR, f"sportfashion_encodings_{len(sportfashion_prompts)}.pt")
469
+ sportfashion_t5, sportfashion_clip = load_or_encode(sportfashion_enc_cache, sportfashion_prompts, "sportfashion")
470
+
471
+ if synthmocap_prompts:
472
+ synthmocap_enc_cache = os.path.join(ENCODING_CACHE_DIR, f"synthmocap_encodings_{len(synthmocap_prompts)}.pt")
473
+ synthmocap_t5, synthmocap_clip = load_or_encode(synthmocap_enc_cache, synthmocap_prompts, "synthmocap")
474
+
475
+ # ============================================================================
476
+ # COMBINED DATASET CLASS WITH MASK SUPPORT
477
+ # ============================================================================
478
+ class CombinedDataset(Dataset):
479
+ """Combined dataset with mask support for weighted loss."""
480
+
481
+ def __init__(
482
+ self,
483
+ portrait_ds, portrait_indices, portrait_t5, portrait_clip,
484
+ schnell_ds, schnell_t5, schnell_clip,
485
+ sportfashion_ds, sportfashion_t5, sportfashion_clip,
486
+ synthmocap_ds, synthmocap_t5, synthmocap_clip,
487
+ vae, vae_scale, device, dtype,
488
+ compute_masks=True
489
+ ):
490
+ self.portrait_ds = portrait_ds
491
+ self.portrait_indices = portrait_indices
492
+ self.portrait_t5 = portrait_t5
493
+ self.portrait_clip = portrait_clip
494
+
495
+ self.schnell_ds = schnell_ds
496
+ self.schnell_t5 = schnell_t5
497
+ self.schnell_clip = schnell_clip
498
+
499
+ self.sportfashion_ds = sportfashion_ds
500
+ self.sportfashion_t5 = sportfashion_t5
501
+ self.sportfashion_clip = sportfashion_clip
502
+
503
+ self.synthmocap_ds = synthmocap_ds
504
+ self.synthmocap_t5 = synthmocap_t5
505
+ self.synthmocap_clip = synthmocap_clip
506
+
507
+ self.vae = vae
508
+ self.vae_scale = vae_scale
509
+ self.device = device
510
+ self.dtype = dtype
511
+ self.compute_masks = compute_masks
512
+
513
+ # Dataset sizes (0 if disabled)
514
+ self.n_portrait = len(portrait_indices) if portrait_indices else 0
515
+ self.n_schnell = len(schnell_ds) if schnell_ds else 0
516
+ self.n_sportfashion = len(sportfashion_ds) if sportfashion_ds else 0
517
+ self.n_synthmocap = len(synthmocap_ds) if synthmocap_ds else 0
518
+
519
+ # Cumulative indices for fast lookup
520
+ self.c1 = self.n_portrait
521
+ self.c2 = self.c1 + self.n_schnell
522
+ self.c3 = self.c2 + self.n_sportfashion
523
+ self.total = self.c3 + self.n_synthmocap
524
+
525
+ def __len__(self):
526
+ return self.total
527
+
528
+ def _get_latent_from_array(self, latent_data):
529
+ """Convert latent data to tensor."""
530
+ if isinstance(latent_data, torch.Tensor):
531
+ return latent_data.to(self.dtype)
532
+ return torch.tensor(np.array(latent_data), dtype=self.dtype)
533
+
534
+ @torch.no_grad()
535
+ def _encode_image(self, image):
536
+ """Encode PIL image to VAE latent."""
537
+ if image.mode != "RGB":
538
+ image = image.convert("RGB")
539
+ if image.size != (512, 512):
540
+ image = image.resize((512, 512), Image.Resampling.LANCZOS)
541
+
542
+ img_tensor = torch.from_numpy(np.array(image)).float() / 255.0
543
+ img_tensor = img_tensor.permute(2, 0, 1).unsqueeze(0)
544
+ img_tensor = (img_tensor * 2.0 - 1.0).to(self.device, dtype=self.dtype)
545
+
546
+ latent = self.vae.encode(img_tensor).latent_dist.sample()
547
+ latent = latent * self.vae_scale
548
+ return latent.squeeze(0).cpu()
549
+
550
+ def __getitem__(self, idx):
551
+ mask = None # Default: no mask (uniform loss)
552
+
553
+ if idx < self.c1:
554
+ # Portrait sample (has pre-computed latent, no mask needed)
555
+ orig_idx = self.portrait_indices[idx]
556
+ item = self.portrait_ds[orig_idx]
557
+ latent = self._get_latent_from_array(item["latent"])
558
+ t5 = self.portrait_t5[idx]
559
+ clip = self.portrait_clip[idx]
560
+
561
+ elif idx < self.c2:
562
+ # Schnell sample (has pre-computed latent, no mask needed)
563
+ schnell_idx = idx - self.c1
564
+ item = self.schnell_ds[schnell_idx]
565
+ latent = self._get_latent_from_array(item["latent"])
566
+ t5 = self.schnell_t5[schnell_idx]
567
+ clip = self.schnell_clip[schnell_idx]
568
+
569
+ elif idx < self.c3:
570
+ # SportFashion (needs VAE encoding + product mask)
571
+ sf_idx = idx - self.c2
572
+ item = self.sportfashion_ds[sf_idx]
573
+ image = item["image"]
574
+
575
+ latent = self._encode_image(image)
576
+ t5 = self.sportfashion_t5[sf_idx]
577
+ clip = self.sportfashion_clip[sf_idx]
578
+
579
+ if self.compute_masks:
580
+ pixel_mask = create_product_mask(image)
581
+ mask = downsample_mask_to_latent(pixel_mask, 64, 64)
582
+
583
+ else:
584
+ # SynthMoCap (needs VAE encoding + SMPL body mask)
585
+ sm_idx = idx - self.c3
586
+ item = self.synthmocap_ds[sm_idx]
587
+ image = item["image"]
588
+ conditioning = item["conditioning_image"]
589
+
590
+ latent = self._encode_image(image)
591
+ t5 = self.synthmocap_t5[sm_idx]
592
+ clip = self.synthmocap_clip[sm_idx]
593
+
594
+ if self.compute_masks:
595
+ pixel_mask = create_smpl_mask(conditioning)
596
+ mask = downsample_mask_to_latent(pixel_mask, 64, 64)
597
+
598
+ result = {
599
+ "latent": latent,
600
+ "t5_embed": t5.to(self.dtype),
601
+ "clip_pooled": clip.to(self.dtype),
602
+ }
603
+
604
+ if mask is not None:
605
+ result["mask"] = mask.to(self.dtype)
606
+
607
+ return result
608
+
609
+ # ============================================================================
610
+ # COLLATE FUNCTION
611
+ # ============================================================================
612
+ def collate_fn(batch):
613
+ latents = torch.stack([b["latent"] for b in batch])
614
+ t5_embeds = torch.stack([b["t5_embed"] for b in batch])
615
+ clip_pooled = torch.stack([b["clip_pooled"] for b in batch])
616
+
617
+ # Handle masks (some samples may not have masks)
618
+ masks = None
619
+ if any("mask" in b for b in batch):
620
+ masks = []
621
+ for b in batch:
622
+ if "mask" in b:
623
+ masks.append(b["mask"])
624
+ else:
625
+ # No mask = uniform weight (all 1s)
626
+ masks.append(torch.ones(64, 64, dtype=latents.dtype))
627
+ masks = torch.stack(masks)
628
+
629
+ return {
630
+ "latents": latents,
631
+ "t5_embeds": t5_embeds,
632
+ "clip_pooled": clip_pooled,
633
+ "masks": masks,
634
+ }
635
+
636
+ # ============================================================================
637
+ # MASKED LOSS FUNCTION
638
+ # ============================================================================
639
+ def masked_mse_loss(pred, target, mask=None, fg_weight=2.0, bg_weight=0.5, snr_weights=None):
640
+ """
641
+ Compute MSE loss with optional foreground/background weighting and min-SNR.
642
+
643
+ Args:
644
+ pred: [B, H*W, C] predicted velocity
645
+ target: [B, H*W, C] target velocity
646
+ mask: [B, H, W] foreground mask (1=foreground, 0=background) or None
647
+ fg_weight: Weight for foreground pixels
648
+ bg_weight: Weight for background pixels
649
+ snr_weights: [B] min-SNR weights per sample or None
650
+
651
+ Returns:
652
+ Scalar loss value
653
+ """
654
+ B, N, C = pred.shape
655
+
656
+ if mask is None:
657
+ # No spatial mask - compute per-sample loss
658
+ loss_per_sample = ((pred - target) ** 2).mean(dim=[1, 2]) # [B]
659
+ else:
660
+ H = W = int(math.sqrt(N))
661
+ mask_flat = mask.view(B, H * W, 1).to(pred.device)
662
+ sq_error = (pred - target) ** 2
663
+ weights = mask_flat * fg_weight + (1 - mask_flat) * bg_weight
664
+ weighted_error = sq_error * weights
665
+ loss_per_sample = weighted_error.mean(dim=[1, 2]) # [B]
666
+
667
+ # Apply min-SNR weighting if provided
668
+ if snr_weights is not None:
669
+ loss_per_sample = loss_per_sample * snr_weights
670
+
671
+ return loss_per_sample.mean()
672
+
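+ # Example of the weighting above: for a 512x512 image the latent is 16x64x64,
+ # so pred/target arrive as [B, 4096, 16]. With fg_weight=2.0 and bg_weight=0.5 a
+ # foreground token's squared error counts 4x as much as a background token's
+ # before the per-sample mean, and min-SNR then rescales each sample as a whole.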
673
+ # ============================================================================
674
+ # CREATE DATASET
675
+ # ============================================================================
676
+ print("\nCreating combined dataset...")
677
+ combined_ds = CombinedDataset(
678
+ portrait_ds, portrait_indices, portrait_t5, portrait_clip,
679
+ schnell_ds, schnell_t5, schnell_clip,
680
+ sportfashion_ds, sportfashion_t5, sportfashion_clip,
681
+ synthmocap_ds, synthmocap_t5, synthmocap_clip,
682
+ vae, VAE_SCALE, DEVICE, DTYPE,
683
+ compute_masks=USE_MASKED_LOSS
684
+ )
685
+ print(f"✓ Combined dataset: {len(combined_ds)} samples")
686
+ print(f" - Portraits (3x): {combined_ds.n_portrait:,}")
687
+ print(f" - Schnell teacher: {combined_ds.n_schnell:,}")
688
+ print(f" - SportFashion: {combined_ds.n_sportfashion:,}")
689
+ print(f" - SynthMoCap: {combined_ds.n_synthmocap:,}")
690
+
691
+ # ============================================================================
692
+ # DATALOADER
693
+ # ============================================================================
694
+ loader = DataLoader(
695
+ combined_ds,
696
+ batch_size=BATCH_SIZE,
697
+ shuffle=True,
698
+ num_workers=8,
699
+ pin_memory=True,
700
+ collate_fn=collate_fn,
701
+ drop_last=True,
702
+ )
703
+ print(f"✓ DataLoader: {len(loader)} batches/epoch")
704
+
705
+ # ============================================================================
706
+ # SAMPLING FUNCTION
707
+ # ============================================================================
708
+ @torch.inference_mode()
709
+ def generate_samples(model, prompts, num_steps=28, guidance_scale=3.5, H=64, W=64, use_ema=True):
710
+ was_training = model.training
711
+ model.eval()
712
+
713
+ if use_ema and 'ema' in globals() and ema is not None:
714
+ ema.apply_shadow_for_eval(model)
715
+
716
+ B = len(prompts)
717
+ C = 16
718
+
719
+ t5_list, clip_list = [], []
720
+ for p in prompts:
721
+ t5, clip = encode_prompt(p)
722
+ t5_list.append(t5)
723
+ clip_list.append(clip)
724
+ t5_embeds = torch.stack(t5_list).to(DTYPE)
725
+ clip_pooleds = torch.stack(clip_list).to(DTYPE)
726
+
727
+ x = torch.randn(B, H * W, C, device=DEVICE, dtype=DTYPE)
728
+ img_ids = TinyFluxDeep.create_img_ids(B, H, W, DEVICE)
729
+
730
+ t_linear = torch.linspace(0, 1, num_steps + 1, device=DEVICE, dtype=DTYPE)
731
+ timesteps = flux_shift(t_linear, s=SHIFT)
732
+
733
+ for i in range(num_steps):
734
+ t_curr = timesteps[i]
735
+ t_next = timesteps[i + 1]
736
+ dt = t_next - t_curr
737
+
738
+ t_batch = t_curr.expand(B).to(DTYPE)
739
+ guidance = torch.full((B,), guidance_scale, device=DEVICE, dtype=DTYPE)
740
+
741
+ with torch.autocast("cuda", dtype=DTYPE):
742
+ v_cond = model(
743
+ hidden_states=x,
744
+ encoder_hidden_states=t5_embeds,
745
+ pooled_projections=clip_pooleds,
746
+ timestep=t_batch,
747
+ img_ids=img_ids,
748
+ guidance=guidance,
749
+ )
750
+ x = x + v_cond * dt
751
+
752
+ latents = x.reshape(B, H, W, C).permute(0, 3, 1, 2)
753
+ latents = latents / VAE_SCALE
754
+
755
+ with torch.autocast("cuda", dtype=DTYPE):
756
+ images = vae.decode(latents.to(vae.dtype)).sample
757
+ images = (images / 2 + 0.5).clamp(0, 1)
758
+
759
+ if use_ema and 'ema' in globals() and ema is not None:
760
+ ema.restore(model)
761
+
762
+ if was_training:
763
+ model.train()
764
+ return images
765
+
766
+ def save_samples(images, prompts, step, output_dir):
767
+ from torchvision.utils import save_image
768
+ os.makedirs(output_dir, exist_ok=True)
769
+
770
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
771
+ grid_path = os.path.join(output_dir, f"samples_step_{step}.png")
772
+ save_image(images, grid_path, nrow=2, padding=2)
773
+
774
+ try:
775
+ api.upload_file(
776
+ path_or_fileobj=grid_path,
777
+ path_in_repo=f"samples/{timestamp}_step_{step}.png",
778
+ repo_id=HF_REPO,
779
+ )
780
+ except Exception as e:
781
+ print(f" ⚠ Sample upload failed (non-fatal): {e}")
782
+
783
+
784
+ # ============================================================================
785
+ # CHECKPOINT LOADING WITH WEIGHT UPGRADE SUPPORT
786
+ # ============================================================================
787
+ # Controlled by ALLOW_WEIGHT_UPGRADE (already defined in the CONFIG section above):
788
+ #
789
+ # ALLOW_WEIGHT_UPGRADE = True # Allow loading old checkpoints into new model
790
+ # ============================================================================
791
+
792
+ def load_checkpoint(model, optimizer, scheduler, target):
793
+ """
794
+ Load checkpoint with optional weight upgrade support.
795
+
796
+ When ALLOW_WEIGHT_UPGRADE=True:
797
+ - Missing Q/K norm weights are initialized to ones (identity transform)
798
+ - Unexpected keys (e.g., old sin_basis caches) are ignored
799
+ - Model behavior is identical to old weights at load time
800
+
801
+ When ALLOW_WEIGHT_UPGRADE=False:
802
+ - Requires exact weight match (strict=True)
803
+ """
804
+ start_step = 0
805
+ start_epoch = 0
806
+
807
+ if target == "none":
808
+ print("Starting fresh (no checkpoint)")
809
+ return start_step, start_epoch
810
+
811
+ ckpt_path = None
812
+ weights_path = None
813
+
814
+ if target == "latest":
815
+ if os.path.exists(CHECKPOINT_DIR):
816
+ ckpts = [f for f in os.listdir(CHECKPOINT_DIR) if f.startswith("step_") and f.endswith(".pt")]
817
+ if ckpts:
818
+ steps = [int(f.split("_")[1].split(".")[0]) for f in ckpts]
819
+ latest_step = max(steps)
820
+ ckpt_path = os.path.join(CHECKPOINT_DIR, f"step_{latest_step}.pt")
821
+ weights_path = ckpt_path.replace(".pt", ".safetensors")
822
+ elif target == "hub" or target.startswith("hub:"):
823
+ try:
824
+ from huggingface_hub import list_repo_files
825
+
826
+ if target.startswith("hub:"):
827
+ step_name = target.split(":")[1]
828
+ weights_path = hf_hub_download(HF_REPO, f"checkpoints/{step_name}.safetensors")
829
+ start_step = int(step_name.split("_")[1]) if "_" in step_name else 0
830
+ print(f"Downloaded {step_name} from hub")
831
+ else:
832
+ files = list_repo_files(HF_REPO)
833
+ ckpts = [f for f in files if f.startswith("checkpoints/step_") and f.endswith(".safetensors") and "_ema" not in f]
834
+ if ckpts:
835
+ steps = [int(f.split("_")[1].split(".")[0]) for f in ckpts]
836
+ latest = max(steps)
837
+ weights_path = hf_hub_download(HF_REPO, f"checkpoints/step_{latest}.safetensors")
838
+ start_step = latest
839
+ print(f"Downloaded step_{latest} from hub")
840
+ except Exception as e:
841
+ print(f"Could not download from hub: {e}")
842
+ return start_step, start_epoch
843
+ elif target == "best":
844
+ ckpt_path = os.path.join(CHECKPOINT_DIR, "best.pt")
845
+ weights_path = ckpt_path.replace(".pt", ".safetensors")
846
+ elif os.path.exists(target):
847
+ # Direct path provided
848
+ if target.endswith(".safetensors"):
849
+ weights_path = target
850
+ ckpt_path = target.replace(".safetensors", ".pt")
851
+ else:
852
+ ckpt_path = target
853
+ weights_path = target.replace(".pt", ".safetensors")
854
+
855
+ if weights_path and os.path.exists(weights_path):
856
+ print(f"Loading weights from {weights_path}")
857
+ state_dict = load_file(weights_path)
858
+ state_dict = {k: v.to(DTYPE) if v.is_floating_point() else v for k, v in state_dict.items()}
859
+
860
+ # Get model reference (handle torch.compile wrapper)
861
+ model_ref = model._orig_mod if hasattr(model, '_orig_mod') else model
862
+
863
+ if ALLOW_WEIGHT_UPGRADE:
864
+ # Flexible loading with weight upgrade
865
+ missing, unexpected = load_with_weight_upgrade(model_ref, state_dict)
866
+
867
+ if missing:
868
+ print(f" ℹ Initialized {len(missing)} new parameters (identity)")
869
+ if unexpected:
870
+ print(f" ℹ Ignored {len(unexpected)} deprecated parameters")
871
+ else:
872
+ # Strict loading - must match exactly
873
+ model_ref.load_state_dict(state_dict, strict=True)
874
+
875
+ print(f"✓ Loaded model weights")
876
+
877
+ if ckpt_path and os.path.exists(ckpt_path):
878
+ state = torch.load(ckpt_path, map_location="cpu")
879
+ start_step = state.get("step", 0)
880
+ start_epoch = state.get("epoch", 0)
881
+ try:
882
+ optimizer.load_state_dict(state["optimizer"])
883
+ scheduler.load_state_dict(state["scheduler"])
884
+ print(f"✓ Loaded optimizer/scheduler state")
885
+ except Exception:
886
+ print(" ⚠ Could not load optimizer state (will use fresh optimizer)")
887
+ print(f"Resuming from step {start_step}, epoch {start_epoch}")
888
+
889
+ return start_step, start_epoch
890
+
891
+
892
+ def load_with_weight_upgrade(model, state_dict):
893
+ """
894
+ Load state dict with automatic handling of:
895
+ - Missing Q/K norm weights → initialize to ones (identity)
896
+ - Unexpected keys → ignore (e.g., old sin_basis caches)
897
+
898
+ Returns:
899
+ (missing_keys, unexpected_keys) - lists of handled keys
900
+ """
901
+ model_state = model.state_dict()
902
+
903
+ # Keys that are new in the repaired model (Q/K norms)
904
+ QK_NORM_PATTERNS = [
905
+ '.norm_q.weight',
906
+ '.norm_k.weight',
907
+ '.norm_added_q.weight',
908
+ '.norm_added_k.weight',
909
+ ]
910
+
911
+ # Keys that may exist in old checkpoints but not new model
912
+ DEPRECATED_PATTERNS = [
913
+ '.sin_basis', # Old cached sin embeddings
914
+ ]
915
+
916
+ loaded_keys = []
917
+ missing_keys = []
918
+ unexpected_keys = []
919
+ initialized_keys = []
920
+
921
+ # First pass: load matching weights
922
+ for key in state_dict.keys():
923
+ if key in model_state:
924
+ if state_dict[key].shape == model_state[key].shape:
925
+ model_state[key] = state_dict[key]
926
+ loaded_keys.append(key)
927
+ else:
928
+ print(f" ⚠ Shape mismatch for {key}: checkpoint {state_dict[key].shape} vs model {model_state[key].shape}")
929
+ unexpected_keys.append(key)
930
+ else:
931
+ # Check if it's a known deprecated key
932
+ is_deprecated = any(pat in key for pat in DEPRECATED_PATTERNS)
933
+ if is_deprecated:
934
+ unexpected_keys.append(key)
935
+ else:
936
+ print(f" ⚠ Unexpected key (not in model): {key}")
937
+ unexpected_keys.append(key)
938
+
939
+ # Second pass: handle missing keys
940
+ for key in model_state.keys():
941
+ if key not in loaded_keys:
942
+ # Check if it's a Q/K norm that needs identity initialization
943
+ is_qk_norm = any(pat in key for pat in QK_NORM_PATTERNS)
944
+
945
+ if is_qk_norm:
946
+ # Initialize to ones (identity transform for RMSNorm)
947
+ model_state[key] = torch.ones_like(model_state[key])
948
+ initialized_keys.append(key)
949
+ else:
950
+ missing_keys.append(key)
951
+ print(f" ⚠ Missing key (not in checkpoint): {key}")
952
+
953
+ # Load the updated state
954
+ model.load_state_dict(model_state, strict=False)
955
+
956
+ # Report
957
+ if initialized_keys:
958
+ print(f" ✓ Initialized Q/K norms to identity ({len(initialized_keys)} params):")
959
+ # Group by block for cleaner output
960
+ blocks = set()
961
+ for k in initialized_keys:
962
+ if 'double_blocks' in k:
963
+ block_num = k.split('.')[1]
964
+ blocks.add(f"double_blocks.{block_num}")
965
+ elif 'single_blocks' in k:
966
+ block_num = k.split('.')[1]
967
+ blocks.add(f"single_blocks.{block_num}")
968
+ for block in sorted(blocks):
969
+ print(f" - {block}.attn.norm_[q,k,added_q,added_k]")
970
+
971
+ if unexpected_keys:
972
+ deprecated = [k for k in unexpected_keys if any(p in k for p in DEPRECATED_PATTERNS)]
973
+ if deprecated:
974
+ print(f" ✓ Ignored deprecated keys: {deprecated}")
975
+
976
+ return missing_keys, unexpected_keys
977
+
978
+
979
+ # ============================================================================
980
+ # SAVE CHECKPOINT (strips the torch.compile _orig_mod wrapper)
981
+ # ============================================================================
982
+ def save_checkpoint(model, optimizer, scheduler, step, epoch, loss, path, ema_state=None):
983
+ """Save checkpoint with proper handling of torch.compile wrapper."""
984
+ os.makedirs(os.path.dirname(path) if os.path.dirname(path) else ".", exist_ok=True)
985
+
986
+ # Get state dict, handling torch.compile wrapper
987
+ if hasattr(model, '_orig_mod'):
988
+ state_dict = model._orig_mod.state_dict()
989
+ else:
990
+ state_dict = model.state_dict()
991
+
992
+ # Ensure proper dtype for storage
993
+ state_dict = {k: v.to(DTYPE) if v.is_floating_point() else v for k, v in state_dict.items()}
994
+
995
+ # Save weights
996
+ weights_path = path.replace(".pt", ".safetensors")
997
+ save_file(state_dict, weights_path)
998
+
999
+ # Save EMA weights if provided
1000
+ if ema_state is not None:
1001
+ ema_weights = {k: v.to(DTYPE) if v.is_floating_point() else v for k, v in ema_state['shadow'].items()}
1002
+ ema_weights_path = path.replace(".pt", "_ema.safetensors")
1003
+ save_file(ema_weights, ema_weights_path)
1004
+
1005
+ # Save optimizer/scheduler state
1006
+ state = {
1007
+ "step": step,
1008
+ "epoch": epoch,
1009
+ "loss": loss,
1010
+ "optimizer": optimizer.state_dict(),
1011
+ "scheduler": scheduler.state_dict(),
1012
+ }
1013
+ if ema_state is not None:
1014
+ state["ema_decay"] = ema_state.get('decay', EMA_DECAY)
1015
+
1016
+ torch.save(state, path)
1017
+ print(f" ✓ Saved checkpoint: step {step}")
1018
+ return weights_path
1019
+
1020
+
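+ # `upload_checkpoint` is called in the training loop below but is not defined in
+ # this cell (it presumably lives alongside the model.py cell). A minimal sketch,
+ # assuming it just pushes the saved safetensors (and EMA weights, if present) to
+ # the checkpoints/ folder of HF_REPO:
+ def upload_checkpoint(weights_path, step):
+     """Best-effort upload of a checkpoint to the Hub; failures are non-fatal."""
+     try:
+         api.upload_file(
+             path_or_fileobj=weights_path,
+             path_in_repo=f"checkpoints/step_{step}.safetensors",
+             repo_id=HF_REPO,
+         )
+         ema_path = weights_path.replace(".safetensors", "_ema.safetensors")
+         if os.path.exists(ema_path):
+             api.upload_file(
+                 path_or_fileobj=ema_path,
+                 path_in_repo=f"checkpoints/step_{step}_ema.safetensors",
+                 repo_id=HF_REPO,
+             )
+     except Exception as e:
+         print(f" ⚠ Checkpoint upload failed: {e}")
+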
1021
+ # ============================================================================
1022
+ # CREATE MODEL
1023
+ # ============================================================================
1024
+ print("\nCreating TinyFluxDeep model...")
1025
+ config = TinyFluxDeepConfig()
1026
+ model = TinyFluxDeep(config).to(device=DEVICE, dtype=DTYPE)
1027
+
1028
+ total_params = sum(p.numel() for p in model.parameters())
1029
+ print(f"Total parameters: {total_params:,}")
1030
+
1031
+ trainable_params = [p for p in model.parameters() if p.requires_grad]
1032
+ print(f"Trainable parameters: {sum(p.numel() for p in trainable_params):,}")
1033
+
1034
+ # ============================================================================
1035
+ # OPTIMIZER
1036
+ # ============================================================================
1037
+ opt = torch.optim.AdamW(trainable_params, lr=LR, betas=(0.9, 0.99), weight_decay=0.01, fused=True)
1038
+
1039
+ total_steps = len(loader) * EPOCHS // GRAD_ACCUM
1040
+ warmup = min(1000, total_steps // 10)
1041
+
1042
+ def lr_fn(step):
1043
+ if step < warmup:
1044
+ return step / warmup
1045
+ return 0.5 * (1 + math.cos(math.pi * (step - warmup) / (total_steps - warmup)))
1046
+
1047
+ sched = torch.optim.lr_scheduler.LambdaLR(opt, lr_fn)
1048
+
1049
+ # ============================================================================
1050
+ # LOAD CHECKPOINT
1051
+ # ============================================================================
1052
+ start_step, start_epoch = load_checkpoint(model, opt, sched, LOAD_TARGET)
1053
+
1054
+ if RESUME_STEP is not None:
1055
+ start_step = RESUME_STEP
1056
+
1057
+ # ============================================================================
1058
+ # COMPILE
1059
+ # ============================================================================
1060
+ model = torch.compile(model, mode="default")
1061
+
1062
+ # ============================================================================
1063
+ # EMA
1064
+ # ============================================================================
1065
+ print("Initializing EMA...")
1066
+ ema = EMA(model, decay=EMA_DECAY)
1067
+
1068
+ # ============================================================================
1069
+ # TENSORBOARD
1070
+ # ============================================================================
1071
+ run_name = f"run_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
1072
+ writer = SummaryWriter(os.path.join(LOG_DIR, run_name))
1073
+
1074
+ # Sample prompts
1075
+ SAMPLE_PROMPTS = [
1076
+ "a photo of a cat sitting on a windowsill",
1077
+ "a portrait of a woman with red hair",
1078
+ "a black backpack on white background",
1079
+ "a person standing in a t-pose",
1080
+ ]
1081
+
1082
+ # ============================================================================
1083
+ # TRAINING LOOP
1084
+ # ============================================================================
1085
+ print(f"\n{'='*60}")
1086
+ print(f"Training TinyFlux-Deep")
1087
+ print(f"{'='*60}")
1088
+ print(f"Total: {len(combined_ds):,} samples")
1089
+ print(f"Epochs: {EPOCHS}, Steps/epoch: {len(loader)}, Total: {total_steps}")
1090
+ print(f"Batch: {BATCH_SIZE} x {GRAD_ACCUM} = {BATCH_SIZE * GRAD_ACCUM}")
1091
+ print(f"Masked loss: {USE_MASKED_LOSS} (fg={FG_LOSS_WEIGHT}, bg={BG_LOSS_WEIGHT})")
1092
+ print(f"Min-SNR gamma: {MIN_SNR_GAMMA}")
1093
+ print(f"Resume: step {start_step}, epoch {start_epoch}")
1094
+
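+ # Flow-matching objective as implemented below: sample t, form
+ # x_t = (1 - t) * noise + t * data, and regress the constant velocity
+ # v = data - noise; at sampling time generate_samples() integrates x' = v
+ # from t = 0 (pure noise) to t = 1 (data) with the same SHIFT schedule.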
1095
+ model.train()
1096
+ step = start_step
1097
+ best = float("inf")
1098
+
1099
+ for ep in range(start_epoch, EPOCHS):
1100
+ ep_loss = 0
1101
+ ep_batches = 0
1102
+ pbar = tqdm(loader, desc=f"E{ep + 1}")
1103
+
1104
+ for i, batch in enumerate(pbar):
1105
+ latents = batch["latents"].to(DEVICE, non_blocking=True)
1106
+ t5 = batch["t5_embeds"].to(DEVICE, non_blocking=True)
1107
+ clip = batch["clip_pooled"].to(DEVICE, non_blocking=True)
1108
+ masks = batch["masks"]
1109
+ if masks is not None:
1110
+ masks = masks.to(DEVICE, non_blocking=True)
1111
+
1112
+ B, C, H, W = latents.shape
1113
+ data = latents.permute(0, 2, 3, 1).reshape(B, H * W, C)
1114
+
1115
+ noise = torch.randn_like(data)
1116
+
1117
+ if TEXT_DROPOUT > 0:
1118
+ t5, clip, _ = apply_text_dropout(t5, clip, TEXT_DROPOUT)
1119
+
1120
+ t = torch.sigmoid(torch.randn(B, device=DEVICE))
1121
+ t = flux_shift(t, s=SHIFT).to(DTYPE).clamp(1e-4, 1 - 1e-4)
1122
+
1123
+ t_expanded = t.view(B, 1, 1)
1124
+ x_t = (1 - t_expanded) * noise + t_expanded * data
1125
+ v_target = data - noise
1126
+
1127
+ img_ids = TinyFluxDeep.create_img_ids(B, H, W, DEVICE)
1128
+
1129
+ guidance = torch.rand(B, device=DEVICE, dtype=DTYPE) * 4 + 1
1130
+ if GUIDANCE_DROPOUT > 0:
1131
+ guide_mask = torch.rand(B, device=DEVICE) < GUIDANCE_DROPOUT
1132
+ guidance[guide_mask] = 1.0
1133
+
1134
+ with torch.autocast("cuda", dtype=DTYPE):
1135
+ v_pred = model(
1136
+ hidden_states=x_t,
1137
+ encoder_hidden_states=t5,
1138
+ pooled_projections=clip,
1139
+ timestep=t,
1140
+ img_ids=img_ids,
1141
+ guidance=guidance,
1142
+ )
1143
+
1144
+ # Compute loss with min-SNR weighting
1145
+ snr_weights = min_snr_weight(t) # [B]
1146
+
1147
+ # Unified loss: handles mask + SNR weighting
1148
+ loss = masked_mse_loss(
1149
+ v_pred, v_target,
1150
+ mask=masks if USE_MASKED_LOSS else None,
1151
+ fg_weight=FG_LOSS_WEIGHT,
1152
+ bg_weight=BG_LOSS_WEIGHT,
1153
+ snr_weights=snr_weights
1154
+ ) / GRAD_ACCUM
1155
+
1156
+ loss.backward()
1157
+
1158
+ if (i + 1) % GRAD_ACCUM == 0:
1159
+ grad_norm = torch.nn.utils.clip_grad_norm_(trainable_params, 1.0)
1160
+ opt.step()
1161
+ sched.step()
1162
+ opt.zero_grad(set_to_none=True)
1163
+
1164
+ ema.update(model)
1165
+ step += 1
1166
+
1167
+ if step % LOG_EVERY == 0:
1168
+ writer.add_scalar("train/loss", loss.item() * GRAD_ACCUM, step)
1169
+ writer.add_scalar("train/lr", sched.get_last_lr()[0], step)
1170
+ writer.add_scalar("train/grad_norm", grad_norm.item(), step)
1171
+
1172
+ if step % SAMPLE_EVERY == 0:
1173
+ print(f"\n Generating samples at step {step}...")
1174
+ images = generate_samples(model, SAMPLE_PROMPTS, num_steps=20, use_ema=True)
1175
+ save_samples(images, SAMPLE_PROMPTS, step, SAMPLE_DIR)
1176
+
1177
+ if step % SAVE_EVERY == 0:
1178
+ ckpt_path = os.path.join(CHECKPOINT_DIR, f"step_{step}.pt")
1179
+ weights_path = save_checkpoint(model, opt, sched, step, ep, loss.item(), ckpt_path, ema_state=ema.state_dict())
1180
+ if step % UPLOAD_EVERY == 0:
1181
+ upload_checkpoint(weights_path, step)
1182
+
1183
+ ep_loss += loss.item() * GRAD_ACCUM
1184
+ ep_batches += 1
1185
+ pbar.set_postfix(loss=f"{loss.item() * GRAD_ACCUM:.4f}", step=step)
1186
+
1187
+ avg = ep_loss / max(ep_batches, 1)
1188
+ print(f"Epoch {ep + 1} loss: {avg:.4f}")
1189
+
1190
+ if avg < best:
1191
+ best = avg
1192
+ weights_path = save_checkpoint(model, opt, sched, step, ep, avg, os.path.join(CHECKPOINT_DIR, "best.pt"), ema_state=ema.state_dict())
1193
+ try:
1194
+ api.upload_file(path_or_fileobj=weights_path, path_in_repo="model.safetensors", repo_id=HF_REPO)
1195
+ except:
1196
+ pass
1197
+
1198
+ print(f"\n✓ Training complete! Best loss: {best:.4f}")
1199
+ writer.close()