AbstractPhil commited on
Commit
55372de
·
verified ·
1 Parent(s): 57cf852

Create trainer_v3_expert_guidance.py

Browse files
Files changed (1) hide show
  1. trainer_v3_expert_guidance.py +1499 -0
trainer_v3_expert_guidance.py ADDED
@@ -0,0 +1,1499 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ============================================================================
2
+ # TinyFlux-Deep Training Cell - With Expert Distillation (Precached)
3
+ # ============================================================================
4
+ # Integrates SD1.5-flow-lune as a frozen timestep expert.
5
+ # Expert features are PRECACHED at 10 timestep buckets for speed.
6
+ # The ExpertPredictor learns to emulate expert features from (t, CLIP).
7
+ # At inference, no expert needed - predictor runs standalone.
8
+ #
9
+ # USAGE: Run model.py cell first, then this cell
10
+ # ============================================================================
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.nn.functional as F
15
+ from torch.utils.data import DataLoader, Dataset
16
+ from datasets import load_dataset, concatenate_datasets
17
+ from transformers import T5EncoderModel, T5Tokenizer, CLIPTextModel, CLIPTokenizer
18
+ from huggingface_hub import HfApi, hf_hub_download
19
+ from safetensors.torch import save_file, load_file
20
+ from torch.utils.tensorboard import SummaryWriter
21
+ from tqdm.auto import tqdm
22
+ import numpy as np
23
+ import math
24
+ import json
25
+ import random
26
+ from typing import Tuple, Optional, Dict, List
27
+ import os
28
+ from datetime import datetime
29
+ from PIL import Image
30
+
31
+ # ============================================================================
32
+ # CUDA OPTIMIZATIONS
33
+ # ============================================================================
34
+ torch.backends.cuda.matmul.allow_tf32 = True
35
+ torch.backends.cudnn.allow_tf32 = True
36
+ torch.backends.cudnn.benchmark = True
37
+ torch.set_float32_matmul_precision('high')
38
+
39
+ import warnings
40
+ warnings.filterwarnings('ignore', message='.*TF32.*')
41
+
42
+ # ============================================================================
43
+ # CONFIG
44
+ # ============================================================================
45
+ BATCH_SIZE = 16
46
+ GRAD_ACCUM = 2
47
+ LR = 3e-4
48
+ EPOCHS = 40
49
+ MAX_SEQ = 128
50
+ SHIFT = 3.0
51
+ DEVICE = "cuda"
52
+ DTYPE = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
53
+
54
+ ALLOW_WEIGHT_UPGRADE = True
55
+
56
+ # HuggingFace Hub
57
+ HF_REPO = "AbstractPhil/tiny-flux-deep"
58
+ SAVE_EVERY = 625
59
+ UPLOAD_EVERY = 625
60
+ SAMPLE_EVERY = 312
61
+ LOG_EVERY = 10
62
+ LOG_UPLOAD_EVERY = 625
63
+
64
+ # Checkpoint loading
65
+ LOAD_TARGET = "hub:step_305000"
66
+ RESUME_STEP = None
67
+
68
+ # ============================================================================
69
+ # EXPERT DISTILLATION CONFIG
70
+ # ============================================================================
71
+ ENABLE_EXPERT_DISTILLATION = True
72
+ EXPERT_CHECKPOINT = "AbstractPhil/sd15-flow-lune-flux"
73
+ EXPERT_CHECKPOINT_PATH = "flux_t2_6_pose_t4_6_port_t1_4/checkpoint-00018765/unet/diffusion_pytorch_model.safetensors"
74
+ EXPERT_DIM = 1280
75
+ EXPERT_HIDDEN_DIM = 512
76
+ EXPERT_DROPOUT = 0.1 # Prob of forcing predictor (applied outside model)
77
+ DISTILL_LOSS_WEIGHT = 0.1
78
+ DISTILL_WARMUP_STEPS = 1000
79
+
80
+ # Timestep buckets for precaching
81
+ EXPERT_T_BUCKETS = torch.linspace(0.05, 0.95, 10)
82
+
83
+ # ============================================================================
84
+ # DATASET CONFIG
85
+ # ============================================================================
86
+ ENABLE_PORTRAIT = False
87
+ ENABLE_SCHNELL = True
88
+ ENABLE_SPORTFASHION = False
89
+ ENABLE_SYNTHMOCAP = False
90
+
91
+ PORTRAIT_REPO = "AbstractPhil/ffhq_flux_latents_repaired"
92
+ PORTRAIT_NUM_SHARDS = 11
93
+ SCHNELL_REPO = "AbstractPhil/flux-schnell-teacher-latents"
94
+ SCHNELL_CONFIGS = ["train_512"]
95
+ SPORTFASHION_REPO = "Pianokill/SportFashion_512x512"
96
+ SYNTHMOCAP_REPO = "toyxyz/SynthMoCap_smpl_512"
97
+
98
+ FG_LOSS_WEIGHT = 2.0
99
+ BG_LOSS_WEIGHT = 0.5
100
+ USE_MASKED_LOSS = False
101
+ MIN_SNR_GAMMA = 5.0
102
+
103
+ # Paths
104
+ CHECKPOINT_DIR = "./tiny_flux_deep_checkpoints"
105
+ LOG_DIR = "./tiny_flux_deep_logs"
106
+ SAMPLE_DIR = "./tiny_flux_deep_samples"
107
+ ENCODING_CACHE_DIR = "./encoding_cache"
108
+ LATENT_CACHE_DIR = "./latent_cache"
109
+
110
+ os.makedirs(CHECKPOINT_DIR, exist_ok=True)
111
+ os.makedirs(LOG_DIR, exist_ok=True)
112
+ os.makedirs(SAMPLE_DIR, exist_ok=True)
113
+ os.makedirs(ENCODING_CACHE_DIR, exist_ok=True)
114
+ os.makedirs(LATENT_CACHE_DIR, exist_ok=True)
115
+
116
+ # ============================================================================
117
+ # REGULARIZATION CONFIG
118
+ # ============================================================================
119
+ TEXT_DROPOUT = 0.1
120
+ GUIDANCE_DROPOUT = 0.1
121
+ EMA_DECAY = 0.9999
122
+
123
+
124
+ # ============================================================================
125
+ # EXPERT FEATURE CACHE (precached, fast lookup + interpolation)
126
+ # ============================================================================
127
+
128
+ class ExpertFeatureCache:
129
+ """
130
+ Precached SD1.5-flow expert features with timestep interpolation.
131
+
132
+ Features extracted at 10 timestep buckets [0.05, 0.15, ..., 0.95].
133
+ At runtime, interpolates between nearest buckets.
134
+ """
135
+
136
+ def __init__(self, features: torch.Tensor, t_buckets: torch.Tensor, dtype=torch.float16):
137
+ self.features = features.to(dtype) # [N, 10, 1280]
138
+ self.t_buckets = t_buckets
139
+ self.t_min = t_buckets[0].item()
140
+ self.t_max = t_buckets[-1].item()
141
+ self.t_step = (t_buckets[1] - t_buckets[0]).item()
142
+ self.n_buckets = len(t_buckets)
143
+ self.dtype = dtype
144
+
145
+ def get_features(self, indices: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor:
146
+ """
147
+ Get interpolated expert features.
148
+
149
+ Args:
150
+ indices: [B] sample indices into dataset
151
+ timesteps: [B] timesteps in [0, 1]
152
+
153
+ Returns:
154
+ [B, 1280] interpolated features
155
+ """
156
+ device = timesteps.device
157
+
158
+ # Clamp to valid range
159
+ t_clamped = timesteps.float().clamp(self.t_min, self.t_max)
160
+
161
+ # Find bucket indices
162
+ t_idx_float = (t_clamped - self.t_min) / self.t_step
163
+ t_idx_low = t_idx_float.long().clamp(0, self.n_buckets - 2)
164
+ t_idx_high = (t_idx_low + 1).clamp(0, self.n_buckets - 1)
165
+
166
+ # Interpolation alpha
167
+ alpha = (t_idx_float - t_idx_low.float()).unsqueeze(-1) # [B, 1]
168
+
169
+ # Gather (on CPU for large caches)
170
+ idx_cpu = indices.cpu()
171
+ t_low_cpu = t_idx_low.cpu()
172
+ t_high_cpu = t_idx_high.cpu()
173
+
174
+ f_low = self.features[idx_cpu, t_low_cpu] # [B, 1280]
175
+ f_high = self.features[idx_cpu, t_high_cpu] # [B, 1280]
176
+
177
+ # Interpolate and move to device
178
+ result = (1 - alpha.cpu()) * f_low + alpha.cpu() * f_high
179
+ return result.to(device=device, dtype=self.dtype)
180
+
181
+
182
+ def load_or_extract_expert_features(cache_path: str, prompts: List[str], name: str,
183
+ clip_tok, clip_enc, t_buckets: torch.Tensor,
184
+ batch_size: int = 32) -> Optional[ExpertFeatureCache]:
185
+ """
186
+ Load cached expert features or extract them from SD1.5-flow.
187
+ Follows same pattern as load_or_encode for text embeddings.
188
+ """
189
+ if not prompts or not ENABLE_EXPERT_DISTILLATION:
190
+ return None
191
+
192
+ # Check cache
193
+ if os.path.exists(cache_path):
194
+ print(f"Loading cached {name} expert features...")
195
+ cached = torch.load(cache_path, map_location="cpu")
196
+ cache = ExpertFeatureCache(cached["features"], cached["t_buckets"], DTYPE)
197
+ print(f" ✓ Loaded {cache.features.shape[0]} samples × {cache.n_buckets} timesteps")
198
+ return cache
199
+
200
+ # Extract features
201
+ print(f"Extracting {name} expert features ({len(prompts)} × {len(t_buckets)} timesteps)...")
202
+ print(f" This is a one-time operation, will be cached for future runs.")
203
+
204
+ # Load expert model temporarily
205
+ checkpoint_path = hf_hub_download(
206
+ repo_id=EXPERT_CHECKPOINT,
207
+ filename=EXPERT_CHECKPOINT_PATH,
208
+ )
209
+
210
+ from diffusers import UNet2DConditionModel
211
+ unet = UNet2DConditionModel.from_pretrained(
212
+ "stable-diffusion-v1-5/stable-diffusion-v1-5",
213
+ subfolder="unet",
214
+ torch_dtype=DTYPE,
215
+ ).to(DEVICE).eval()
216
+
217
+ state_dict = load_file(checkpoint_path)
218
+ unet.load_state_dict(state_dict, strict=False)
219
+
220
+ for p in unet.parameters():
221
+ p.requires_grad = False
222
+
223
+ # Hook for mid-block features
224
+ mid_features = [None]
225
+ def hook_fn(module, inp, out):
226
+ mid_features[0] = out.mean(dim=[2, 3])
227
+ unet.mid_block.register_forward_hook(hook_fn)
228
+
229
+ # Extract
230
+ n_prompts = len(prompts)
231
+ n_buckets = len(t_buckets)
232
+ all_features = torch.zeros(n_prompts, n_buckets, EXPERT_DIM, dtype=torch.float16)
233
+
234
+ with torch.no_grad():
235
+ for start_idx in tqdm(range(0, n_prompts, batch_size), desc=f"Extracting {name}"):
236
+ end_idx = min(start_idx + batch_size, n_prompts)
237
+ batch_prompts = prompts[start_idx:end_idx]
238
+ B = len(batch_prompts)
239
+
240
+ # Encode CLIP hidden states
241
+ clip_inputs = clip_tok(
242
+ batch_prompts, return_tensors="pt", padding="max_length",
243
+ max_length=77, truncation=True
244
+ ).to(DEVICE)
245
+ clip_hidden = clip_enc(**clip_inputs).last_hidden_state # [B, 77, 768]
246
+
247
+ # Extract at each timestep bucket
248
+ for t_idx, t_val in enumerate(t_buckets):
249
+ timesteps = torch.full((B,), t_val.item(), device=DEVICE)
250
+ latents = torch.randn(B, 4, 64, 64, device=DEVICE, dtype=DTYPE)
251
+
252
+ _ = unet(latents, timesteps * 1000, encoder_hidden_states=clip_hidden.to(DTYPE))
253
+
254
+ all_features[start_idx:end_idx, t_idx] = mid_features[0].cpu().to(torch.float16)
255
+
256
+ # Cleanup
257
+ del unet
258
+ torch.cuda.empty_cache()
259
+
260
+ # Save cache
261
+ torch.save({"features": all_features, "t_buckets": t_buckets}, cache_path)
262
+ print(f" ✓ Cached to {cache_path}")
263
+ print(f" Size: {all_features.numel() * 2 / 1e9:.2f} GB")
264
+
265
+ return ExpertFeatureCache(all_features, t_buckets, DTYPE)
266
+
267
+
268
+ # ============================================================================
269
+ # EMA
270
+ # ============================================================================
271
+ class EMA:
272
+ def __init__(self, model, decay=0.9999):
273
+ self.decay = decay
274
+ self.shadow = {}
275
+ self._backup = {}
276
+ if hasattr(model, '_orig_mod'):
277
+ state = model._orig_mod.state_dict()
278
+ else:
279
+ state = model.state_dict()
280
+ for k, v in state.items():
281
+ self.shadow[k] = v.clone().detach()
282
+
283
+ @torch.no_grad()
284
+ def update(self, model):
285
+ if hasattr(model, '_orig_mod'):
286
+ state = model._orig_mod.state_dict()
287
+ else:
288
+ state = model.state_dict()
289
+ for k, v in state.items():
290
+ if k in self.shadow:
291
+ self.shadow[k].lerp_(v.to(self.shadow[k].dtype), 1 - self.decay)
292
+
293
+ def apply_shadow_for_eval(self, model):
294
+ if hasattr(model, '_orig_mod'):
295
+ self._backup = {k: v.clone() for k, v in model._orig_mod.state_dict().items()}
296
+ model._orig_mod.load_state_dict(self.shadow)
297
+ else:
298
+ self._backup = {k: v.clone() for k, v in model.state_dict().items()}
299
+ model.load_state_dict(self.shadow)
300
+
301
+ def restore(self, model):
302
+ if hasattr(model, '_orig_mod'):
303
+ model._orig_mod.load_state_dict(self._backup)
304
+ else:
305
+ model.load_state_dict(self._backup)
306
+ self._backup = {}
307
+
308
+ def state_dict(self):
309
+ return {'shadow': self.shadow, 'decay': self.decay}
310
+
311
+ def load_state_dict(self, state):
312
+ self.shadow = {k: v.clone() for k, v in state['shadow'].items()}
313
+ self.decay = state.get('decay', self.decay)
314
+
315
+ def load_shadow(self, shadow_state):
316
+ """Load EMA shadow weights, handling architecture changes gracefully."""
317
+ device = next(iter(self.shadow.values())).device if self.shadow else 'cuda'
318
+
319
+ loaded = 0
320
+ skipped_old = 0
321
+ kept_new = 0
322
+
323
+ for k, v in shadow_state.items():
324
+ if k in self.shadow:
325
+ # Key exists in current model - load it
326
+ self.shadow[k] = v.clone().to(device)
327
+ loaded += 1
328
+ else:
329
+ # Key doesn't exist (deprecated like guidance_in)
330
+ skipped_old += 1
331
+
332
+ # Count new keys not in checkpoint
333
+ for k in self.shadow:
334
+ if k not in shadow_state:
335
+ kept_new += 1
336
+
337
+ print(f" ✓ Restored EMA: {loaded} loaded, {skipped_old} deprecated skipped, {kept_new} new (fresh init)")
338
+
339
+
340
+ # ============================================================================
341
+ # REGULARIZATION
342
+ # ============================================================================
343
+ def apply_text_dropout(t5_embeds, clip_pooled, dropout_prob=0.1):
344
+ B = t5_embeds.shape[0]
345
+ mask = torch.rand(B, device=t5_embeds.device) < dropout_prob
346
+ t5_embeds = t5_embeds.clone()
347
+ clip_pooled = clip_pooled.clone()
348
+ t5_embeds[mask] = 0
349
+ clip_pooled[mask] = 0
350
+ return t5_embeds, clip_pooled, mask
351
+
352
+
353
+ # ============================================================================
354
+ # MASKING UTILITIES
355
+ # ============================================================================
356
+ def detect_background_color(image: Image.Image, sample_size: int = 100) -> Tuple[int, int, int]:
357
+ img = np.array(image)
358
+ if len(img.shape) == 2:
359
+ img = np.stack([img] * 3, axis=-1)
360
+ h, w = img.shape[:2]
361
+ corners = [
362
+ img[:sample_size, :sample_size],
363
+ img[:sample_size, -sample_size:],
364
+ img[-sample_size:, :sample_size],
365
+ img[-sample_size:, -sample_size:],
366
+ ]
367
+ corner_pixels = np.concatenate([c.reshape(-1, 3) for c in corners], axis=0)
368
+ bg_color = np.median(corner_pixels, axis=0).astype(np.uint8)
369
+ return tuple(bg_color)
370
+
371
+
372
+ def create_product_mask(image: Image.Image, threshold: int = 30) -> np.ndarray:
373
+ img = np.array(image).astype(np.float32)
374
+ if len(img.shape) == 2:
375
+ img = np.stack([img] * 3, axis=-1)
376
+ bg_color = detect_background_color(image)
377
+ bg_color = np.array(bg_color, dtype=np.float32)
378
+ diff = np.sqrt(np.sum((img - bg_color) ** 2, axis=-1))
379
+ mask = (diff > threshold).astype(np.float32)
380
+ return mask
381
+
382
+
383
+ def create_smpl_mask(conditioning_image: Image.Image, threshold: int = 20) -> np.ndarray:
384
+ img = np.array(conditioning_image).astype(np.float32)
385
+ if len(img.shape) == 2:
386
+ return (img > threshold).astype(np.float32)
387
+ r, g, b = img[:, :, 0], img[:, :, 1], img[:, :, 2]
388
+ is_background = (g > r + 20) & (g > b + 20)
389
+ mask = (~is_background).astype(np.float32)
390
+ return mask
391
+
392
+
393
+ def downsample_mask_to_latent(mask: np.ndarray, latent_h: int = 64, latent_w: int = 64) -> torch.Tensor:
394
+ mask_pil = Image.fromarray((mask * 255).astype(np.uint8))
395
+ mask_pil = mask_pil.resize((latent_w, latent_h), Image.Resampling.BILINEAR)
396
+ mask_latent = np.array(mask_pil).astype(np.float32) / 255.0
397
+ return torch.from_numpy(mask_latent)
398
+
399
+
400
+ # ============================================================================
401
+ # HF HUB SETUP
402
+ # ============================================================================
403
+ print("Setting up HuggingFace Hub...")
404
+ api = HfApi()
405
+
406
+
407
+ # ============================================================================
408
+ # FLOW MATCHING HELPERS
409
+ # ============================================================================
410
+ def flux_shift(t, s=SHIFT):
411
+ return s * t / (1 + (s - 1) * t)
412
+
413
+ def min_snr_weight(t, gamma=MIN_SNR_GAMMA):
414
+ snr = (t / (1 - t).clamp(min=1e-5)).pow(2)
415
+ return torch.clamp(snr, max=gamma) / snr.clamp(min=1e-5)
416
+
417
+
418
+ # ============================================================================
419
+ # LOAD TEXT ENCODERS
420
+ # ============================================================================
421
+ print("Loading text encoders...")
422
+ t5_tok = T5Tokenizer.from_pretrained("google/flan-t5-base")
423
+ t5_enc = T5EncoderModel.from_pretrained("google/flan-t5-base", torch_dtype=DTYPE).to(DEVICE).eval()
424
+ for p in t5_enc.parameters():
425
+ p.requires_grad = False
426
+
427
+ clip_tok = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
428
+ clip_enc = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14", torch_dtype=DTYPE).to(DEVICE).eval()
429
+ for p in clip_enc.parameters():
430
+ p.requires_grad = False
431
+ print("✓ Text encoders loaded")
432
+
433
+
434
+ # ============================================================================
435
+ # LOAD VAE
436
+ # ============================================================================
437
+ print("Loading VAE...")
438
+ from diffusers import AutoencoderKL
439
+ vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=DTYPE).to(DEVICE).eval()
440
+ for p in vae.parameters():
441
+ p.requires_grad = False
442
+ VAE_SCALE = vae.config.scaling_factor
443
+ print(f"✓ VAE loaded (scale={VAE_SCALE})")
444
+
445
+
446
+ # ============================================================================
447
+ # ENCODING FUNCTIONS
448
+ # ============================================================================
449
+ @torch.no_grad()
450
+ def encode_prompt(prompt: str) -> Tuple[torch.Tensor, torch.Tensor]:
451
+ t5_inputs = t5_tok(prompt, return_tensors="pt", padding="max_length",
452
+ max_length=MAX_SEQ, truncation=True).to(DEVICE)
453
+ t5_out = t5_enc(**t5_inputs).last_hidden_state
454
+ clip_inputs = clip_tok(prompt, return_tensors="pt", padding="max_length",
455
+ max_length=77, truncation=True).to(DEVICE)
456
+ clip_out = clip_enc(**clip_inputs).pooler_output
457
+ return t5_out.squeeze(0), clip_out.squeeze(0)
458
+
459
+
460
+ @torch.no_grad()
461
+ def encode_prompts_batched(prompts: List[str], batch_size: int = 64) -> Tuple[torch.Tensor, torch.Tensor]:
462
+ all_t5 = []
463
+ all_clip = []
464
+ for i in tqdm(range(0, len(prompts), batch_size), desc="Encoding", leave=False):
465
+ batch = prompts[i:i+batch_size]
466
+ t5_inputs = t5_tok(batch, return_tensors="pt", padding="max_length",
467
+ max_length=MAX_SEQ, truncation=True).to(DEVICE)
468
+ t5_out = t5_enc(**t5_inputs).last_hidden_state
469
+ all_t5.append(t5_out.cpu())
470
+ clip_inputs = clip_tok(batch, return_tensors="pt", padding="max_length",
471
+ max_length=77, truncation=True).to(DEVICE)
472
+ clip_out = clip_enc(**clip_inputs).pooler_output
473
+ all_clip.append(clip_out.cpu())
474
+ return torch.cat(all_t5, dim=0), torch.cat(all_clip, dim=0)
475
+
476
+
477
+ @torch.no_grad()
478
+ def encode_image_to_latent(image: Image.Image) -> torch.Tensor:
479
+ if image.mode != "RGB":
480
+ image = image.convert("RGB")
481
+ if image.size != (512, 512):
482
+ image = image.resize((512, 512), Image.Resampling.LANCZOS)
483
+ img_tensor = torch.from_numpy(np.array(image)).float() / 255.0
484
+ img_tensor = img_tensor.permute(2, 0, 1).unsqueeze(0)
485
+ img_tensor = (img_tensor * 2.0 - 1.0).to(DEVICE, dtype=DTYPE)
486
+ latent = vae.encode(img_tensor).latent_dist.sample()
487
+ latent = latent * VAE_SCALE
488
+ return latent.squeeze(0).cpu()
489
+
490
+
491
+ # ============================================================================
492
+ # LOAD DATASETS
493
+ # ============================================================================
494
+
495
+ portrait_ds = None
496
+ portrait_indices = []
497
+ portrait_prompts = []
498
+
499
+ if ENABLE_PORTRAIT:
500
+ print(f"\n[1/4] Loading portrait dataset from {PORTRAIT_REPO}...")
501
+ portrait_shards = []
502
+ for i in range(PORTRAIT_NUM_SHARDS):
503
+ split_name = f"train_{i:02d}"
504
+ print(f" Loading {split_name}...")
505
+ shard = load_dataset(PORTRAIT_REPO, split=split_name)
506
+ portrait_shards.append(shard)
507
+ portrait_ds = concatenate_datasets(portrait_shards)
508
+ print(f"✓ Portrait: {len(portrait_ds)} base samples")
509
+ print(" Extracting prompts (columnar)...")
510
+ florence_list = list(portrait_ds["text_florence"])
511
+ llava_list = list(portrait_ds["text_llava"])
512
+ blip_list = list(portrait_ds["text_blip"])
513
+ for i, (f, l, b) in enumerate(zip(florence_list, llava_list, blip_list)):
514
+ if f and f.strip():
515
+ portrait_indices.append(i)
516
+ portrait_prompts.append(f)
517
+ if l and l.strip():
518
+ portrait_indices.append(i)
519
+ portrait_prompts.append(l)
520
+ if b and b.strip():
521
+ portrait_indices.append(i)
522
+ portrait_prompts.append(b)
523
+ print(f" Expanded: {len(portrait_prompts)} samples (3 prompts/image)")
524
+ else:
525
+ print("\n[1/4] Portrait dataset DISABLED")
526
+
527
+ schnell_ds = None
528
+ schnell_prompts = []
529
+
530
+ if ENABLE_SCHNELL:
531
+ print(f"\n[2/4] Loading schnell teacher dataset from {SCHNELL_REPO}...")
532
+ schnell_datasets = []
533
+ for config in SCHNELL_CONFIGS:
534
+ print(f" Loading {config}...")
535
+ ds = load_dataset(SCHNELL_REPO, config, split="train")
536
+ schnell_datasets.append(ds)
537
+ print(f" {len(ds)} samples")
538
+ schnell_ds = concatenate_datasets(schnell_datasets)
539
+ schnell_prompts = list(schnell_ds["prompt"])
540
+ print(f"✓ Schnell: {len(schnell_ds)} samples")
541
+ else:
542
+ print("\n[2/4] Schnell dataset DISABLED")
543
+
544
+ sportfashion_ds = None
545
+ sportfashion_prompts = []
546
+
547
+ if ENABLE_SPORTFASHION:
548
+ print(f"\n[3/4] Loading SportFashion dataset from {SPORTFASHION_REPO}...")
549
+ sportfashion_ds = load_dataset(SPORTFASHION_REPO, split="train")
550
+ sportfashion_prompts = list(sportfashion_ds["text"])
551
+ print(f"✓ SportFashion: {len(sportfashion_ds)} samples")
552
+ else:
553
+ print("\n[3/4] SportFashion dataset DISABLED")
554
+
555
+ synthmocap_ds = None
556
+ synthmocap_prompts = []
557
+
558
+ if ENABLE_SYNTHMOCAP:
559
+ print(f"\n[4/4] Loading SynthMoCap dataset from {SYNTHMOCAP_REPO}...")
560
+ synthmocap_ds = load_dataset(SYNTHMOCAP_REPO, split="train")
561
+ synthmocap_prompts = list(synthmocap_ds["text"])
562
+ print(f"✓ SynthMoCap: {len(synthmocap_ds)} samples")
563
+ else:
564
+ print("\n[4/4] SynthMoCap dataset DISABLED")
565
+
566
+
567
+ # ============================================================================
568
+ # ENCODE ALL PROMPTS
569
+ # ============================================================================
570
+ total_samples = len(portrait_prompts) + len(schnell_prompts) + len(sportfashion_prompts) + len(synthmocap_prompts)
571
+ print(f"\nTotal combined samples: {total_samples}")
572
+
573
+ def load_or_encode(cache_path, prompts, name):
574
+ if not prompts:
575
+ return None, None
576
+ if os.path.exists(cache_path):
577
+ print(f"Loading cached {name} encodings...")
578
+ cached = torch.load(cache_path)
579
+ return cached["t5_embeds"], cached["clip_pooled"]
580
+ else:
581
+ print(f"Encoding {len(prompts)} {name} prompts...")
582
+ t5, clip = encode_prompts_batched(prompts, batch_size=64)
583
+ torch.save({"t5_embeds": t5, "clip_pooled": clip}, cache_path)
584
+ print(f"✓ Cached to {cache_path}")
585
+ return t5, clip
586
+
587
+
588
+ # Standard text encodings
589
+ portrait_t5, portrait_clip = None, None
590
+ schnell_t5, schnell_clip = None, None
591
+ sportfashion_t5, sportfashion_clip = None, None
592
+ synthmocap_t5, synthmocap_clip = None, None
593
+
594
+ if portrait_prompts:
595
+ portrait_enc_cache = os.path.join(ENCODING_CACHE_DIR, f"portrait_encodings_{len(portrait_prompts)}.pt")
596
+ portrait_t5, portrait_clip = load_or_encode(portrait_enc_cache, portrait_prompts, "portrait")
597
+
598
+ if schnell_prompts:
599
+ schnell_enc_cache = os.path.join(ENCODING_CACHE_DIR, f"schnell_encodings_{len(schnell_prompts)}.pt")
600
+ schnell_t5, schnell_clip = load_or_encode(schnell_enc_cache, schnell_prompts, "schnell")
601
+
602
+ if sportfashion_prompts:
603
+ sportfashion_enc_cache = os.path.join(ENCODING_CACHE_DIR, f"sportfashion_encodings_{len(sportfashion_prompts)}.pt")
604
+ sportfashion_t5, sportfashion_clip = load_or_encode(sportfashion_enc_cache, sportfashion_prompts, "sportfashion")
605
+
606
+ if synthmocap_prompts:
607
+ synthmocap_enc_cache = os.path.join(ENCODING_CACHE_DIR, f"synthmocap_encodings_{len(synthmocap_prompts)}.pt")
608
+ synthmocap_t5, synthmocap_clip = load_or_encode(synthmocap_enc_cache, synthmocap_prompts, "synthmocap")
609
+
610
+
611
+ # ============================================================================
612
+ # EXTRACT/LOAD EXPERT FEATURES (precached)
613
+ # ============================================================================
614
+ print("\n" + "="*60)
615
+ print("Expert Feature Caching")
616
+ print("="*60)
617
+
618
+ schnell_expert_cache = None
619
+ portrait_expert_cache = None
620
+ sportfashion_expert_cache = None
621
+ synthmocap_expert_cache = None
622
+
623
+ if schnell_prompts and ENABLE_EXPERT_DISTILLATION:
624
+ schnell_expert_path = os.path.join(ENCODING_CACHE_DIR, f"schnell_expert_{len(schnell_prompts)}.pt")
625
+ schnell_expert_cache = load_or_extract_expert_features(
626
+ schnell_expert_path, schnell_prompts, "schnell",
627
+ clip_tok, clip_enc, EXPERT_T_BUCKETS
628
+ )
629
+
630
+ if portrait_prompts and ENABLE_EXPERT_DISTILLATION:
631
+ portrait_expert_path = os.path.join(ENCODING_CACHE_DIR, f"portrait_expert_{len(portrait_prompts)}.pt")
632
+ portrait_expert_cache = load_or_extract_expert_features(
633
+ portrait_expert_path, portrait_prompts, "portrait",
634
+ clip_tok, clip_enc, EXPERT_T_BUCKETS
635
+ )
636
+
637
+ if sportfashion_prompts and ENABLE_EXPERT_DISTILLATION:
638
+ sportfashion_expert_path = os.path.join(ENCODING_CACHE_DIR, f"sportfashion_expert_{len(sportfashion_prompts)}.pt")
639
+ sportfashion_expert_cache = load_or_extract_expert_features(
640
+ sportfashion_expert_path, sportfashion_prompts, "sportfashion",
641
+ clip_tok, clip_enc, EXPERT_T_BUCKETS
642
+ )
643
+
644
+ if synthmocap_prompts and ENABLE_EXPERT_DISTILLATION:
645
+ synthmocap_expert_path = os.path.join(ENCODING_CACHE_DIR, f"synthmocap_expert_{len(synthmocap_prompts)}.pt")
646
+ synthmocap_expert_cache = load_or_extract_expert_features(
647
+ synthmocap_expert_path, synthmocap_prompts, "synthmocap",
648
+ clip_tok, clip_enc, EXPERT_T_BUCKETS
649
+ )
650
+
651
+
652
+ # ============================================================================
653
+ # COMBINED DATASET CLASS (with sample_idx for expert lookup)
654
+ # ============================================================================
655
+ class CombinedDataset(Dataset):
656
+ """Combined dataset returning sample index for expert feature lookup."""
657
+
658
+ def __init__(
659
+ self,
660
+ portrait_ds, portrait_indices, portrait_t5, portrait_clip,
661
+ schnell_ds, schnell_t5, schnell_clip,
662
+ sportfashion_ds, sportfashion_t5, sportfashion_clip,
663
+ synthmocap_ds, synthmocap_t5, synthmocap_clip,
664
+ vae, vae_scale, device, dtype,
665
+ compute_masks=True,
666
+ ):
667
+ self.portrait_ds = portrait_ds
668
+ self.portrait_indices = portrait_indices
669
+ self.portrait_t5 = portrait_t5
670
+ self.portrait_clip = portrait_clip
671
+
672
+ self.schnell_ds = schnell_ds
673
+ self.schnell_t5 = schnell_t5
674
+ self.schnell_clip = schnell_clip
675
+
676
+ self.sportfashion_ds = sportfashion_ds
677
+ self.sportfashion_t5 = sportfashion_t5
678
+ self.sportfashion_clip = sportfashion_clip
679
+
680
+ self.synthmocap_ds = synthmocap_ds
681
+ self.synthmocap_t5 = synthmocap_t5
682
+ self.synthmocap_clip = synthmocap_clip
683
+
684
+ self.vae = vae
685
+ self.vae_scale = vae_scale
686
+ self.device = device
687
+ self.dtype = dtype
688
+ self.compute_masks = compute_masks
689
+
690
+ self.n_portrait = len(portrait_indices) if portrait_indices else 0
691
+ self.n_schnell = len(schnell_ds) if schnell_ds else 0
692
+ self.n_sportfashion = len(sportfashion_ds) if sportfashion_ds else 0
693
+ self.n_synthmocap = len(synthmocap_ds) if synthmocap_ds else 0
694
+
695
+ self.c1 = self.n_portrait
696
+ self.c2 = self.c1 + self.n_schnell
697
+ self.c3 = self.c2 + self.n_sportfashion
698
+ self.total = self.c3 + self.n_synthmocap
699
+
700
+ def __len__(self):
701
+ return self.total
702
+
703
+ def _get_latent_from_array(self, latent_data):
704
+ if isinstance(latent_data, torch.Tensor):
705
+ return latent_data.to(self.dtype)
706
+ return torch.tensor(np.array(latent_data), dtype=self.dtype)
707
+
708
+ @torch.no_grad()
709
+ def _encode_image(self, image):
710
+ if image.mode != "RGB":
711
+ image = image.convert("RGB")
712
+ if image.size != (512, 512):
713
+ image = image.resize((512, 512), Image.Resampling.LANCZOS)
714
+ img_tensor = torch.from_numpy(np.array(image)).float() / 255.0
715
+ img_tensor = img_tensor.permute(2, 0, 1).unsqueeze(0)
716
+ img_tensor = (img_tensor * 2.0 - 1.0).to(self.device, dtype=self.dtype)
717
+ latent = self.vae.encode(img_tensor).latent_dist.sample()
718
+ latent = latent * self.vae_scale
719
+ return latent.squeeze(0).cpu()
720
+
721
+ def __getitem__(self, idx):
722
+ mask = None
723
+
724
+ # Determine which dataset and local index
725
+ if idx < self.c1:
726
+ # Portrait
727
+ local_idx = idx
728
+ orig_idx = self.portrait_indices[idx]
729
+ item = self.portrait_ds[orig_idx]
730
+ latent = self._get_latent_from_array(item["latent"])
731
+ t5 = self.portrait_t5[idx]
732
+ clip = self.portrait_clip[idx]
733
+ dataset_id = 0
734
+
735
+ elif idx < self.c2:
736
+ # Schnell
737
+ local_idx = idx - self.c1
738
+ item = self.schnell_ds[local_idx]
739
+ latent = self._get_latent_from_array(item["latent"])
740
+ t5 = self.schnell_t5[local_idx]
741
+ clip = self.schnell_clip[local_idx]
742
+ dataset_id = 1
743
+
744
+ elif idx < self.c3:
745
+ # SportFashion
746
+ local_idx = idx - self.c2
747
+ item = self.sportfashion_ds[local_idx]
748
+ image = item["image"]
749
+ latent = self._encode_image(image)
750
+ t5 = self.sportfashion_t5[local_idx]
751
+ clip = self.sportfashion_clip[local_idx]
752
+ dataset_id = 2
753
+ if self.compute_masks:
754
+ pixel_mask = create_product_mask(image)
755
+ mask = downsample_mask_to_latent(pixel_mask, 64, 64)
756
+
757
+ else:
758
+ # SynthMoCap
759
+ local_idx = idx - self.c3
760
+ item = self.synthmocap_ds[local_idx]
761
+ image = item["image"]
762
+ conditioning = item["conditioning_image"]
763
+ latent = self._encode_image(image)
764
+ t5 = self.synthmocap_t5[local_idx]
765
+ clip = self.synthmocap_clip[local_idx]
766
+ dataset_id = 3
767
+ if self.compute_masks:
768
+ pixel_mask = create_smpl_mask(conditioning)
769
+ mask = downsample_mask_to_latent(pixel_mask, 64, 64)
770
+
771
+ result = {
772
+ "latent": latent,
773
+ "t5_embed": t5.to(self.dtype),
774
+ "clip_pooled": clip.to(self.dtype),
775
+ "sample_idx": idx, # Global index for expert cache lookup
776
+ "local_idx": local_idx, # Local index within dataset
777
+ "dataset_id": dataset_id, # Which dataset (0=portrait, 1=schnell, etc)
778
+ }
779
+
780
+ if mask is not None:
781
+ result["mask"] = mask.to(self.dtype)
782
+
783
+ return result
784
+
785
+
786
+ # ============================================================================
787
+ # COLLATE FUNCTION
788
+ # ============================================================================
789
+ def collate_fn(batch):
790
+ latents = torch.stack([b["latent"] for b in batch])
791
+ t5_embeds = torch.stack([b["t5_embed"] for b in batch])
792
+ clip_pooled = torch.stack([b["clip_pooled"] for b in batch])
793
+ sample_indices = torch.tensor([b["sample_idx"] for b in batch], dtype=torch.long)
794
+ local_indices = torch.tensor([b["local_idx"] for b in batch], dtype=torch.long)
795
+ dataset_ids = torch.tensor([b["dataset_id"] for b in batch], dtype=torch.long)
796
+
797
+ masks = None
798
+ if any("mask" in b for b in batch):
799
+ masks = []
800
+ for b in batch:
801
+ if "mask" in b:
802
+ masks.append(b["mask"])
803
+ else:
804
+ masks.append(torch.ones(64, 64, dtype=latents.dtype))
805
+ masks = torch.stack(masks)
806
+
807
+ return {
808
+ "latents": latents,
809
+ "t5_embeds": t5_embeds,
810
+ "clip_pooled": clip_pooled,
811
+ "sample_indices": sample_indices,
812
+ "local_indices": local_indices,
813
+ "dataset_ids": dataset_ids,
814
+ "masks": masks,
815
+ }
816
+
817
+
818
+ # ============================================================================
819
+ # EXPERT FEATURE LOOKUP (handles multiple datasets)
820
+ # ============================================================================
821
+ def get_expert_features_for_batch(
822
+ local_indices: torch.Tensor,
823
+ dataset_ids: torch.Tensor,
824
+ timesteps: torch.Tensor,
825
+ portrait_cache: Optional[ExpertFeatureCache],
826
+ schnell_cache: Optional[ExpertFeatureCache],
827
+ sportfashion_cache: Optional[ExpertFeatureCache],
828
+ synthmocap_cache: Optional[ExpertFeatureCache],
829
+ ) -> Optional[torch.Tensor]:
830
+ """Get expert features from the appropriate cache for each sample."""
831
+
832
+ caches = [portrait_cache, schnell_cache, sportfashion_cache, synthmocap_cache]
833
+
834
+ # Check if any cache is available
835
+ if not any(c is not None for c in caches):
836
+ return None
837
+
838
+ B = local_indices.shape[0]
839
+ device = timesteps.device
840
+ features = torch.zeros(B, EXPERT_DIM, device=device, dtype=DTYPE)
841
+
842
+ for ds_id, cache in enumerate(caches):
843
+ if cache is None:
844
+ continue
845
+
846
+ # Find samples from this dataset
847
+ mask = dataset_ids == ds_id
848
+ if not mask.any():
849
+ continue
850
+
851
+ # Get features for these samples
852
+ ds_local_indices = local_indices[mask]
853
+ ds_timesteps = timesteps[mask]
854
+ ds_features = cache.get_features(ds_local_indices, ds_timesteps)
855
+ features[mask] = ds_features
856
+
857
+ return features
858
+
859
+
860
+ # ============================================================================
861
+ # MASKED LOSS FUNCTION
862
+ # ============================================================================
863
+ def masked_mse_loss(pred, target, mask=None, fg_weight=2.0, bg_weight=0.5, snr_weights=None):
864
+ B, N, C = pred.shape
865
+ if mask is None:
866
+ loss_per_sample = ((pred - target) ** 2).mean(dim=[1, 2])
867
+ else:
868
+ H = W = int(math.sqrt(N))
869
+ mask_flat = mask.view(B, H * W, 1).to(pred.device)
870
+ sq_error = (pred - target) ** 2
871
+ weights = mask_flat * fg_weight + (1 - mask_flat) * bg_weight
872
+ weighted_error = sq_error * weights
873
+ loss_per_sample = weighted_error.mean(dim=[1, 2])
874
+ if snr_weights is not None:
875
+ loss_per_sample = loss_per_sample * snr_weights
876
+ return loss_per_sample.mean()
877
+
878
+
879
+ # ============================================================================
880
+ # CREATE DATASET
881
+ # ============================================================================
882
+ print("\nCreating combined dataset...")
883
+ combined_ds = CombinedDataset(
884
+ portrait_ds, portrait_indices, portrait_t5, portrait_clip,
885
+ schnell_ds, schnell_t5, schnell_clip,
886
+ sportfashion_ds, sportfashion_t5, sportfashion_clip,
887
+ synthmocap_ds, synthmocap_t5, synthmocap_clip,
888
+ vae, VAE_SCALE, DEVICE, DTYPE,
889
+ compute_masks=USE_MASKED_LOSS,
890
+ )
891
+ print(f"✓ Combined dataset: {len(combined_ds)} samples")
892
+ print(f" - Portraits (3x): {combined_ds.n_portrait:,}")
893
+ print(f" - Schnell teacher: {combined_ds.n_schnell:,}")
894
+ print(f" - SportFashion: {combined_ds.n_sportfashion:,}")
895
+ print(f" - SynthMoCap: {combined_ds.n_synthmocap:,}")
896
+ print(f" - Expert distillation: {ENABLE_EXPERT_DISTILLATION}")
897
+
898
+
899
+ # ============================================================================
900
+ # DATALOADER
901
+ # ============================================================================
902
+ loader = DataLoader(
903
+ combined_ds,
904
+ batch_size=BATCH_SIZE,
905
+ shuffle=True,
906
+ num_workers=8,
907
+ pin_memory=True,
908
+ collate_fn=collate_fn,
909
+ drop_last=True,
910
+ )
911
+ print(f"✓ DataLoader: {len(loader)} batches/epoch")
912
+
913
+
914
+ # ============================================================================
915
+ # SAMPLING FUNCTION
916
+ # ============================================================================
917
+ @torch.inference_mode()
918
+ def generate_samples(model, prompts, num_steps=28, guidance_scale=3.5, H=64, W=64, use_ema=True):
919
+ was_training = model.training
920
+ model.eval()
921
+
922
+ if use_ema and 'ema' in globals() and ema is not None:
923
+ ema.apply_shadow_for_eval(model)
924
+
925
+ B = len(prompts)
926
+ C = 16
927
+
928
+ t5_list, clip_list = [], []
929
+ for p in prompts:
930
+ t5, clip = encode_prompt(p)
931
+ t5_list.append(t5)
932
+ clip_list.append(clip)
933
+ t5_embeds = torch.stack(t5_list).to(DTYPE)
934
+ clip_pooleds = torch.stack(clip_list).to(DTYPE)
935
+
936
+ x = torch.randn(B, H * W, C, device=DEVICE, dtype=DTYPE)
937
+ img_ids = TinyFluxDeep.create_img_ids(B, H, W, DEVICE)
938
+
939
+ t_linear = torch.linspace(0, 1, num_steps + 1, device=DEVICE, dtype=DTYPE)
940
+ timesteps = flux_shift(t_linear, s=SHIFT)
941
+
942
+ for i in range(num_steps):
943
+ t_curr = timesteps[i]
944
+ t_next = timesteps[i + 1]
945
+ dt = t_next - t_curr
946
+
947
+ t_batch = t_curr.expand(B).to(DTYPE)
948
+
949
+ with torch.autocast("cuda", dtype=DTYPE):
950
+ # No expert_features at inference - predictor runs standalone
951
+ v_cond = model(
952
+ hidden_states=x,
953
+ encoder_hidden_states=t5_embeds,
954
+ pooled_projections=clip_pooleds,
955
+ timestep=t_batch,
956
+ img_ids=img_ids,
957
+ )
958
+ x = x + v_cond * dt
959
+
960
+ latents = x.reshape(B, H, W, C).permute(0, 3, 1, 2)
961
+ latents = latents / VAE_SCALE
962
+
963
+ with torch.autocast("cuda", dtype=DTYPE):
964
+ images = vae.decode(latents.to(vae.dtype)).sample
965
+ images = (images / 2 + 0.5).clamp(0, 1)
966
+
967
+ if use_ema and 'ema' in globals() and ema is not None:
968
+ ema.restore(model)
969
+
970
+ if was_training:
971
+ model.train()
972
+ return images
973
+
974
+
975
+ def save_samples(images, prompts, step, output_dir):
976
+ from torchvision.utils import save_image
977
+ os.makedirs(output_dir, exist_ok=True)
978
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
979
+ grid_path = os.path.join(output_dir, f"samples_step_{step}.png")
980
+ save_image(images, grid_path, nrow=2, padding=2)
981
+ try:
982
+ api.upload_file(
983
+ path_or_fileobj=grid_path,
984
+ path_in_repo=f"samples/{timestamp}_step_{step}.png",
985
+ repo_id=HF_REPO,
986
+ )
987
+ except:
988
+ pass
989
+
990
+
991
+ # ============================================================================
992
+ # CHECKPOINT FUNCTIONS
993
+ # ============================================================================
994
+ def save_checkpoint(model, optimizer, scheduler, step, epoch, loss, path, ema=None):
995
+ os.makedirs(os.path.dirname(path) if os.path.dirname(path) else ".", exist_ok=True)
996
+ if hasattr(model, '_orig_mod'):
997
+ state_dict = model._orig_mod.state_dict()
998
+ else:
999
+ state_dict = model.state_dict()
1000
+ state_dict = {k: v.to(DTYPE) if v.is_floating_point() else v for k, v in state_dict.items()}
1001
+ weights_path = path.replace(".pt", ".safetensors")
1002
+ save_file(state_dict, weights_path)
1003
+ if ema is not None:
1004
+ ema_weights = {k: v.to(DTYPE) if v.is_floating_point() else v for k, v in ema.shadow.items()}
1005
+ ema_weights_path = path.replace(".pt", "_ema.safetensors")
1006
+ save_file(ema_weights, ema_weights_path)
1007
+ state = {
1008
+ "step": step,
1009
+ "epoch": epoch,
1010
+ "loss": loss,
1011
+ "optimizer": optimizer.state_dict(),
1012
+ "scheduler": scheduler.state_dict(),
1013
+ }
1014
+ if ema is not None:
1015
+ state["ema_decay"] = ema.decay
1016
+ torch.save(state, path)
1017
+ print(f" ✓ Saved checkpoint: step {step}")
1018
+ return weights_path
1019
+
1020
+
1021
+ def upload_checkpoint(weights_path, step):
1022
+ try:
1023
+ api.upload_file(
1024
+ path_or_fileobj=weights_path,
1025
+ path_in_repo=f"checkpoints/step_{step}.safetensors",
1026
+ repo_id=HF_REPO,
1027
+ )
1028
+ ema_path = weights_path.replace(".safetensors", "_ema.safetensors")
1029
+ if os.path.exists(ema_path):
1030
+ api.upload_file(
1031
+ path_or_fileobj=ema_path,
1032
+ path_in_repo=f"checkpoints/step_{step}_ema.safetensors",
1033
+ repo_id=HF_REPO,
1034
+ )
1035
+ print(f" ✓ Uploaded checkpoint to {HF_REPO}")
1036
+ except Exception as e:
1037
+ print(f" ⚠ Upload failed: {e}")
1038
+
1039
+
1040
+ def load_with_weight_upgrade(model, state_dict):
1041
+ """
1042
+ Load state dict with automatic handling of:
1043
+ - Missing ExpertPredictor weights → initialize fresh
1044
+ - Missing Q/K norm weights → initialize to ones (identity)
1045
+ - Unexpected keys → ignore (e.g., old guidance_in, sin_basis caches)
1046
+ """
1047
+ model_state = model.state_dict()
1048
+
1049
+ # Patterns for new weights that may not exist in old checkpoints
1050
+ NEW_WEIGHT_PATTERNS = [
1051
+ 'expert_predictor.', # New ExpertPredictor module
1052
+ '.norm_q.weight',
1053
+ '.norm_k.weight',
1054
+ '.norm_added_q.weight',
1055
+ '.norm_added_k.weight',
1056
+ ]
1057
+
1058
+ # Keys that may exist in old checkpoints but not new model
1059
+ DEPRECATED_PATTERNS = [
1060
+ 'guidance_in.', # Replaced by expert_predictor
1061
+ '.sin_basis', # Old cached sin embeddings
1062
+ ]
1063
+
1064
+ loaded_keys = []
1065
+ missing_keys = []
1066
+ unexpected_keys = []
1067
+ initialized_keys = []
1068
+
1069
+ # First pass: load matching weights
1070
+ for key in state_dict.keys():
1071
+ if key in model_state:
1072
+ if state_dict[key].shape == model_state[key].shape:
1073
+ model_state[key] = state_dict[key]
1074
+ loaded_keys.append(key)
1075
+ else:
1076
+ print(f" ⚠ Shape mismatch for {key}: checkpoint {state_dict[key].shape} vs model {model_state[key].shape}")
1077
+ unexpected_keys.append(key)
1078
+ else:
1079
+ is_deprecated = any(pat in key for pat in DEPRECATED_PATTERNS)
1080
+ if is_deprecated:
1081
+ unexpected_keys.append(key)
1082
+ else:
1083
+ print(f" ⚠ Unexpected key (not in model): {key}")
1084
+ unexpected_keys.append(key)
1085
+
1086
+ # Second pass: handle missing keys
1087
+ for key in model_state.keys():
1088
+ if key not in loaded_keys:
1089
+ is_new = any(pat in key for pat in NEW_WEIGHT_PATTERNS)
1090
+
1091
+ if is_new:
1092
+ # Keep default initialization for new modules
1093
+ initialized_keys.append(key)
1094
+ else:
1095
+ missing_keys.append(key)
1096
+ print(f" ⚠ Missing key (not in checkpoint): {key}")
1097
+
1098
+ # Load the updated state
1099
+ model.load_state_dict(model_state, strict=False)
1100
+
1101
+ # Report
1102
+ if initialized_keys:
1103
+ # Group by module for cleaner output
1104
+ modules = set()
1105
+ for k in initialized_keys:
1106
+ parts = k.split('.')
1107
+ if len(parts) >= 2:
1108
+ modules.add(parts[0] + '.' + parts[1] if parts[0] == 'expert_predictor' else parts[0])
1109
+ print(f" ✓ Initialized new modules (fresh): {sorted(modules)}")
1110
+
1111
+ if unexpected_keys:
1112
+ deprecated = [k for k in unexpected_keys if any(p in k for p in DEPRECATED_PATTERNS)]
1113
+ if deprecated:
1114
+ print(f" ✓ Ignored deprecated keys: {len(deprecated)} (guidance_in, etc)")
1115
+
1116
+ return missing_keys, unexpected_keys
1117
+
1118
+
1119
+ def load_checkpoint(model, optimizer, scheduler, target):
1120
+ """
1121
+ Load checkpoint with weight upgrade support for ExpertPredictor.
1122
+
1123
+ When ALLOW_WEIGHT_UPGRADE=True:
1124
+ - Missing ExpertPredictor weights are initialized fresh
1125
+ - Old guidance_in weights are ignored
1126
+ - Model continues training with new architecture
1127
+ """
1128
+ start_step = 0
1129
+ start_epoch = 0
1130
+ ema_state = None
1131
+
1132
+ if target == "none":
1133
+ print("Starting fresh (no checkpoint)")
1134
+ return start_step, start_epoch, None
1135
+
1136
+ ckpt_path = None
1137
+ weights_path = None
1138
+ ema_weights_path = None
1139
+
1140
+ if target == "latest":
1141
+ if os.path.exists(CHECKPOINT_DIR):
1142
+ ckpts = [f for f in os.listdir(CHECKPOINT_DIR) if f.startswith("step_") and f.endswith(".pt")]
1143
+ if ckpts:
1144
+ steps = [int(f.split("_")[1].split(".")[0]) for f in ckpts]
1145
+ latest_step = max(steps)
1146
+ ckpt_path = os.path.join(CHECKPOINT_DIR, f"step_{latest_step}.pt")
1147
+ weights_path = ckpt_path.replace(".pt", ".safetensors")
1148
+ ema_weights_path = ckpt_path.replace(".pt", "_ema.safetensors")
1149
+
1150
+ elif target == "hub" or target.startswith("hub:"):
1151
+ try:
1152
+ from huggingface_hub import list_repo_files
1153
+
1154
+ if target.startswith("hub:"):
1155
+ step_name = target.split(":")[1]
1156
+ weights_path = hf_hub_download(HF_REPO, f"checkpoints/{step_name}.safetensors")
1157
+ try:
1158
+ ema_weights_path = hf_hub_download(HF_REPO, f"checkpoints/{step_name}_ema.safetensors")
1159
+ print(f" Found EMA weights on hub")
1160
+ except:
1161
+ ema_weights_path = None
1162
+ print(f" No EMA weights on hub (will start fresh)")
1163
+ start_step = int(step_name.split("_")[1]) if "_" in step_name else 0
1164
+ print(f"Downloaded {step_name} from hub")
1165
+ else:
1166
+ files = list_repo_files(HF_REPO)
1167
+ ckpts = [f for f in files if f.startswith("checkpoints/step_") and f.endswith(".safetensors") and "_ema" not in f]
1168
+ if ckpts:
1169
+ steps = [int(f.split("_")[1].split(".")[0]) for f in ckpts]
1170
+ latest = max(steps)
1171
+ weights_path = hf_hub_download(HF_REPO, f"checkpoints/step_{latest}.safetensors")
1172
+ try:
1173
+ ema_weights_path = hf_hub_download(HF_REPO, f"checkpoints/step_{latest}_ema.safetensors")
1174
+ print(f" Found EMA weights on hub")
1175
+ except:
1176
+ ema_weights_path = None
1177
+ print(f" No EMA weights on hub (will start fresh)")
1178
+ start_step = latest
1179
+ print(f"Downloaded step_{latest} from hub")
1180
+ except Exception as e:
1181
+ print(f"Could not download from hub: {e}")
1182
+ return start_step, start_epoch, None
1183
+
1184
+ elif target == "best":
1185
+ ckpt_path = os.path.join(CHECKPOINT_DIR, "best.pt")
1186
+ weights_path = ckpt_path.replace(".pt", ".safetensors")
1187
+ ema_weights_path = ckpt_path.replace(".pt", "_ema.safetensors")
1188
+
1189
+ elif os.path.exists(target):
1190
+ if target.endswith(".safetensors"):
1191
+ weights_path = target
1192
+ ckpt_path = target.replace(".safetensors", ".pt")
1193
+ ema_weights_path = target.replace(".safetensors", "_ema.safetensors")
1194
+ else:
1195
+ ckpt_path = target
1196
+ weights_path = target.replace(".pt", ".safetensors")
1197
+ ema_weights_path = target.replace(".pt", "_ema.safetensors")
1198
+
1199
+ # Load main model weights
1200
+ if weights_path and os.path.exists(weights_path):
1201
+ print(f"Loading weights from {weights_path}")
1202
+ state_dict = load_file(weights_path)
1203
+ state_dict = {k: v.to(DTYPE) if v.is_floating_point() else v for k, v in state_dict.items()}
1204
+
1205
+ # Get model reference (handle torch.compile wrapper)
1206
+ model_ref = model._orig_mod if hasattr(model, '_orig_mod') else model
1207
+
1208
+ if ALLOW_WEIGHT_UPGRADE:
1209
+ # Flexible loading with weight upgrade
1210
+ missing, unexpected = load_with_weight_upgrade(model_ref, state_dict)
1211
+
1212
+ if missing:
1213
+ print(f" ⚠ {len(missing)} truly missing parameters (may need attention)")
1214
+ else:
1215
+ # Strict loading - must match exactly
1216
+ model_ref.load_state_dict(state_dict, strict=True)
1217
+
1218
+ print(f"✓ Loaded model weights")
1219
+
1220
+ # Load EMA weights if they exist
1221
+ if ema_weights_path and os.path.exists(ema_weights_path):
1222
+ ema_state = load_file(ema_weights_path)
1223
+ ema_state = {k: v.to(DTYPE) if v.is_floating_point() else v for k, v in ema_state.items()}
1224
+ print(f"✓ Loaded EMA weights ({len(ema_state)} params)")
1225
+ else:
1226
+ print(f" ℹ No EMA weights found (will initialize fresh)")
1227
+ else:
1228
+ print(f" ⚠ Weights file not found: {weights_path}")
1229
+ print(f" Starting with fresh model")
1230
+ return start_step, start_epoch, None
1231
+
1232
+ # Load optimizer/scheduler state
1233
+ if ckpt_path and os.path.exists(ckpt_path):
1234
+ state = torch.load(ckpt_path, map_location="cpu")
1235
+ start_step = state.get("step", 0)
1236
+ start_epoch = state.get("epoch", 0)
1237
+ try:
1238
+ optimizer.load_state_dict(state["optimizer"])
1239
+ scheduler.load_state_dict(state["scheduler"])
1240
+ print(f"✓ Loaded optimizer/scheduler state")
1241
+ except Exception as e:
1242
+ print(f" ⚠ Could not load optimizer state: {e}")
1243
+ print(f" Will use fresh optimizer (this is fine for architecture changes)")
1244
+ print(f"Resuming from step {start_step}, epoch {start_epoch}")
1245
+
1246
+ return start_step, start_epoch, ema_state
1247
+
1248
+
1249
+ # ============================================================================
1250
+ # CREATE MODEL
1251
+ # ============================================================================
1252
+ print("\nCreating TinyFluxDeep model with ExpertPredictor...")
1253
+
1254
+ config = TinyFluxDeepConfig(
1255
+ use_expert_predictor=ENABLE_EXPERT_DISTILLATION,
1256
+ expert_dim=EXPERT_DIM,
1257
+ expert_hidden_dim=EXPERT_HIDDEN_DIM,
1258
+ expert_dropout=EXPERT_DROPOUT,
1259
+ guidance_embeds=False,
1260
+ )
1261
+ model = TinyFluxDeep(config).to(device=DEVICE, dtype=DTYPE)
1262
+
1263
+ total_params = sum(p.numel() for p in model.parameters())
1264
+ print(f"Total parameters: {total_params:,}")
1265
+
1266
+ if hasattr(model, 'expert_predictor') and model.expert_predictor is not None:
1267
+ expert_params = sum(p.numel() for p in model.expert_predictor.parameters())
1268
+ print(f"Expert predictor parameters: {expert_params:,}")
1269
+
1270
+ trainable_params = [p for p in model.parameters() if p.requires_grad]
1271
+ print(f"Trainable parameters: {sum(p.numel() for p in trainable_params):,}")
1272
+
1273
+
1274
+ # ============================================================================
1275
+ # OPTIMIZER
1276
+ # ============================================================================
1277
+ opt = torch.optim.AdamW(trainable_params, lr=LR, betas=(0.9, 0.99), weight_decay=0.01, fused=True)
1278
+
1279
+ total_steps = len(loader) * EPOCHS // GRAD_ACCUM
1280
+ warmup = min(1000, total_steps // 10)
1281
+
1282
+ def lr_fn(step):
1283
+ if step < warmup:
1284
+ return step / warmup
1285
+ return 0.5 * (1 + math.cos(math.pi * (step - warmup) / (total_steps - warmup)))
1286
+
1287
+ sched = torch.optim.lr_scheduler.LambdaLR(opt, lr_fn)
1288
+
1289
+
1290
+ # ============================================================================
1291
+ # LOAD CHECKPOINT
1292
+ # ============================================================================
1293
+ start_step, start_epoch, ema_state = load_checkpoint(model, opt, sched, LOAD_TARGET)
1294
+
1295
+ if RESUME_STEP is not None:
1296
+ start_step = RESUME_STEP
1297
+
1298
+
1299
+ # ============================================================================
1300
+ # COMPILE
1301
+ # ============================================================================
1302
+ model = torch.compile(model, mode="default")
1303
+
1304
+
1305
+ # ============================================================================
1306
+ # EMA
1307
+ # ============================================================================
1308
+ print("Initializing EMA...")
1309
+ ema = EMA(model, decay=EMA_DECAY)
1310
+ if ema_state is not None:
1311
+ ema.load_shadow(ema_state)
1312
+ else:
1313
+ print(" Starting fresh EMA from current weights")
1314
+
1315
+
1316
+ # ============================================================================
1317
+ # TENSORBOARD
1318
+ # ============================================================================
1319
+ run_name = f"run_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
1320
+ writer = SummaryWriter(os.path.join(LOG_DIR, run_name))
1321
+
1322
+ SAMPLE_PROMPTS = [
1323
+ "a photo of a cat sitting on a windowsill",
1324
+ "a portrait of a woman with red hair",
1325
+ "a black backpack on white background",
1326
+ "a person standing in a t-pose",
1327
+ ]
1328
+
1329
+
1330
+ # ============================================================================
1331
+ # DISTILLATION WEIGHT SCHEDULE
1332
+ # ============================================================================
1333
+ def get_distill_weight(step):
1334
+ if step < DISTILL_WARMUP_STEPS:
1335
+ return DISTILL_LOSS_WEIGHT * (step / DISTILL_WARMUP_STEPS)
1336
+ return DISTILL_LOSS_WEIGHT
1337
+
1338
+
1339
+ # ============================================================================
1340
+ # TRAINING LOOP
1341
+ # ============================================================================
1342
+ print(f"\n{'='*60}")
1343
+ print(f"Training TinyFlux-Deep with Expert Distillation (Precached)")
1344
+ print(f"{'='*60}")
1345
+ print(f"Total: {len(combined_ds):,} samples")
1346
+ print(f"Epochs: {EPOCHS}, Steps/epoch: {len(loader)}, Total: {total_steps}")
1347
+ print(f"Batch: {BATCH_SIZE} x {GRAD_ACCUM} = {BATCH_SIZE * GRAD_ACCUM}")
1348
+ print(f"Expert distillation: {ENABLE_EXPERT_DISTILLATION} (PRECACHED)")
1349
+ if ENABLE_EXPERT_DISTILLATION:
1350
+ print(f" - Expert: {EXPERT_CHECKPOINT}")
1351
+ print(f" - Timestep buckets: {len(EXPERT_T_BUCKETS)}")
1352
+ print(f" - Distill weight: {DISTILL_LOSS_WEIGHT} (warmup: {DISTILL_WARMUP_STEPS} steps)")
1353
+ print(f" - Expert dropout: {EXPERT_DROPOUT}")
1354
+ print(f"Masked loss: {USE_MASKED_LOSS}")
1355
+ print(f"Min-SNR gamma: {MIN_SNR_GAMMA}")
1356
+ print(f"Resume: step {start_step}, epoch {start_epoch}")
1357
+
1358
+ model.train()
1359
+ step = start_step
1360
+ best = float("inf")
1361
+
1362
+ for ep in range(start_epoch, EPOCHS):
1363
+ ep_loss = 0
1364
+ ep_main_loss = 0
1365
+ ep_distill_loss = 0
1366
+ ep_batches = 0
1367
+ pbar = tqdm(loader, desc=f"E{ep + 1}")
1368
+
1369
+ for i, batch in enumerate(pbar):
1370
+ latents = batch["latents"].to(DEVICE, non_blocking=True)
1371
+ t5 = batch["t5_embeds"].to(DEVICE, non_blocking=True)
1372
+ clip = batch["clip_pooled"].to(DEVICE, non_blocking=True)
1373
+ local_indices = batch["local_indices"]
1374
+ dataset_ids = batch["dataset_ids"]
1375
+ masks = batch["masks"]
1376
+
1377
+ if masks is not None:
1378
+ masks = masks.to(DEVICE, non_blocking=True)
1379
+
1380
+ B, C, H, W = latents.shape
1381
+ data = latents.permute(0, 2, 3, 1).reshape(B, H * W, C)
1382
+ noise = torch.randn_like(data)
1383
+
1384
+ if TEXT_DROPOUT > 0:
1385
+ t5, clip, _ = apply_text_dropout(t5, clip, TEXT_DROPOUT)
1386
+
1387
+ t = torch.sigmoid(torch.randn(B, device=DEVICE))
1388
+ t = flux_shift(t, s=SHIFT).to(DTYPE).clamp(1e-4, 1 - 1e-4)
1389
+
1390
+ t_expanded = t.view(B, 1, 1)
1391
+ x_t = (1 - t_expanded) * noise + t_expanded * data
1392
+ v_target = data - noise
1393
+
1394
+ img_ids = TinyFluxDeep.create_img_ids(B, H, W, DEVICE)
1395
+
1396
+ # Get expert features from CACHE (fast!)
1397
+ expert_features = None
1398
+ if ENABLE_EXPERT_DISTILLATION:
1399
+ expert_features = get_expert_features_for_batch(
1400
+ local_indices, dataset_ids, t,
1401
+ portrait_expert_cache, schnell_expert_cache,
1402
+ sportfashion_expert_cache, synthmocap_expert_cache,
1403
+ )
1404
+
1405
+ # Apply dropout OUTSIDE model (no graph break)
1406
+ if expert_features is not None and random.random() < EXPERT_DROPOUT:
1407
+ expert_features = None
1408
+
1409
+ with torch.autocast("cuda", dtype=DTYPE):
1410
+ v_pred, expert_info = model(
1411
+ hidden_states=x_t,
1412
+ encoder_hidden_states=t5,
1413
+ pooled_projections=clip,
1414
+ timestep=t,
1415
+ img_ids=img_ids,
1416
+ expert_features=expert_features,
1417
+ return_expert_pred=True,
1418
+ )
1419
+
1420
+ # Compute losses
1421
+ snr_weights = min_snr_weight(t)
1422
+
1423
+ main_loss = masked_mse_loss(
1424
+ v_pred, v_target,
1425
+ mask=masks if USE_MASKED_LOSS else None,
1426
+ fg_weight=FG_LOSS_WEIGHT,
1427
+ bg_weight=BG_LOSS_WEIGHT,
1428
+ snr_weights=snr_weights
1429
+ )
1430
+
1431
+ # Distillation loss
1432
+ distill_loss = torch.tensor(0.0, device=DEVICE)
1433
+ if expert_features is not None and expert_info is not None and 'expert_pred' in expert_info:
1434
+ distill_weight = get_distill_weight(step)
1435
+ distill_loss = F.mse_loss(expert_info['expert_pred'], expert_features)
1436
+ total_loss = main_loss + distill_weight * distill_loss
1437
+ else:
1438
+ total_loss = main_loss
1439
+
1440
+ loss = total_loss / GRAD_ACCUM
1441
+ loss.backward()
1442
+
1443
+ if (i + 1) % GRAD_ACCUM == 0:
1444
+ grad_norm = torch.nn.utils.clip_grad_norm_(trainable_params, 1.0)
1445
+ opt.step()
1446
+ sched.step()
1447
+ opt.zero_grad(set_to_none=True)
1448
+
1449
+ ema.update(model)
1450
+ step += 1
1451
+
1452
+ if step % LOG_EVERY == 0:
1453
+ writer.add_scalar("train/loss", total_loss.item(), step)
1454
+ writer.add_scalar("train/main_loss", main_loss.item(), step)
1455
+ if ENABLE_EXPERT_DISTILLATION:
1456
+ writer.add_scalar("train/distill_loss", distill_loss.item(), step)
1457
+ writer.add_scalar("train/distill_weight", get_distill_weight(step), step)
1458
+ writer.add_scalar("train/lr", sched.get_last_lr()[0], step)
1459
+ writer.add_scalar("train/grad_norm", grad_norm.item(), step)
1460
+
1461
+ if step % SAMPLE_EVERY == 0:
1462
+ print(f"\n Generating samples at step {step}...")
1463
+ images = generate_samples(model, SAMPLE_PROMPTS, num_steps=20, use_ema=True)
1464
+ save_samples(images, SAMPLE_PROMPTS, step, SAMPLE_DIR)
1465
+
1466
+ if step % SAVE_EVERY == 0:
1467
+ ckpt_path = os.path.join(CHECKPOINT_DIR, f"step_{step}.pt")
1468
+ weights_path = save_checkpoint(model, opt, sched, step, ep, total_loss.item(), ckpt_path, ema=ema)
1469
+ if step % UPLOAD_EVERY == 0:
1470
+ upload_checkpoint(weights_path, step)
1471
+
1472
+ ep_loss += total_loss.item()
1473
+ ep_main_loss += main_loss.item()
1474
+ ep_distill_loss += distill_loss.item()
1475
+ ep_batches += 1
1476
+
1477
+ pbar.set_postfix(
1478
+ loss=f"{total_loss.item():.4f}",
1479
+ main=f"{main_loss.item():.4f}",
1480
+ dist=f"{distill_loss.item():.4f}" if ENABLE_EXPERT_DISTILLATION else "off",
1481
+ step=step
1482
+ )
1483
+
1484
+ avg = ep_loss / max(ep_batches, 1)
1485
+ avg_main = ep_main_loss / max(ep_batches, 1)
1486
+ avg_distill = ep_distill_loss / max(ep_batches, 1)
1487
+
1488
+ print(f"Epoch {ep + 1} - total: {avg:.4f}, main: {avg_main:.4f}, distill: {avg_distill:.4f}")
1489
+
1490
+ if avg < best:
1491
+ best = avg
1492
+ weights_path = save_checkpoint(model, opt, sched, step, ep, avg, os.path.join(CHECKPOINT_DIR, "best.pt"), ema=ema)
1493
+ try:
1494
+ api.upload_file(path_or_fileobj=weights_path, path_in_repo="model.safetensors", repo_id=HF_REPO)
1495
+ except:
1496
+ pass
1497
+
1498
+ print(f"\n✓ Training complete! Best loss: {best:.4f}")
1499
+ writer.close()