tiny-flux / trainer_colab.py
AbstractPhil's picture
Update trainer_colab.py
4adc3d8 verified
raw
history blame
33.9 kB
# ============================================================================
# TinyFlux Training Cell - Full Featured
# ============================================================================
# Run the model cell before this one (defines TinyFlux, TinyFluxConfig)
# Dataset: AbstractPhil/flux-schnell-teacher-latents
# Uploads checkpoints to: AbstractPhil/tiny-flux
# ============================================================================
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from datasets import load_dataset
from transformers import T5EncoderModel, T5Tokenizer, CLIPTextModel, CLIPTokenizer
from huggingface_hub import HfApi, hf_hub_download
from safetensors.torch import save_file, load_file
from torch.utils.tensorboard import SummaryWriter
from tqdm.auto import tqdm
import numpy as np
import math
import os
import json
from datetime import datetime
# ============================================================================
# CONFIG
# ============================================================================
# Optimisation hyper-parameters.
BATCH_SIZE = 4
GRAD_ACCUM = 2  # micro-batches accumulated per optimizer step
LR = 1e-4
EPOCHS = 10
MAX_SEQ = 128  # T5 token length used by encode_prompt
MIN_SNR = 5.0  # default gamma of min_snr_weight
SHIFT = 3.0  # default shift of flux_shift / flux_shift_inverse
DEVICE = "cuda"
# Prefer bfloat16 when the GPU supports it, otherwise fall back to float16.
# Availability is checked first because some PyTorch releases raise from
# is_bf16_supported() when no CUDA runtime is present.
DTYPE = (
    torch.bfloat16
    if torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    else torch.float16
)
# HuggingFace Hub
HF_REPO = "AbstractPhil/tiny-flux"
SAVE_EVERY = 1000  # steps - local save
UPLOAD_EVERY = 1000  # steps - hub upload
SAMPLE_EVERY = 500  # steps - generate samples
LOG_EVERY = 10  # steps - tensorboard
# Checkpoint loading target
# Options:
# None or "latest" - load most recent checkpoint
# "best" - load best model
# int (e.g. 1500) - load specific step
# "hub:step_1000" - load specific checkpoint from hub
# "local:path/to/checkpoint.safetensors" or "local:path/to/checkpoint.pt"
# "none" - start fresh, ignore existing checkpoints
LOAD_TARGET = "latest"
# Manual resume step (set to override step from checkpoint, or None to use checkpoint's step)
# Useful when checkpoint doesn't contain step info
RESUME_STEP = None  # e.g., 5000 to resume from step 5000
# Local paths
CHECKPOINT_DIR = "./tiny_flux_checkpoints"
LOG_DIR = "./tiny_flux_logs"
SAMPLE_DIR = "./tiny_flux_samples"
# Create every output directory up front so later writes cannot fail on a
# missing folder.
for _out_dir in (CHECKPOINT_DIR, LOG_DIR, SAMPLE_DIR):
    os.makedirs(_out_dir, exist_ok=True)
# ============================================================================
# HF HUB SETUP
# ============================================================================
# Ensure the target model repo exists. Any failure is only reported, so the
# rest of the cell still runs.
print("Setting up HuggingFace Hub...")
api = HfApi()
try:
    _repo_kwargs = {"repo_id": HF_REPO, "repo_type": "model", "exist_ok": True}
    api.create_repo(**_repo_kwargs)
    print(f"✓ Repo ready: {HF_REPO}")
except Exception as err:
    print(f"Note: {err}")
# ============================================================================
# TENSORBOARD
# ============================================================================
# Every run writes its events into a timestamp-named sub-directory of LOG_DIR.
run_name = f"{datetime.now():%Y%m%d_%H%M%S}"
writer = SummaryWriter(log_dir=os.path.join(LOG_DIR, run_name))
print(f"✓ Tensorboard: {LOG_DIR}/{run_name}")
# ============================================================================
# LOAD DATASET
# ============================================================================
# Rows are read in collate() through their "latent" and "prompt" fields.
print("\nLoading dataset...")
ds = load_dataset(
    "AbstractPhil/flux-schnell-teacher-latents",
    split="train",
)
print(f"Samples: {len(ds)}")
# ============================================================================
# LOAD TEXT ENCODERS
# ============================================================================
# Both text encoders are used for inference only: they are put in eval mode
# and all of their parameters are frozen.
print("\nLoading flan-t5-base (768 dim)...")
t5_tok = T5Tokenizer.from_pretrained("google/flan-t5-base")
t5_enc = T5EncoderModel.from_pretrained("google/flan-t5-base", torch_dtype=DTYPE)
t5_enc = t5_enc.to(DEVICE).eval()
print("Loading CLIP-L...")
clip_tok = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
clip_enc = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14", torch_dtype=DTYPE)
clip_enc = clip_enc.to(DEVICE).eval()
t5_enc.requires_grad_(False)
clip_enc.requires_grad_(False)
# ============================================================================
# LOAD VAE FOR SAMPLE GENERATION
# ============================================================================
# The VAE is only used to decode latents in generate_samples(); it is frozen too.
print("Loading Flux VAE for samples...")
from diffusers import AutoencoderKL
vae = AutoencoderKL.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",
    subfolder="vae",
    torch_dtype=DTYPE,
)
vae = vae.to(DEVICE).eval()
vae.requires_grad_(False)
# ============================================================================
# ENCODING HELPERS
# ============================================================================
@torch.no_grad()
def encode_prompt(prompt):
    """Encode text with the frozen T5 and CLIP text encoders.

    The prompt is tokenized to a fixed length for each encoder (MAX_SEQ tokens
    for T5, 77 for CLIP), padding or truncating as needed, and moved to DEVICE.

    Args:
        prompt: Text handed unchanged to both tokenizers.

    Returns:
        Tuple ``(t5_hidden, clip_pooled)``: the T5 encoder's
        ``last_hidden_state`` and the CLIP text model's ``pooler_output``.
    """
    shared_tok_args = dict(padding="max_length", truncation=True, return_tensors="pt")
    t5_tokens = t5_tok(prompt, max_length=MAX_SEQ, **shared_tok_args).to(DEVICE)
    clip_tokens = clip_tok(prompt, max_length=77, **shared_tok_args).to(DEVICE)

    t5_hidden = t5_enc(
        input_ids=t5_tokens.input_ids,
        attention_mask=t5_tokens.attention_mask,
    ).last_hidden_state
    clip_pooled = clip_enc(
        input_ids=clip_tokens.input_ids,
        attention_mask=clip_tokens.attention_mask,
    ).pooler_output
    return t5_hidden, clip_pooled
