# tiny-flux / trainer_colab.py
# Source: AbstractPhil's HuggingFace repo — commit 61d1a19 ("Update trainer_colab.py", verified).
# ============================================================================
# TinyFlux Training Cell - OPTIMIZED
# ============================================================================
# Optimizations:
# - TF32 and cuDNN settings for faster matmuls
# - Fused AdamW optimizer
# - Pre-encoded prompts (encode once at startup, not per batch)
# - Batched prompt encoding
# - DataLoader with num_workers and pin_memory
# - torch.inference_mode() for sampling
# - Cached img_ids in model
# - torch.compile with max-autotune
# ============================================================================
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from datasets import load_dataset
from transformers import T5EncoderModel, T5Tokenizer, CLIPTextModel, CLIPTokenizer
from huggingface_hub import HfApi, hf_hub_download
from safetensors.torch import save_file, load_file
from torch.utils.tensorboard import SummaryWriter
from tqdm.auto import tqdm
import numpy as np
import math
import os
import json
from datetime import datetime
# ============================================================================
# CUDA OPTIMIZATIONS - Set these BEFORE model creation
# ============================================================================
# New PyTorch 2.x API for TF32
# TF32 trades a few mantissa bits for substantially faster fp32 matmuls/convs
# on Ampere-and-newer GPUs.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
# Let cuDNN autotune conv algorithms; pays off because batch shapes are fixed.
torch.backends.cudnn.benchmark = True
torch.set_float32_matmul_precision('high')
# Suppress the deprecation warning (settings still work)
import warnings
warnings.filterwarnings('ignore', message='.*TF32.*')
# ============================================================================
# CONFIG
# ============================================================================
BATCH_SIZE = 128   # samples per dataloader batch
GRAD_ACCUM = 1     # gradient-accumulation factor (1 = optimizer step every batch)
LR = 1e-4          # peak learning rate (after warmup)
EPOCHS = 10
MAX_SEQ = 128      # T5 token length used when encoding prompts
MIN_SNR = 5.0      # gamma for Min-SNR loss weighting
SHIFT = 3.0        # Flux timestep-shift strength (used in training and sampling)
DEVICE = "cuda"
# bf16 where the GPU supports it (no loss scaling needed), else fp16
DTYPE = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
# HuggingFace Hub
HF_REPO = "AbstractPhil/tiny-flux"
SAVE_EVERY = 1000    # steps between local checkpoints
UPLOAD_EVERY = 1000  # steps between hub uploads (checked when SAVE_EVERY fires)
SAMPLE_EVERY = 500   # steps between sample-image generations
LOG_EVERY = 10       # steps between tensorboard scalar logs
# Checkpoint loading
LOAD_TARGET = "hub:step_24000" # "latest", "best", int, "hub:step_X", "local:path", "none"
RESUME_STEP = None  # optional manual override for the resume step counter
# Paths
CHECKPOINT_DIR = "./tiny_flux_checkpoints"
LOG_DIR = "./tiny_flux_logs"
SAMPLE_DIR = "./tiny_flux_samples"
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
os.makedirs(LOG_DIR, exist_ok=True)
os.makedirs(SAMPLE_DIR, exist_ok=True)
# ============================================================================
# HF HUB SETUP
# ============================================================================
print("Setting up HuggingFace Hub...")
api = HfApi()
try:
    api.create_repo(repo_id=HF_REPO, exist_ok=True, repo_type="model")
    print(f"✓ Repo ready: {HF_REPO}")
except Exception as e:
    # Best-effort: a missing token or network failure shouldn't stop training;
    # later uploads have their own try/except guards.
    print(f"Note: {e}")
