AbstractPhil committed
Commit 1f24ada · verified · 1 Parent(s): 8b9a727

Create trainer_v4_testing.py

Files changed (1):
  1. scripts/trainer_v4_testing.py +2344 -0
scripts/trainer_v4_testing.py ADDED
@@ -0,0 +1,2344 @@
# ============================================================================
# TinyFlux-Deep v4.1 Training Cell - Dual Expert Distillation (Lune + Sol)
# ============================================================================
# Integrates:
#   - Lune: SD1.5-flow trajectory guidance (mid-block features)
#   - Sol: Geometric attention prior (attention statistics + spatial importance)
#
# Both expert features are PRECACHED at 10 timestep buckets for speed.
# At inference, predictors run standalone - no teachers needed.
#
# USAGE: Run model_v4.py cell first, then this cell
# ============================================================================

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from datasets import load_dataset, concatenate_datasets
from transformers import T5EncoderModel, T5Tokenizer, CLIPTextModel, CLIPTokenizer
from huggingface_hub import HfApi, hf_hub_download
from safetensors.torch import save_file, load_file
from torch.utils.tensorboard import SummaryWriter
from tqdm.auto import tqdm
import numpy as np
import math
import json
import random
from typing import Tuple, Optional, Dict, List
import os
from datetime import datetime
from PIL import Image

# ============================================================================
# CUDA OPTIMIZATIONS
# ============================================================================
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cudnn.benchmark = True
torch.set_float32_matmul_precision('high')

import warnings

warnings.filterwarnings('ignore', message='.*TF32.*')

# ============================================================================
# CONFIG
# ============================================================================
BATCH_SIZE = 16
GRAD_ACCUM = 2
LR = 3e-4
EPOCHS = 10
MAX_SEQ = 128
SHIFT = 3.0
DEVICE = "cuda"
DTYPE = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16

ALLOW_WEIGHT_UPGRADE = True

# HuggingFace Hub
HF_REPO = "AbstractPhil/tiny-flux-deep"
SAVE_EVERY = 1562
UPLOAD_EVERY = 1562
SAMPLE_EVERY = 781
LOG_EVERY = 200
LOG_UPLOAD_EVERY = 1562

# Checkpoint loading
# v4.1 init checkpoint (converted from v3 step_401434)
# Options:
#   "hub:checkpoint_runs/v4_init/lailah_401434_v4_init" - v4.1 init (no EMA, fresh Sol)
#   "hub:step_401434" - v3 checkpoint (will auto-remap expert_predictor -> lune_predictor)
#   "latest" - latest local checkpoint
#   "none" - start fresh
LOAD_TARGET = "hub:checkpoint_runs/v4_init/lailah_401434_v4_init"
RESUME_STEP = 401434

# ============================================================================
# EXPERT REPOSITORY (both Lune and Sol)
# ============================================================================
EXPERTS_REPO = "AbstractPhil/tinyflux-experts"

# ============================================================================
# LUNE EXPERT DISTILLATION CONFIG (trajectory guidance)
# ============================================================================
ENABLE_LUNE_DISTILLATION = True
LUNE_FILENAME = "sd15-flow-lune-unet.safetensors"
LUNE_DIM = 1280  # SD1.5 mid-block dimension
LUNE_HIDDEN_DIM = 512
LUNE_DROPOUT = 0.1
LUNE_LOSS_WEIGHT = 0.1
LUNE_WARMUP_STEPS = 1000
LUNE_DISTILL_MODE = "cosine"  # "hard", "soft", "cosine", "huber"

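# [Illustrative sketch, not part of this commit] The distillation loss itself
# is applied in the training loop. A minimal cosine-mode variant, assuming the
# model exposes a `lune_predictor` head (the name the v3 remap above refers to)
# that yields a [B, LUNE_DIM] feature per sample:
#
#   pred = model.lune_predictor(hidden)                # [B, 1280], assumed head
#   target = lune_cache.get_features(sample_idx, t)    # [B, 1280], precached
#   lune_loss = 1.0 - F.cosine_similarity(pred.float(), target.float(), dim=-1).mean()
#   warmup = min(1.0, step / LUNE_WARMUP_STEPS)
#   loss = main_loss + LUNE_LOSS_WEIGHT * warmup * lune_loss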
# ============================================================================
# SOL ATTENTION PRIOR CONFIG (structural guidance)
# ============================================================================
ENABLE_SOL_DISTILLATION = True
SOL_FILENAME = "sd15-flow-sol-unet.safetensors"
SOL_HIDDEN_DIM = 256
SOL_SPATIAL_SIZE = 8  # 8x8 spatial importance map
SOL_GEOMETRIC_WEIGHT = 0.7  # 70% geometric, 30% learned
SOL_LOSS_WEIGHT = 0.05
SOL_WARMUP_STEPS = 2000  # Start Sol later than Lune

# Timestep buckets for precaching (shared by Lune and Sol)
EXPERT_T_BUCKETS = torch.linspace(0.05, 0.95, 10)

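# torch.linspace(0.05, 0.95, 10) gives evenly spaced bucket centers:
#   [0.05, 0.15, 0.25, 0.35, 0.45, 0.55, 0.65, 0.75, 0.85, 0.95]
# i.e. a fixed step of 0.10, which the caches below rely on for linear
# interpolation between neighboring buckets.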
# ============================================================================
# LOSS CONFIG
# ============================================================================
USE_HUBER_LOSS = True
HUBER_DELTA = 0.1
USE_SPATIAL_WEIGHTING = False  # Weight main loss by Sol spatial importance

# ============================================================================
# DATASET CONFIG
# ============================================================================
ENABLE_PORTRAIT = False
ENABLE_SCHNELL = False
ENABLE_SPORTFASHION = False
ENABLE_SYNTHMOCAP = False
ENABLE_IMAGENET = False
ENABLE_OBJECT_RELATIONS = True

PORTRAIT_REPO = "AbstractPhil/ffhq_flux_latents_repaired"
PORTRAIT_NUM_SHARDS = 11
SCHNELL_REPO = "AbstractPhil/flux-schnell-teacher-latents"
SCHNELL_CONFIGS = ["train_512"]
SPORTFASHION_REPO = "Pianokill/SportFashion_512x512"
SYNTHMOCAP_REPO = "toyxyz/SynthMoCap_smpl_512"
IMAGENET_REPO = "AbstractPhil/synthetic-imagenet-1k"
IMAGENET_SUBSET = "schnell_512"
OBJECT_RELATIONS_REPO = "AbstractPhil/synthetic-object-relations"

# Confidence threshold for misprediction filtering
IMAGENET_CONFIDENCE_THRESHOLD = 0.5  # If confident but wrong, remove label

FG_LOSS_WEIGHT = 2.0
BG_LOSS_WEIGHT = 0.5
USE_MASKED_LOSS = False
MIN_SNR_GAMMA = 5.0

# Paths
CHECKPOINT_DIR = "./tiny_flux_deep_checkpoints"
LOG_DIR = "./tiny_flux_deep_logs"
SAMPLE_DIR = "./tiny_flux_deep_samples"
ENCODING_CACHE_DIR = "./encoding_cache"
LATENT_CACHE_DIR = "./latent_cache"

os.makedirs(CHECKPOINT_DIR, exist_ok=True)
os.makedirs(LOG_DIR, exist_ok=True)
os.makedirs(SAMPLE_DIR, exist_ok=True)
os.makedirs(ENCODING_CACHE_DIR, exist_ok=True)
os.makedirs(LATENT_CACHE_DIR, exist_ok=True)

# ============================================================================
# REGULARIZATION CONFIG
# ============================================================================
TEXT_DROPOUT = 0.1
GUIDANCE_DROPOUT = 0.1
EMA_DECAY = 0.9999


# ============================================================================
# LUNE FEATURE CACHE (SD1.5 mid-block features)
# ============================================================================
class LuneFeatureCache:
    """
    Precached SD1.5-flow Lune features with timestep interpolation.
    Features extracted at 10 timestep buckets [0.05, 0.15, ..., 0.95].
    """

    def __init__(self, features: torch.Tensor, t_buckets: torch.Tensor, dtype=torch.float16):
        self.features = features.to(dtype)  # [N, 10, 1280]
        self.t_buckets = t_buckets
        self.t_min = t_buckets[0].item()
        self.t_max = t_buckets[-1].item()
        self.t_step = (t_buckets[1] - t_buckets[0]).item()
        self.n_buckets = len(t_buckets)
        self.dtype = dtype

    def get_features(self, indices: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor:
        device = timesteps.device
        t_clamped = timesteps.float().clamp(self.t_min, self.t_max)
        t_idx_float = (t_clamped - self.t_min) / self.t_step
        t_idx_low = t_idx_float.long().clamp(0, self.n_buckets - 2)
        t_idx_high = (t_idx_low + 1).clamp(0, self.n_buckets - 1)
        alpha = (t_idx_float - t_idx_low.float()).unsqueeze(-1)

        idx_cpu = indices.cpu()
        t_low_cpu = t_idx_low.cpu()
        t_high_cpu = t_idx_high.cpu()

        f_low = self.features[idx_cpu, t_low_cpu]
        f_high = self.features[idx_cpu, t_high_cpu]

        result = (1 - alpha.cpu()) * f_low + alpha.cpu() * f_high
        return result.to(device=device, dtype=self.dtype)

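# [Illustrative sanity check, not part of this commit] How the interpolation
# above behaves: a query at t=0.20 sits halfway between buckets 0.15 and 0.25,
# so alpha = 0.5 and the result is the mean of the two cached features.
#
#   cache = LuneFeatureCache(torch.randn(4, 10, LUNE_DIM), EXPERT_T_BUCKETS)
#   feats = cache.get_features(torch.tensor([0, 2]), torch.tensor([0.20, 0.90]))
#   assert feats.shape == (2, LUNE_DIM)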
# ============================================================================
# SOL FEATURE CACHE (attention statistics + spatial importance)
# ============================================================================
class SolFeatureCache:
    """
    Precached Sol attention statistics with timestep interpolation.

    Statistics per sample per timestep:
      - stats: [N, 10, 4] - locality, entropy, clustering, sparsity
      - spatial: [N, 10, 8, 8] - spatial importance map
    """

    def __init__(self, stats: torch.Tensor, spatial: torch.Tensor,
                 t_buckets: torch.Tensor, dtype=torch.float16):
        self.stats = stats.to(dtype)  # [N, 10, 4]
        self.spatial = spatial.to(dtype)  # [N, 10, 8, 8]
        self.t_buckets = t_buckets
        self.t_min = t_buckets[0].item()
        self.t_max = t_buckets[-1].item()
        self.t_step = (t_buckets[1] - t_buckets[0]).item()
        self.n_buckets = len(t_buckets)
        self.dtype = dtype

    def get_features(self, indices: torch.Tensor, timesteps: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        device = timesteps.device
        t_clamped = timesteps.float().clamp(self.t_min, self.t_max)
        t_idx_float = (t_clamped - self.t_min) / self.t_step
        t_idx_low = t_idx_float.long().clamp(0, self.n_buckets - 2)
        t_idx_high = (t_idx_low + 1).clamp(0, self.n_buckets - 1)

        alpha_stats = (t_idx_float - t_idx_low.float()).unsqueeze(-1)
        alpha_spatial = alpha_stats.unsqueeze(-1)

        idx_cpu = indices.cpu()
        t_low_cpu = t_idx_low.cpu()
        t_high_cpu = t_idx_high.cpu()

        s_low = self.stats[idx_cpu, t_low_cpu]
        s_high = self.stats[idx_cpu, t_high_cpu]
        stats_result = (1 - alpha_stats.cpu()) * s_low + alpha_stats.cpu() * s_high

        sp_low = self.spatial[idx_cpu, t_low_cpu]
        sp_high = self.spatial[idx_cpu, t_high_cpu]
        spatial_result = (1 - alpha_spatial.cpu()) * sp_low + alpha_spatial.cpu() * sp_high

        return (
            stats_result.to(device=device, dtype=self.dtype),
            spatial_result.to(device=device, dtype=self.dtype)
        )


def load_or_extract_lune_features(cache_path: str, prompts: List[str], name: str,
                                  clip_tok, clip_enc, t_buckets: torch.Tensor,
                                  batch_size: int = 32) -> Optional[LuneFeatureCache]:
    """Load cached Lune features or extract from SD1.5-flow teacher."""
    if not prompts or not ENABLE_LUNE_DISTILLATION:
        return None

    if os.path.exists(cache_path):
        print(f"Loading cached {name} Lune features...")
        cached = torch.load(cache_path, map_location="cpu")
        cache = LuneFeatureCache(cached["features"], cached["t_buckets"], DTYPE)
        print(f"  ✓ Loaded {cache.features.shape[0]} samples × {cache.n_buckets} timesteps")
        return cache

    print(f"Extracting {name} Lune features ({len(prompts)} × {len(t_buckets)} timesteps)...")
    print("  This is a one-time operation; the result is cached.")

    checkpoint_path = hf_hub_download(
        repo_id=EXPERTS_REPO,
        filename=LUNE_FILENAME,
    )
    print(f"  Loaded Lune from {EXPERTS_REPO}/{LUNE_FILENAME}")

    from diffusers import UNet2DConditionModel
    unet = UNet2DConditionModel.from_pretrained(
        "stable-diffusion-v1-5/stable-diffusion-v1-5",
        subfolder="unet",
        torch_dtype=torch.float16,
    ).to(DEVICE).eval()

    state_dict = load_file(checkpoint_path)
    unet.load_state_dict(state_dict, strict=False)

    # Convert to fp16 and compile for speed
    unet = unet.half()
    unet = torch.compile(unet, mode="reduce-overhead")
    print("  ✓ Lune UNet compiled (fp16)")

    for p in unet.parameters():
        p.requires_grad = False

    # Capture the mid-block output via a forward hook; spatial mean pooling
    # turns [B, 1280, h, w] into a [B, 1280] feature vector.
    mid_features = [None]

    def hook_fn(module, inp, out):
        mid_features[0] = out.mean(dim=[2, 3])

    unet.mid_block.register_forward_hook(hook_fn)

    n_prompts = len(prompts)
    n_buckets = len(t_buckets)
    all_features = torch.zeros(n_prompts, n_buckets, LUNE_DIM, dtype=torch.float16)

    # A100 can handle large batches: 64 prompts × 10 timesteps = 640 samples
    # in a single batched UNet forward pass.
    # SD1.5 UNet at 64x64 latents uses ~2GB for batch of 64, so 640 samples ~10-15GB
    LUNE_BATCH_PROMPTS = 64  # Number of prompts per iteration

    with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.float16):
        for start_idx in tqdm(range(0, n_prompts, LUNE_BATCH_PROMPTS), desc=f"Extracting {name} Lune"):
            end_idx = min(start_idx + LUNE_BATCH_PROMPTS, n_prompts)
            batch_prompts = prompts[start_idx:end_idx]
            B = len(batch_prompts)

            # Encode CLIP once per prompt batch
            clip_inputs = clip_tok(
                batch_prompts, return_tensors="pt", padding="max_length",
                max_length=77, truncation=True
            ).to(DEVICE)
            clip_hidden = clip_enc(**clip_inputs).last_hidden_state  # [B, 77, 768]

            # Expand for all timesteps: [B * n_buckets, 77, 768]
            clip_expanded = clip_hidden.unsqueeze(1).expand(-1, n_buckets, -1, -1)
            clip_expanded = clip_expanded.reshape(B * n_buckets, 77, -1)

            # Create timesteps for all buckets: [B * n_buckets]
            t_expanded = t_buckets.unsqueeze(0).expand(B, -1).reshape(-1).to(DEVICE)

            # Random latents: [B * n_buckets, 4, 64, 64]
            # Use fp16 explicitly to match the halved UNet (DTYPE may be bf16)
            latents = torch.randn(B * n_buckets, 4, 64, 64, device=DEVICE, dtype=torch.float16)

            # Single batched UNet forward pass
            _ = unet(latents, t_expanded * 1000, encoder_hidden_states=clip_expanded.to(torch.float16))

            # Reshape features back to [B, n_buckets, LUNE_DIM]
            features = mid_features[0].reshape(B, n_buckets, -1)
            all_features[start_idx:end_idx] = features.cpu().to(torch.float16)

    del unet
    torch.cuda.empty_cache()

    torch.save({"features": all_features, "t_buckets": t_buckets}, cache_path)
    print(f"  ✓ Cached to {cache_path}")
    print(f"  Size: {all_features.numel() * 2 / 1e9:.2f} GB")

    return LuneFeatureCache(all_features, t_buckets, DTYPE)


def load_or_extract_sol_features(cache_path: str, prompts: List[str], name: str,
                                 clip_tok, clip_enc, t_buckets: torch.Tensor,
                                 spatial_size: int = 8,
                                 batch_size: int = 32) -> Optional[SolFeatureCache]:
    """Load cached Sol features or generate geometric heuristics."""
    if not prompts or not ENABLE_SOL_DISTILLATION:
        return None

    if os.path.exists(cache_path):
        print(f"Loading cached {name} Sol features...")
        cached = torch.load(cache_path, map_location="cpu")
        cache = SolFeatureCache(
            cached["stats"], cached["spatial"], cached["t_buckets"], DTYPE
        )
        print(f"  ✓ Loaded {cache.stats.shape[0]} samples × {cache.n_buckets} timesteps")
        return cache

    print(f"Generating {name} Sol features ({len(prompts)} × {len(t_buckets)} timesteps)...")
    print("  Using geometric heuristics (no teacher needed)")

    n_prompts = len(prompts)
    n_buckets = len(t_buckets)

    # Vectorized generation - no loops needed
    # Stats: [n_buckets, 4] then broadcast to [n_prompts, n_buckets, 4]
    t_vals = t_buckets.float()  # [n_buckets]

    locality = 1 - t_vals  # [n_buckets]
    entropy = t_vals
    clustering = 0.5 - 0.3 * (t_vals - 0.5).abs()
    sparsity = 1 - t_vals

    stats_per_t = torch.stack([locality, entropy, clustering, sparsity], dim=-1)  # [n_buckets, 4]
    all_stats = stats_per_t.unsqueeze(0).expand(n_prompts, -1, -1).to(torch.float16)  # [n_prompts, n_buckets, 4]

    # Spatial: [n_buckets, spatial_size, spatial_size] then broadcast
    y, x = torch.meshgrid(
        torch.linspace(-1, 1, spatial_size),
        torch.linspace(-1, 1, spatial_size),
        indexing='ij'
    )
    center_dist = torch.sqrt(x**2 + y**2)  # [spatial_size, spatial_size]

    # Vectorized across timesteps: [n_buckets, spatial_size, spatial_size]
    t_weight = (1 - t_vals).view(-1, 1, 1)  # [n_buckets, 1, 1]
    center_bias = 1 - center_dist.unsqueeze(0) * t_weight  # [n_buckets, spatial_size, spatial_size]
    center_bias = center_bias / center_bias.sum(dim=[-2, -1], keepdim=True)  # Normalize per timestep

    all_spatial = center_bias.unsqueeze(0).expand(n_prompts, -1, -1, -1).to(torch.float16)  # [n_prompts, n_buckets, 8, 8]

    torch.save({
        "stats": all_stats,
        "spatial": all_spatial,
        "t_buckets": t_buckets
    }, cache_path)
    print(f"  ✓ Cached to {cache_path}")

    return SolFeatureCache(all_stats, all_spatial, t_buckets, DTYPE)


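# Worked example of the heuristics above (values only, no new behavior): at the
# earliest bucket t = 0.05 the stats vector is
#   locality   = 1 - 0.05              = 0.95
#   entropy    = 0.05
#   clustering = 0.5 - 0.3*|0.05 - 0.5| = 0.365
#   sparsity   = 1 - 0.05              = 0.95
# i.e. early timesteps are modeled as local/sparse and late ones as diffuse,
# while the center bias flattens toward uniform as t -> 1 (t_weight -> 0).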
# ============================================================================
# EMA
# ============================================================================
class EMA:
    def __init__(self, model, decay=0.9999):
        self.decay = decay
        self.shadow = {}
        self._backup = {}
        if hasattr(model, '_orig_mod'):
            state = model._orig_mod.state_dict()
        else:
            state = model.state_dict()
        for k, v in state.items():
            self.shadow[k] = v.clone().detach()

    @torch.no_grad()
    def update(self, model):
        if hasattr(model, '_orig_mod'):
            state = model._orig_mod.state_dict()
        else:
            state = model.state_dict()
        for k, v in state.items():
            if k in self.shadow:
                self.shadow[k].lerp_(v.to(self.shadow[k].dtype), 1 - self.decay)

    def apply_shadow_for_eval(self, model):
        if hasattr(model, '_orig_mod'):
            self._backup = {k: v.clone() for k, v in model._orig_mod.state_dict().items()}
            model._orig_mod.load_state_dict(self.shadow)
        else:
            self._backup = {k: v.clone() for k, v in model.state_dict().items()}
            model.load_state_dict(self.shadow)

    def restore(self, model):
        if hasattr(model, '_orig_mod'):
            model._orig_mod.load_state_dict(self._backup)
        else:
            model.load_state_dict(self._backup)
        self._backup = {}

    def state_dict(self):
        return {'shadow': self.shadow, 'decay': self.decay}

    def sync_from_model(self, model, pattern=None):
        if hasattr(model, '_orig_mod'):
            model_state = model._orig_mod.state_dict()
        else:
            model_state = model.state_dict()

        synced = 0
        for k, v in model_state.items():
            if pattern is None or pattern in k:
                if k in self.shadow:
                    self.shadow[k] = v.clone().to(self.shadow[k].device)
                    synced += 1

        print(f"  ✓ Synced EMA: {synced} weights" + (f" matching '{pattern}'" if pattern else ""))

    def load_state_dict(self, state):
        self.shadow = {k: v.clone() for k, v in state['shadow'].items()}
        self.decay = state.get('decay', self.decay)

    def load_shadow(self, shadow_state, model=None):
        device = next(iter(self.shadow.values())).device if self.shadow else 'cuda'

        loaded = 0
        skipped_old = 0
        initialized_from_model = 0

        for k, v in shadow_state.items():
            if k in self.shadow:
                self.shadow[k] = v.clone().to(device)
                loaded += 1
            else:
                skipped_old += 1

        if model is not None:
            if hasattr(model, '_orig_mod'):
                model_state = model._orig_mod.state_dict()
            else:
                model_state = model.state_dict()

            for k in self.shadow:
                if k not in shadow_state and k in model_state:
                    self.shadow[k] = model_state[k].clone().to(device)
                    initialized_from_model += 1

        print(f"  ✓ Restored EMA: {loaded} loaded, {skipped_old} deprecated, {initialized_from_model} new (from model)")

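# [Usage sketch, not part of this commit] Typical EMA lifecycle around the
# sampling/eval path; `generate_samples` is a hypothetical eval call, and the
# class already handles torch.compile wrappers via the `_orig_mod` checks:
#
#   ema = EMA(model, decay=EMA_DECAY)
#   ...
#   ema.update(model)                 # once per optimizer step
#   ...
#   ema.apply_shadow_for_eval(model)  # swap in averaged weights
#   generate_samples(model)           # hypothetical
#   ema.restore(model)                # swap training weights back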
# ============================================================================
# REGULARIZATION
# ============================================================================
def apply_text_dropout(t5_embeds, clip_pooled, dropout_prob=0.1):
    B = t5_embeds.shape[0]
    mask = torch.rand(B, device=t5_embeds.device) < dropout_prob
    t5_embeds = t5_embeds.clone()
    clip_pooled = clip_pooled.clone()
    t5_embeds[mask] = 0
    clip_pooled[mask] = 0
    return t5_embeds, clip_pooled, mask

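# The zeroed rows act as unconditional examples (the usual classifier-free
# guidance recipe); the returned mask tells the caller which samples were
# dropped, e.g. so any auxiliary conditioning can be zeroed to match.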
# ============================================================================
# MASKING UTILITIES
# ============================================================================
def detect_background_color(image: Image.Image, sample_size: int = 100) -> Tuple[int, int, int]:
    img = np.array(image)
    if len(img.shape) == 2:
        img = np.stack([img] * 3, axis=-1)
    h, w = img.shape[:2]
    corners = [
        img[:sample_size, :sample_size],
        img[:sample_size, -sample_size:],
        img[-sample_size:, :sample_size],
        img[-sample_size:, -sample_size:],
    ]
    corner_pixels = np.concatenate([c.reshape(-1, 3) for c in corners], axis=0)
    bg_color = np.median(corner_pixels, axis=0).astype(np.uint8)
    return tuple(bg_color)


def create_product_mask(image: Image.Image, threshold: int = 30) -> np.ndarray:
    img = np.array(image).astype(np.float32)
    if len(img.shape) == 2:
        img = np.stack([img] * 3, axis=-1)
    bg_color = detect_background_color(image)
    bg_color = np.array(bg_color, dtype=np.float32)
    diff = np.sqrt(np.sum((img - bg_color) ** 2, axis=-1))
    mask = (diff > threshold).astype(np.float32)
    return mask


def create_smpl_mask(conditioning_image: Image.Image, threshold: int = 20) -> np.ndarray:
    img = np.array(conditioning_image).astype(np.float32)
    if len(img.shape) == 2:
        return (img > threshold).astype(np.float32)
    r, g, b = img[:, :, 0], img[:, :, 1], img[:, :, 2]
    is_background = (g > r + 20) & (g > b + 20)
    mask = (~is_background).astype(np.float32)
    return mask


def downsample_mask_to_latent(mask: np.ndarray, latent_h: int = 64, latent_w: int = 64) -> torch.Tensor:
    mask_pil = Image.fromarray((mask * 255).astype(np.uint8))
    mask_pil = mask_pil.resize((latent_w, latent_h), Image.Resampling.BILINEAR)
    mask_latent = np.array(mask_pil).astype(np.float32) / 255.0
    return torch.from_numpy(mask_latent)

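# [Usage sketch, not part of this commit] How these masks would feed a
# foreground-weighted loss using the FG/BG weights from the config above
# (pred/target are hypothetical latent-space tensors):
#
#   mask = downsample_mask_to_latent(create_product_mask(image))  # [64, 64]
#   weight = BG_LOSS_WEIGHT + (FG_LOSS_WEIGHT - BG_LOSS_WEIGHT) * mask
#   loss = (weight.to(pred.device) * (pred - target).pow(2)).mean()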
# ============================================================================
# HF HUB SETUP
# ============================================================================
print("Setting up HuggingFace Hub...")
api = HfApi()


# ============================================================================
# FLOW MATCHING HELPERS
# ============================================================================
def flux_shift(t, s=SHIFT):
    return s * t / (1 + (s - 1) * t)


def min_snr_weight(t, gamma=MIN_SNR_GAMMA):
    snr = (t / (1 - t).clamp(min=1e-5)).pow(2)
    return torch.clamp(snr, max=gamma) / snr.clamp(min=1e-5)

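# Worked example (values only): with SHIFT = 3.0, flux_shift maps
#   t = 0.25 -> 3*0.25 / (1 + 2*0.25) = 0.50
#   t = 0.50 -> 1.50 / 2.00           = 0.75
# pushing sampling density toward high-noise timesteps. min_snr_weight caps the
# effective SNR at gamma = 5.0: the weight is 1 while snr <= gamma and decays
# as gamma / snr beyond that, so low-noise steps stop dominating the loss.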
# ============================================================================
# LOAD TEXT ENCODERS
# ============================================================================
print("Loading text encoders...")
t5_tok = T5Tokenizer.from_pretrained("google/flan-t5-base")
t5_enc = T5EncoderModel.from_pretrained("google/flan-t5-base", torch_dtype=DTYPE).to(DEVICE).eval()
for p in t5_enc.parameters():
    p.requires_grad = False

clip_tok = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
clip_enc = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14", torch_dtype=DTYPE).to(DEVICE).eval()
for p in clip_enc.parameters():
    p.requires_grad = False
print("✓ Text encoders loaded")

# ============================================================================
# LOAD VAE
# ============================================================================
print("Loading VAE...")
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=DTYPE).to(
    DEVICE).eval()
for p in vae.parameters():
    p.requires_grad = False
VAE_SCALE = vae.config.scaling_factor
print(f"✓ VAE loaded (scale={VAE_SCALE})")


# ============================================================================
# ENCODING FUNCTIONS
# ============================================================================
@torch.no_grad()
def encode_prompt(prompt: str) -> Tuple[torch.Tensor, torch.Tensor]:
    t5_inputs = t5_tok(prompt, return_tensors="pt", padding="max_length",
                       max_length=MAX_SEQ, truncation=True).to(DEVICE)
    t5_out = t5_enc(**t5_inputs).last_hidden_state
    clip_inputs = clip_tok(prompt, return_tensors="pt", padding="max_length",
                           max_length=77, truncation=True).to(DEVICE)
    clip_out = clip_enc(**clip_inputs).pooler_output
    return t5_out.squeeze(0), clip_out.squeeze(0)


@torch.no_grad()
def encode_prompts_batched(prompts: List[str], batch_size: int = 128) -> Tuple[torch.Tensor, torch.Tensor]:
    """Batch encode prompts with T5 and CLIP."""
    all_t5 = []
    all_clip = []
    for i in tqdm(range(0, len(prompts), batch_size), desc="Encoding prompts", leave=False):
        batch = prompts[i:i + batch_size]
        t5_inputs = t5_tok(batch, return_tensors="pt", padding="max_length",
                           max_length=MAX_SEQ, truncation=True).to(DEVICE)
        t5_out = t5_enc(**t5_inputs).last_hidden_state
        all_t5.append(t5_out.cpu())
        clip_inputs = clip_tok(batch, return_tensors="pt", padding="max_length",
                               max_length=77, truncation=True).to(DEVICE)
        clip_out = clip_enc(**clip_inputs).pooler_output
        all_clip.append(clip_out.cpu())
    return torch.cat(all_t5, dim=0), torch.cat(all_clip, dim=0)


@torch.no_grad()
def encode_image_to_latent(image: Image.Image) -> torch.Tensor:
    if image.mode != "RGB":
        image = image.convert("RGB")
    if image.size != (512, 512):
        image = image.resize((512, 512), Image.Resampling.LANCZOS)
    img_tensor = torch.from_numpy(np.array(image)).float() / 255.0
    img_tensor = img_tensor.permute(2, 0, 1).unsqueeze(0)
    img_tensor = (img_tensor * 2.0 - 1.0).to(DEVICE, dtype=DTYPE)
    latent = vae.encode(img_tensor).latent_dist.sample()
    latent = latent * VAE_SCALE
    return latent.squeeze(0).cpu()


# ============================================================================
# LOAD DATASETS
# ============================================================================

portrait_ds = None
portrait_indices = []
portrait_prompts = []

if ENABLE_PORTRAIT:
    print(f"\n[1/6] Loading portrait dataset from {PORTRAIT_REPO}...")
    portrait_shards = []
    for i in range(PORTRAIT_NUM_SHARDS):
        split_name = f"train_{i:02d}"
        print(f"  Loading {split_name}...")
        shard = load_dataset(PORTRAIT_REPO, split=split_name)
        portrait_shards.append(shard)
    portrait_ds = concatenate_datasets(portrait_shards)
    print(f"✓ Portrait: {len(portrait_ds)} base samples")
    print("  Extracting prompts (columnar)...")
    florence_list = list(portrait_ds["text_florence"])
    llava_list = list(portrait_ds["text_llava"])
    blip_list = list(portrait_ds["text_blip"])
    for i, (f, l, b) in enumerate(zip(florence_list, llava_list, blip_list)):
        if f and f.strip():
            portrait_indices.append(i)
            portrait_prompts.append(f)
        if l and l.strip():
            portrait_indices.append(i)
            portrait_prompts.append(l)
        if b and b.strip():
            portrait_indices.append(i)
            portrait_prompts.append(b)
    print(f"  Expanded: {len(portrait_prompts)} samples (up to 3 prompts/image)")
else:
    print("\n[1/6] Portrait dataset DISABLED")

schnell_ds = None
schnell_prompts = []

if ENABLE_SCHNELL:
    print(f"\n[2/6] Loading schnell teacher dataset from {SCHNELL_REPO}...")
    schnell_datasets = []
    for config in SCHNELL_CONFIGS:
        print(f"  Loading {config}...")
        ds = load_dataset(SCHNELL_REPO, config, split="train")
        schnell_datasets.append(ds)
        print(f"    {len(ds)} samples")
    schnell_ds = concatenate_datasets(schnell_datasets)
    schnell_prompts = list(schnell_ds["prompt"])
    print(f"✓ Schnell: {len(schnell_ds)} samples")
else:
    print("\n[2/6] Schnell dataset DISABLED")

sportfashion_ds = None
sportfashion_prompts = []
sportfashion_latents = None
sportfashion_masks = None

if ENABLE_SPORTFASHION:
    print(f"\n[3/6] Loading SportFashion dataset from {SPORTFASHION_REPO}...")
    sportfashion_ds = load_dataset(SPORTFASHION_REPO, split="train")
    sportfashion_prompts = list(sportfashion_ds["text"])
    print(f"✓ SportFashion: {len(sportfashion_ds)} samples")

    # Precache latents and masks
    sportfashion_latent_cache = os.path.join(LATENT_CACHE_DIR, f"sportfashion_latents_{len(sportfashion_ds)}.pt")
    sportfashion_mask_cache = os.path.join(LATENT_CACHE_DIR, f"sportfashion_masks_{len(sportfashion_ds)}.pt")

    if os.path.exists(sportfashion_latent_cache):
        print("  Loading cached SportFashion latents...")
        sportfashion_latents = torch.load(sportfashion_latent_cache)
        print(f"  ✓ Loaded {len(sportfashion_latents)} latents")
        if os.path.exists(sportfashion_mask_cache):
            sportfashion_masks = torch.load(sportfashion_mask_cache)
            print(f"  ✓ Loaded {len(sportfashion_masks)} masks")
    else:
        print("  Encoding SportFashion images to latents (one-time)...")
        VAE_BATCH_SIZE = 64  # A100 can handle large batches
        sportfashion_latents = []
        sportfashion_masks = []
        with torch.no_grad():
            for start_idx in tqdm(range(0, len(sportfashion_ds), VAE_BATCH_SIZE), desc="Encoding latents"):
                end_idx = min(start_idx + VAE_BATCH_SIZE, len(sportfashion_ds))
                batch_images = []
                batch_masks = []
                for i in range(start_idx, end_idx):
                    image = sportfashion_ds[i]["image"]
                    if image.mode != "RGB":
                        image = image.convert("RGB")
                    if image.size != (512, 512):
                        image = image.resize((512, 512), Image.Resampling.LANCZOS)
                    img_tensor = torch.from_numpy(np.array(image)).float() / 255.0
                    img_tensor = img_tensor.permute(2, 0, 1)
                    batch_images.append(img_tensor)
                    # Create mask
                    pixel_mask = create_product_mask(image)
                    mask = downsample_mask_to_latent(pixel_mask, 64, 64)
                    batch_masks.append(mask)
                batch_tensor = torch.stack(batch_images)
                batch_tensor = (batch_tensor * 2.0 - 1.0).to(DEVICE, dtype=DTYPE)
                latents = vae.encode(batch_tensor).latent_dist.sample()
                latents = latents * VAE_SCALE
                sportfashion_latents.append(latents.cpu())
                sportfashion_masks.extend(batch_masks)
        sportfashion_latents = torch.cat(sportfashion_latents, dim=0)
        sportfashion_masks = torch.stack(sportfashion_masks)
        torch.save(sportfashion_latents, sportfashion_latent_cache)
        torch.save(sportfashion_masks, sportfashion_mask_cache)
        print(f"  ✓ Cached to {sportfashion_latent_cache}")
else:
    print("\n[3/6] SportFashion dataset DISABLED")

synthmocap_ds = None
synthmocap_prompts = []
synthmocap_latents = None
synthmocap_masks = None

if ENABLE_SYNTHMOCAP:
    print(f"\n[4/6] Loading SynthMoCap dataset from {SYNTHMOCAP_REPO}...")
    synthmocap_ds = load_dataset(SYNTHMOCAP_REPO, split="train")
    synthmocap_prompts = list(synthmocap_ds["text"])
    print(f"✓ SynthMoCap: {len(synthmocap_ds)} samples")

    # Precache latents and masks
    synthmocap_latent_cache = os.path.join(LATENT_CACHE_DIR, f"synthmocap_latents_{len(synthmocap_ds)}.pt")
    synthmocap_mask_cache = os.path.join(LATENT_CACHE_DIR, f"synthmocap_masks_{len(synthmocap_ds)}.pt")

    if os.path.exists(synthmocap_latent_cache):
        print("  Loading cached SynthMoCap latents...")
        synthmocap_latents = torch.load(synthmocap_latent_cache)
        print(f"  ✓ Loaded {len(synthmocap_latents)} latents")
        if os.path.exists(synthmocap_mask_cache):
            synthmocap_masks = torch.load(synthmocap_mask_cache)
            print(f"  ✓ Loaded {len(synthmocap_masks)} masks")
    else:
        print("  Encoding SynthMoCap images to latents (one-time)...")
        VAE_BATCH_SIZE = 64  # A100 can handle large batches
        synthmocap_latents = []
        synthmocap_masks = []
        with torch.no_grad():
            for start_idx in tqdm(range(0, len(synthmocap_ds), VAE_BATCH_SIZE), desc="Encoding latents"):
                end_idx = min(start_idx + VAE_BATCH_SIZE, len(synthmocap_ds))
                batch_images = []
                batch_masks = []
                for i in range(start_idx, end_idx):
                    image = synthmocap_ds[i]["image"]
                    conditioning = synthmocap_ds[i]["conditioning_image"]
                    if image.mode != "RGB":
                        image = image.convert("RGB")
                    if image.size != (512, 512):
                        image = image.resize((512, 512), Image.Resampling.LANCZOS)
                    img_tensor = torch.from_numpy(np.array(image)).float() / 255.0
                    img_tensor = img_tensor.permute(2, 0, 1)
                    batch_images.append(img_tensor)
                    # Create mask from conditioning image
                    pixel_mask = create_smpl_mask(conditioning)
                    mask = downsample_mask_to_latent(pixel_mask, 64, 64)
                    batch_masks.append(mask)
                batch_tensor = torch.stack(batch_images)
                batch_tensor = (batch_tensor * 2.0 - 1.0).to(DEVICE, dtype=DTYPE)
                latents = vae.encode(batch_tensor).latent_dist.sample()
                latents = latents * VAE_SCALE
                synthmocap_latents.append(latents.cpu())
                synthmocap_masks.extend(batch_masks)
        synthmocap_latents = torch.cat(synthmocap_latents, dim=0)
        synthmocap_masks = torch.stack(synthmocap_masks)
        torch.save(synthmocap_latents, synthmocap_latent_cache)
        torch.save(synthmocap_masks, synthmocap_mask_cache)
        print(f"  ✓ Cached to {synthmocap_latent_cache}")
else:
    print("\n[4/6] SynthMoCap dataset DISABLED")

# ============================================================================
# IMAGENET DATASET WITH SMART PROMPT FILTERING
# ============================================================================
imagenet_ds = None
imagenet_prompts = []


def build_imagenet_prompt(item):
    semantic_class = item.get("semantic_class", "object")
    semantic_subclass = item.get("semantic_subclass", "")
    label = item.get("label", "").replace("_", " ")
    base_prompt = item.get("prompt", "")
    synset_id = item.get("synset_id", "")

    pred_confidence = item.get("pred_confidence", 0.0)
    top1_correct = item.get("top1_correct", False)
    top5_correct = item.get("top5_correct", False)

    confident_but_wrong = (
        pred_confidence >= IMAGENET_CONFIDENCE_THRESHOLD and
        not top1_correct and
        not top5_correct
    )

    if confident_but_wrong:
        parts = ["subject", semantic_class]
        if semantic_subclass:
            parts.append(semantic_subclass)
        parts.append(base_prompt)
        parts.append(synset_id)
        parts.append("imagenet")
    else:
        parts = ["subject", semantic_class]
        if semantic_subclass:
            parts.append(semantic_subclass)
        if label:
            parts.append(label)
        parts.append(base_prompt)
        parts.append(synset_id)
        parts.append("imagenet")

    return ", ".join(p for p in parts if p)

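# Example of the two prompt shapes above (hypothetical values): a correctly
# classified sample keeps its label,
#   "subject, animal, dog, a photo of a golden retriever, n02099601, imagenet"
# while a confident misprediction drops the (likely wrong) label and keeps
# only class/synset context:
#   "subject, animal, a photo of a golden retriever, n02099601, imagenet"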
if ENABLE_IMAGENET:
    print(f"\n[5/6] Loading Synthetic ImageNet from {IMAGENET_REPO}...")
    imagenet_ds = load_dataset(IMAGENET_REPO, IMAGENET_SUBSET, split="train")
    print(f"  Raw samples: {len(imagenet_ds)}")

    # Use columnar access - MUCH faster than row iteration
    # (datasets.Dataset has no .get(); index optional columns only if present)
    print("  Building prompts...")
    semantic_classes = imagenet_ds["semantic_class"]
    semantic_subclasses = imagenet_ds["semantic_subclass"] if "semantic_subclass" in imagenet_ds.features else [""] * len(imagenet_ds)
    labels = imagenet_ds["label"]
    base_prompts = imagenet_ds["prompt"]
    synset_ids = imagenet_ds["synset_id"]
    pred_confidences = imagenet_ds["pred_confidence"] if "pred_confidence" in imagenet_ds.features else [0.0] * len(imagenet_ds)
    top1_corrects = imagenet_ds["top1_correct"] if "top1_correct" in imagenet_ds.features else [False] * len(imagenet_ds)
    top5_corrects = imagenet_ds["top5_correct"] if "top5_correct" in imagenet_ds.features else [False] * len(imagenet_ds)

    # Handle case where columns might not come back as plain lists
    if not isinstance(semantic_subclasses, list):
        semantic_subclasses = list(semantic_subclasses) if semantic_subclasses else [""] * len(imagenet_ds)
    if not isinstance(pred_confidences, list):
        pred_confidences = list(pred_confidences) if pred_confidences else [0.0] * len(imagenet_ds)
    if not isinstance(top1_corrects, list):
        top1_corrects = list(top1_corrects) if top1_corrects else [False] * len(imagenet_ds)
    if not isinstance(top5_corrects, list):
        top5_corrects = list(top5_corrects) if top5_corrects else [False] * len(imagenet_ds)

    confident_wrong = 0
    for i in range(len(imagenet_ds)):
        semantic_class = semantic_classes[i] if semantic_classes[i] else "object"
        semantic_subclass = semantic_subclasses[i] if i < len(semantic_subclasses) else ""
        label = labels[i].replace("_", " ") if labels[i] else ""
        base_prompt = base_prompts[i] if base_prompts[i] else ""
        synset_id = synset_ids[i] if synset_ids[i] else ""
        pred_confidence = pred_confidences[i] if i < len(pred_confidences) else 0.0
        top1_correct = top1_corrects[i] if i < len(top1_corrects) else False
        top5_correct = top5_corrects[i] if i < len(top5_corrects) else False

        confident_but_wrong = (
            pred_confidence >= IMAGENET_CONFIDENCE_THRESHOLD and
            not top1_correct and
            not top5_correct
        )

        if confident_but_wrong:
            parts = ["subject", semantic_class]
            if semantic_subclass:
                parts.append(semantic_subclass)
            parts.append(base_prompt)
            parts.append(synset_id)
            parts.append("imagenet")
            confident_wrong += 1
        else:
            parts = ["subject", semantic_class]
            if semantic_subclass:
                parts.append(semantic_subclass)
            if label:
                parts.append(label)
            parts.append(base_prompt)
            parts.append(synset_id)
            parts.append("imagenet")

        imagenet_prompts.append(", ".join(p for p in parts if p))

    print(f"✓ ImageNet: {len(imagenet_ds)} samples")
    print(f"  Confident mispredictions (label removed): {confident_wrong}")

    imagenet_latent_cache = os.path.join(LATENT_CACHE_DIR, f"imagenet_latents_{len(imagenet_ds)}.pt")
    if os.path.exists(imagenet_latent_cache):
        print("  Loading cached ImageNet latents...")
        imagenet_latents = torch.load(imagenet_latent_cache)
        print(f"  ✓ Loaded {len(imagenet_latents)} latents")
    else:
        print("  Encoding ImageNet images to latents (one-time)...")
        VAE_BATCH_SIZE = 64  # A100 can handle large batches
        imagenet_latents = []
        with torch.no_grad():
            for start_idx in tqdm(range(0, len(imagenet_ds), VAE_BATCH_SIZE), desc="Encoding latents"):
                end_idx = min(start_idx + VAE_BATCH_SIZE, len(imagenet_ds))
                batch_images = []
                for i in range(start_idx, end_idx):
                    image = imagenet_ds[i]["image"]
                    if image.mode != "RGB":
                        image = image.convert("RGB")
                    if image.size != (512, 512):
                        image = image.resize((512, 512), Image.Resampling.LANCZOS)
                    img_tensor = torch.from_numpy(np.array(image)).float() / 255.0
                    img_tensor = img_tensor.permute(2, 0, 1)
                    batch_images.append(img_tensor)
                batch_tensor = torch.stack(batch_images)
                batch_tensor = (batch_tensor * 2.0 - 1.0).to(DEVICE, dtype=DTYPE)
                latents = vae.encode(batch_tensor).latent_dist.sample()
                latents = latents * VAE_SCALE
                imagenet_latents.append(latents.cpu())
        imagenet_latents = torch.cat(imagenet_latents, dim=0)
        torch.save(imagenet_latents, imagenet_latent_cache)
        print(f"  ✓ Cached to {imagenet_latent_cache}")
else:
    print("\n[5/6] ImageNet dataset DISABLED")
    imagenet_latents = None

# ============================================================================
# OBJECT RELATIONS DATASET WITH SUBJECT PREFIX
# ============================================================================
object_relations_ds = None
object_relations_prompts = []
object_relations_latents = None


def build_object_relations_prompt(item):
    prompt = item.get("prompt", "")
    if random.random() < 0.5:
        return f"subject, object, {prompt}"
    else:
        return f"subject, {prompt}"


if ENABLE_OBJECT_RELATIONS:
    print(f"\n[6/6] Loading Object Relations from {OBJECT_RELATIONS_REPO}...")
    object_relations_ds = load_dataset(OBJECT_RELATIONS_REPO, split="train")
    print(f"  Raw samples: {len(object_relations_ds)}")

    # Use columnar access - MUCH faster than row iteration
    print("  Building prompts...")
    all_prompts = object_relations_ds["prompt"]  # Get entire column at once

    random.seed(42)  # deterministic prefix assignment across runs
    object_relations_prompts = []
    for prompt in all_prompts:
        if random.random() < 0.5:
            object_relations_prompts.append(f"subject, object, {prompt}")
        else:
            object_relations_prompts.append(f"subject, {prompt}")
    random.seed()  # re-randomize the global RNG afterwards

    subject_object_count = sum(1 for p in object_relations_prompts if p.startswith("subject, object,"))
    subject_only_count = len(object_relations_prompts) - subject_object_count
    print(f"✓ Object Relations: {len(object_relations_ds)} samples")
    print(f"  'subject, object, ...' prefix: {subject_object_count}")
    print(f"  'subject, ...' prefix: {subject_only_count}")

    object_relations_latent_cache = os.path.join(LATENT_CACHE_DIR, f"object_relations_latents_{len(object_relations_ds)}.pt")
    if os.path.exists(object_relations_latent_cache):
        print("  Loading cached Object Relations latents...")
        object_relations_latents = torch.load(object_relations_latent_cache)
        print(f"  ✓ Loaded {len(object_relations_latents)} latents")
    else:
        print("  Encoding Object Relations images to latents (one-time)...")
        VAE_BATCH_SIZE = 64  # A100 can handle large batches
        object_relations_latents = []
        with torch.no_grad():
            for start_idx in tqdm(range(0, len(object_relations_ds), VAE_BATCH_SIZE), desc="Encoding latents"):
                end_idx = min(start_idx + VAE_BATCH_SIZE, len(object_relations_ds))
                batch_images = []
                for i in range(start_idx, end_idx):
                    image = object_relations_ds[i]["image"]
                    if image.mode != "RGB":
                        image = image.convert("RGB")
                    if image.size != (512, 512):
                        image = image.resize((512, 512), Image.Resampling.LANCZOS)
                    img_tensor = torch.from_numpy(np.array(image)).float() / 255.0
                    img_tensor = img_tensor.permute(2, 0, 1)
                    batch_images.append(img_tensor)
                batch_tensor = torch.stack(batch_images)
                batch_tensor = (batch_tensor * 2.0 - 1.0).to(DEVICE, dtype=DTYPE)
                latents = vae.encode(batch_tensor).latent_dist.sample()
                latents = latents * VAE_SCALE
                object_relations_latents.append(latents.cpu())
        object_relations_latents = torch.cat(object_relations_latents, dim=0)
        torch.save(object_relations_latents, object_relations_latent_cache)
        print(f"  ✓ Cached to {object_relations_latent_cache}")
else:
    print("\n[6/6] Object Relations dataset DISABLED")

1044
+ # ============================================================================
1045
+ # ENCODE ALL PROMPTS
1046
+ # ============================================================================
1047
+ total_samples = len(portrait_prompts) + len(schnell_prompts) + len(sportfashion_prompts) + len(synthmocap_prompts) + len(imagenet_prompts) + len(object_relations_prompts)
1048
+ print(f"\nTotal combined samples: {total_samples}")
1049
+
1050
+
1051
+ def load_or_encode(cache_path, prompts, name):
1052
+ if not prompts:
1053
+ return None, None
1054
+ if os.path.exists(cache_path):
1055
+ print(f"Loading cached {name} encodings...")
1056
+ cached = torch.load(cache_path)
1057
+ return cached["t5_embeds"], cached["clip_pooled"]
1058
+ else:
1059
+ print(f"Encoding {len(prompts)} {name} prompts...")
1060
+ t5, clip = encode_prompts_batched(prompts, batch_size=64)
1061
+ torch.save({"t5_embeds": t5, "clip_pooled": clip}, cache_path)
1062
+ print(f"✓ Cached to {cache_path}")
1063
+ return t5, clip
1064
+
1065
+
1066
+ portrait_t5, portrait_clip = None, None
1067
+ schnell_t5, schnell_clip = None, None
1068
+ sportfashion_t5, sportfashion_clip = None, None
1069
+ synthmocap_t5, synthmocap_clip = None, None
1070
+
1071
+ if portrait_prompts:
1072
+ portrait_enc_cache = os.path.join(ENCODING_CACHE_DIR, f"portrait_encodings_{len(portrait_prompts)}.pt")
1073
+ portrait_t5, portrait_clip = load_or_encode(portrait_enc_cache, portrait_prompts, "portrait")
1074
+
1075
+ if schnell_prompts:
1076
+ schnell_enc_cache = os.path.join(ENCODING_CACHE_DIR, f"schnell_encodings_{len(schnell_prompts)}.pt")
1077
+ schnell_t5, schnell_clip = load_or_encode(schnell_enc_cache, schnell_prompts, "schnell")
1078
+
1079
+ if sportfashion_prompts:
1080
+ sportfashion_enc_cache = os.path.join(ENCODING_CACHE_DIR, f"sportfashion_encodings_{len(sportfashion_prompts)}.pt")
1081
+ sportfashion_t5, sportfashion_clip = load_or_encode(sportfashion_enc_cache, sportfashion_prompts, "sportfashion")
1082
+
1083
+ if synthmocap_prompts:
1084
+ synthmocap_enc_cache = os.path.join(ENCODING_CACHE_DIR, f"synthmocap_encodings_{len(synthmocap_prompts)}.pt")
1085
+ synthmocap_t5, synthmocap_clip = load_or_encode(synthmocap_enc_cache, synthmocap_prompts, "synthmocap")
1086
+
1087
+ imagenet_t5, imagenet_clip = None, None
1088
+ if imagenet_prompts:
1089
+ imagenet_enc_cache = os.path.join(ENCODING_CACHE_DIR, f"imagenet_encodings_{len(imagenet_prompts)}.pt")
1090
+ imagenet_t5, imagenet_clip = load_or_encode(imagenet_enc_cache, imagenet_prompts, "imagenet")
1091
+
1092
+ object_relations_t5, object_relations_clip = None, None
1093
+ if object_relations_prompts:
1094
+ object_relations_enc_cache = os.path.join(ENCODING_CACHE_DIR, f"object_relations_encodings_{len(object_relations_prompts)}.pt")
1095
+ object_relations_t5, object_relations_clip = load_or_encode(object_relations_enc_cache, object_relations_prompts, "object_relations")
1096
+
1097
+
1098
+
1099
+ # ============================================================================
1100
+ # EXTRACT/LOAD LUNE AND SOL FEATURES (precached)
1101
+ # ============================================================================
1102
+ print("\n" + "=" * 60)
1103
+ print("Expert Feature Caching (Lune + Sol)")
1104
+ print("=" * 60)
1105
+
1106
+ # Lune caches
1107
+ schnell_lune_cache = None
1108
+ portrait_lune_cache = None
1109
+ sportfashion_lune_cache = None
1110
+ synthmocap_lune_cache = None
1111
+ imagenet_lune_cache = None
1112
+ object_relations_lune_cache = None
1113
+
1114
+ # Sol caches
1115
+ schnell_sol_cache = None
1116
+ portrait_sol_cache = None
1117
+ sportfashion_sol_cache = None
1118
+ synthmocap_sol_cache = None
1119
+ imagenet_sol_cache = None
1120
+ object_relations_sol_cache = None
1121
+
1122
+ if schnell_prompts:
1123
+ if ENABLE_LUNE_DISTILLATION:
1124
+         schnell_lune_path = os.path.join(ENCODING_CACHE_DIR, f"schnell_lune_{len(schnell_prompts)}.pt")
+         schnell_lune_cache = load_or_extract_lune_features(
+             schnell_lune_path, schnell_prompts, "schnell",
+             clip_tok, clip_enc, EXPERT_T_BUCKETS
+         )
+     if ENABLE_SOL_DISTILLATION:
+         schnell_sol_path = os.path.join(ENCODING_CACHE_DIR, f"schnell_sol_{len(schnell_prompts)}.pt")
+         schnell_sol_cache = load_or_extract_sol_features(
+             schnell_sol_path, schnell_prompts, "schnell",
+             clip_tok, clip_enc, EXPERT_T_BUCKETS, SOL_SPATIAL_SIZE
+         )
+ 
+ if portrait_prompts:
+     if ENABLE_LUNE_DISTILLATION:
+         portrait_lune_path = os.path.join(ENCODING_CACHE_DIR, f"portrait_lune_{len(portrait_prompts)}.pt")
+         portrait_lune_cache = load_or_extract_lune_features(
+             portrait_lune_path, portrait_prompts, "portrait",
+             clip_tok, clip_enc, EXPERT_T_BUCKETS
+         )
+     if ENABLE_SOL_DISTILLATION:
+         portrait_sol_path = os.path.join(ENCODING_CACHE_DIR, f"portrait_sol_{len(portrait_prompts)}.pt")
+         portrait_sol_cache = load_or_extract_sol_features(
+             portrait_sol_path, portrait_prompts, "portrait",
+             clip_tok, clip_enc, EXPERT_T_BUCKETS, SOL_SPATIAL_SIZE
+         )
+ 
+ if sportfashion_prompts:
+     if ENABLE_LUNE_DISTILLATION:
+         sportfashion_lune_path = os.path.join(ENCODING_CACHE_DIR, f"sportfashion_lune_{len(sportfashion_prompts)}.pt")
+         sportfashion_lune_cache = load_or_extract_lune_features(
+             sportfashion_lune_path, sportfashion_prompts, "sportfashion",
+             clip_tok, clip_enc, EXPERT_T_BUCKETS
+         )
+     if ENABLE_SOL_DISTILLATION:
+         sportfashion_sol_path = os.path.join(ENCODING_CACHE_DIR, f"sportfashion_sol_{len(sportfashion_prompts)}.pt")
+         sportfashion_sol_cache = load_or_extract_sol_features(
+             sportfashion_sol_path, sportfashion_prompts, "sportfashion",
+             clip_tok, clip_enc, EXPERT_T_BUCKETS, SOL_SPATIAL_SIZE
+         )
+ 
+ if synthmocap_prompts:
+     if ENABLE_LUNE_DISTILLATION:
+         synthmocap_lune_path = os.path.join(ENCODING_CACHE_DIR, f"synthmocap_lune_{len(synthmocap_prompts)}.pt")
+         synthmocap_lune_cache = load_or_extract_lune_features(
+             synthmocap_lune_path, synthmocap_prompts, "synthmocap",
+             clip_tok, clip_enc, EXPERT_T_BUCKETS
+         )
+     if ENABLE_SOL_DISTILLATION:
+         synthmocap_sol_path = os.path.join(ENCODING_CACHE_DIR, f"synthmocap_sol_{len(synthmocap_prompts)}.pt")
+         synthmocap_sol_cache = load_or_extract_sol_features(
+             synthmocap_sol_path, synthmocap_prompts, "synthmocap",
+             clip_tok, clip_enc, EXPERT_T_BUCKETS, SOL_SPATIAL_SIZE
+         )
+ 
+ if imagenet_prompts:
+     if ENABLE_LUNE_DISTILLATION:
+         imagenet_lune_path = os.path.join(ENCODING_CACHE_DIR, f"imagenet_lune_{len(imagenet_prompts)}.pt")
+         imagenet_lune_cache = load_or_extract_lune_features(
+             imagenet_lune_path, imagenet_prompts, "imagenet",
+             clip_tok, clip_enc, EXPERT_T_BUCKETS
+         )
+     if ENABLE_SOL_DISTILLATION:
+         imagenet_sol_path = os.path.join(ENCODING_CACHE_DIR, f"imagenet_sol_{len(imagenet_prompts)}.pt")
+         imagenet_sol_cache = load_or_extract_sol_features(
+             imagenet_sol_path, imagenet_prompts, "imagenet",
+             clip_tok, clip_enc, EXPERT_T_BUCKETS, SOL_SPATIAL_SIZE
+         )
+ 
+ if object_relations_prompts:
+     if ENABLE_LUNE_DISTILLATION:
+         object_relations_lune_path = os.path.join(ENCODING_CACHE_DIR, f"object_relations_lune_{len(object_relations_prompts)}.pt")
+         object_relations_lune_cache = load_or_extract_lune_features(
+             object_relations_lune_path, object_relations_prompts, "object_relations",
+             clip_tok, clip_enc, EXPERT_T_BUCKETS
+         )
+     if ENABLE_SOL_DISTILLATION:
+         object_relations_sol_path = os.path.join(ENCODING_CACHE_DIR, f"object_relations_sol_{len(object_relations_prompts)}.pt")
+         object_relations_sol_cache = load_or_extract_sol_features(
+             object_relations_sol_path, object_relations_prompts, "object_relations",
+             clip_tok, clip_enc, EXPERT_T_BUCKETS, SOL_SPATIAL_SIZE
+         )
+ 
+ 
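+ # --- Editor's sketch (hedged, not wired into the script) --------------------
+ # The six per-dataset blocks above are structurally identical; a table-driven
+ # helper like the one below could replace them. `_extract_expert_caches` is an
+ # illustrative name, not something defined elsewhere in this file, and the
+ # function is deliberately never called so behavior is unchanged.
+ def _extract_expert_caches(name, prompts):
+     """Mirror one block above: return (lune_cache, sol_cache), either may be None."""
+     lune_cache, sol_cache = None, None
+     if not prompts:
+         return lune_cache, sol_cache
+     if ENABLE_LUNE_DISTILLATION:
+         lune_path = os.path.join(ENCODING_CACHE_DIR, f"{name}_lune_{len(prompts)}.pt")
+         lune_cache = load_or_extract_lune_features(
+             lune_path, prompts, name, clip_tok, clip_enc, EXPERT_T_BUCKETS
+         )
+     if ENABLE_SOL_DISTILLATION:
+         sol_path = os.path.join(ENCODING_CACHE_DIR, f"{name}_sol_{len(prompts)}.pt")
+         sol_cache = load_or_extract_sol_features(
+             sol_path, prompts, name, clip_tok, clip_enc, EXPERT_T_BUCKETS, SOL_SPATIAL_SIZE
+         )
+     return lune_cache, sol_cache
+ 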
+ # ============================================================================
+ # COMBINED DATASET CLASS
+ # ============================================================================
+ class CombinedDataset(Dataset):
+     """Combined dataset returning sample index for expert feature lookup."""
+ 
+     def __init__(
+         self,
+         portrait_ds, portrait_indices, portrait_t5, portrait_clip,
+         schnell_ds, schnell_t5, schnell_clip,
+         sportfashion_ds, sportfashion_latents, sportfashion_masks, sportfashion_t5, sportfashion_clip,
+         synthmocap_ds, synthmocap_latents, synthmocap_masks, synthmocap_t5, synthmocap_clip,
+         imagenet_ds, imagenet_latents, imagenet_t5, imagenet_clip,
+         object_relations_ds, object_relations_latents, object_relations_t5, object_relations_clip,
+         vae, vae_scale, device, dtype,
+         compute_masks=True,
+     ):
+         self.portrait_ds = portrait_ds
+         self.portrait_indices = portrait_indices
+         self.portrait_t5 = portrait_t5
+         self.portrait_clip = portrait_clip
+ 
+         self.schnell_ds = schnell_ds
+         self.schnell_t5 = schnell_t5
+         self.schnell_clip = schnell_clip
+ 
+         self.sportfashion_ds = sportfashion_ds
+         self.sportfashion_latents = sportfashion_latents
+         self.sportfashion_masks = sportfashion_masks
+         self.sportfashion_t5 = sportfashion_t5
+         self.sportfashion_clip = sportfashion_clip
+ 
+         self.synthmocap_ds = synthmocap_ds
+         self.synthmocap_latents = synthmocap_latents
+         self.synthmocap_masks = synthmocap_masks
+         self.synthmocap_t5 = synthmocap_t5
+         self.synthmocap_clip = synthmocap_clip
+ 
+         self.imagenet_ds = imagenet_ds
+         self.imagenet_latents = imagenet_latents
+         self.imagenet_t5 = imagenet_t5
+         self.imagenet_clip = imagenet_clip
+ 
+         self.object_relations_ds = object_relations_ds
+         self.object_relations_latents = object_relations_latents
+         self.object_relations_t5 = object_relations_t5
+         self.object_relations_clip = object_relations_clip
+ 
+         self.vae = vae
+         self.vae_scale = vae_scale
+         self.device = device
+         self.dtype = dtype
+         self.compute_masks = compute_masks
+ 
+         self.n_portrait = len(portrait_indices) if portrait_indices else 0
+         self.n_schnell = len(schnell_ds) if schnell_ds else 0
+         self.n_sportfashion = len(sportfashion_latents) if sportfashion_latents is not None else 0
+         self.n_synthmocap = len(synthmocap_latents) if synthmocap_latents is not None else 0
+         self.n_imagenet = len(imagenet_latents) if imagenet_latents is not None else 0
+         self.n_object_relations = len(object_relations_latents) if object_relations_latents is not None else 0
+ 
+         # Cumulative boundaries used to map a global index to (dataset, local index)
+         self.c1 = self.n_portrait
+         self.c2 = self.c1 + self.n_schnell
+         self.c3 = self.c2 + self.n_sportfashion
+         self.c4 = self.c3 + self.n_synthmocap
+         self.c5 = self.c4 + self.n_imagenet
+         self.total = self.c5 + self.n_object_relations
+ 
+     def __len__(self):
+         return self.total
+ 
+     def _get_latent_from_array(self, latent_data):
+         if isinstance(latent_data, torch.Tensor):
+             return latent_data.to(self.dtype)
+         return torch.tensor(np.array(latent_data), dtype=self.dtype)
+ 
+     def __getitem__(self, idx):
+         mask = None
+ 
+         if idx < self.c1:
+             local_idx = idx
+             orig_idx = self.portrait_indices[idx]
+             item = self.portrait_ds[orig_idx]
+             latent = self._get_latent_from_array(item["latent"])
+             t5 = self.portrait_t5[idx]
+             clip = self.portrait_clip[idx]
+             dataset_id = 0
+ 
+         elif idx < self.c2:
+             local_idx = idx - self.c1
+             item = self.schnell_ds[local_idx]
+             latent = self._get_latent_from_array(item["latent"])
+             t5 = self.schnell_t5[local_idx]
+             clip = self.schnell_clip[local_idx]
+             dataset_id = 1
+ 
+         elif idx < self.c3:
+             local_idx = idx - self.c2
+             latent = self.sportfashion_latents[local_idx].to(self.dtype)
+             t5 = self.sportfashion_t5[local_idx]
+             clip = self.sportfashion_clip[local_idx]
+             dataset_id = 2
+             if self.compute_masks and self.sportfashion_masks is not None:
+                 mask = self.sportfashion_masks[local_idx].to(self.dtype)
+ 
+         elif idx < self.c4:
+             local_idx = idx - self.c3
+             latent = self.synthmocap_latents[local_idx].to(self.dtype)
+             t5 = self.synthmocap_t5[local_idx]
+             clip = self.synthmocap_clip[local_idx]
+             dataset_id = 3
+             if self.compute_masks and self.synthmocap_masks is not None:
+                 mask = self.synthmocap_masks[local_idx].to(self.dtype)
+ 
+         elif idx < self.c5:
+             local_idx = idx - self.c4
+             latent = self.imagenet_latents[local_idx].to(self.dtype)
+             t5 = self.imagenet_t5[local_idx]
+             clip = self.imagenet_clip[local_idx]
+             dataset_id = 4
+ 
+         else:
+             local_idx = idx - self.c5
+             latent = self.object_relations_latents[local_idx].to(self.dtype)
+             t5 = self.object_relations_t5[local_idx]
+             clip = self.object_relations_clip[local_idx]
+             dataset_id = 5
+ 
+         result = {
+             "latent": latent,
+             "t5_embed": t5.to(self.dtype),
+             "clip_pooled": clip.to(self.dtype),
+             "sample_idx": idx,
+             "local_idx": local_idx,
+             "dataset_id": dataset_id,
+         }
+ 
+         if mask is not None:
+             result["mask"] = mask.to(self.dtype)
+ 
+         return result
+ 
+ 
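+ # Index layout example (illustrative numbers): with n_portrait=2 and
+ # n_schnell=3, c1=2 and c2=5, so global idx 3 falls in the schnell range and
+ # maps to dataset_id=1, local_idx=3-2=1; idx 5 would be the first
+ # sportfashion sample (dataset_id=2, local_idx=0).
+ 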
+ # ============================================================================
+ # COLLATE FUNCTION
+ # ============================================================================
+ def collate_fn(batch):
+     latents = torch.stack([b["latent"] for b in batch])
+     t5_embeds = torch.stack([b["t5_embed"] for b in batch])
+     clip_pooled = torch.stack([b["clip_pooled"] for b in batch])
+     sample_indices = torch.tensor([b["sample_idx"] for b in batch], dtype=torch.long)
+     local_indices = torch.tensor([b["local_idx"] for b in batch], dtype=torch.long)
+     dataset_ids = torch.tensor([b["dataset_id"] for b in batch], dtype=torch.long)
+ 
+     masks = None
+     if any("mask" in b for b in batch):
+         masks = []
+         for b in batch:
+             if "mask" in b:
+                 masks.append(b["mask"])
+             else:
+                 # Neutral all-ones mask for samples from datasets without masks,
+                 # sized from the actual latent grid rather than a hardcoded 64x64.
+                 masks.append(torch.ones(latents.shape[-2], latents.shape[-1], dtype=latents.dtype))
+         masks = torch.stack(masks)
+ 
+     return {
+         "latents": latents,
+         "t5_embeds": t5_embeds,
+         "clip_pooled": clip_pooled,
+         "sample_indices": sample_indices,
+         "local_indices": local_indices,
+         "dataset_ids": dataset_ids,
+         "masks": masks,
+     }
+ 
+ 
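+ # Example (illustrative): a batch mixing one sportfashion sample (which carries
+ # a foreground mask) and one imagenet sample (which does not) yields masks of
+ # shape [2, H, W], where the imagenet row is all ones, i.e. uniform weighting.
+ 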
+ # ============================================================================
+ # EXPERT FEATURE LOOKUP (handles multiple datasets, dual experts)
+ # ============================================================================
+ def get_lune_features_for_batch(
+     local_indices: torch.Tensor,
+     dataset_ids: torch.Tensor,
+     timesteps: torch.Tensor,
+ ) -> Optional[torch.Tensor]:
+     """Get Lune features from the appropriate cache for each sample."""
+     caches = [
+         portrait_lune_cache, schnell_lune_cache, sportfashion_lune_cache,
+         synthmocap_lune_cache, imagenet_lune_cache, object_relations_lune_cache
+     ]
+ 
+     if not any(c is not None for c in caches):
+         return None
+ 
+     B = local_indices.shape[0]
+     device = timesteps.device
+     features = torch.zeros(B, LUNE_DIM, device=device, dtype=DTYPE)
+ 
+     for ds_id, cache in enumerate(caches):
+         if cache is None:
+             continue
+         mask = dataset_ids == ds_id
+         if not mask.any():
+             continue
+         ds_local_indices = local_indices[mask]
+         ds_timesteps = timesteps[mask]
+         ds_features = cache.get_features(ds_local_indices, ds_timesteps)
+         features[mask] = ds_features
+ 
+     return features
+ 
+ 
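+ # Caveat worth noting: `features` is zero-initialized, so samples whose dataset
+ # has no cache keep all-zero rows. When some (but not all) caches exist, those
+ # zero rows are still used as distillation targets downstream; masking them
+ # out of the Lune loss would be a possible refinement.
+ 
+ 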
+ def get_sol_features_for_batch(
+     local_indices: torch.Tensor,
+     dataset_ids: torch.Tensor,
+     timesteps: torch.Tensor,
+ ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
+     """Get Sol features (stats + spatial) from the appropriate cache."""
+     caches = [
+         portrait_sol_cache, schnell_sol_cache, sportfashion_sol_cache,
+         synthmocap_sol_cache, imagenet_sol_cache, object_relations_sol_cache
+     ]
+ 
+     if not any(c is not None for c in caches):
+         return None, None
+ 
+     B = local_indices.shape[0]
+     device = timesteps.device
+     stats = torch.zeros(B, 4, device=device, dtype=DTYPE)
+     spatial = torch.zeros(B, SOL_SPATIAL_SIZE, SOL_SPATIAL_SIZE, device=device, dtype=DTYPE)
+ 
+     for ds_id, cache in enumerate(caches):
+         if cache is None:
+             continue
+         mask = dataset_ids == ds_id
+         if not mask.any():
+             continue
+         ds_local_indices = local_indices[mask]
+         ds_timesteps = timesteps[mask]
+         ds_stats, ds_spatial = cache.get_features(ds_local_indices, ds_timesteps)
+         stats[mask] = ds_stats
+         spatial[mask] = ds_spatial
+ 
+     return stats, spatial
+ 
+ 
+ # ============================================================================
+ # LOSS FUNCTIONS
+ # ============================================================================
+ def huber_loss(pred, target, delta=0.1):
+     """Huber loss - L2 for small errors, L1 for large."""
+     diff = pred - target
+     abs_diff = diff.abs()
+     quadratic = torch.clamp(abs_diff, max=delta)
+     linear = abs_diff - quadratic
+     return 0.5 * quadratic ** 2 + delta * linear
+ 
+ 
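+ # Worked example: with delta=0.1, a residual of 0.05 stays in the quadratic
+ # zone, giving 0.5 * 0.05**2 = 0.00125; a residual of 1.0 is clamped, giving
+ # 0.5 * 0.1**2 + 0.1 * (1.0 - 0.1) = 0.095, so large errors grow only linearly.
+ 
+ 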
+ def compute_main_loss(pred, target, mask=None, spatial_weights=None,
+                       fg_weight=2.0, bg_weight=0.5, snr_weights=None):
+     """Compute main prediction loss with optional spatial weighting."""
+     B, N, C = pred.shape
+ 
+     if USE_HUBER_LOSS:
+         loss_per_elem = huber_loss(pred, target, HUBER_DELTA)
+     else:
+         loss_per_elem = (pred - target) ** 2
+ 
+     # Apply spatial weights from Sol if enabled
+     if spatial_weights is not None and USE_SPATIAL_WEIGHTING:
+         # Assumes a square token grid, i.e. N == H * W with H == W
+         H = W = int(math.sqrt(N))
+         # Upsample spatial weights from the SOL_SPATIAL_SIZE grid to HxW
+         spatial_upsampled = F.interpolate(
+             spatial_weights.unsqueeze(1),  # [B, 1, S, S]
+             size=(H, W),
+             mode='bilinear',
+             align_corners=False
+         ).squeeze(1)  # [B, H, W]
+         # Normalize so the mean weight is 1 (redistributes, does not rescale)
+         spatial_upsampled = spatial_upsampled / (spatial_upsampled.mean(dim=[1, 2], keepdim=True) + 1e-6)
+         spatial_flat = spatial_upsampled.view(B, N, 1)
+         loss_per_elem = loss_per_elem * spatial_flat
+ 
+     # Apply foreground/background mask
+     if mask is not None:
+         H = W = int(math.sqrt(N))
+         mask_flat = mask.view(B, H * W, 1).to(pred.device)
+         weights = mask_flat * fg_weight + (1 - mask_flat) * bg_weight
+         loss_per_elem = loss_per_elem * weights
+ 
+     loss_per_sample = loss_per_elem.mean(dim=[1, 2])
+ 
+     if snr_weights is not None:
+         loss_per_sample = loss_per_sample * snr_weights
+ 
+     return loss_per_sample.mean()
+ 
+ 
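+ # Order of operations above, for reference: per-element Huber (or MSE) ->
+ # Sol spatial weighting (mean-1 normalized) -> foreground/background mask
+ # weighting -> per-sample mean over tokens and channels -> optional min-SNR
+ # weighting -> batch mean.
+ 
+ 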
+ def compute_lune_loss(pred, target, mode="cosine"):
+     """Compute Lune distillation loss."""
+     if mode == "cosine":
+         # Cosine similarity loss (1 - cos_sim)
+         pred_norm = F.normalize(pred, dim=-1)
+         target_norm = F.normalize(target, dim=-1)
+         return (1 - (pred_norm * target_norm).sum(dim=-1)).mean()
+     elif mode == "huber":
+         return huber_loss(pred, target, HUBER_DELTA).mean()
+     elif mode == "soft":
+         # Soft L2 with temperature
+         return F.mse_loss(pred / 10.0, target / 10.0)
+     else:  # hard
+         return F.mse_loss(pred, target)
+ 
+ 
+ def compute_sol_loss(pred_stats, pred_spatial, target_stats, target_spatial):
+     """Compute Sol distillation loss (stats + spatial)."""
+     stats_loss = F.mse_loss(pred_stats, target_stats)
+     spatial_loss = F.mse_loss(pred_spatial, target_spatial)
+     return stats_loss + spatial_loss
+ 
+ 
+ # ============================================================================
+ # WEIGHT SCHEDULES
+ # ============================================================================
+ def get_lune_weight(step):
+     if step < LUNE_WARMUP_STEPS:
+         return LUNE_LOSS_WEIGHT * (step / LUNE_WARMUP_STEPS)
+     return LUNE_LOSS_WEIGHT
+ 
+ 
+ def get_sol_weight(step):
+     if step < SOL_WARMUP_STEPS:
+         return SOL_LOSS_WEIGHT * (step / SOL_WARMUP_STEPS)
+     return SOL_LOSS_WEIGHT
+ 
+ 
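+ # Warmup example (hypothetical values): with LUNE_LOSS_WEIGHT = 0.5 and
+ # LUNE_WARMUP_STEPS = 1000, the Lune term ramps linearly: step 250 -> 0.125,
+ # step 500 -> 0.25, and any step >= 1000 -> the full 0.5.
+ 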
+ # ============================================================================
+ # CREATE DATASET
+ # ============================================================================
+ print("\nCreating combined dataset...")
+ combined_ds = CombinedDataset(
+     portrait_ds, portrait_indices, portrait_t5, portrait_clip,
+     schnell_ds, schnell_t5, schnell_clip,
+     sportfashion_ds, sportfashion_latents, sportfashion_masks, sportfashion_t5, sportfashion_clip,
+     synthmocap_ds, synthmocap_latents, synthmocap_masks, synthmocap_t5, synthmocap_clip,
+     imagenet_ds, imagenet_latents, imagenet_t5, imagenet_clip,
+     object_relations_ds, object_relations_latents, object_relations_t5, object_relations_clip,
+     vae, VAE_SCALE, DEVICE, DTYPE,
+     compute_masks=USE_MASKED_LOSS,
+ )
+ print(f"✓ Combined dataset: {len(combined_ds)} samples")
+ print(f" - Portraits (3x): {combined_ds.n_portrait:,}")
+ print(f" - Schnell teacher: {combined_ds.n_schnell:,}")
+ print(f" - SportFashion: {combined_ds.n_sportfashion:,}")
+ print(f" - SynthMoCap: {combined_ds.n_synthmocap:,}")
+ print(f" - ImageNet: {combined_ds.n_imagenet:,}")
+ print(f" - Object Relations: {combined_ds.n_object_relations:,}")
+ print(f" - Lune distillation: {ENABLE_LUNE_DISTILLATION}")
+ print(f" - Sol distillation: {ENABLE_SOL_DISTILLATION}")
+ 
+ # ============================================================================
+ # DATALOADER
+ # ============================================================================
+ loader = DataLoader(
+     combined_ds,
+     batch_size=BATCH_SIZE,
+     shuffle=True,
+     num_workers=8,
+     pin_memory=True,
+     collate_fn=collate_fn,
+     drop_last=True,
+ )
+ print(f"✓ DataLoader: {len(loader)} batches/epoch")
+ 
+ 
+ # ============================================================================
+ # SAMPLING FUNCTION
+ # ============================================================================
+ @torch.inference_mode()
+ def generate_samples(model, prompts, num_steps=28, guidance_scale=5.0, H=64, W=64,
+                      use_ema=True, seed=None,
+                      negative_prompt="blurry, distorted, low quality"):
+     """Generate samples during training with proper CFG support."""
+     was_training = model.training
+     model.eval()
+ 
+     if seed is not None:
+         torch.manual_seed(seed)
+ 
+     model_ref = model._orig_mod if hasattr(model, '_orig_mod') else model
+ 
+     if use_ema and 'ema' in globals() and ema is not None:
+         ema.apply_shadow_for_eval(model)
+ 
+     B = len(prompts)
+     C = 16
+ 
+     t5_list, clip_list = [], []
+     for p in prompts:
+         t5, clip = encode_prompt(p)
+         t5_list.append(t5)
+         clip_list.append(clip)
+     t5_cond = torch.stack(t5_list).to(DTYPE)
+     clip_cond = torch.stack(clip_list).to(DTYPE)
+ 
+     if guidance_scale > 1.0:
+         t5_uncond, clip_uncond = encode_prompt(negative_prompt)
+         t5_uncond = t5_uncond.expand(B, -1, -1)
+         clip_uncond = clip_uncond.expand(B, -1)
+     else:
+         t5_uncond, clip_uncond = None, None
+ 
+     x = torch.randn(B, H * W, C, device=DEVICE, dtype=DTYPE)
+     img_ids = model_ref.create_img_ids(B, H, W, DEVICE)
+ 
+     t_linear = torch.linspace(0, 1, num_steps + 1, device=DEVICE, dtype=DTYPE)
+     timesteps = flux_shift(t_linear, s=SHIFT)
+ 
+     for i in range(num_steps):
+         t_curr = timesteps[i]
+         t_next = timesteps[i + 1]
+         dt = t_next - t_curr
+ 
+         t_batch = t_curr.expand(B).to(DTYPE)
+ 
+         with torch.autocast("cuda", dtype=DTYPE):
+             v_cond = model_ref(
+                 hidden_states=x,
+                 encoder_hidden_states=t5_cond,
+                 pooled_projections=clip_cond,
+                 timestep=t_batch,
+                 img_ids=img_ids,
+             )
+             if isinstance(v_cond, tuple):
+                 v_cond = v_cond[0]
+ 
+             if guidance_scale > 1.0 and t5_uncond is not None:
+                 v_uncond = model_ref(
+                     hidden_states=x,
+                     encoder_hidden_states=t5_uncond,
+                     pooled_projections=clip_uncond,
+                     timestep=t_batch,
+                     img_ids=img_ids,
+                 )
+                 if isinstance(v_uncond, tuple):
+                     v_uncond = v_uncond[0]
+                 v = v_uncond + guidance_scale * (v_cond - v_uncond)
+             else:
+                 v = v_cond
+ 
+         x = x + v * dt
+ 
+     latents = x.reshape(B, H, W, C).permute(0, 3, 1, 2)
+     latents = latents / VAE_SCALE
+ 
+     with torch.autocast("cuda", dtype=DTYPE):
+         images = vae.decode(latents.to(vae.dtype)).sample
+     images = (images / 2 + 0.5).clamp(0, 1)
+ 
+     if use_ema and 'ema' in globals() and ema is not None:
+         ema.restore(model)
+ 
+     if was_training:
+         model.train()
+     return images
+ 
+ 
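+ # Classifier-free guidance above follows v = v_uncond + g * (v_cond - v_uncond);
+ # g <= 1.0 skips the unconditional pass and uses the conditional prediction
+ # directly. Hypothetical usage (prompt and settings are illustrative only):
+ #   images = generate_samples(model, ["a photo of a cat"], num_steps=8,
+ #                             guidance_scale=3.0, seed=0)
+ 
+ 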
+ def save_samples(images, prompts, step, output_dir):
+     from torchvision.utils import save_image
+     os.makedirs(output_dir, exist_ok=True)
+     timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+     grid_path = os.path.join(output_dir, f"samples_step_{step}.png")
+     save_image(images, grid_path, nrow=2, padding=2)
+     try:
+         api.upload_file(
+             path_or_fileobj=grid_path,
+             path_in_repo=f"samples/{timestamp}_step_{step}.png",
+             repo_id=HF_REPO,
+         )
+     except Exception as e:
+         # Upload is best-effort; a bare `except:` would also swallow KeyboardInterrupt.
+         print(f" ⚠ Sample upload failed: {e}")
+ 
+ # ============================================================================
+ # CHECKPOINT FUNCTIONS
+ # ============================================================================
+ def save_checkpoint(model, optimizer, scheduler, step, epoch, loss, path, ema=None):
+     os.makedirs(os.path.dirname(path) if os.path.dirname(path) else ".", exist_ok=True)
+     if hasattr(model, '_orig_mod'):
+         state_dict = model._orig_mod.state_dict()
+     else:
+         state_dict = model.state_dict()
+     state_dict = {k: v.to(DTYPE) if v.is_floating_point() else v for k, v in state_dict.items()}
+     weights_path = path.replace(".pt", ".safetensors")
+     save_file(state_dict, weights_path)
+     if ema is not None:
+         ema_weights = {k: v.to(DTYPE) if v.is_floating_point() else v for k, v in ema.shadow.items()}
+         ema_weights_path = path.replace(".pt", "_ema.safetensors")
+         save_file(ema_weights, ema_weights_path)
+     state = {
+         "step": step,
+         "epoch": epoch,
+         "loss": loss,
+         "optimizer": optimizer.state_dict(),
+         "scheduler": scheduler.state_dict(),
+     }
+     if ema is not None:
+         state["ema_decay"] = ema.decay
+     torch.save(state, path)
+     print(f" ✓ Saved checkpoint: step {step}")
+     return weights_path
+ 
+ 
+ def upload_checkpoint(weights_path, step):
+     try:
+         api.upload_file(
+             path_or_fileobj=weights_path,
+             path_in_repo=f"checkpoints/step_{step}.safetensors",
+             repo_id=HF_REPO,
+         )
+         ema_path = weights_path.replace(".safetensors", "_ema.safetensors")
+         if os.path.exists(ema_path):
+             api.upload_file(
+                 path_or_fileobj=ema_path,
+                 path_in_repo=f"checkpoints/step_{step}_ema.safetensors",
+                 repo_id=HF_REPO,
+             )
+         print(f" ✓ Uploaded checkpoint to {HF_REPO}")
+     except Exception as e:
+         print(f" ⚠ Upload failed: {e}")
+ 
+ 
+ def upload_logs():
+     try:
+         for root, dirs, files in os.walk(LOG_DIR):
+             for f in files:
+                 if f.startswith("events.out.tfevents"):
+                     local_path = os.path.join(root, f)
+                     rel_path = os.path.relpath(local_path, LOG_DIR)
+                     repo_path = f"logs/{rel_path}"
+                     api.upload_file(
+                         path_or_fileobj=local_path,
+                         path_in_repo=repo_path,
+                         repo_id=HF_REPO,
+                     )
+         print(f" ✓ Uploaded logs to {HF_REPO}")
+     except Exception as e:
+         print(f" ⚠ Log upload failed: {e}")
+ 
+ 
+ # ============================================================================
+ # WEIGHT UPGRADE LOADING (v3 -> v4.1)
+ # ============================================================================
+ V3_TO_V4_REMAP = {
+     # ExpertPredictor -> LunePredictor
+     'expert_predictor.t_embed.0.weight': 'lune_predictor.t_embed.0.weight',
+     'expert_predictor.t_embed.0.bias': 'lune_predictor.t_embed.0.bias',
+     'expert_predictor.t_embed.2.weight': 'lune_predictor.t_embed.2.weight',
+     'expert_predictor.t_embed.2.bias': 'lune_predictor.t_embed.2.bias',
+     'expert_predictor.clip_proj.weight': 'lune_predictor.clip_proj.weight',
+     'expert_predictor.clip_proj.bias': 'lune_predictor.clip_proj.bias',
+     'expert_predictor.out_proj.0.weight': 'lune_predictor.out_proj.0.weight',
+     'expert_predictor.out_proj.0.bias': 'lune_predictor.out_proj.0.bias',
+     'expert_predictor.out_proj.2.weight': 'lune_predictor.out_proj.2.weight',
+     'expert_predictor.out_proj.2.bias': 'lune_predictor.out_proj.2.bias',
+     'expert_predictor.gate': 'lune_predictor.gate',
+     # expert_features -> lune_features
+     'expert_features': 'lune_features',
+ }
+ 
+ 
+ def load_with_weight_upgrade(model, state_dict):
+     """Load state dict with v3 -> v4.1 remapping support."""
+     model_state = model.state_dict()
+ 
+     # New modules in v4.1
+     NEW_WEIGHT_PATTERNS = [
+         'lune_predictor.',
+         'sol_prior.',
+         't5_vec_proj.',
+         '.norm_q.weight',
+         '.norm_k.weight',
+         '.norm_added_q.weight',
+         '.norm_added_k.weight',
+     ]
+ 
+     # Deprecated keys from v3
+     DEPRECATED_PATTERNS = [
+         'guidance_in.',
+         '.sin_basis',
+         'expert_predictor.',  # Renamed to lune_predictor
+         'expert_features',  # Renamed to lune_features
+     ]
+ 
+     # A set keeps the membership test in the third pass O(1).
+     loaded_keys = set()
+     missing_keys = []
+     unexpected_keys = []
+     initialized_keys = []
+     remapped_keys = []
+ 
+     # First pass: remap v3 keys to v4 keys
+     remapped_state = {}
+     for k, v in state_dict.items():
+         if k in V3_TO_V4_REMAP:
+             new_key = V3_TO_V4_REMAP[k]
+             remapped_state[new_key] = v
+             remapped_keys.append(f"{k} -> {new_key}")
+         else:
+             remapped_state[k] = v
+ 
+     # Second pass: load matching weights
+     for key, v in remapped_state.items():
+         if key in model_state:
+             if v.shape == model_state[key].shape:
+                 model_state[key] = v
+                 loaded_keys.add(key)
+             else:
+                 print(f" ⚠ Shape mismatch for {key}: checkpoint {v.shape} vs model {model_state[key].shape}")
+                 unexpected_keys.append(key)
+         else:
+             is_deprecated = any(pat in key for pat in DEPRECATED_PATTERNS)
+             if is_deprecated:
+                 unexpected_keys.append(key)
+             else:
+                 print(f" ⚠ Unexpected key (not in model): {key}")
+                 unexpected_keys.append(key)
+ 
+     # Third pass: handle missing keys
+     for key in model_state.keys():
+         if key not in loaded_keys:
+             is_new = any(pat in key for pat in NEW_WEIGHT_PATTERNS)
+             if is_new:
+                 initialized_keys.append(key)
+             else:
+                 missing_keys.append(key)
+                 print(f" ⚠ Missing key (not in checkpoint): {key}")
+ 
+     model.load_state_dict(model_state, strict=False)
+ 
+     # Report
+     if remapped_keys:
+         print(f" ✓ Remapped v3->v4: {len(remapped_keys)} keys")
+         for rk in remapped_keys[:5]:
+             print(f" {rk}")
+         if len(remapped_keys) > 5:
+             print(f" ... and {len(remapped_keys) - 5} more")
+ 
+     if initialized_keys:
+         modules = set()
+         for k in initialized_keys:
+             parts = k.split('.')
+             if len(parts) >= 2:
+                 modules.add(parts[0])
+         print(f" ✓ Initialized new modules (fresh): {sorted(modules)}")
+ 
+     if unexpected_keys:
+         deprecated = [k for k in unexpected_keys if any(p in k for p in DEPRECATED_PATTERNS)]
+         if deprecated:
+             print(f" ✓ Ignored deprecated keys: {len(deprecated)}")
+ 
+     return missing_keys, unexpected_keys
+ 
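+ # Example of the remap in action: a v3 checkpoint key 'expert_predictor.gate'
+ # is loaded into 'lune_predictor.gate'; keys matching deprecated patterns such
+ # as 'guidance_in.' are collected and summarized at the end rather than warned
+ # about individually.
+ 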
+ 
+ def load_checkpoint(model, optimizer, scheduler, target):
+     """Load checkpoint with weight upgrade support for v4.1."""
+     start_step = 0
+     start_epoch = 0
+     ema_state = None
+ 
+     if target == "none":
+         print("Starting fresh (no checkpoint)")
+         return start_step, start_epoch, None
+ 
+     ckpt_path = None
+     weights_path = None
+     ema_weights_path = None
+ 
+     if target == "latest":
+         if os.path.exists(CHECKPOINT_DIR):
+             ckpts = [f for f in os.listdir(CHECKPOINT_DIR) if f.startswith("step_") and f.endswith(".pt")]
+             if ckpts:
+                 steps = [int(f.split("_")[1].split(".")[0]) for f in ckpts]
+                 latest_step = max(steps)
+                 ckpt_path = os.path.join(CHECKPOINT_DIR, f"step_{latest_step}.pt")
+                 weights_path = ckpt_path.replace(".pt", ".safetensors")
+                 ema_weights_path = ckpt_path.replace(".pt", "_ema.safetensors")
+ 
+     elif target == "hub" or target.startswith("hub:"):
+         try:
+             from huggingface_hub import list_repo_files
+ 
+             if target.startswith("hub:"):
+                 path_or_name = target.split(":", 1)[1]
+ 
+                 # Check if it's a full path (contains /) or just a step name
+                 if "/" in path_or_name:
+                     # Full path like checkpoint_runs/v4_init/lailah_401434_v4_init
+                     weights_path = hf_hub_download(HF_REPO, f"{path_or_name}.safetensors")
+                     try:
+                         ema_weights_path = hf_hub_download(HF_REPO, f"{path_or_name}_ema.safetensors")
+                         print(f" Found EMA weights on hub")
+                     except Exception:
+                         ema_weights_path = None
+                         print(f" No EMA weights on hub (will start fresh)")
+                     print(f"Downloaded {path_or_name} from hub")
+                 else:
+                     # Simple step name like step_401434
+                     step_name = path_or_name
+                     weights_path = hf_hub_download(HF_REPO, f"checkpoints/{step_name}.safetensors")
+                     try:
+                         ema_weights_path = hf_hub_download(HF_REPO, f"checkpoints/{step_name}_ema.safetensors")
+                         print(f" Found EMA weights on hub")
+                     except Exception:
+                         ema_weights_path = None
+                         print(f" No EMA weights on hub (will start fresh)")
+                     start_step = int(step_name.split("_")[1]) if "_" in step_name else 0
+                     print(f"Downloaded {step_name} from hub")
+             else:
+                 files = list_repo_files(HF_REPO)
+                 ckpts = [f for f in files if
+                          f.startswith("checkpoints/step_") and f.endswith(".safetensors") and "_ema" not in f]
+                 if ckpts:
+                     steps = [int(f.split("_")[1].split(".")[0]) for f in ckpts]
+                     latest = max(steps)
+                     weights_path = hf_hub_download(HF_REPO, f"checkpoints/step_{latest}.safetensors")
+                     try:
+                         ema_weights_path = hf_hub_download(HF_REPO, f"checkpoints/step_{latest}_ema.safetensors")
+                         print(f" Found EMA weights on hub")
+                     except Exception:
+                         ema_weights_path = None
+                         print(f" No EMA weights on hub (will start fresh)")
+                     start_step = latest
+                     print(f"Downloaded step_{latest} from hub")
+         except Exception as e:
+             print(f"Could not download from hub: {e}")
+             return start_step, start_epoch, None
+ 
+     elif target == "best":
+         ckpt_path = os.path.join(CHECKPOINT_DIR, "best.pt")
+         weights_path = ckpt_path.replace(".pt", ".safetensors")
+         ema_weights_path = ckpt_path.replace(".pt", "_ema.safetensors")
+ 
+     elif os.path.exists(target):
+         if target.endswith(".safetensors"):
+             weights_path = target
+             ckpt_path = target.replace(".safetensors", ".pt")
+             ema_weights_path = target.replace(".safetensors", "_ema.safetensors")
+         else:
+             ckpt_path = target
+             weights_path = target.replace(".pt", ".safetensors")
+             ema_weights_path = target.replace(".pt", "_ema.safetensors")
+ 
+     # Load main model weights
+     if weights_path and os.path.exists(weights_path):
+         print(f"Loading weights from {weights_path}")
+         state_dict = load_file(weights_path)
+         state_dict = {k: v.to(DTYPE) if v.is_floating_point() else v for k, v in state_dict.items()}
+ 
+         model_ref = model._orig_mod if hasattr(model, '_orig_mod') else model
+ 
+         if ALLOW_WEIGHT_UPGRADE:
+             missing, unexpected = load_with_weight_upgrade(model_ref, state_dict)
+             if missing:
+                 print(f" ⚠ {len(missing)} truly missing parameters")
+         else:
+             model_ref.load_state_dict(state_dict, strict=True)
+ 
+         print(f"✓ Loaded model weights")
+ 
+         # Load EMA weights
+         if ema_weights_path and os.path.exists(ema_weights_path):
+             ema_state = load_file(ema_weights_path)
+             ema_state = {k: v.to(DTYPE) if v.is_floating_point() else v for k, v in ema_state.items()}
+             print(f"✓ Loaded EMA weights ({len(ema_state)} params)")
+         else:
+             print(f" ℹ No EMA weights found (will initialize fresh)")
+     else:
+         print(f" ⚠ Weights file not found: {weights_path}")
+         print(f" Starting with fresh model")
+         return start_step, start_epoch, None
+ 
+     # Load optimizer/scheduler state
+     if ckpt_path and os.path.exists(ckpt_path):
+         state = torch.load(ckpt_path, map_location="cpu")
+         start_step = state.get("step", 0)
+         start_epoch = state.get("epoch", 0)
+         try:
+             optimizer.load_state_dict(state["optimizer"])
+             scheduler.load_state_dict(state["scheduler"])
+             print(f"✓ Loaded optimizer/scheduler state")
+         except Exception as e:
+             print(f" ⚠ Could not load optimizer state: {e}")
+             print(f" Will use fresh optimizer")
+         print(f"Resuming from step {start_step}, epoch {start_epoch}")
+ 
+     return start_step, start_epoch, ema_state
+ 
+ 
+ 
+ # ============================================================================
+ # CREATE MODEL (v4.1 with dual experts)
+ # ============================================================================
+ print("\nCreating TinyFlux v4.1 model with Lune + Sol...")
+ 
+ # Import model - expects model_v4.py to define TinyFluxConfig and TinyFlux
+ # If running as a notebook cell, ensure the model_v4.py cell was run first
+ # If running as a script, uncomment the import below:
+ # from model_v4 import TinyFluxConfig, TinyFlux
+ 
+ # Check that the model classes exist
+ if 'TinyFluxConfig' not in dir() or 'TinyFlux' not in dir():
+     raise RuntimeError(
+         "TinyFluxConfig and TinyFlux not found! "
+         "Run model_v4.py cell first, or add: from model_v4 import TinyFluxConfig, TinyFlux"
+     )
+ 
+ config = TinyFluxConfig(
+     hidden_size=512,
+     num_attention_heads=4,
+     attention_head_dim=128,
+     num_double_layers=15,
+     num_single_layers=25,
+ 
+     # Lune expert (trajectory guidance)
+     use_lune_expert=ENABLE_LUNE_DISTILLATION,
+     lune_expert_dim=LUNE_DIM,
+     lune_hidden_dim=LUNE_HIDDEN_DIM,
+     lune_dropout=LUNE_DROPOUT,
+ 
+     # Sol prior (structural guidance)
+     use_sol_prior=ENABLE_SOL_DISTILLATION,
+     sol_spatial_size=SOL_SPATIAL_SIZE,
+     sol_hidden_dim=SOL_HIDDEN_DIM,
+     sol_geometric_weight=SOL_GEOMETRIC_WEIGHT,
+ 
+     # Other settings
+     use_t5_vec=True,
+     lune_distill_mode=LUNE_DISTILL_MODE,
+     use_huber_loss=USE_HUBER_LOSS,
+     huber_delta=HUBER_DELTA,
+     guidance_embeds=False,
+ )
+ model = TinyFlux(config).to(device=DEVICE, dtype=DTYPE)
+ 
+ total_params = sum(p.numel() for p in model.parameters())
+ print(f"Total parameters: {total_params:,}")
+ 
+ if hasattr(model, 'lune_predictor') and model.lune_predictor is not None:
+     lune_params = sum(p.numel() for p in model.lune_predictor.parameters())
+     print(f"Lune predictor parameters: {lune_params:,}")
+ 
+ if hasattr(model, 'sol_prior') and model.sol_prior is not None:
+     sol_params = sum(p.numel() for p in model.sol_prior.parameters())
+     print(f"Sol prior parameters: {sol_params:,}")
+ 
+ trainable_params = [p for p in model.parameters() if p.requires_grad]
+ print(f"Trainable parameters: {sum(p.numel() for p in trainable_params):,}")
+ 
+ # ============================================================================
+ # OPTIMIZER
+ # ============================================================================
+ opt = torch.optim.AdamW(trainable_params, lr=LR, betas=(0.9, 0.99), weight_decay=0.01, fused=True)
+ 
+ total_steps = len(loader) * EPOCHS // GRAD_ACCUM
+ # max(1, ...) guards against a zero-length warmup (division by zero in lr_fn)
+ # when the run is very short.
+ warmup = max(1, min(1000, total_steps // 10))
+ 
+ 
+ def lr_fn(step):
+     if step < warmup:
+         return step / warmup
+     return 0.5 * (1 + math.cos(math.pi * (step - warmup) / max(1, total_steps - warmup)))
+ 
+ 
+ sched = torch.optim.lr_scheduler.LambdaLR(opt, lr_fn)
+ 
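+ # Schedule example: with warmup=1000 and total_steps=10000, step 500 gives a
+ # multiplier of 0.5 (linear ramp); step 5500 gives 0.5 * (1 + cos(pi * 0.5)) = 0.5
+ # on the cosine branch; the multiplier decays toward 0 at step 10000.
+ 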
+ # ============================================================================
+ # LOAD CHECKPOINT
+ # ============================================================================
+ start_step, start_epoch, ema_state = load_checkpoint(model, opt, sched, LOAD_TARGET)
+ 
+ if RESUME_STEP is not None:
+     start_step = RESUME_STEP
+ 
+ # ============================================================================
+ # COMPILE
+ # ============================================================================
+ model = torch.compile(model, mode="default")
+ 
+ # ============================================================================
+ # EMA
+ # ============================================================================
+ print("Initializing EMA...")
+ ema = EMA(model, decay=EMA_DECAY)
+ if ema_state is not None:
+     # Remap v3 EMA keys to v4
+     remapped_ema = {}
+     for k, v in ema_state.items():
+         if k in V3_TO_V4_REMAP:
+             remapped_ema[V3_TO_V4_REMAP[k]] = v
+         else:
+             remapped_ema[k] = v
+     ema.load_shadow(remapped_ema, model=model)
+ 
+     # Sync new modules from model
+     has_lune_in_ema = any('lune_predictor' in k for k in ema_state.keys())
+     has_sol_in_ema = any('sol_prior' in k for k in ema_state.keys())
+ 
+     if ENABLE_LUNE_DISTILLATION and not has_lune_in_ema:
+         # Check if expert_predictor was in the v3 checkpoint (remapped to lune_predictor)
+         has_expert_in_v3 = any('expert_predictor' in k for k in ema_state.keys())
+         if not has_expert_in_v3:
+             ema.sync_from_model(model, pattern='lune_predictor')
+             print(" ✓ Force-synced lune_predictor (new weights)")
+         else:
+             print(" ✓ lune_predictor loaded from remapped v3 checkpoint")
+ 
+     if ENABLE_SOL_DISTILLATION and not has_sol_in_ema:
+         ema.sync_from_model(model, pattern='sol_prior')
+         print(" ✓ Force-synced sol_prior (new weights)")
+ else:
+     print(" Starting fresh EMA from current weights")
+ 
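+ # For reference: the conventional EMA update this presumably relies on is
+ # shadow = decay * shadow + (1 - decay) * param, applied after each optimizer
+ # step, with `apply_shadow_for_eval`/`restore` swapping the shadow in and out
+ # around sampling. (The exact rule lives in the EMA helper defined earlier.)
+ 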
+ # ============================================================================
+ # TENSORBOARD
+ # ============================================================================
+ run_name = f"run_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
+ writer = SummaryWriter(os.path.join(LOG_DIR, run_name))
+ 
+ SAMPLE_PROMPTS = [
+     "a photo of a cat sitting on a windowsill",
+     "a portrait of a woman with red hair",
+     "a black backpack on white background",
+     "a person standing in a t-pose",
+ ]
+ 
+ 
+ # ============================================================================
+ # TRAINING LOOP
+ # ============================================================================
+ print(f"\n{'=' * 60}")
+ print(f"Training TinyFlux v4.1 with Dual Expert Distillation")
+ print(f"{'=' * 60}")
+ print(f"Total: {len(combined_ds):,} samples")
+ print(f"Epochs: {EPOCHS}, Steps/epoch: {len(loader)}, Total: {total_steps}")
+ print(f"Batch: {BATCH_SIZE} x {GRAD_ACCUM} = {BATCH_SIZE * GRAD_ACCUM}")
+ print(f"Lune distillation: {ENABLE_LUNE_DISTILLATION}")
+ if ENABLE_LUNE_DISTILLATION:
+     print(f" - Mode: {LUNE_DISTILL_MODE}")
+     print(f" - Weight: {LUNE_LOSS_WEIGHT} (warmup: {LUNE_WARMUP_STEPS} steps)")
+ print(f"Sol distillation: {ENABLE_SOL_DISTILLATION}")
+ if ENABLE_SOL_DISTILLATION:
+     print(f" - Weight: {SOL_LOSS_WEIGHT} (warmup: {SOL_WARMUP_STEPS} steps)")
+ print(f"Huber loss: {USE_HUBER_LOSS} (delta={HUBER_DELTA})")
+ print(f"Spatial weighting: {USE_SPATIAL_WEIGHTING}")
+ print(f"Resume: step {start_step}, epoch {start_epoch}")
+ 
+ model.train()
+ step = start_step
+ best = float("inf")
+ 
+ for ep in range(start_epoch, EPOCHS):
+     ep_loss = 0
+     ep_main_loss = 0
+     ep_lune_loss = 0
+     ep_sol_loss = 0
+     ep_batches = 0
+     pbar = tqdm(loader, desc=f"E{ep + 1}")
+ 
+     for i, batch in enumerate(pbar):
+         latents = batch["latents"].to(DEVICE, non_blocking=True)
+         t5 = batch["t5_embeds"].to(DEVICE, non_blocking=True)
+         clip = batch["clip_pooled"].to(DEVICE, non_blocking=True)
+         local_indices = batch["local_indices"]
+         dataset_ids = batch["dataset_ids"]
+         masks = batch["masks"]
+ 
+         if masks is not None:
+             masks = masks.to(DEVICE, non_blocking=True)
+ 
+         B, C, H, W = latents.shape
+         data = latents.permute(0, 2, 3, 1).reshape(B, H * W, C)
+         noise = torch.randn_like(data)
+ 
+         if TEXT_DROPOUT > 0:
+             t5, clip, _ = apply_text_dropout(t5, clip, TEXT_DROPOUT)
+ 
+         # Logit-normal timestep sampling, then Flux-style shift
+         t = torch.sigmoid(torch.randn(B, device=DEVICE))
+         t = flux_shift(t, s=SHIFT).to(DTYPE).clamp(1e-4, 1 - 1e-4)
+ 
+         t_expanded = t.view(B, 1, 1)
+         x_t = (1 - t_expanded) * noise + t_expanded * data
+         v_target = data - noise
+ 
+         img_ids = TinyFlux.create_img_ids(B, H, W, DEVICE)
+ 
+         # Get expert features from CACHE
+         lune_features = None
+         sol_stats = None
+         sol_spatial = None
+ 
+         if ENABLE_LUNE_DISTILLATION:
+             lune_features = get_lune_features_for_batch(local_indices, dataset_ids, t)
+             # Batch-level feature dropout so the model also learns to run without Lune
+             if lune_features is not None and random.random() < LUNE_DROPOUT:
+                 lune_features = None
+ 
+         if ENABLE_SOL_DISTILLATION:
+             sol_stats, sol_spatial = get_sol_features_for_batch(local_indices, dataset_ids, t)
+ 
+         with torch.autocast("cuda", dtype=DTYPE):
+             result = model(
+                 hidden_states=x_t,
+                 encoder_hidden_states=t5,
+                 pooled_projections=clip,
+                 timestep=t,
+                 img_ids=img_ids,
+                 lune_features=lune_features,
+                 sol_stats=sol_stats,
+                 sol_spatial=sol_spatial,
+                 return_expert_pred=True,
+             )
+ 
+         if isinstance(result, tuple):
+             v_pred, expert_info = result
+         else:
+             v_pred = result
+             expert_info = {}
+ 
+         # Compute losses
+         snr_weights = min_snr_weight(t)
+ 
+         # Main loss with optional spatial weighting from Sol
+         spatial_weights = sol_spatial if USE_SPATIAL_WEIGHTING else None
+         main_loss = compute_main_loss(
+             v_pred, v_target,
+             mask=masks if USE_MASKED_LOSS else None,
+             spatial_weights=spatial_weights,
+             fg_weight=FG_LOSS_WEIGHT,
+             bg_weight=BG_LOSS_WEIGHT,
+             snr_weights=snr_weights
+         )
+ 
+         # Lune distillation loss
+         lune_loss = torch.tensor(0.0, device=DEVICE)
+         if lune_features is not None and expert_info.get('lune_pred') is not None:
+             lune_loss = compute_lune_loss(
+                 expert_info['lune_pred'], lune_features, mode=LUNE_DISTILL_MODE
+             )
+ 
+         # Sol distillation loss
+         sol_loss = torch.tensor(0.0, device=DEVICE)
+         if sol_stats is not None and expert_info.get('sol_stats_pred') is not None:
+             sol_loss = compute_sol_loss(
+                 expert_info['sol_stats_pred'], expert_info.get('sol_spatial_pred'),
+                 sol_stats, sol_spatial
+             )
+ 
+         # Total loss with warmup weights
+         total_loss = main_loss
+         total_loss = total_loss + get_lune_weight(step) * lune_loss
+         total_loss = total_loss + get_sol_weight(step) * sol_loss
+ 
+         loss = total_loss / GRAD_ACCUM
+         loss.backward()
+ 
+         if (i + 1) % GRAD_ACCUM == 0:
+             grad_norm = torch.nn.utils.clip_grad_norm_(trainable_params, 1.0)
+             opt.step()
+             sched.step()
+             opt.zero_grad(set_to_none=True)
+ 
+             ema.update(model)
+             step += 1
+ 
+             if step % LOG_EVERY == 0:
+                 writer.add_scalar("train/loss", total_loss.item(), step)
+                 writer.add_scalar("train/main_loss", main_loss.item(), step)
+                 if ENABLE_LUNE_DISTILLATION:
+                     writer.add_scalar("train/lune_loss", lune_loss.item(), step)
+                     writer.add_scalar("train/lune_weight", get_lune_weight(step), step)
+                 if ENABLE_SOL_DISTILLATION:
+                     writer.add_scalar("train/sol_loss", sol_loss.item(), step)
+                     writer.add_scalar("train/sol_weight", get_sol_weight(step), step)
+                 writer.add_scalar("train/lr", sched.get_last_lr()[0], step)
+                 writer.add_scalar("train/grad_norm", grad_norm.item(), step)
+ 
+             if step % SAMPLE_EVERY == 0:
+                 print(f"\n Generating samples at step {step}...")
+                 images = generate_samples(
+                     model, SAMPLE_PROMPTS,
+                     num_steps=28,
+                     guidance_scale=5.0,
+                     use_ema=True,
+                     negative_prompt="blurry, distorted, low quality, deformed",
+                 )
+                 save_samples(images, SAMPLE_PROMPTS, step, SAMPLE_DIR)
+ 
+             if step % SAVE_EVERY == 0:
+                 ckpt_path = os.path.join(CHECKPOINT_DIR, f"step_{step}.pt")
+                 weights_path = save_checkpoint(model, opt, sched, step, ep, total_loss.item(), ckpt_path, ema=ema)
+                 if step % UPLOAD_EVERY == 0:
+                     upload_checkpoint(weights_path, step)
+                 if step % LOG_UPLOAD_EVERY == 0:
+                     writer.flush()
+                     upload_logs()
+ 
+         ep_loss += total_loss.item()
+         ep_main_loss += main_loss.item()
+         ep_lune_loss += lune_loss.item()
+         ep_sol_loss += sol_loss.item()
+         ep_batches += 1
+ 
+         pbar.set_postfix(
+             loss=f"{total_loss.item():.4f}",
+             main=f"{main_loss.item():.4f}",
+             lune=f"{lune_loss.item():.4f}" if ENABLE_LUNE_DISTILLATION else "-",
+             sol=f"{sol_loss.item():.4f}" if ENABLE_SOL_DISTILLATION else "-",
+             step=step
+         )
+ 
+     avg = ep_loss / max(ep_batches, 1)
+     avg_main = ep_main_loss / max(ep_batches, 1)
+     avg_lune = ep_lune_loss / max(ep_batches, 1)
+     avg_sol = ep_sol_loss / max(ep_batches, 1)
+ 
+     print(f"Epoch {ep + 1} - total: {avg:.4f}, main: {avg_main:.4f}, lune: {avg_lune:.4f}, sol: {avg_sol:.4f}")
+ 
+     if avg < best:
+         best = avg
+         weights_path = save_checkpoint(model, opt, sched, step, ep, avg, os.path.join(CHECKPOINT_DIR, "best.pt"),
+                                        ema=ema)
+         try:
+             api.upload_file(path_or_fileobj=weights_path, path_in_repo="model.safetensors", repo_id=HF_REPO)
+         except Exception:
+             # Best-model upload is best-effort; training continues either way.
+             pass
+ 
+ print(f"\n✓ Training complete! Best loss: {best:.4f}")
+ writer.close()