# ============================================================================
# FLOW MATCHING HELPERS
# ============================================================================
# Rectified Flow / Flow Matching formulation:
# x_t = (1-t) * x_0 + t * x_1
# where x_0 = noise, x_1 = data
# t=0: pure noise, t=1: pure data
# velocity v = x_1 - x_0 = data - noise
#
# Training: model learns to predict v given (x_t, t)
# Inference: start from noise (t=0), integrate to data (t=1)
# x_{t+dt} = x_t + v_pred * dt
# ============================================================================
def flux_shift(t, s=SHIFT):
    """Apply the Flux timestep shift ``s*t / (1 + (s-1)*t)``.

    For s > 1 the result is pushed towards higher values (closer to data in
    this file's convention), so training spends more time on late timesteps.
    With s=3.0, ``flux_shift(0.5) == 0.75``.

    Args:
        t: Timestep(s) in [0, 1]; a number or a tensor.
        s: Shift strength; ``s=1`` leaves ``t`` unchanged.

    Returns:
        The shifted timestep(s), same type as ``t``.
    """
    numerator = s * t
    denominator = 1 + (s - 1) * t
    return numerator / denominator
def flux_shift_inverse(t_shifted, s=SHIFT):
    """Undo ``flux_shift``: map a shifted timestep back to the original one.

    Solving ``y = s*t / (1 + (s-1)*t)`` for ``t`` gives ``t = y / (s - (s-1)*y)``.

    Args:
        t_shifted: Output(s) of ``flux_shift``; a number or a tensor.
        s: The shift strength that was used for the forward mapping.

    Returns:
        The unshifted timestep(s).
    """
    denominator = s - (s - 1) * t_shifted
    return t_shifted / denominator
def min_snr_weight(t, gamma=MIN_SNR):
    """Min-SNR loss weight for each timestep.

    The signal-to-noise ratio is taken as ``snr = (t / (1 - t)) ** 2`` and the
    weight is ``min(snr, gamma) / snr``:

    * ``1`` while ``1e-5 <= snr <= gamma``,
    * ``gamma / snr`` once ``snr > gamma`` (t close to 1 is downweighted),
    * ``snr / 1e-5`` for the tiny range ``snr < 1e-5`` (t extremely close to 0).

    The ratio and square are evaluated in float32 so that half-precision
    timesteps do not lose accuracy in the intermediates. The result is cast
    back to the dtype of ``t``, so callers see the same dtype as before.

    Args:
        t: Tensor of timesteps in [0, 1].
        gamma: SNR cap; 5.0 is the value used by this script.

    Returns:
        Tensor of weights with the same shape and dtype as ``t``.
    """
    t32 = t.float()
    # The clamps keep both divisions finite at t == 1 and t == 0.
    snr = (t32 / (1 - t32).clamp(min=1e-5)).pow(2)
    weight = torch.clamp(snr, max=gamma) / snr.clamp(min=1e-5)
    return weight.to(t.dtype)