# ============================================================================
# TENSORBOARD
# ============================================================================
run_name = datetime.now().strftime("%Y%m%d_%H%M%S")  # one log dir per run
writer = SummaryWriter(log_dir=os.path.join(LOG_DIR, run_name))
print(f"✓ Tensorboard: {LOG_DIR}/{run_name}")
# ============================================================================
# LOAD DATASET
# ============================================================================
print("\nLoading dataset...")
# Pre-computed teacher latents; each row carries "prompt" and "latent" fields
# (see the collate functions below).
ds = load_dataset("AbstractPhil/flux-schnell-teacher-latents", "train_3_512", split="train")
print(f"Samples: {len(ds)}")
# ============================================================================
# LOAD TEXT ENCODERS
# ============================================================================
print("\nLoading flan-t5-base (768 dim)...")
t5_tok = T5Tokenizer.from_pretrained("google/flan-t5-base")
t5_enc = T5EncoderModel.from_pretrained("google/flan-t5-base", torch_dtype=DTYPE).to(DEVICE).eval()
print("Loading CLIP-L...")
clip_tok = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
clip_enc = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14", torch_dtype=DTYPE).to(DEVICE).eval()
# Both text encoders are frozen: they only produce conditioning features.
for p in t5_enc.parameters(): p.requires_grad = False
for p in clip_enc.parameters(): p.requires_grad = False
# ============================================================================
# LOAD VAE FOR SAMPLE GENERATION
# ============================================================================
print("Loading Flux VAE for samples...")
from diffusers import AutoencoderKL
vae = AutoencoderKL.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",
    subfolder="vae",
    torch_dtype=DTYPE
).to(DEVICE).eval()
# Frozen as well; only used to decode sample latents into viewable images.
for p in vae.parameters(): p.requires_grad = False
# ============================================================================
# BATCHED ENCODING - Much faster than one-by-one
# ============================================================================
@torch.inference_mode()
def encode_prompts_batched(prompts: list) -> tuple:
    """Encode a list of prompts through both frozen text encoders in one pass.

    Returns:
        (t5_sequence_embeddings, clip_pooled_embeddings) for the whole batch.
    """
    def _tokenize(tokenizer, texts, max_len):
        # Fixed-length padding so every batch stacks to the same shape.
        return tokenizer(
            texts,
            max_length=max_len,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        ).to(DEVICE)

    t5_batch = _tokenize(t5_tok, prompts, MAX_SEQ)
    t5_hidden = t5_enc(
        input_ids=t5_batch.input_ids,
        attention_mask=t5_batch.attention_mask,
    ).last_hidden_state

    clip_batch = _tokenize(clip_tok, prompts, 77)
    clip_result = clip_enc(
        input_ids=clip_batch.input_ids,
        attention_mask=clip_batch.attention_mask,
    )
    return t5_hidden, clip_result.pooler_output
@torch.inference_mode()
def encode_prompt(prompt: str) -> tuple:
    """Encode one prompt; thin single-item wrapper kept for API compatibility."""
    single = [prompt]
    return encode_prompts_batched(single)
# ============================================================================
# PRE-ENCODE ALL PROMPTS (with disk caching)
# ============================================================================
print("\nPre-encoding prompts...")
PRECOMPUTE_ENCODINGS = True
ENCODING_CACHE_DIR = "./encoding_cache"
os.makedirs(ENCODING_CACHE_DIR, exist_ok=True)
# Cache filename based on dataset size and encoder
cache_file = os.path.join(ENCODING_CACHE_DIR, f"encodings_{len(ds)}_t5base_clipL.pt")
if PRECOMPUTE_ENCODINGS:
if os.path.exists(cache_file):
# Load from cache
print(f"Loading cached encodings from {cache_file}...")
cached = torch.load(cache_file, weights_only=True)
all_t5_embeds = cached["t5_embeds"]
all_clip_pooled = cached["clip_pooled"]
print(f"✓ Loaded cached encodings")
else:
# Get all prompts via columnar access (instant, no iteration)
print("Encoding prompts (will cache for future runs)...")
all_prompts = ds["prompt"] # Columnar access - instant!
encode_batch_size = 64
all_t5_embeds = []
all_clip_pooled = []
for i in tqdm(range(0, len(all_prompts), encode_batch_size), desc="Encoding"):
batch_prompts = all_prompts[i:i+encode_batch_size]
t5_out, clip_out = encode_prompts_batched(batch_prompts)
all_t5_embeds.append(t5_out.cpu())
all_clip_pooled.append(clip_out.cpu())
all_t5_embeds = torch.cat(all_t5_embeds, dim=0)
all_clip_pooled = torch.cat(all_clip_pooled, dim=0)
# Save cache (~750MB for 10k samples)
torch.save({
"t5_embeds": all_t5_embeds,
"clip_pooled": all_clip_pooled,
}, cache_file)
print(f"✓ Saved encoding cache to {cache_file}")
print(f" T5 embeds: {all_t5_embeds.shape}")
print(f" CLIP pooled: {all_clip_pooled.shape}")
# ============================================================================
# FLOW MATCHING HELPERS
# ============================================================================
def flux_shift(t, s=SHIFT):
    """Warp uniform timesteps with the Flux shift: s*t / (1 + (s-1)*t)."""
    denominator = 1 + (s - 1) * t
    return (s * t) / denominator
