Update trainer_colab.py
Browse files- trainer_colab.py +280 -111
trainer_colab.py
CHANGED
|
@@ -36,9 +36,9 @@ DTYPE = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
|
|
| 36 |
|
| 37 |
# HuggingFace Hub
|
| 38 |
HF_REPO = "AbstractPhil/tiny-flux"
|
| 39 |
-
SAVE_EVERY =
|
| 40 |
-
UPLOAD_EVERY =
|
| 41 |
-
SAMPLE_EVERY =
|
| 42 |
LOG_EVERY = 10 # steps - tensorboard
|
| 43 |
|
| 44 |
# Checkpoint loading target
|
|
@@ -47,10 +47,14 @@ LOG_EVERY = 10 # steps - tensorboard
|
|
| 47 |
# "best" - load best model
|
| 48 |
# int (e.g. 1500) - load specific step
|
| 49 |
# "hub:step_1000" - load specific checkpoint from hub
|
| 50 |
-
# "local:path/to/checkpoint.safetensors"
|
| 51 |
# "none" - start fresh, ignore existing checkpoints
|
| 52 |
LOAD_TARGET = "latest"
|
| 53 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 54 |
# Local paths
|
| 55 |
CHECKPOINT_DIR = "./tiny_flux_checkpoints"
|
| 56 |
LOG_DIR = "./tiny_flux_logs"
|
|
@@ -127,14 +131,37 @@ def encode_prompt(prompt):
|
|
| 127 |
# ============================================================================
|
| 128 |
# FLOW MATCHING HELPERS
|
| 129 |
# ============================================================================
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 130 |
def flux_shift(t, s=SHIFT):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 131 |
return s * t / (1 + (s - 1) * t)
|
| 132 |
|
| 133 |
def flux_shift_inverse(t_shifted, s=SHIFT):
|
| 134 |
-
"""Inverse of flux_shift
|
| 135 |
return t_shifted / (s - (s - 1) * t_shifted)
|
| 136 |
|
| 137 |
def min_snr_weight(t, gamma=MIN_SNR):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 138 |
snr = (t / (1 - t).clamp(min=1e-5)).pow(2)
|
| 139 |
return torch.clamp(snr, max=gamma) / snr.clamp(min=1e-5)
|
| 140 |
|
|
@@ -143,7 +170,12 @@ def min_snr_weight(t, gamma=MIN_SNR):
|
|
| 143 |
# ============================================================================
|
| 144 |
@torch.no_grad()
|
| 145 |
def generate_samples(model, prompts, num_steps=20, guidance_scale=3.5, H=64, W=64):
|
| 146 |
-
"""Generate sample images using Euler sampling.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 147 |
model.eval()
|
| 148 |
B = len(prompts)
|
| 149 |
C = 16 # VAE channels
|
|
@@ -157,18 +189,21 @@ def generate_samples(model, prompts, num_steps=20, guidance_scale=3.5, H=64, W=6
|
|
| 157 |
t5_embeds = torch.stack(t5_embeds)
|
| 158 |
clip_pooleds = torch.stack(clip_pooleds)
|
| 159 |
|
| 160 |
-
# Start from noise
|
| 161 |
x = torch.randn(B, H * W, C, device=DEVICE, dtype=DTYPE)
|
| 162 |
|
| 163 |
# Create image IDs
|
| 164 |
img_ids = TinyFlux.create_img_ids(B, H, W, DEVICE)
|
| 165 |
|
| 166 |
-
# Euler sampling
|
| 167 |
-
timesteps = torch.linspace(
|
| 168 |
|
| 169 |
-
for i
|
| 170 |
-
|
| 171 |
-
|
|
|
|
|
|
|
|
|
|
| 172 |
|
| 173 |
# Conditional prediction
|
| 174 |
guidance = torch.full((B,), guidance_scale, device=DEVICE, dtype=DTYPE)
|
|
@@ -181,15 +216,15 @@ def generate_samples(model, prompts, num_steps=20, guidance_scale=3.5, H=64, W=6
|
|
| 181 |
guidance=guidance,
|
| 182 |
)
|
| 183 |
|
| 184 |
-
# Euler step:
|
| 185 |
x = x + v_cond * dt
|
| 186 |
|
| 187 |
# Reshape to image format: (B, H*W, C) -> (B, C, H, W)
|
| 188 |
latents = x.reshape(B, H, W, C).permute(0, 3, 1, 2)
|
| 189 |
|
| 190 |
-
# Decode with VAE
|
| 191 |
latents = latents / vae.config.scaling_factor
|
| 192 |
-
images = vae.decode(latents.
|
| 193 |
images = (images / 2 + 0.5).clamp(0, 1)
|
| 194 |
|
| 195 |
model.train()
|
|
@@ -235,6 +270,32 @@ def collate(batch):
|
|
| 235 |
# ============================================================================
|
| 236 |
# CHECKPOINT FUNCTIONS
|
| 237 |
# ============================================================================
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 238 |
def save_checkpoint(model, optimizer, scheduler, step, epoch, loss, path):
|
| 239 |
"""Save checkpoint locally."""
|
| 240 |
os.makedirs(os.path.dirname(path) if os.path.dirname(path) else ".", exist_ok=True)
|
|
@@ -306,7 +367,7 @@ def load_checkpoint(model, optimizer, scheduler, target):
|
|
| 306 |
"best" - best model
|
| 307 |
int (1500) - specific step
|
| 308 |
"hub:step_1000" - specific hub checkpoint
|
| 309 |
-
"local:/path/to/file.safetensors" - specific local file
|
| 310 |
"none" - skip loading, start fresh
|
| 311 |
"""
|
| 312 |
if target == "none":
|
|
@@ -342,102 +403,163 @@ def load_checkpoint(model, optimizer, scheduler, target):
|
|
| 342 |
|
| 343 |
# Load based on mode
|
| 344 |
if load_mode == "local":
|
| 345 |
-
# Direct local file
|
| 346 |
if os.path.exists(load_path):
|
| 347 |
-
weights =
|
| 348 |
model.load_state_dict(weights)
|
| 349 |
-
|
| 350 |
-
|
| 351 |
-
if
|
| 352 |
-
|
| 353 |
-
|
| 354 |
-
|
| 355 |
-
|
| 356 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 357 |
print(f"✓ Loaded local: {load_path} (step {start_step})")
|
| 358 |
return start_step, start_epoch
|
| 359 |
else:
|
| 360 |
print(f"⚠ Local file not found: {load_path}")
|
| 361 |
|
| 362 |
elif load_mode == "hub":
|
| 363 |
-
# Specific hub checkpoint
|
| 364 |
-
|
| 365 |
-
|
| 366 |
-
|
| 367 |
-
|
| 368 |
-
|
| 369 |
-
|
| 370 |
-
|
| 371 |
-
|
| 372 |
-
|
| 373 |
-
|
| 374 |
-
|
| 375 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 376 |
|
| 377 |
elif load_mode == "best":
|
| 378 |
-
# Try hub best first
|
| 379 |
-
|
| 380 |
-
|
| 381 |
-
|
| 382 |
-
|
| 383 |
-
|
| 384 |
-
|
| 385 |
-
|
| 386 |
-
|
| 387 |
-
|
| 388 |
-
|
| 389 |
-
|
| 390 |
-
|
| 391 |
-
|
| 392 |
-
|
| 393 |
-
if os.path.exists(
|
| 394 |
-
|
| 395 |
-
|
| 396 |
-
|
| 397 |
-
|
| 398 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 399 |
|
| 400 |
elif load_mode == "step":
|
| 401 |
# Specific step number
|
| 402 |
step_num = load_path
|
| 403 |
-
# Try hub
|
| 404 |
-
|
| 405 |
-
|
| 406 |
-
|
| 407 |
-
|
| 408 |
-
|
| 409 |
-
|
| 410 |
-
|
| 411 |
-
|
| 412 |
-
|
| 413 |
-
|
| 414 |
-
|
| 415 |
-
|
| 416 |
-
|
| 417 |
-
|
| 418 |
-
|
| 419 |
-
|
| 420 |
-
|
| 421 |
-
|
| 422 |
-
|
| 423 |
-
|
| 424 |
-
|
| 425 |
-
|
| 426 |
-
|
| 427 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 428 |
print(f"⚠ Step {step_num} not found")
|
| 429 |
|
| 430 |
# Default: latest
|
| 431 |
-
# Try Hub first
|
| 432 |
try:
|
| 433 |
files = api.list_repo_files(repo_id=HF_REPO)
|
| 434 |
-
checkpoints = [f for f in files if f.startswith("checkpoints/step_") and f.endswith(".safetensors")]
|
| 435 |
if checkpoints:
|
| 436 |
-
|
|
|
|
|
|
|
|
|
|
| 437 |
latest = checkpoints[-1]
|
| 438 |
-
step =
|
| 439 |
local_path = hf_hub_download(repo_id=HF_REPO, filename=latest)
|
| 440 |
-
weights =
|
| 441 |
model.load_state_dict(weights)
|
| 442 |
start_step = step
|
| 443 |
print(f"✓ Loaded latest from Hub: step {step}")
|
|
@@ -445,22 +567,33 @@ def load_checkpoint(model, optimizer, scheduler, target):
|
|
| 445 |
except Exception as e:
|
| 446 |
print(f"Hub check: {e}")
|
| 447 |
|
| 448 |
-
# Try local
|
| 449 |
if os.path.exists(CHECKPOINT_DIR):
|
| 450 |
-
local_ckpts = [f for f in os.listdir(CHECKPOINT_DIR) if f.startswith("step_") and f.endswith(".safetensors")]
|
|
|
|
|
|
|
| 451 |
if local_ckpts:
|
| 452 |
-
|
|
|
|
|
|
|
| 453 |
latest = local_ckpts[-1]
|
| 454 |
-
step =
|
| 455 |
weights_path = os.path.join(CHECKPOINT_DIR, latest)
|
| 456 |
-
weights =
|
| 457 |
model.load_state_dict(weights)
|
| 458 |
-
|
|
|
|
| 459 |
if os.path.exists(state_path):
|
| 460 |
-
state = torch.load(state_path, weights_only=False)
|
| 461 |
-
|
| 462 |
-
|
| 463 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 464 |
start_step = step
|
| 465 |
print(f"✓ Loaded latest local: step {step}")
|
| 466 |
return start_step, start_epoch
|
|
@@ -479,6 +612,7 @@ loader = DataLoader(ds, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate,
|
|
| 479 |
config = TinyFluxConfig()
|
| 480 |
model = TinyFlux(config).to(DEVICE).to(DTYPE)
|
| 481 |
print(f"\nParams: {sum(p.numel() for p in model.parameters()):,}")
|
|
|
|
| 482 |
|
| 483 |
# ============================================================================
|
| 484 |
# OPTIMIZER & SCHEDULER
|
|
@@ -499,6 +633,11 @@ sched = torch.optim.lr_scheduler.LambdaLR(opt, lr_fn)
|
|
| 499 |
print(f"\nLoad target: {LOAD_TARGET}")
|
| 500 |
start_step, start_epoch = load_checkpoint(model, opt, sched, LOAD_TARGET)
|
| 501 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 502 |
# Log config to tensorboard
|
| 503 |
writer.add_text("config", json.dumps(config.__dict__, indent=2), 0)
|
| 504 |
writer.add_text("training_config", json.dumps({
|
|
@@ -537,26 +676,53 @@ for ep in range(start_epoch, EPOCHS):
|
|
| 537 |
pbar = tqdm(loader, desc=f"E{ep+1}")
|
| 538 |
|
| 539 |
for i, batch in enumerate(pbar):
|
| 540 |
-
|
| 541 |
t5 = batch["t5_embeds"]
|
| 542 |
clip = batch["clip_pooled"]
|
| 543 |
|
| 544 |
-
B, C, H, W =
|
| 545 |
-
x1 = lat.permute(0, 2, 3, 1).reshape(B, H*W, C)
|
| 546 |
-
x0 = torch.randn_like(x1)
|
| 547 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 548 |
t = torch.sigmoid(torch.randn(B, device=DEVICE))
|
| 549 |
-
t = flux_shift(t).to(DTYPE).clamp(1e-4, 1-1e-4)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 550 |
|
| 551 |
-
|
| 552 |
-
|
| 553 |
-
v_target = x1 - x0
|
| 554 |
|
|
|
|
| 555 |
img_ids = TinyFlux.create_img_ids(B, H, W, DEVICE)
|
| 556 |
-
guidance = torch.rand(B, device=DEVICE, dtype=DTYPE) * 4 + 1
|
| 557 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 558 |
with torch.autocast("cuda", dtype=DTYPE):
|
| 559 |
-
|
| 560 |
hidden_states=x_t,
|
| 561 |
encoder_hidden_states=t5,
|
| 562 |
pooled_projections=clip,
|
|
@@ -565,7 +731,10 @@ for ep in range(start_epoch, EPOCHS):
|
|
| 565 |
guidance=guidance,
|
| 566 |
)
|
| 567 |
|
| 568 |
-
|
|
|
|
|
|
|
|
|
|
| 569 |
snr_weights = min_snr_weight(t)
|
| 570 |
loss = (loss_raw * snr_weights).mean() / GRAD_ACCUM
|
| 571 |
loss.backward()
|
|
|
|
| 36 |
|
| 37 |
# HuggingFace Hub
|
| 38 |
HF_REPO = "AbstractPhil/tiny-flux"
|
| 39 |
+
SAVE_EVERY = 1000 # steps - local save
|
| 40 |
+
UPLOAD_EVERY = 1000 # steps - hub upload
|
| 41 |
+
SAMPLE_EVERY = 500 # steps - generate samples
|
| 42 |
LOG_EVERY = 10 # steps - tensorboard
|
| 43 |
|
| 44 |
# Checkpoint loading target
|
|
|
|
| 47 |
# "best" - load best model
|
| 48 |
# int (e.g. 1500) - load specific step
|
| 49 |
# "hub:step_1000" - load specific checkpoint from hub
|
| 50 |
+
# "local:path/to/checkpoint.safetensors" or "local:path/to/checkpoint.pt"
|
| 51 |
# "none" - start fresh, ignore existing checkpoints
|
| 52 |
LOAD_TARGET = "latest"
|
| 53 |
|
| 54 |
+
# Manual resume step (set to override step from checkpoint, or None to use checkpoint's step)
|
| 55 |
+
# Useful when checkpoint doesn't contain step info
|
| 56 |
+
RESUME_STEP = None # e.g., 5000 to resume from step 5000
|
| 57 |
+
|
| 58 |
# Local paths
|
| 59 |
CHECKPOINT_DIR = "./tiny_flux_checkpoints"
|
| 60 |
LOG_DIR = "./tiny_flux_logs"
|
|
|
|
| 131 |
# ============================================================================
|
| 132 |
# FLOW MATCHING HELPERS
|
| 133 |
# ============================================================================
|
| 134 |
+
# Rectified Flow / Flow Matching formulation:
|
| 135 |
+
# x_t = (1-t) * x_0 + t * x_1
|
| 136 |
+
# where x_0 = noise, x_1 = data
|
| 137 |
+
# t=0: pure noise, t=1: pure data
|
| 138 |
+
# velocity v = x_1 - x_0 = data - noise
|
| 139 |
+
#
|
| 140 |
+
# Training: model learns to predict v given (x_t, t)
|
| 141 |
+
# Inference: start from noise (t=0), integrate to data (t=1)
|
| 142 |
+
# x_{t+dt} = x_t + v_pred * dt
|
| 143 |
+
# ============================================================================
|
| 144 |
+
|
| 145 |
def flux_shift(t, s=SHIFT):
    """Apply the Flux timestep shift to ``t``.

    Evaluates ``s * t / (1 + (s - 1) * t)``. The endpoints 0 and 1 map to
    themselves; for ``s > 1`` every value strictly between them is moved
    upward (toward t=1, the data end in this file's convention).

    Example: with ``s=3.0``, ``flux_shift(0.5)`` is 0.75.
    """
    scaled = s * t
    denominator = 1 + (s - 1) * t
    return scaled / denominator
|
| 154 |
|
| 155 |
def flux_shift_inverse(t_shifted, s=SHIFT):
    """Undo ``flux_shift``: recover the unshifted timestep from ``t_shifted``.

    Evaluates ``t_shifted / (s - (s - 1) * t_shifted)``.
    """
    denominator = s - (s - 1) * t_shifted
    return t_shifted / denominator
|
| 158 |
|
| 159 |
def min_snr_weight(t, gamma=MIN_SNR):
    """Return the Min-SNR loss weight ``min(snr, gamma) / snr`` for each ``t``.

    ``snr`` is ``(t / (1 - t)) ** 2``, so it grows as ``t`` approaches 1.
    Timesteps whose SNR exceeds ``gamma`` receive a weight below 1; all
    others receive a weight of 1, except where the SNR falls under the
    1e-5 floor that guards the division.
    """
    # Floor the (1 - t) term so t == 1 cannot divide by zero.
    one_minus_t = (1 - t).clamp(min=1e-5)
    snr = (t / one_minus_t).pow(2)
    capped = torch.clamp(snr, max=gamma)
    return capped / snr.clamp(min=1e-5)
|
| 167 |
|
|
|
|
| 170 |
# ============================================================================
|
| 171 |
@torch.no_grad()
|
| 172 |
def generate_samples(model, prompts, num_steps=20, guidance_scale=3.5, H=64, W=64):
|
| 173 |
+
"""Generate sample images using Euler sampling.
|
| 174 |
+
|
| 175 |
+
Flow matching: x_t = (1-t)*noise + t*data, v = data - noise
|
| 176 |
+
At t=0: pure noise. At t=1: pure data.
|
| 177 |
+
We integrate from t=0 to t=1.
|
| 178 |
+
"""
|
| 179 |
model.eval()
|
| 180 |
B = len(prompts)
|
| 181 |
C = 16 # VAE channels
|
|
|
|
| 189 |
t5_embeds = torch.stack(t5_embeds)
|
| 190 |
clip_pooleds = torch.stack(clip_pooleds)
|
| 191 |
|
| 192 |
+
# Start from pure noise (t=0)
|
| 193 |
x = torch.randn(B, H * W, C, device=DEVICE, dtype=DTYPE)
|
| 194 |
|
| 195 |
# Create image IDs
|
| 196 |
img_ids = TinyFlux.create_img_ids(B, H, W, DEVICE)
|
| 197 |
|
| 198 |
+
# Euler sampling: t goes from 0 (noise) to 1 (data)
|
| 199 |
+
timesteps = torch.linspace(0, 1, num_steps + 1, device=DEVICE, dtype=DTYPE)
|
| 200 |
|
| 201 |
+
for i in range(num_steps):
|
| 202 |
+
t_curr = timesteps[i]
|
| 203 |
+
t_next = timesteps[i + 1]
|
| 204 |
+
dt = t_next - t_curr # positive
|
| 205 |
+
|
| 206 |
+
t_batch = t_curr.expand(B)
|
| 207 |
|
| 208 |
# Conditional prediction
|
| 209 |
guidance = torch.full((B,), guidance_scale, device=DEVICE, dtype=DTYPE)
|
|
|
|
| 216 |
guidance=guidance,
|
| 217 |
)
|
| 218 |
|
| 219 |
+
# Euler step: x_{t+dt} = x_t + v * dt
|
| 220 |
x = x + v_cond * dt
|
| 221 |
|
| 222 |
# Reshape to image format: (B, H*W, C) -> (B, C, H, W)
|
| 223 |
latents = x.reshape(B, H, W, C).permute(0, 3, 1, 2)
|
| 224 |
|
| 225 |
+
# Decode with VAE (match VAE dtype)
|
| 226 |
latents = latents / vae.config.scaling_factor
|
| 227 |
+
images = vae.decode(latents.to(vae.dtype)).sample
|
| 228 |
images = (images / 2 + 0.5).clamp(0, 1)
|
| 229 |
|
| 230 |
model.train()
|
|
|
|
| 270 |
# ============================================================================
|
| 271 |
# CHECKPOINT FUNCTIONS
|
| 272 |
# ============================================================================
|
| 273 |
+
def load_weights(path):
    """Load model weights from a .safetensors or .pt file.

    Args:
        path: Checkpoint file path. ``.safetensors`` files are read with
            ``load_file`` and ``.pt`` files with ``torch.load``. Any other
            extension is tried as safetensors first, then as a torch pickle.

    Returns:
        For a ``.pt`` file holding a dict with a ``"model"`` or
        ``"state_dict"`` entry, that entry. In every other case the loaded
        object is returned unchanged.
    """
    if path.endswith(".safetensors"):
        return load_file(path)

    if path.endswith(".pt"):
        # NOTE(security): weights_only=False unpickles arbitrary objects.
        # Only load checkpoints that come from a trusted source.
        ckpt = torch.load(path, map_location=DEVICE, weights_only=False)
        if isinstance(ckpt, dict):
            # Full training checkpoints nest the weights under a known key.
            for key in ("model", "state_dict"):
                if key in ckpt:
                    return ckpt[key]
        # Either a bare state dict or a non-dict object: hand it back as is.
        return ckpt

    # Unknown extension: try safetensors first, then fall back to torch.
    try:
        return load_file(path)
    except Exception:
        # Deliberately not a bare ``except`` so Ctrl-C and SystemExit
        # still propagate instead of triggering the fallback load.
        return torch.load(path, map_location=DEVICE, weights_only=False)
|
| 298 |
+
|
| 299 |
def save_checkpoint(model, optimizer, scheduler, step, epoch, loss, path):
|
| 300 |
"""Save checkpoint locally."""
|
| 301 |
os.makedirs(os.path.dirname(path) if os.path.dirname(path) else ".", exist_ok=True)
|
|
|
|
| 367 |
"best" - best model
|
| 368 |
int (1500) - specific step
|
| 369 |
"hub:step_1000" - specific hub checkpoint
|
| 370 |
+
"local:/path/to/file.safetensors" or "local:/path/to/file.pt" - specific local file
|
| 371 |
"none" - skip loading, start fresh
|
| 372 |
"""
|
| 373 |
if target == "none":
|
|
|
|
| 403 |
|
| 404 |
# Load based on mode
|
| 405 |
if load_mode == "local":
|
| 406 |
+
# Direct local file (.pt or .safetensors)
|
| 407 |
if os.path.exists(load_path):
|
| 408 |
+
weights = load_weights(load_path)
|
| 409 |
model.load_state_dict(weights)
|
| 410 |
+
|
| 411 |
+
# Try to find associated state file for optimizer/scheduler
|
| 412 |
+
if load_path.endswith(".safetensors"):
|
| 413 |
+
state_path = load_path.replace(".safetensors", ".pt")
|
| 414 |
+
elif load_path.endswith(".pt"):
|
| 415 |
+
# The .pt file might contain everything
|
| 416 |
+
ckpt = torch.load(load_path, map_location=DEVICE, weights_only=False)
|
| 417 |
+
if isinstance(ckpt, dict):
|
| 418 |
+
# Debug: show what keys are in the checkpoint
|
| 419 |
+
non_tensor_keys = [k for k in ckpt.keys() if not isinstance(ckpt.get(k), torch.Tensor)]
|
| 420 |
+
if non_tensor_keys:
|
| 421 |
+
print(f" Checkpoint keys: {non_tensor_keys}")
|
| 422 |
+
|
| 423 |
+
# Extract step/epoch - try multiple common key names
|
| 424 |
+
start_step = ckpt.get("step", ckpt.get("global_step", ckpt.get("iteration", 0)))
|
| 425 |
+
start_epoch = ckpt.get("epoch", 0)
|
| 426 |
+
|
| 427 |
+
# Also check for nested state dict
|
| 428 |
+
if "state" in ckpt and isinstance(ckpt["state"], dict):
|
| 429 |
+
start_step = ckpt["state"].get("step", start_step)
|
| 430 |
+
start_epoch = ckpt["state"].get("epoch", start_epoch)
|
| 431 |
+
|
| 432 |
+
# Try to load optimizer/scheduler if present
|
| 433 |
+
if "optimizer" in ckpt:
|
| 434 |
+
try:
|
| 435 |
+
optimizer.load_state_dict(ckpt["optimizer"])
|
| 436 |
+
if "scheduler" in ckpt:
|
| 437 |
+
scheduler.load_state_dict(ckpt["scheduler"])
|
| 438 |
+
except Exception as e:
|
| 439 |
+
print(f" Note: Could not load optimizer state: {e}")
|
| 440 |
+
state_path = None
|
| 441 |
+
else:
|
| 442 |
+
state_path = load_path + ".pt"
|
| 443 |
+
|
| 444 |
+
if state_path and os.path.exists(state_path):
|
| 445 |
+
state = torch.load(state_path, map_location=DEVICE, weights_only=False)
|
| 446 |
+
try:
|
| 447 |
+
start_step = state.get("step", start_step)
|
| 448 |
+
start_epoch = state.get("epoch", start_epoch)
|
| 449 |
+
if "optimizer" in state:
|
| 450 |
+
optimizer.load_state_dict(state["optimizer"])
|
| 451 |
+
if "scheduler" in state:
|
| 452 |
+
scheduler.load_state_dict(state["scheduler"])
|
| 453 |
+
except Exception as e:
|
| 454 |
+
print(f" Note: Could not load optimizer state: {e}")
|
| 455 |
+
|
| 456 |
print(f"✓ Loaded local: {load_path} (step {start_step})")
|
| 457 |
return start_step, start_epoch
|
| 458 |
else:
|
| 459 |
print(f"⚠ Local file not found: {load_path}")
|
| 460 |
|
| 461 |
elif load_mode == "hub":
|
| 462 |
+
# Specific hub checkpoint - try both extensions
|
| 463 |
+
for ext in [".safetensors", ".pt", ""]:
|
| 464 |
+
try:
|
| 465 |
+
if load_path.endswith((".safetensors", ".pt")):
|
| 466 |
+
filename = load_path if "/" in load_path else f"checkpoints/{load_path}"
|
| 467 |
+
else:
|
| 468 |
+
filename = f"checkpoints/{load_path}{ext}"
|
| 469 |
+
local_path = hf_hub_download(repo_id=HF_REPO, filename=filename)
|
| 470 |
+
weights = load_weights(local_path)
|
| 471 |
+
model.load_state_dict(weights)
|
| 472 |
+
# Extract step from filename
|
| 473 |
+
if "step_" in load_path:
|
| 474 |
+
start_step = int(load_path.split("step_")[-1].replace(".safetensors", "").replace(".pt", ""))
|
| 475 |
+
print(f"✓ Loaded from Hub: {filename} (step {start_step})")
|
| 476 |
+
return start_step, start_epoch
|
| 477 |
+
except Exception as e:
|
| 478 |
+
continue
|
| 479 |
+
print(f"⚠ Could not load from hub: {load_path}")
|
| 480 |
|
| 481 |
elif load_mode == "best":
|
| 482 |
+
# Try hub best first (try both extensions)
|
| 483 |
+
for ext in [".safetensors", ".pt"]:
|
| 484 |
+
try:
|
| 485 |
+
filename = f"model{ext}" if ext else "model.safetensors"
|
| 486 |
+
local_path = hf_hub_download(repo_id=HF_REPO, filename=filename)
|
| 487 |
+
weights = load_weights(local_path)
|
| 488 |
+
model.load_state_dict(weights)
|
| 489 |
+
print(f"✓ Loaded best model from Hub")
|
| 490 |
+
return start_step, start_epoch
|
| 491 |
+
except:
|
| 492 |
+
continue
|
| 493 |
+
|
| 494 |
+
# Try local best (both extensions)
|
| 495 |
+
for ext in [".safetensors", ".pt"]:
|
| 496 |
+
best_path = os.path.join(CHECKPOINT_DIR, f"best{ext}")
|
| 497 |
+
if os.path.exists(best_path):
|
| 498 |
+
weights = load_weights(best_path)
|
| 499 |
+
model.load_state_dict(weights)
|
| 500 |
+
# Try to load optimizer state
|
| 501 |
+
state_path = best_path.replace(ext, ".pt") if ext == ".safetensors" else best_path
|
| 502 |
+
if os.path.exists(state_path):
|
| 503 |
+
state = torch.load(state_path, map_location=DEVICE, weights_only=False)
|
| 504 |
+
if isinstance(state, dict) and "step" in state:
|
| 505 |
+
start_step = state.get("step", 0)
|
| 506 |
+
start_epoch = state.get("epoch", 0)
|
| 507 |
+
print(f"✓ Loaded local best (step {start_step})")
|
| 508 |
+
return start_step, start_epoch
|
| 509 |
|
| 510 |
elif load_mode == "step":
|
| 511 |
# Specific step number
|
| 512 |
step_num = load_path
|
| 513 |
+
# Try hub (both extensions)
|
| 514 |
+
for ext in [".safetensors", ".pt"]:
|
| 515 |
+
try:
|
| 516 |
+
filename = f"checkpoints/step_{step_num}{ext}"
|
| 517 |
+
local_path = hf_hub_download(repo_id=HF_REPO, filename=filename)
|
| 518 |
+
weights = load_weights(local_path)
|
| 519 |
+
model.load_state_dict(weights)
|
| 520 |
+
start_step = step_num
|
| 521 |
+
print(f"✓ Loaded step {step_num} from Hub")
|
| 522 |
+
return start_step, start_epoch
|
| 523 |
+
except:
|
| 524 |
+
continue
|
| 525 |
+
|
| 526 |
+
# Try local (both extensions)
|
| 527 |
+
for ext in [".safetensors", ".pt"]:
|
| 528 |
+
local_path = os.path.join(CHECKPOINT_DIR, f"step_{step_num}{ext}")
|
| 529 |
+
if os.path.exists(local_path):
|
| 530 |
+
weights = load_weights(local_path)
|
| 531 |
+
model.load_state_dict(weights)
|
| 532 |
+
state_path = local_path.replace(".safetensors", ".pt") if ext == ".safetensors" else local_path
|
| 533 |
+
if os.path.exists(state_path):
|
| 534 |
+
state = torch.load(state_path, map_location=DEVICE, weights_only=False)
|
| 535 |
+
if isinstance(state, dict):
|
| 536 |
+
try:
|
| 537 |
+
if "optimizer" in state:
|
| 538 |
+
optimizer.load_state_dict(state["optimizer"])
|
| 539 |
+
if "scheduler" in state:
|
| 540 |
+
scheduler.load_state_dict(state["scheduler"])
|
| 541 |
+
start_epoch = state.get("epoch", 0)
|
| 542 |
+
except:
|
| 543 |
+
pass
|
| 544 |
+
start_step = step_num
|
| 545 |
+
print(f"✓ Loaded local step {step_num}")
|
| 546 |
+
return start_step, start_epoch
|
| 547 |
print(f"⚠ Step {step_num} not found")
|
| 548 |
|
| 549 |
# Default: latest
|
| 550 |
+
# Try Hub first (both extensions)
|
| 551 |
try:
|
| 552 |
files = api.list_repo_files(repo_id=HF_REPO)
|
| 553 |
+
checkpoints = [f for f in files if f.startswith("checkpoints/step_") and (f.endswith(".safetensors") or f.endswith(".pt"))]
|
| 554 |
if checkpoints:
|
| 555 |
+
# Sort by step number
|
| 556 |
+
def get_step(f):
|
| 557 |
+
return int(f.split("step_")[-1].replace(".safetensors", "").replace(".pt", ""))
|
| 558 |
+
checkpoints.sort(key=get_step)
|
| 559 |
latest = checkpoints[-1]
|
| 560 |
+
step = get_step(latest)
|
| 561 |
local_path = hf_hub_download(repo_id=HF_REPO, filename=latest)
|
| 562 |
+
weights = load_weights(local_path)
|
| 563 |
model.load_state_dict(weights)
|
| 564 |
start_step = step
|
| 565 |
print(f"✓ Loaded latest from Hub: step {step}")
|
|
|
|
| 567 |
except Exception as e:
|
| 568 |
print(f"Hub check: {e}")
|
| 569 |
|
| 570 |
+
# Try local (both extensions)
|
| 571 |
if os.path.exists(CHECKPOINT_DIR):
|
| 572 |
+
local_ckpts = [f for f in os.listdir(CHECKPOINT_DIR) if f.startswith("step_") and (f.endswith(".safetensors") or f.endswith(".pt"))]
|
| 573 |
+
# Filter to just weights files (not state .pt files that pair with .safetensors)
|
| 574 |
+
local_ckpts = [f for f in local_ckpts if not (f.endswith(".pt") and f.replace(".pt", ".safetensors") in local_ckpts)]
|
| 575 |
if local_ckpts:
|
| 576 |
+
def get_step(f):
|
| 577 |
+
return int(f.split("step_")[-1].replace(".safetensors", "").replace(".pt", ""))
|
| 578 |
+
local_ckpts.sort(key=get_step)
|
| 579 |
latest = local_ckpts[-1]
|
| 580 |
+
step = get_step(latest)
|
| 581 |
weights_path = os.path.join(CHECKPOINT_DIR, latest)
|
| 582 |
+
weights = load_weights(weights_path)
|
| 583 |
model.load_state_dict(weights)
|
| 584 |
+
# Try to load optimizer state
|
| 585 |
+
state_path = weights_path.replace(".safetensors", ".pt") if weights_path.endswith(".safetensors") else weights_path
|
| 586 |
if os.path.exists(state_path):
|
| 587 |
+
state = torch.load(state_path, map_location=DEVICE, weights_only=False)
|
| 588 |
+
if isinstance(state, dict):
|
| 589 |
+
try:
|
| 590 |
+
if "optimizer" in state:
|
| 591 |
+
optimizer.load_state_dict(state["optimizer"])
|
| 592 |
+
if "scheduler" in state:
|
| 593 |
+
scheduler.load_state_dict(state["scheduler"])
|
| 594 |
+
start_epoch = state.get("epoch", 0)
|
| 595 |
+
except:
|
| 596 |
+
pass
|
| 597 |
start_step = step
|
| 598 |
print(f"✓ Loaded latest local: step {step}")
|
| 599 |
return start_step, start_epoch
|
|
|
|
| 612 |
config = TinyFluxConfig()
|
| 613 |
model = TinyFlux(config).to(DEVICE).to(DTYPE)
|
| 614 |
print(f"\nParams: {sum(p.numel() for p in model.parameters()):,}")
|
| 615 |
+
model = torch.compile(model, mode="default")
|
| 616 |
|
| 617 |
# ============================================================================
|
| 618 |
# OPTIMIZER & SCHEDULER
|
|
|
|
| 633 |
print(f"\nLoad target: {LOAD_TARGET}")
|
| 634 |
start_step, start_epoch = load_checkpoint(model, opt, sched, LOAD_TARGET)
|
| 635 |
|
| 636 |
+
# Override start_step if RESUME_STEP is set
|
| 637 |
+
if RESUME_STEP is not None:
|
| 638 |
+
print(f"Overriding start_step: {start_step} -> {RESUME_STEP}")
|
| 639 |
+
start_step = RESUME_STEP
|
| 640 |
+
|
| 641 |
# Log config to tensorboard
|
| 642 |
writer.add_text("config", json.dumps(config.__dict__, indent=2), 0)
|
| 643 |
writer.add_text("training_config", json.dumps({
|
|
|
|
| 676 |
pbar = tqdm(loader, desc=f"E{ep+1}")
|
| 677 |
|
| 678 |
for i, batch in enumerate(pbar):
|
| 679 |
+
latents = batch["latents"] # Ground truth data (VAE encoded images)
|
| 680 |
t5 = batch["t5_embeds"]
|
| 681 |
clip = batch["clip_pooled"]
|
| 682 |
|
| 683 |
+
B, C, H, W = latents.shape
|
|
|
|
|
|
|
| 684 |
|
| 685 |
+
# ================================================================
|
| 686 |
+
# FLOW MATCHING FORMULATION
|
| 687 |
+
# ================================================================
|
| 688 |
+
# x_1 = data (what we want to generate)
|
| 689 |
+
# x_0 = noise (where we start at inference)
|
| 690 |
+
# x_t = (1-t)*x_0 + t*x_1 (linear interpolation)
|
| 691 |
+
#
|
| 692 |
+
# At t=0: x_t = x_0 (pure noise)
|
| 693 |
+
# At t=1: x_t = x_1 (pure data)
|
| 694 |
+
#
|
| 695 |
+
# Velocity field: v = dx/dt = x_1 - x_0
|
| 696 |
+
# Model learns to predict v given (x_t, t)
|
| 697 |
+
#
|
| 698 |
+
# At inference: start from noise, integrate v from t=0 to t=1
|
| 699 |
+
# ================================================================
|
| 700 |
+
|
| 701 |
+
# Reshape data to sequence format: (B, C, H, W) -> (B, H*W, C)
|
| 702 |
+
data = latents.permute(0, 2, 3, 1).reshape(B, H*W, C) # x_1
|
| 703 |
+
noise = torch.randn_like(data) # x_0
|
| 704 |
+
|
| 705 |
+
# Sample timesteps with logit-normal distribution + Flux shift
|
| 706 |
+
# This biases training towards higher t (closer to data)
|
| 707 |
t = torch.sigmoid(torch.randn(B, device=DEVICE))
|
| 708 |
+
t = flux_shift(t, s=SHIFT).to(DTYPE).clamp(1e-4, 1-1e-4)
|
| 709 |
+
|
| 710 |
+
# Create noisy samples via linear interpolation
|
| 711 |
+
t_expanded = t.view(B, 1, 1)
|
| 712 |
+
x_t = (1 - t_expanded) * noise + t_expanded * data # Noisy sample at time t
|
| 713 |
|
| 714 |
+
# Target velocity: direction from noise to data
|
| 715 |
+
v_target = data - noise
|
|
|
|
| 716 |
|
| 717 |
+
# Create position IDs for RoPE
|
| 718 |
img_ids = TinyFlux.create_img_ids(B, H, W, DEVICE)
|
|
|
|
| 719 |
|
| 720 |
+
# Random guidance scale (for CFG training)
|
| 721 |
+
guidance = torch.rand(B, device=DEVICE, dtype=DTYPE) * 4 + 1 # [1, 5]
|
| 722 |
+
|
| 723 |
+
# Forward pass: predict velocity
|
| 724 |
with torch.autocast("cuda", dtype=DTYPE):
|
| 725 |
+
v_pred = model(
|
| 726 |
hidden_states=x_t,
|
| 727 |
encoder_hidden_states=t5,
|
| 728 |
pooled_projections=clip,
|
|
|
|
| 731 |
guidance=guidance,
|
| 732 |
)
|
| 733 |
|
| 734 |
+
# Loss: MSE between predicted and target velocity
|
| 735 |
+
loss_raw = F.mse_loss(v_pred, v_target, reduction="none").mean(dim=[1, 2])
|
| 736 |
+
|
| 737 |
+
# Min-SNR weighting: downweight high-SNR timesteps (t near 1, close to data)
|
| 738 |
snr_weights = min_snr_weight(t)
|
| 739 |
loss = (loss_raw * snr_weights).mean() / GRAD_ACCUM
|
| 740 |
loss.backward()
|