# ============================================================================
# SAMPLING FUNCTION
# ============================================================================
@torch.no_grad()
def generate_samples(model, prompts, num_steps=20, guidance_scale=3.5, H=64, W=64):
    """Generate images for ``prompts`` with fixed-step Euler integration.

    Flow matching convention used in this file:
    ``x_t = (1 - t) * noise + t * data`` and ``v = data - noise``, so sampling
    starts from pure noise at t=0 and integrates the predicted velocity to t=1.

    The model is put in eval mode while sampling and is returned to whatever
    mode it was in before the call, even if sampling raises.

    Args:
        model: Velocity model, called with the same keyword arguments as in
            the training loop.
        prompts: Sequence of prompt strings; one image is produced per prompt.
        num_steps: Number of uniform Euler steps between t=0 and t=1.
        guidance_scale: Value fed to the model's ``guidance`` input for every
            sample.
        H: Latent grid height; the token sequence has ``H * W`` entries.
        W: Latent grid width.

    Returns:
        The VAE-decoded images, mapped with ``x / 2 + 0.5`` and clamped to
        [0, 1].
    """
    was_training = model.training
    model.eval()
    try:
        B = len(prompts)
        C = 16  # VAE channels
        # Encode prompts
        t5_embeds, clip_pooleds = [], []
        for p in prompts:
            t5_out, clip_pooled = encode_prompt(p)
            t5_embeds.append(t5_out.squeeze(0))
            clip_pooleds.append(clip_pooled.squeeze(0))
        t5_embeds = torch.stack(t5_embeds)
        clip_pooleds = torch.stack(clip_pooleds)
        # Start from pure noise (t=0)
        x = torch.randn(B, H * W, C, device=DEVICE, dtype=DTYPE)
        # Create image IDs
        img_ids = TinyFlux.create_img_ids(B, H, W, DEVICE)
        # The guidance input is identical at every step, so build it once.
        guidance = torch.full((B,), guidance_scale, device=DEVICE, dtype=DTYPE)
        # Euler sampling: t goes from 0 (noise) to 1 (data)
        timesteps = torch.linspace(0, 1, num_steps + 1, device=DEVICE, dtype=DTYPE)
        for i in range(num_steps):
            t_curr = timesteps[i]
            t_next = timesteps[i + 1]
            dt = t_next - t_curr  # positive
            t_batch = t_curr.expand(B)
            # Conditional prediction
            v_cond = model(
                hidden_states=x,
                encoder_hidden_states=t5_embeds,
                pooled_projections=clip_pooleds,
                timestep=t_batch,
                img_ids=img_ids,
                guidance=guidance,
            )
            # Euler step: x_{t+dt} = x_t + v * dt
            x = x + v_cond * dt
        # Reshape to image format: (B, H*W, C) -> (B, C, H, W)
        latents = x.reshape(B, H, W, C).permute(0, 3, 1, 2)
        # Decode with VAE (match VAE dtype)
        # NOTE(review): only scaling_factor is undone here; if the training
        # latents were also offset by a VAE shift factor it would have to be
        # added back - confirm against how the dataset latents were encoded.
        latents = latents / vae.config.scaling_factor
        images = vae.decode(latents.to(vae.dtype)).sample
        images = (images / 2 + 0.5).clamp(0, 1)
    finally:
        # Restore the caller's mode instead of unconditionally forcing train().
        model.train(was_training)
    return images
def save_samples(images, prompts, step, save_dir):
    """Save sample images to disk and log them to tensorboard.

    Each image is written to ``save_dir`` as
    ``step{step}_{index}_{prompt-derived name}.png``; a grid of all images and
    the prompt list are added to the module-level tensorboard ``writer``.

    Args:
        images: Batch of image tensors, one per prompt (as returned by
            ``generate_samples``).
        prompts: Prompt strings matching ``images`` one-to-one.
        step: Global training step used in filenames and as tensorboard step.
        save_dir: Directory for the PNG files; created if missing.
    """
    import re
    from torchvision.utils import make_grid, save_image
    # NumPy has no bfloat16 dtype, so move to float32 on the CPU before the
    # tensors reach tensorboard's tensor-to-NumPy conversion.
    images = images.detach().float().cpu()
    os.makedirs(save_dir, exist_ok=True)
    # Save individual images
    for i, (img, prompt) in enumerate(zip(images, prompts)):
        # "/" -> "-" and space -> "_" as before; any other character that is
        # not a word character, dot or dash is also mapped to "_" so the
        # prompt can never produce an invalid or nested file name.
        safe_prompt = re.sub(r"[^\w.-]", "_", prompt[:50].replace("/", "-"))
        path = os.path.join(save_dir, f"step{step}_{i}_{safe_prompt}.png")
        save_image(img, path)
    # Log grid to tensorboard
    grid = make_grid(images, nrow=2, normalize=False)
    writer.add_image("samples", grid, step)
    # Log prompts
    writer.add_text("sample_prompts", "\n".join(prompts), step)
    print(f" ✓ Saved {len(images)} samples")
# ============================================================================
# COLLATE
# ============================================================================
def collate(batch):
    """Turn a list of dataset rows into one training batch.

    Each row must provide a ``"latent"`` array and a ``"prompt"`` string. All
    prompts of the batch are encoded in a single call to ``encode_prompt``
    (the tokenizers accept a list of strings), so each text encoder runs once
    per batch rather than once per sample.

    Args:
        batch: List of dataset rows (mappings with "latent" and "prompt").

    Returns:
        Dict with
        ``"latents"``: stacked latents in DTYPE on DEVICE,
        ``"t5_embeds"``: T5 hidden states for every prompt,
        ``"clip_pooled"``: CLIP pooled embeddings for every prompt,
        ``"prompts"``: the raw prompt strings.
    """
    prompts = [row["prompt"] for row in batch]
    latents = torch.stack(
        [torch.tensor(np.array(row["latent"]), dtype=DTYPE) for row in batch]
    )
    # Batched encode: outputs already carry the leading batch dimension.
    t5_embeds, clip_pooled = encode_prompt(prompts)
    return {
        "latents": latents.to(DEVICE),
        "t5_embeds": t5_embeds,
        "clip_pooled": clip_pooled,
        "prompts": prompts,
    }