def min_snr_weight(t, gamma=MIN_SNR):
    """Min-SNR-gamma weight: cap the SNR at gamma, then normalize by the SNR."""
    ratio = t / (1 - t).clamp(min=1e-5)
    snr = ratio * ratio
    capped = torch.clamp(snr, max=gamma)
    return capped / snr.clamp(min=1e-5)
# ============================================================================
# SAMPLING FUNCTION - Optimized
# ============================================================================
@torch.inference_mode()
def generate_samples(model, prompts, num_steps=20, guidance_scale=3.5, H=64, W=64):
    """Generate sample images using Euler sampling.

    Args:
        model: TinyFlux velocity model (temporarily switched to eval mode).
        prompts: list of text prompts; one image is generated per prompt.
        num_steps: number of Euler integration steps.
        guidance_scale: constant fed to the model's guidance conditioning.
        H, W: latent grid height/width in tokens.

    Returns:
        Decoded images clamped to [0, 1].
        # NOTE(review): pixel resolution depends on the Flux VAE's upsampling
        # factor — not visible here; confirm against the VAE config.
    """
    model.eval()
    B = len(prompts)
    C = 16  # NOTE(review): assumes 16 latent channels (matches training latents' C) — confirm
    # Batch encode prompts
    t5_embeds, clip_pooleds = encode_prompts_batched(prompts)
    t5_embeds = t5_embeds.to(DTYPE)
    clip_pooleds = clip_pooleds.to(DTYPE)
    # Start from pure noise
    x = torch.randn(B, H * W, C, device=DEVICE, dtype=DTYPE)
    # Create image IDs (cached in optimized model)
    img_ids = TinyFlux.create_img_ids(B, H, W, DEVICE)
    # Timesteps with flux_shift: warp the uniform schedule the same way training does
    t_linear = torch.linspace(0, 1, num_steps + 1, device=DEVICE, dtype=DTYPE)
    timesteps = flux_shift(t_linear, s=SHIFT)
    # Euler sampling: integrate dx/dt = v(x, t) from t=0 (noise) toward t=1 (data),
    # matching the training interpolation x_t = (1-t)*noise + t*data.
    for i in range(num_steps):
        t_curr = timesteps[i]
        t_next = timesteps[i + 1]
        dt = t_next - t_curr
        t_batch = t_curr.expand(B).to(DTYPE)
        guidance = torch.full((B,), guidance_scale, device=DEVICE, dtype=DTYPE)
        v_cond = model(
            hidden_states=x,
            encoder_hidden_states=t5_embeds,
            pooled_projections=clip_pooleds,
            timestep=t_batch,
            img_ids=img_ids,
            guidance=guidance,
        )
        # One Euler step along the predicted velocity field
        x = x + v_cond * dt
    # Decode: token sequence -> (B, C, H, W) latent grid -> VAE pixels
    latents = x.reshape(B, H, W, C).permute(0, 3, 1, 2)
    latents = latents / vae.config.scaling_factor
    images = vae.decode(latents.to(vae.dtype)).sample
    # Map VAE output from [-1, 1] to [0, 1]
    images = (images / 2 + 0.5).clamp(0, 1)
    model.train()
    return images
def save_samples(images, prompts, step, save_dir):
    """Write each sample image to disk and log a grid + prompts to tensorboard."""
    from torchvision.utils import make_grid, save_image
    for idx, (image, text) in enumerate(zip(images, prompts)):
        # Filename-safe slug from the first 50 chars of the prompt.
        safe_prompt = text[:50].replace(" ", "_").replace("/", "-")
        out_path = os.path.join(save_dir, f"step{step}_{idx}_{safe_prompt}.png")
        save_image(image, out_path)
    sample_grid = make_grid(images, nrow=2, normalize=False)
    writer.add_image("samples", sample_grid, step)
    writer.add_text("sample_prompts", "\n".join(prompts), step)
    print(f" ✓ Saved {len(images)} samples")
# ============================================================================
# OPTIMIZED COLLATE - Returns CPU tensors (GPU transfer in training loop)
# ============================================================================
def collate_preencoded(batch):
    """Collate a batch by indexing into the pre-encoded embedding tables.

    Returns CPU tensors only; the training loop owns the GPU transfer so this
    collate can safely run inside DataLoader worker processes.
    """
    idx_list = [sample["__index__"] for sample in batch]
    latent_stack = torch.stack(
        [torch.tensor(np.array(sample["latent"]), dtype=DTYPE) for sample in batch]
    )
    return {
        "latents": latent_stack,
        "t5_embeds": all_t5_embeds[idx_list].to(DTYPE),
        "clip_pooled": all_clip_pooled[idx_list].to(DTYPE),
    }
def collate_online(batch):
    """Collate a batch, encoding prompts on the fly; returns CPU tensors.

    Encoding runs on CUDA, so this collate must be used with num_workers=0.
    """
    texts = [sample["prompt"] for sample in batch]
    latent_stack = torch.stack(
        [torch.tensor(np.array(sample["latent"]), dtype=DTYPE) for sample in batch]
    )
    seq_embeds, pooled = encode_prompts_batched(texts)
    return {
        "latents": latent_stack,
        "t5_embeds": seq_embeds.cpu().to(DTYPE),
        "clip_pooled": pooled.cpu().to(DTYPE),
    }
# Simple wrapper to add index without touching the data
class IndexedDataset:
    """Dataset wrapper that tags each item with its integer index.

    Cheaper than an eager ds.map(): the "__index__" field is injected lazily
    at __getitem__ time so collate functions can look up pre-encoded rows.
    """

    def __init__(self, ds):
        self.ds = ds

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx):
        record = dict(self.ds[idx])  # shallow copy; never mutate the source row
        record["__index__"] = idx
        return record
# Choose collate strategy
if PRECOMPUTE_ENCODINGS:
    ds = IndexedDataset(ds) # Instant, no iteration
    collate_fn = collate_preencoded
    num_workers = 2  # workers are safe: this path touches no CUDA in collate
else:
    collate_fn = collate_online
    num_workers = 0  # online encoding uses CUDA, so keep collate in the main process
# ============================================================================
# CHECKPOINT FUNCTIONS
# ============================================================================
def load_weights(path):
    """Load a model state dict from a .safetensors or .pt file.

    Args:
        path: Checkpoint file path. Extension decides the loader; unknown
            extensions try safetensors first and fall back to torch.load.

    Returns:
        A state dict with any torch.compile "_orig_mod." key prefix stripped,
        so keys match the uncompiled module.
    """
    if path.endswith(".safetensors"):
        state_dict = load_file(path)
    elif path.endswith(".pt"):
        # NOTE(security): weights_only=False runs arbitrary pickle code —
        # only load checkpoints from trusted sources.
        ckpt = torch.load(path, map_location=DEVICE, weights_only=False)
        if isinstance(ckpt, dict):
            # Full training checkpoints nest weights under "model"/"state_dict".
            state_dict = ckpt.get("model", ckpt.get("state_dict", ckpt))
        else:
            state_dict = ckpt
    else:
        # Unknown extension: try safetensors, then fall back to torch.load.
        # (was a bare `except:`, which also swallowed KeyboardInterrupt/SystemExit)
        try:
            state_dict = load_file(path)
        except Exception:
            state_dict = torch.load(path, map_location=DEVICE, weights_only=False)
    # Strip torch.compile prefix
    if isinstance(state_dict, dict) and any(k.startswith("_orig_mod.") for k in state_dict.keys()):
        print(" Stripping torch.compile prefix...")
        state_dict = {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()}
    return state_dict
def save_checkpoint(model, optimizer, scheduler, step, epoch, loss, path):
    """Save model weights (.safetensors) plus training state (.pt) side by side."""
    parent = os.path.dirname(path)
    os.makedirs(parent if parent else ".", exist_ok=True)
    state_dict = model.state_dict()
    # torch.compile wraps the module; drop its "_orig_mod." key prefix.
    if any(key.startswith("_orig_mod.") for key in state_dict.keys()):
        state_dict = {key.replace("_orig_mod.", ""): val for key, val in state_dict.items()}
    weights_path = path.replace(".pt", ".safetensors")
    save_file(state_dict, weights_path)
    # Optimizer/scheduler state goes into the .pt alongside run metadata.
    torch.save(
        {
            "step": step,
            "epoch": epoch,
            "loss": loss,
            "optimizer": optimizer.state_dict(),
            "scheduler": scheduler.state_dict(),
        },
        path,
    )
    print(f" ✓ Saved checkpoint: step {step}")
    return weights_path