# ============================================================================
# CHECKPOINT FUNCTIONS
# ============================================================================
def _unwrap_state_dict(ckpt):
    """Return the model weights held inside an object loaded from a ``.pt`` file."""
    if isinstance(ckpt, dict):
        # Common wrappers: {"model": ...} or {"state_dict": ...}.
        for key in ("model", "state_dict"):
            if key in ckpt:
                return ckpt[key]
    # Either already a plain state dict or something we cannot interpret;
    # hand it back unchanged and let load_state_dict report any mismatch.
    return ckpt


def load_weights(path):
    """Load model weights from a ``.safetensors`` or ``.pt`` file.

    ``.pt`` files may hold the state dict directly or wrap it under a
    ``"model"`` / ``"state_dict"`` key; the wrapper is removed in every code
    path, including the extension-less fallback.

    Args:
        path: File to read. Without a recognised extension, safetensors is
            tried first and ``torch.load`` second.

    Returns:
        The loaded state dict (or whatever object the ``.pt`` file contained
        if it is not a recognisable wrapper).
    """
    if path.endswith(".safetensors"):
        return load_file(path)
    # SECURITY: weights_only=False unpickles arbitrary Python objects. Only
    # load .pt files from sources you trust.
    if path.endswith(".pt"):
        return _unwrap_state_dict(
            torch.load(path, map_location=DEVICE, weights_only=False)
        )
    # Unknown extension: try safetensors first, then pt.
    try:
        return load_file(path)
    except Exception:
        return _unwrap_state_dict(
            torch.load(path, map_location=DEVICE, weights_only=False)
        )
def save_checkpoint(model, optimizer, scheduler, step, epoch, loss, path):
    """Save a checkpoint locally as a weights file plus a training-state file.

    Model weights go to ``<stem>.safetensors`` and the training state (step,
    epoch, loss, optimizer and scheduler state dicts) to ``<stem>.pt``, where
    ``<stem>`` is ``path`` without its final extension. For the usual
    ``*.pt`` argument this yields exactly ``path`` and its ``.safetensors``
    sibling.

    Only the final extension is swapped, so a ``.pt`` elsewhere in the path is
    left alone, and a path that does not end in ``.pt`` can no longer make the
    state file overwrite the weights file.

    Args:
        model: Module whose ``state_dict()`` is saved.
        optimizer: Optimizer whose state is saved.
        scheduler: LR scheduler whose state is saved.
        step: Global step stored in the state file.
        epoch: Epoch index stored in the state file.
        loss: Loss value stored in the state file.
        path: Target path, normally ending in ``.pt``.

    Returns:
        Path of the written ``.safetensors`` weights file.
    """
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
    stem, ext = os.path.splitext(path)
    weights_path = stem + ".safetensors"
    state_path = path if ext == ".pt" else stem + ".pt"
    save_file(model.state_dict(), weights_path)
    state = {
        "step": step,
        "epoch": epoch,
        "loss": loss,
        "optimizer": optimizer.state_dict(),
        "scheduler": scheduler.state_dict(),
    }
    torch.save(state, state_path)
    print(f" ✓ Saved checkpoint: step {step}")
    return weights_path
def upload_checkpoint(weights_path, step, config, include_logs=True):
    """Push a checkpoint and its companion artifacts to HF_REPO.

    Uploaded in order: the weights (as ``checkpoints/step_{step}.safetensors``),
    ``config.json`` (dumped from ``config.__dict__``), the tensorboard log
    folder (optional) and the sample folder (when it exists and is non-empty).
    Any exception is caught and reported, so a failed upload never stops
    training.

    Args:
        weights_path: Local weights file to upload.
        step: Global step used in the remote file name and commit messages.
        config: Object whose ``__dict__`` is written as the model config.
        include_logs: Whether to upload LOG_DIR as well.
    """
    try:
        # 1) weights, under a step-specific name
        remote_weights = f"checkpoints/step_{step}.safetensors"
        api.upload_file(
            repo_id=HF_REPO,
            path_or_fileobj=weights_path,
            path_in_repo=remote_weights,
            commit_message=f"Checkpoint step {step}",
        )

        # 2) model config, written locally first
        local_config = os.path.join(CHECKPOINT_DIR, "config.json")
        with open(local_config, "w") as fh:
            json.dump(config.__dict__, fh, indent=2)
        api.upload_file(
            repo_id=HF_REPO,
            path_or_fileobj=local_config,
            path_in_repo="config.json",
        )

        # 3) tensorboard logs
        want_logs = include_logs and os.path.exists(LOG_DIR)
        if want_logs:
            api.upload_folder(
                repo_id=HF_REPO,
                folder_path=LOG_DIR,
                path_in_repo="logs",
                commit_message=f"Logs at step {step}",
            )

        # 4) generated samples, skipped while the folder is missing or empty
        have_samples = os.path.exists(SAMPLE_DIR) and os.listdir(SAMPLE_DIR)
        if have_samples:
            api.upload_folder(
                repo_id=HF_REPO,
                folder_path=SAMPLE_DIR,
                path_in_repo="samples",
                commit_message=f"Samples at step {step}",
            )

        print(f" ✓ Uploaded to {HF_REPO}")
    except Exception as err:
        print(f" ⚠ Upload failed: {err}")