def upload_checkpoint(weights_path, step, config):
    """Best-effort upload of a step checkpoint and its config to the hub."""
    try:
        api.upload_file(
            path_or_fileobj=weights_path,
            path_in_repo=f"checkpoints/step_{step}.safetensors",
            repo_id=HF_REPO,
            commit_message=f"Checkpoint step {step}",
        )
        # Write the model config next to the checkpoints, then push it too.
        config_path = os.path.join(CHECKPOINT_DIR, "config.json")
        serialized = json.dumps(config.__dict__, indent=2)
        with open(config_path, "w") as fh:
            fh.write(serialized)
        api.upload_file(path_or_fileobj=config_path, path_in_repo="config.json", repo_id=HF_REPO)
        print(f" ✓ Uploaded step {step} to {HF_REPO}")
    except Exception as err:
        # Uploads must never kill the training run.
        print(f" ⚠ Upload failed: {err}")
def load_checkpoint(model, optimizer, scheduler, target):
    """Load model weights from a hub/local/none target.

    Args:
        model: module to load weights into.
        optimizer, scheduler: accepted for signature symmetry with
            save_checkpoint(). NOTE: their saved state is currently NOT
            restored — only model weights are loaded.
        target: "none"/None, "hub", "hub:step_X", or "local:/path/to/file".

    Returns:
        (start_step, start_epoch) — start_step is recovered from "step_X"
        hub names; everything else resumes at (0, 0).
    """
    # Buffers recomputed at model init and intentionally absent from saved
    # weights; "missing" reports for these are expected and harmless.
    expected_missing = {'time_in.sin_basis', 'guidance_in.sin_basis',
                        'rope.freqs_0', 'rope.freqs_1', 'rope.freqs_2'}

    def _load_into_model(weights_path):
        # Shared by the hub and local branches so both report consistently.
        weights = load_weights(weights_path)
        # strict=False: ignore missing precomputed buffers (see set above)
        missing, unexpected = model.load_state_dict(weights, strict=False)
        if missing:
            actual_missing = set(missing) - expected_missing
            if actual_missing:
                print(f" ⚠ Unexpected missing keys: {actual_missing}")
            else:
                print(f" ✓ Missing only precomputed buffers (OK)")

    start_step, start_epoch = 0, 0
    if target == "none" or target is None:
        print("Starting fresh (no checkpoint)")
        return 0, 0
    # Hub loading: "hub" (best model) or "hub:step_X" (specific checkpoint)
    if target == "hub" or (isinstance(target, str) and target.startswith("hub:")):
        try:
            if target == "hub":
                weights_path = hf_hub_download(repo_id=HF_REPO, filename="model.safetensors")
            else:
                step_name = target.split(":")[1]
                # Prefer safetensors; fall back to the legacy .pt upload.
                # (was a bare `except:`, which also swallowed KeyboardInterrupt)
                try:
                    weights_path = hf_hub_download(repo_id=HF_REPO, filename=f"checkpoints/{step_name}.safetensors")
                except Exception:
                    weights_path = hf_hub_download(repo_id=HF_REPO, filename=f"checkpoints/{step_name}.pt")
                # Recover the step counter from names like "step_24000".
                start_step = int(step_name.split("_")[-1]) if "_" in step_name else 0
            _load_into_model(weights_path)
            print(f"✓ Loaded from hub: {target}")
            return start_step, start_epoch
        except Exception as e:
            print(f"Hub load failed: {e}")
            return 0, 0
    # Local loading: "local:/path/to/checkpoint"
    if isinstance(target, str) and target.startswith("local:"):
        path = target.split(":", 1)[1]
        _load_into_model(path)
        print(f"✓ Loaded from local: {path}")
        return 0, 0
    print("No checkpoint found, starting fresh")
    return 0, 0