def _parse_load_target(target):
    """Translate a LOAD_TARGET value into a ``(mode, argument)`` pair."""
    if target is None or target == "latest":
        return "latest", None
    if target == "best":
        return "best", None
    if isinstance(target, int):
        return "step", target
    if isinstance(target, str):
        if target.startswith("hub:"):
            return "hub", target[4:]  # Remove "hub:" prefix
        if target.startswith("local:"):
            return "local", target[6:]  # Remove "local:" prefix
    print(f"Unknown target format: {target}, trying as step number")
    try:
        return "step", int(target)
    except (TypeError, ValueError):
        return "latest", None


def _step_from_name(name):
    """Return the integer following the last ``step_`` in ``name``, or None."""
    tail = name.split("step_")[-1].replace(".safetensors", "").replace(".pt", "")
    try:
        return int(tail)
    except ValueError:
        return None


def _paired_state_path(weights_path):
    """Return the ``.pt`` training-state file that belongs to a weights file."""
    suffix = ".safetensors"
    if weights_path.endswith(suffix):
        return weights_path[: -len(suffix)] + ".pt"
    # A .pt file is treated as its own state file.
    return weights_path


def _restore_optimizer(state, optimizer, scheduler):
    """Best-effort restore of optimizer/scheduler state from a state dict."""
    try:
        if "optimizer" in state:
            optimizer.load_state_dict(state["optimizer"])
        if "scheduler" in state:
            scheduler.load_state_dict(state["scheduler"])
    except Exception as e:
        print(f"  Note: Could not load optimizer state: {e}")


def _load_hub_weights(model, filename):
    """Download ``filename`` from HF_REPO into ``model``; return True on success."""
    try:
        local_path = hf_hub_download(repo_id=HF_REPO, filename=filename)
        model.load_state_dict(load_weights(local_path))
    except Exception as e:
        # Report instead of silently skipping, so real errors (e.g. a key
        # mismatch) are visible next to ordinary "file not found" cases.
        print(f"  Hub: could not load {filename}: {e}")
        return False
    return True


def _load_local_pair(model, optimizer, scheduler, weights_path, restore_optimizer=True):
    """Load local weights and return the paired training-state dict ({} if none)."""
    model.load_state_dict(load_weights(weights_path))
    state_path = _paired_state_path(weights_path)
    if not os.path.exists(state_path):
        return {}
    # SECURITY: weights_only=False unpickles arbitrary objects; trusted files only.
    state = torch.load(state_path, map_location=DEVICE, weights_only=False)
    if not isinstance(state, dict):
        return {}
    if restore_optimizer:
        _restore_optimizer(state, optimizer, scheduler)
    return state


def _load_local_file(model, optimizer, scheduler, load_path):
    """Load an explicit local file; return ``(start_step, start_epoch)``."""
    start_step, start_epoch = 0, 0
    model.load_state_dict(load_weights(load_path))
    state_path = None
    if load_path.endswith(".safetensors"):
        state_path = _paired_state_path(load_path)
    elif load_path.endswith(".pt"):
        # The .pt file might contain everything (weights and training state).
        ckpt = torch.load(load_path, map_location=DEVICE, weights_only=False)
        if isinstance(ckpt, dict):
            # Debug aid: show which non-tensor entries the checkpoint carries.
            non_tensor_keys = [k for k, v in ckpt.items() if not isinstance(v, torch.Tensor)]
            if non_tensor_keys:
                print(f"  Checkpoint keys: {non_tensor_keys}")
            # Step/epoch may be stored under several common key names.
            start_step = ckpt.get("step", ckpt.get("global_step", ckpt.get("iteration", 0)))
            start_epoch = ckpt.get("epoch", 0)
            nested = ckpt.get("state")
            if isinstance(nested, dict):
                start_step = nested.get("step", start_step)
                start_epoch = nested.get("epoch", start_epoch)
            _restore_optimizer(ckpt, optimizer, scheduler)
    else:
        state_path = load_path + ".pt"
    if state_path and os.path.exists(state_path):
        state = torch.load(state_path, map_location=DEVICE, weights_only=False)
        if isinstance(state, dict):
            start_step = state.get("step", start_step)
            start_epoch = state.get("epoch", start_epoch)
            _restore_optimizer(state, optimizer, scheduler)
    return start_step, start_epoch


def load_checkpoint(model, optimizer, scheduler, target):
    """
    Load checkpoint based on target specification.

    Hub checkpoints only carry model weights; optimizer and scheduler state is
    restored only from local ``.pt`` state files. If the requested target
    cannot be loaded, the function falls back to the "latest" search (Hub
    first, then CHECKPOINT_DIR) and finally to a fresh start.

    Args:
        model: Module that receives the weights via ``load_state_dict``.
        optimizer: Optimizer to restore when local state is available.
        scheduler: LR scheduler to restore when local state is available.
        target:
            None, "latest" - most recent checkpoint
            "best" - best model
            int (1500) - specific step
            "hub:step_1000" - specific hub checkpoint
            "local:/path/to/file.safetensors" or "local:/path/to/file.pt" - specific local file
            "none" - skip loading, start fresh

    Returns:
        Tuple ``(start_step, start_epoch)``; ``(0, 0)`` when nothing was loaded.
    """
    if target == "none":
        print("Starting fresh (no checkpoint loading)")
        return 0, 0
    start_step, start_epoch = 0, 0
    load_mode, load_path = _parse_load_target(target)

    if load_mode == "local":
        # Direct local file (.pt or .safetensors)
        if os.path.exists(load_path):
            start_step, start_epoch = _load_local_file(model, optimizer, scheduler, load_path)
            print(f"✓ Loaded local: {load_path} (step {start_step})")
            return start_step, start_epoch
        print(f"⚠ Local file not found: {load_path}")
    elif load_mode == "hub":
        # Specific hub checkpoint. With an explicit extension there is exactly
        # one candidate; otherwise try both extensions and the bare name.
        if load_path.endswith((".safetensors", ".pt")):
            candidates = [load_path if "/" in load_path else f"checkpoints/{load_path}"]
        else:
            candidates = [f"checkpoints/{load_path}{ext}" for ext in (".safetensors", ".pt", "")]
        for filename in candidates:
            if _load_hub_weights(model, filename):
                # Extract step from filename; keep 0 if it cannot be parsed.
                if "step_" in load_path:
                    parsed = _step_from_name(load_path)
                    if parsed is not None:
                        start_step = parsed
                print(f"✓ Loaded from Hub: {filename} (step {start_step})")
                return start_step, start_epoch
        print(f"⚠ Could not load from hub: {load_path}")
    elif load_mode == "best":
        # Try hub best first (both extensions)
        for ext in (".safetensors", ".pt"):
            if _load_hub_weights(model, f"model{ext}"):
                print(f"✓ Loaded best model from Hub")
                return start_step, start_epoch
        # Try local best (both extensions)
        for ext in (".safetensors", ".pt"):
            best_path = os.path.join(CHECKPOINT_DIR, f"best{ext}")
            if os.path.exists(best_path):
                # Only step/epoch are taken from the paired state file here;
                # optimizer and scheduler keep their fresh state.
                state = _load_local_pair(
                    model, optimizer, scheduler, best_path, restore_optimizer=False
                )
                if "step" in state:
                    start_step = state.get("step", 0)
                    start_epoch = state.get("epoch", 0)
                print(f"✓ Loaded local best (step {start_step})")
                return start_step, start_epoch
    elif load_mode == "step":
        # Specific step number
        step_num = load_path
        # Try hub (both extensions)
        for ext in (".safetensors", ".pt"):
            if _load_hub_weights(model, f"checkpoints/step_{step_num}{ext}"):
                print(f"✓ Loaded step {step_num} from Hub")
                return step_num, start_epoch
        # Try local (both extensions)
        for ext in (".safetensors", ".pt"):
            local_path = os.path.join(CHECKPOINT_DIR, f"step_{step_num}{ext}")
            if os.path.exists(local_path):
                state = _load_local_pair(model, optimizer, scheduler, local_path)
                start_epoch = state.get("epoch", 0)
                print(f"✓ Loaded local step {step_num}")
                return step_num, start_epoch
        print(f"⚠ Step {step_num} not found")

    # Default: latest (also the fallback when the requested target failed)
    # Try Hub first (both extensions)
    try:
        files = api.list_repo_files(repo_id=HF_REPO)
        # Names whose step cannot be parsed are skipped rather than aborting
        # the whole search.
        checkpoints = [
            f for f in files
            if f.startswith("checkpoints/step_")
            and f.endswith((".safetensors", ".pt"))
            and _step_from_name(f) is not None
        ]
        if checkpoints:
            latest = sorted(checkpoints, key=_step_from_name)[-1]
            step = _step_from_name(latest)
            local_path = hf_hub_download(repo_id=HF_REPO, filename=latest)
            model.load_state_dict(load_weights(local_path))
            print(f"✓ Loaded latest from Hub: step {step}")
            return step, start_epoch
    except Exception as e:
        print(f"Hub check: {e}")
    # Try local (both extensions)
    if os.path.exists(CHECKPOINT_DIR):
        names = [
            f for f in os.listdir(CHECKPOINT_DIR)
            if f.startswith("step_")
            and f.endswith((".safetensors", ".pt"))
            and _step_from_name(f) is not None
        ]
        # Keep only weights files: a .pt that sits next to a .safetensors of
        # the same stem is the training-state file, not a weights file.
        present = set(names)
        local_ckpts = [
            f for f in names
            if not (f.endswith(".pt") and f[: -len(".pt")] + ".safetensors" in present)
        ]
        if local_ckpts:
            latest = sorted(local_ckpts, key=_step_from_name)[-1]
            step = _step_from_name(latest)
            weights_path = os.path.join(CHECKPOINT_DIR, latest)
            state = _load_local_pair(model, optimizer, scheduler, weights_path)
            start_epoch = state.get("epoch", 0)
            print(f"✓ Loaded latest local: step {step}")
            return step, start_epoch
    print("No checkpoint found, starting fresh")
    return 0, 0