# ============================================================================
# DATALOADER - Optimized
# ============================================================================
loader = DataLoader(
    ds,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=collate_fn,
    num_workers=num_workers, # 2 for precomputed, 0 for online
    pin_memory=True,
    persistent_workers=(num_workers > 0),
    prefetch_factor=2 if num_workers > 0 else None,
)
# ============================================================================
# MODEL
# ============================================================================
# NOTE(review): TinyFluxConfig / TinyFlux are not imported here — they are
# expected to be defined by an earlier notebook cell; confirm before running
# this file standalone.
config = TinyFluxConfig()
model = TinyFlux(config).to(DEVICE).to(DTYPE)
print(f"\nParams: {sum(p.numel() for p in model.parameters()):,}")
# ============================================================================
# OPTIMIZER - Fused for speed
# ============================================================================
opt = torch.optim.AdamW(
    model.parameters(),
    lr=LR,
    betas=(0.9, 0.99),
    weight_decay=0.01,
    fused=True,  # single fused CUDA kernel per step
)
# Optimizer steps over the whole run (grad accumulation divides the count)
total_steps = len(loader) * EPOCHS // GRAD_ACCUM
warmup = min(500, total_steps // 10)
def lr_fn(step):
    """LambdaLR multiplier: linear warmup, then cosine decay to zero."""
    if step < warmup:
        return step / warmup
    return 0.5 * (1 + math.cos(math.pi * (step - warmup) / (total_steps - warmup)))
sched = torch.optim.lr_scheduler.LambdaLR(opt, lr_fn)
# ============================================================================
# LOAD CHECKPOINT (before compile!)
# ============================================================================
print(f"\nLoad target: {LOAD_TARGET}")
start_step, start_epoch = load_checkpoint(model, opt, sched, LOAD_TARGET)
if RESUME_STEP is not None:
    print(f"Overriding start_step: {start_step} -> {RESUME_STEP}")
    start_step = RESUME_STEP
# ============================================================================
# COMPILE MODEL (after loading weights)
# ============================================================================
# Loading before compile avoids "_orig_mod." key mismatches in the state dict.
model = torch.compile(model, mode="default")
# Log config
writer.add_text("config", json.dumps(config.__dict__, indent=2), 0)
writer.add_text("training_config", json.dumps({
    "batch_size": BATCH_SIZE,
    "grad_accum": GRAD_ACCUM,
    "lr": LR,
    "epochs": EPOCHS,
    "min_snr": MIN_SNR,
    "shift": SHIFT,
    "optimizations": ["TF32", "fused_adamw", "precomputed_encodings", "flash_attention", "torch.compile"]
}, indent=2), 0)
# Sample prompts
# Fixed prompts reused at every sampling step so progress stays comparable.
SAMPLE_PROMPTS = [
    "a photo of a cat sitting on a windowsill",
    "a beautiful sunset over mountains",
    "a portrait of a woman with red hair",
    "a futuristic cityscape at night",
]
# ============================================================================
# TRAINING LOOP
# ============================================================================
print(f"\nTraining {EPOCHS} epochs, {total_steps} total steps")
print(f"Resuming from step {start_step}, epoch {start_epoch}")
print(f"Save: {SAVE_EVERY}, Upload: {UPLOAD_EVERY}, Sample: {SAMPLE_EVERY}, Log: {LOG_EVERY}")
print("Optimizations: TF32, fused AdamW, pre-encoded prompts, Flash Attention, torch.compile")
model.train()
step = start_step
best = float("inf")  # lowest epoch-average loss seen so far
# Pre-create img_ids for common resolution (will be cached)
_cached_img_ids = None  # NOTE(review): never read or written below — appears vestigial
for ep in range(start_epoch, EPOCHS):
    ep_loss = 0
    ep_batches = 0
    pbar = tqdm(loader, desc=f"E{ep + 1}")
    for i, batch in enumerate(pbar):
        # Move to GPU here (not in collate, to support multiprocessing)
        latents = batch["latents"].to(DEVICE, non_blocking=True)
        t5 = batch["t5_embeds"].to(DEVICE, non_blocking=True)
        clip = batch["clip_pooled"].to(DEVICE, non_blocking=True)
        B, C, H, W = latents.shape
        # Reshape: (B, C, H, W) -> (B, H*W, C)
        data = latents.permute(0, 2, 3, 1).reshape(B, H * W, C)
        noise = torch.randn_like(data)
        # Sample timesteps with logit-normal + flux shift
        t = torch.sigmoid(torch.randn(B, device=DEVICE))
        t = flux_shift(t, s=SHIFT).to(DTYPE).clamp(1e-4, 1 - 1e-4)
        # Linear interpolation: t=0 is pure noise, t=1 is clean data
        t_expanded = t.view(B, 1, 1)
        x_t = (1 - t_expanded) * noise + t_expanded * data
        # Velocity target (flow matching: v = data - noise)
        v_target = data - noise
        # Get img_ids (cached in model)
        img_ids = TinyFlux.create_img_ids(B, H, W, DEVICE)
        # Random guidance in [1, 5) so the guidance embedding is trained
        guidance = torch.rand(B, device=DEVICE, dtype=DTYPE) * 4 + 1
        # Forward
        with torch.autocast("cuda", dtype=DTYPE):
            v_pred = model(
                hidden_states=x_t,
                encoder_hidden_states=t5,
                pooled_projections=clip,
                timestep=t,
                img_ids=img_ids,
                guidance=guidance,
            )
        # Loss with Min-SNR weighting: per-sample MSE, then weighted mean
        loss_raw = F.mse_loss(v_pred, v_target, reduction="none").mean(dim=[1, 2])
        snr_weights = min_snr_weight(t)
        loss = (loss_raw * snr_weights).mean() / GRAD_ACCUM
        loss.backward()
        if (i + 1) % GRAD_ACCUM == 0:
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()
            sched.step()
            opt.zero_grad(set_to_none=True) # Slightly faster than zero_grad()
            step += 1
            if step % LOG_EVERY == 0:
                writer.add_scalar("train/loss", loss.item() * GRAD_ACCUM, step)
                writer.add_scalar("train/lr", sched.get_last_lr()[0], step)
                writer.add_scalar("train/grad_norm", grad_norm.item(), step)
                writer.add_scalar("train/t_mean", t.mean().item(), step)
            if step % SAMPLE_EVERY == 0:
                print(f"\n Generating samples at step {step}...")
                images = generate_samples(model, SAMPLE_PROMPTS, num_steps=20)
                save_samples(images, SAMPLE_PROMPTS, step, SAMPLE_DIR)
            if step % SAVE_EVERY == 0:
                ckpt_path = os.path.join(CHECKPOINT_DIR, f"step_{step}.pt")
                weights_path = save_checkpoint(model, opt, sched, step, ep, loss.item(), ckpt_path)
                # Upload only on steps that are also a SAVE boundary, so
                # weights_path is guaranteed to exist.
                if step % UPLOAD_EVERY == 0:
                    upload_checkpoint(weights_path, step, config)
        ep_loss += loss.item() * GRAD_ACCUM
        ep_batches += 1
        pbar.set_postfix(loss=f"{loss.item() * GRAD_ACCUM:.4f}", lr=f"{sched.get_last_lr()[0]:.1e}", step=step)
    avg = ep_loss / max(ep_batches, 1)
    print(f"Epoch {ep + 1} loss: {avg:.4f}")
    writer.add_scalar("train/epoch_loss", avg, ep + 1)
    # Track the best epoch average; push it to the hub as the main model.
    if avg < best:
        best = avg
        best_path = os.path.join(CHECKPOINT_DIR, "best.pt")
        weights_path = save_checkpoint(model, opt, sched, step, ep, avg, best_path)
        try:
            api.upload_file(
                path_or_fileobj=weights_path,
                path_in_repo="model.safetensors",
                repo_id=HF_REPO,
                commit_message=f"Best model (epoch {ep + 1}, loss {avg:.4f})",
            )
            print(f" ✓ Uploaded best to {HF_REPO}")
        except Exception as e:
            print(f" ⚠ Upload failed: {e}")
# ============================================================================
# FINAL
# ============================================================================
print("\nSaving final model...")
final_path = os.path.join(CHECKPOINT_DIR, "final.pt")
# NOTE(review): `best` (an epoch-average loss) is passed in the loss slot here.
weights_path = save_checkpoint(model, opt, sched, step, EPOCHS, best, final_path)
print("Generating final samples...")
images = generate_samples(model, SAMPLE_PROMPTS, num_steps=20)
save_samples(images, SAMPLE_PROMPTS, step, SAMPLE_DIR)
try:
    # Push the final weights as the repo's main model, plus the model config.
    api.upload_file(path_or_fileobj=weights_path, path_in_repo="model.safetensors", repo_id=HF_REPO)
    config_path = os.path.join(CHECKPOINT_DIR, "config.json")
    with open(config_path, "w") as f:
        json.dump(config.__dict__, f, indent=2)
    api.upload_file(path_or_fileobj=config_path, path_in_repo="config.json", repo_id=HF_REPO)
    print(f"\n✓ Training complete! https://huggingface.co/{HF_REPO}")
except Exception as e:
    print(f"\n⚠ Final upload failed: {e}")
writer.close()
print(f"Best loss: {best:.4f}")