# ============================================================================
# DATALOADER
# ============================================================================
# collate() runs the text encoders on DEVICE, which is presumably why batches
# are assembled in the main process (num_workers=0).
loader = DataLoader(
    ds,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0,
    collate_fn=collate,
)
# ============================================================================
# MODEL
# ============================================================================
# TinyFlux / TinyFluxConfig come from the model cell that must run first.
config = TinyFluxConfig()
model = TinyFlux(config).to(device=DEVICE, dtype=DTYPE)
print(f"\nParams: {sum(map(torch.numel, model.parameters())):,}")
model = torch.compile(model, mode="default")
# ============================================================================
# OPTIMIZER & SCHEDULER
# ============================================================================
opt = torch.optim.AdamW(model.parameters(), lr=LR, betas=(0.9, 0.99), weight_decay=0.01)
# Planned number of optimizer steps for the whole run.
total_steps = len(loader) * EPOCHS // GRAD_ACCUM
warmup = min(500, total_steps // 10)


def lr_fn(step):
    """Return the LR multiplier for optimizer step ``step``.

    Linear warmup from 0 to 1 over ``warmup`` steps, then cosine decay from 1
    to 0 until ``total_steps``. Past ``total_steps`` the multiplier stays at 0
    instead of following the cosine back up, which matters when a resumed run
    takes more steps than originally planned. The decay length is floored at 1
    so a degenerate schedule cannot divide by zero.
    """
    if step < warmup:
        return step / warmup
    decay_steps = max(1, total_steps - warmup)
    progress = min(1.0, (step - warmup) / decay_steps)
    return 0.5 * (1 + math.cos(math.pi * progress))


sched = torch.optim.lr_scheduler.LambdaLR(opt, lr_fn)
# ============================================================================
# LOAD CHECKPOINT
# ============================================================================
print(f"\nLoad target: {LOAD_TARGET}")
start_step, start_epoch = load_checkpoint(model, opt, sched, LOAD_TARGET)
# A manually configured RESUME_STEP wins over the step found in the checkpoint.
# NOTE(review): checkpoints loaded from the Hub restore weights only, so the
# scheduler still starts at 0 while the step counter continues from start_step
# - confirm this is the intended resume behaviour.
if RESUME_STEP is not None:
    print(f"Overriding start_step: {start_step} -> {RESUME_STEP}")
    start_step = RESUME_STEP
# Record the model config and the training hyper-parameters in tensorboard.
writer.add_text("config", json.dumps(config.__dict__, indent=2), global_step=0)
writer.add_text(
    "training_config",
    json.dumps(
        dict(
            batch_size=BATCH_SIZE,
            grad_accum=GRAD_ACCUM,
            lr=LR,
            epochs=EPOCHS,
            min_snr=MIN_SNR,
            shift=SHIFT,
        ),
        indent=2,
    ),
    global_step=0,
)
# ============================================================================
# SAMPLE PROMPTS FOR PERIODIC GENERATION
# ============================================================================
# The same fixed prompts are rendered at every sampling step so the images
# can be compared across training.
SAMPLE_PROMPTS = [
    "a photo of a cat sitting on a windowsill",
    "a beautiful sunset over mountains",
    "a portrait of a woman with red hair",
    "a futuristic cityscape at night",
]
# ============================================================================
# TRAINING
# ============================================================================
print(f"\nTraining {EPOCHS} epochs, {total_steps} total steps")
print(f"Resuming from step {start_step}, epoch {start_epoch}")
print(f"Save: {SAVE_EVERY}, Upload: {UPLOAD_EVERY}, Sample: {SAMPLE_EVERY}, Log: {LOG_EVERY}")
model.train()
step = start_step  # global optimizer-step counter
best = float("inf")  # best (lowest) epoch-average loss seen in this run
for ep in range(start_epoch, EPOCHS):
    ep_loss = 0
    ep_batches = 0
    # Drop gradients left over from an incomplete accumulation window at the
    # end of the previous epoch, so they cannot leak into this epoch's first
    # optimizer step.
    opt.zero_grad()
    pbar = tqdm(loader, desc=f"E{ep+1}")
    for i, batch in enumerate(pbar):
        latents = batch["latents"]  # Ground truth data (VAE encoded images)
        t5 = batch["t5_embeds"]
        clip = batch["clip_pooled"]
        B, C, H, W = latents.shape
        # ================================================================
        # FLOW MATCHING FORMULATION
        # ================================================================
        # x_1 = data (what we want to generate)
        # x_0 = noise (where we start at inference)
        # x_t = (1-t)*x_0 + t*x_1 (linear interpolation)
        #
        # At t=0: x_t = x_0 (pure noise)
        # At t=1: x_t = x_1 (pure data)
        #
        # Velocity field: v = dx/dt = x_1 - x_0
        # Model learns to predict v given (x_t, t)
        #
        # At inference: start from noise, integrate v from t=0 to t=1
        # ================================================================
        # Reshape data to sequence format: (B, C, H, W) -> (B, H*W, C)
        data = latents.permute(0, 2, 3, 1).reshape(B, H * W, C)  # x_1
        noise = torch.randn_like(data)  # x_0
        # Sample timesteps with logit-normal distribution + Flux shift
        # This biases training towards higher t (closer to data)
        t = torch.sigmoid(torch.randn(B, device=DEVICE))
        t = flux_shift(t, s=SHIFT).to(DTYPE).clamp(1e-4, 1 - 1e-4)
        # Create noisy samples via linear interpolation
        t_expanded = t.view(B, 1, 1)
        x_t = (1 - t_expanded) * noise + t_expanded * data  # Noisy sample at time t
        # Target velocity: direction from noise to data
        v_target = data - noise
        # Create position IDs for RoPE
        img_ids = TinyFlux.create_img_ids(B, H, W, DEVICE)
        # Random guidance scale (for CFG training)
        guidance = torch.rand(B, device=DEVICE, dtype=DTYPE) * 4 + 1  # [1, 5]
        # Forward pass: predict velocity
        with torch.autocast("cuda", dtype=DTYPE):
            v_pred = model(
                hidden_states=x_t,
                encoder_hidden_states=t5,
                pooled_projections=clip,
                timestep=t,
                img_ids=img_ids,
                guidance=guidance,
            )
        # Loss: MSE between predicted and target velocity, evaluated in
        # float32 so the squared errors are not rounded to half precision.
        loss_raw = F.mse_loss(
            v_pred.float(), v_target.float(), reduction="none"
        ).mean(dim=[1, 2])
        # Min-SNR weighting (see min_snr_weight)
        snr_weights = min_snr_weight(t)
        loss = (loss_raw * snr_weights).mean() / GRAD_ACCUM
        loss.backward()
        # One host sync per iteration; both values are reused below.
        loss_scaled = loss.item()  # already divided by GRAD_ACCUM
        loss_val = loss_scaled * GRAD_ACCUM  # per-batch loss for reporting
        if (i + 1) % GRAD_ACCUM == 0:
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()
            sched.step()
            opt.zero_grad()
            step += 1
            # Everything below is keyed on `step`, so it runs only right
            # after an optimizer step (once per step value).
            # Tensorboard logging
            if step % LOG_EVERY == 0:
                writer.add_scalar("train/loss", loss_val, step)
                writer.add_scalar("train/lr", sched.get_last_lr()[0], step)
                writer.add_scalar("train/grad_norm", grad_norm.item(), step)
                writer.add_scalar("train/t_mean", t.mean().item(), step)
                writer.add_scalar("train/snr_weight_mean", snr_weights.mean().item(), step)
            # Generate samples
            if step % SAMPLE_EVERY == 0:
                print(f"\n Generating samples at step {step}...")
                images = generate_samples(model, SAMPLE_PROMPTS, num_steps=20)
                save_samples(images, SAMPLE_PROMPTS, step, SAMPLE_DIR)
            # Save checkpoint
            if step % SAVE_EVERY == 0:
                ckpt_path = os.path.join(CHECKPOINT_DIR, f"step_{step}.pt")
                weights_path = save_checkpoint(model, opt, sched, step, ep, loss_scaled, ckpt_path)
                # Upload only a checkpoint that was just written for this
                # step, so the uploaded file always matches its step label.
                if step % UPLOAD_EVERY == 0:
                    upload_checkpoint(weights_path, step, config, include_logs=True)
        ep_loss += loss_val
        ep_batches += 1
        pbar.set_postfix(loss=f"{loss_val:.4f}", lr=f"{sched.get_last_lr()[0]:.1e}", step=step)
    avg = ep_loss / max(ep_batches, 1)
    print(f"Epoch {ep+1} loss: {avg:.4f}")
    writer.add_scalar("train/epoch_loss", avg, ep + 1)
    # Keep (and publish) the weights of the epoch with the lowest average loss.
    if avg < best:
        best = avg
        best_path = os.path.join(CHECKPOINT_DIR, "best.pt")
        weights_path = save_checkpoint(model, opt, sched, step, ep, avg, best_path)
        try:
            api.upload_file(
                path_or_fileobj=weights_path,
                path_in_repo="model.safetensors",
                repo_id=HF_REPO,
                commit_message=f"Best model (epoch {ep+1}, loss {avg:.4f})",
            )
            print(f" ✓ Uploaded best to {HF_REPO}")
        except Exception as e:
            print(f" ⚠ Upload failed: {e}")
# ============================================================================
# FINAL
# ============================================================================
# Persist the end-of-training weights and state as "final".
print("\nSaving final model...")
final_path = os.path.join(CHECKPOINT_DIR, "final.pt")
weights_path = save_checkpoint(
    model, opt, sched,
    step=step, epoch=EPOCHS, loss=best, path=final_path,
)
# Final samples
print("Generating final samples...")
images = generate_samples(model, SAMPLE_PROMPTS, num_steps=20)
save_samples(images, SAMPLE_PROMPTS, step, SAMPLE_DIR)
# Final upload: the final weights become model.safetensors, followed by the
# config, the tensorboard logs and the samples. Failures are reported only.
try:
    api.upload_file(
        repo_id=HF_REPO,
        path_or_fileobj=weights_path,
        path_in_repo="model.safetensors",
    )
    config_path = os.path.join(CHECKPOINT_DIR, "config.json")
    with open(config_path, "w") as f:
        json.dump(config.__dict__, f, indent=2)
    api.upload_file(
        repo_id=HF_REPO,
        path_or_fileobj=config_path,
        path_in_repo="config.json",
    )
    api.upload_folder(repo_id=HF_REPO, folder_path=LOG_DIR, path_in_repo="logs")
    api.upload_folder(repo_id=HF_REPO, folder_path=SAMPLE_DIR, path_in_repo="samples")
    print(f"\n✓ Training complete! https://huggingface.co/{HF_REPO}")
except Exception as err:
    print(f"\n⚠ Final upload failed: {err}")
writer.close()
print(f"Best loss: {best:.4f}")