# ============================================================================
# TinyFlux Training Cell - OPTIMIZED
# ============================================================================
# Speed-ups used by this script:
# - TF32 matmul / cuDNN flags
# - Fused AdamW optimizer
# - Prompts encoded once at startup (and cached on disk) instead of per batch
# - Batched prompt encoding
# - DataLoader with worker processes and pinned memory
# - torch.inference_mode() for sampling
# - img_ids built via TinyFlux.create_img_ids
# - torch.compile (this file calls it with mode="default")
# ============================================================================

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from datasets import load_dataset
from transformers import T5EncoderModel, T5Tokenizer, CLIPTextModel, CLIPTokenizer
from huggingface_hub import HfApi, hf_hub_download
from safetensors.torch import save_file, load_file
from torch.utils.tensorboard import SummaryWriter
from tqdm.auto import tqdm
import numpy as np
import math
import os
import json
import warnings
from datetime import datetime

# ============================================================================
# CUDA OPTIMIZATIONS - applied before any model is created
# ============================================================================
# Allow TF32 for matmuls and cuDNN, enable the cuDNN autotuner, and request
# 'high' float32 matmul precision.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cudnn.benchmark = True
torch.set_float32_matmul_precision('high')

# Silence any warning whose message mentions TF32.
warnings.filterwarnings('ignore', message='.*TF32.*')

# ============================================================================
# CONFIG
# ============================================================================
BATCH_SIZE = 128
GRAD_ACCUM = 1
LR = 1e-4
EPOCHS = 10
MAX_SEQ = 128  # T5 token length (prompts are padded/truncated to this)
MIN_SNR = 5.0  # default gamma for min_snr_weight
SHIFT = 3.0  # default shift for flux_shift
DEVICE = "cuda"
# Prefer bfloat16 when the GPU reports support for it, otherwise float16.
DTYPE = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16

# HuggingFace Hub repo and step intervals
HF_REPO = "AbstractPhil/tiny-flux"
SAVE_EVERY = 1000
UPLOAD_EVERY = 1000
SAMPLE_EVERY = 500
LOG_EVERY = 10

# Checkpoint loading
LOAD_TARGET = "hub:step_24000"  # "latest", "best", int, "hub:step_X", "local:path", "none"
RESUME_STEP = None

# Output directories (created up front)
CHECKPOINT_DIR = "./tiny_flux_checkpoints"
LOG_DIR = "./tiny_flux_logs"
SAMPLE_DIR = "./tiny_flux_samples"
for _out_dir in (CHECKPOINT_DIR, LOG_DIR, SAMPLE_DIR):
    os.makedirs(_out_dir, exist_ok=True)

# ============================================================================
# HF HUB SETUP
# ============================================================================
print("Setting up HuggingFace Hub...")
api = HfApi()
try:
    api.create_repo(repo_id=HF_REPO, exist_ok=True, repo_type="model")
    print(f"✓ Repo ready: {HF_REPO}")
except Exception as e:
    # Best effort: a failure here is reported but does not stop the script.
    print(f"Note: {e}")

# ============================================================================
# TENSORBOARD
# ============================================================================
run_name = datetime.now().strftime("%Y%m%d_%H%M%S")
writer = SummaryWriter(log_dir=os.path.join(LOG_DIR, run_name))
print(f"✓ Tensorboard: {LOG_DIR}/{run_name}")

# ============================================================================
# LOAD DATASET
# ============================================================================
print("\nLoading dataset...")
ds = load_dataset("AbstractPhil/flux-schnell-teacher-latents", "train_3_512", split="train")
print(f"Samples: {len(ds)}")

# ============================================================================
# LOAD TEXT ENCODERS (frozen, eval mode)
# ============================================================================
print("\nLoading flan-t5-base (768 dim)...")
t5_tok = T5Tokenizer.from_pretrained("google/flan-t5-base")
t5_enc = T5EncoderModel.from_pretrained("google/flan-t5-base", torch_dtype=DTYPE)
t5_enc = t5_enc.to(DEVICE).eval()

print("Loading CLIP-L...")
clip_tok = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
clip_enc = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14", torch_dtype=DTYPE)
clip_enc = clip_enc.to(DEVICE).eval()

# The encoders are never trained here; turn off gradients for all their parameters.
t5_enc.requires_grad_(False)
clip_enc.requires_grad_(False)

# ============================================================================
# LOAD VAE FOR SAMPLE GENERATION (frozen, eval mode)
# ============================================================================
print("Loading Flux VAE for samples...")
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",
    subfolder="vae",
    torch_dtype=DTYPE,
)
vae = vae.to(DEVICE).eval()
vae.requires_grad_(False)


# ============================================================================
# BATCHED ENCODING
# ============================================================================
@torch.inference_mode()
def encode_prompts_batched(prompts: list) -> tuple:
    """Encode a list of prompts with both text encoders in one pass.

    Args:
        prompts: list of prompt strings.

    Returns:
        ``(t5_hidden, clip_pooled)`` where ``t5_hidden`` is the T5 encoder's
        ``last_hidden_state`` (prompts padded/truncated to ``MAX_SEQ`` tokens)
        and ``clip_pooled`` is the CLIP text model's ``pooler_output``
        (prompts padded/truncated to 77 tokens). Both live on ``DEVICE``.
    """
    shared_tok_args = dict(padding="max_length", truncation=True, return_tensors="pt")

    # T5: per-token hidden states
    t5_tokens = t5_tok(prompts, max_length=MAX_SEQ, **shared_tok_args).to(DEVICE)
    t5_hidden = t5_enc(
        input_ids=t5_tokens.input_ids,
        attention_mask=t5_tokens.attention_mask,
    ).last_hidden_state

    # CLIP: pooled sentence embedding
    clip_tokens = clip_tok(prompts, max_length=77, **shared_tok_args).to(DEVICE)
    clip_result = clip_enc(
        input_ids=clip_tokens.input_ids,
        attention_mask=clip_tokens.attention_mask,
    )
    clip_pooled = clip_result.pooler_output

    return t5_hidden, clip_pooled


@torch.inference_mode()
def encode_prompt(prompt: str) -> tuple:
    """Single-prompt convenience wrapper around ``encode_prompts_batched``."""
    single = [prompt]
    return encode_prompts_batched(single)


# ============================================================================
# PRE-ENCODE ALL PROMPTS (with disk caching)
# ============================================================================
print("\nPre-encoding prompts...")
PRECOMPUTE_ENCODINGS = True
ENCODING_CACHE_DIR = "./encoding_cache"
os.makedirs(ENCODING_CACHE_DIR, exist_ok=True)
# Cache filename based on dataset size and encoder
cache_file = os.path.join(ENCODING_CACHE_DIR, f"encodings_{len(ds)}_t5base_clipL.pt")


def _load_encoding_cache(path, num_samples, seq_len):
    """Load pre-computed prompt encodings from ``path`` if they are usable.

    The cache filename only encodes the dataset size, so a file written with
    a different ``MAX_SEQ`` (or a truncated/corrupt file) would otherwise be
    accepted silently. This helper validates what it reads.

    Args:
        path: cache file location.
        num_samples: expected size of dim 0 for both tensors.
        seq_len: expected size of dim 1 of the T5 embeddings.

    Returns:
        ``(t5_embeds, clip_pooled)`` on success, or ``None`` when the file is
        missing, unreadable, or does not match the expected sizes.
    """
    if not os.path.exists(path):
        return None

    print(f"Loading cached encodings from {path}...")
    try:
        cached = torch.load(path, weights_only=True)
        t5_embeds = cached["t5_embeds"]
        clip_pooled = cached["clip_pooled"]
    except Exception as e:
        # Unreadable cache (e.g. interrupted write): fall back to re-encoding.
        print(f"⚠ Could not read encoding cache ({e}); re-encoding")
        return None

    sizes_ok = (
        t5_embeds.shape[0] == num_samples
        and clip_pooled.shape[0] == num_samples
        and t5_embeds.dim() >= 2
        and t5_embeds.shape[1] == seq_len
    )
    if not sizes_ok:
        print(
            f"⚠ Encoding cache does not match the current setup "
            f"(T5 {tuple(t5_embeds.shape)}, CLIP {tuple(clip_pooled.shape)}); re-encoding"
        )
        return None

    return t5_embeds, clip_pooled


if PRECOMPUTE_ENCODINGS:
    _cached_encodings = _load_encoding_cache(cache_file, len(ds), MAX_SEQ)

    if _cached_encodings is not None:
        all_t5_embeds, all_clip_pooled = _cached_encodings
        print(f"✓ Loaded cached encodings")
    else:
        # Pull the whole prompt column at once rather than iterating rows.
        print("Encoding prompts (will cache for future runs)...")
        all_prompts = ds["prompt"]
        encode_batch_size = 64

        all_t5_embeds = []
        all_clip_pooled = []
        for i in tqdm(range(0, len(all_prompts), encode_batch_size), desc="Encoding"):
            batch_prompts = all_prompts[i:i + encode_batch_size]
            t5_out, clip_out = encode_prompts_batched(batch_prompts)
            # Keep results on CPU so the full set does not sit in GPU memory.
            all_t5_embeds.append(t5_out.cpu())
            all_clip_pooled.append(clip_out.cpu())

        all_t5_embeds = torch.cat(all_t5_embeds, dim=0)
        all_clip_pooled = torch.cat(all_clip_pooled, dim=0)

        # Write to a temporary file and rename, so an interrupted save never
        # leaves a half-written cache under the real name.
        tmp_cache_file = cache_file + ".tmp"
        torch.save({
            "t5_embeds": all_t5_embeds,
            "clip_pooled": all_clip_pooled,
        }, tmp_cache_file)
        os.replace(tmp_cache_file, cache_file)
        print(f"✓ Saved encoding cache to {cache_file}")

    print(f" T5 embeds: {all_t5_embeds.shape}")
    print(f" CLIP pooled: {all_clip_pooled.shape}")


# ============================================================================
# FLOW MATCHING HELPERS
# ============================================================================
def flux_shift(t, s=SHIFT):
    """Apply the Flux timestep shift ``s * t / (1 + (s - 1) * t)``.

    Args:
        t: timestep value(s); used with tensors in this file.
        s: shift factor (defaults to ``SHIFT``). ``s == 1`` is the identity.

    Returns:
        The shifted timestep(s), same shape as ``t``.
    """
    return s * t / (1 + (s - 1) * t)


def min_snr_weight(t, gamma=MIN_SNR):
    """Return per-sample Min-SNR loss weights for timesteps ``t``.

    SNR is taken as ``(t / (1 - t)) ** 2`` and the weight is
    ``min(snr, gamma) / snr``. Both denominators are clamped to at least
    1e-5 to avoid division by zero.

    Args:
        t: tensor of timesteps.
        gamma: SNR cap (defaults to ``MIN_SNR``).

    Returns:
        Tensor of weights with the same shape as ``t``.
    """
    snr = (t / (1 - t).clamp(min=1e-5)).pow(2)
    return torch.clamp(snr, max=gamma) / snr.clamp(min=1e-5)


# ============================================================================
# SAMPLING FUNCTION - Optimized
# ============================================================================
@torch.inference_mode()
def generate_samples(model, prompts, num_steps=20, guidance_scale=3.5, H=64, W=64):
    """Generate images for ``prompts`` with Euler integration of the model's velocity.

    Args:
        model: callable taking ``hidden_states``, ``encoder_hidden_states``,
            ``pooled_projections``, ``timestep``, ``img_ids`` and ``guidance``.
        prompts: list of prompt strings (one image per prompt).
        num_steps: number of Euler steps from t=0 (noise) to t=1.
        guidance_scale: value passed to the model's ``guidance`` input.
        H, W: latent grid height/width; the token sequence has ``H * W`` entries.

    Returns:
        Decoded images from the VAE, rescaled with ``x / 2 + 0.5`` and
        clamped to [0, 1].

    The model is put in eval mode for sampling and restored to whatever mode
    it was in on entry, even if sampling raises.
    """
    was_training = model.training
    model.eval()
    try:
        B = len(prompts)
        C = 16  # latent channels

        # Batch encode prompts
        t5_embeds, clip_pooleds = encode_prompts_batched(prompts)
        t5_embeds = t5_embeds.to(DTYPE)
        clip_pooleds = clip_pooleds.to(DTYPE)

        # Start from pure noise, laid out as (B, H*W, C)
        x = torch.randn(B, H * W, C, device=DEVICE, dtype=DTYPE)

        img_ids = TinyFlux.create_img_ids(B, H, W, DEVICE)

        # Timesteps 0 -> 1, passed through flux_shift
        t_linear = torch.linspace(0, 1, num_steps + 1, device=DEVICE, dtype=DTYPE)
        timesteps = flux_shift(t_linear, s=SHIFT)

        # Guidance is the same for every step, so build it once.
        guidance = torch.full((B,), guidance_scale, device=DEVICE, dtype=DTYPE)

        # Euler sampling: x <- x + v * dt
        for i in range(num_steps):
            t_curr = timesteps[i]
            dt = timesteps[i + 1] - t_curr
            t_batch = t_curr.expand(B).to(DTYPE)

            v_cond = model(
                hidden_states=x,
                encoder_hidden_states=t5_embeds,
                pooled_projections=clip_pooleds,
                timestep=t_batch,
                img_ids=img_ids,
                guidance=guidance,
            )
            x = x + v_cond * dt

        # Decode: (B, H*W, C) -> (B, C, H, W), undo the VAE scaling factor
        latents = x.reshape(B, H, W, C).permute(0, 3, 1, 2)
        latents = latents / vae.config.scaling_factor
        images = vae.decode(latents.to(vae.dtype)).sample
        images = (images / 2 + 0.5).clamp(0, 1)
        return images
    finally:
        model.train(was_training)


def save_samples(images, prompts, step, save_dir):
    """Write each image to ``save_dir`` and log a grid plus the prompts to TensorBoard.

    Args:
        images: batch of images (iterated one per prompt).
        prompts: prompt strings matching ``images``.
        step: training step, used in file names and as the TensorBoard step.
        save_dir: directory for the PNG files.
    """
    from torchvision.utils import make_grid, save_image

    for i, (img, prompt) in enumerate(zip(images, prompts)):
        # First 50 chars of the prompt, with spaces and slashes made filename-safe.
        safe_prompt = prompt[:50].replace(" ", "_").replace("/", "-")
        path = os.path.join(save_dir, f"step{step}_{i}_{safe_prompt}.png")
        save_image(img, path)

    grid = make_grid(images, nrow=2, normalize=False)
    writer.add_image("samples", grid, step)
    writer.add_text("sample_prompts", "\n".join(prompts), step)
    print(f" ✓ Saved {len(images)} samples")


# ============================================================================
# OPTIMIZED COLLATE - Returns CPU tensors (GPU transfer in training loop)
# ============================================================================
def collate_preencoded(batch):
    """Collate a batch using the pre-computed prompt embeddings.

    Each item must carry ``"__index__"`` (row into ``all_t5_embeds`` /
    ``all_clip_pooled``) and ``"latent"``.

    Returns:
        Dict with ``"latents"``, ``"t5_embeds"`` and ``"clip_pooled"`` tensors
        in ``DTYPE``; nothing is moved to the GPU here.
    """
    indices = [b["__index__"] for b in batch]
    latents = torch.stack([
        torch.tensor(np.array(b["latent"]), dtype=DTYPE) for b in batch
    ])
    return {
        "latents": latents,
        "t5_embeds": all_t5_embeds[indices].to(DTYPE),
        "clip_pooled": all_clip_pooled[indices].to(DTYPE),
    }


def collate_online(batch):
    """Collate a batch, encoding its prompts on the fly.

    Each item must carry ``"prompt"`` and ``"latent"``. Encoding runs on
    ``DEVICE`` inside the collate call, which is why this path is paired with
    ``num_workers = 0`` below.

    Returns:
        Dict with ``"latents"``, ``"t5_embeds"`` and ``"clip_pooled"`` CPU
        tensors in ``DTYPE``.
    """
    prompts = [b["prompt"] for b in batch]
    latents = torch.stack([
        torch.tensor(np.array(b["latent"]), dtype=DTYPE) for b in batch
    ])
    t5_embeds, clip_pooled = encode_prompts_batched(prompts)
    return {
        "latents": latents,
        "t5_embeds": t5_embeds.cpu().to(DTYPE),
        "clip_pooled": clip_pooled.cpu().to(DTYPE),
    }


class IndexedDataset:
    """Wrap a dataset so every item also carries its own index.

    ``__getitem__`` copies the wrapped item into a dict and adds an
    ``"__index__"`` key; the wrapped dataset itself is not modified.
    """

    def __init__(self, ds):
        self.ds = ds

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx):
        item = dict(self.ds[idx])
        item["__index__"] = idx
        return item


# Choose collate strategy
if PRECOMPUTE_ENCODINGS:
    ds = IndexedDataset(ds)
    collate_fn = collate_preencoded
    num_workers = 2
else:
    collate_fn = collate_online
    num_workers = 0


# ============================================================================
# CHECKPOINT FUNCTIONS
# ============================================================================
# Buffers that load_checkpoint tolerates being absent from a checkpoint.
_EXPECTED_MISSING_KEYS = frozenset({
    'time_in.sin_basis',
    'guidance_in.sin_basis',
    'rope.freqs_0',
    'rope.freqs_1',
    'rope.freqs_2',
})


def load_weights(path):
    """Load a model state dict from ``path``.

    ``.safetensors`` files are read with ``load_file``. ``.pt`` files are read
    with ``torch.load``; if the result is a dict, its ``"model"`` or
    ``"state_dict"`` entry is used when present. For any other extension
    safetensors is tried first, then ``torch.load``.

    Keys containing ``"_orig_mod."`` (added by ``torch.compile``) have that
    fragment removed.

    Returns:
        The state dict.
    """
    # NOTE(security): weights_only=False unpickles arbitrary objects; only
    # point this at checkpoint files you trust.
    if path.endswith(".safetensors"):
        state_dict = load_file(path)
    elif path.endswith(".pt"):
        ckpt = torch.load(path, map_location=DEVICE, weights_only=False)
        if isinstance(ckpt, dict):
            state_dict = ckpt.get("model", ckpt.get("state_dict", ckpt))
        else:
            state_dict = ckpt
    else:
        try:
            state_dict = load_file(path)
        except Exception:
            # Not a safetensors file; fall back to torch.load. (A bare
            # `except:` here would also swallow KeyboardInterrupt.)
            state_dict = torch.load(path, map_location=DEVICE, weights_only=False)

    # Strip torch.compile prefix
    if isinstance(state_dict, dict) and any(k.startswith("_orig_mod.") for k in state_dict):
        print(" Stripping torch.compile prefix...")
        state_dict = {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()}
    return state_dict


def save_checkpoint(model, optimizer, scheduler, step, epoch, loss, path):
    """Save model weights (safetensors) and training state (torch.save).

    Weights go to ``path`` with its extension replaced by ``.safetensors``;
    step/epoch/loss plus optimizer and scheduler state go to ``path`` itself.
    ``"_orig_mod."`` fragments from ``torch.compile`` are removed from keys.

    Returns:
        The path of the written ``.safetensors`` file.
    """
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)

    state_dict = model.state_dict()
    if any(k.startswith("_orig_mod.") for k in state_dict):
        state_dict = {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()}

    # Swap only the extension. str.replace(".pt", ...) would also rewrite a
    # ".pt" occurring elsewhere in the path, and would leave the weights path
    # equal to `path` when it has no ".pt" at all.
    weights_path = os.path.splitext(path)[0] + ".safetensors"
    save_file(state_dict, weights_path)

    torch.save({
        "step": step,
        "epoch": epoch,
        "loss": loss,
        "optimizer": optimizer.state_dict(),
        "scheduler": scheduler.state_dict(),
    }, path)
    print(f" ✓ Saved checkpoint: step {step}")
    return weights_path


def upload_checkpoint(weights_path, step, config):
    """Upload a weights file and ``config.json`` to ``HF_REPO``.

    Best effort: any failure is printed and swallowed so training continues.
    """
    try:
        api.upload_file(
            path_or_fileobj=weights_path,
            path_in_repo=f"checkpoints/step_{step}.safetensors",
            repo_id=HF_REPO,
            commit_message=f"Checkpoint step {step}",
        )
        config_path = os.path.join(CHECKPOINT_DIR, "config.json")
        with open(config_path, "w") as f:
            json.dump(config.__dict__, f, indent=2)
        api.upload_file(path_or_fileobj=config_path, path_in_repo="config.json", repo_id=HF_REPO)
        print(f" ✓ Uploaded step {step} to {HF_REPO}")
    except Exception as e:
        print(f" ⚠ Upload failed: {e}")


def _report_missing_keys(missing):
    """Print whether ``missing`` state-dict keys go beyond the tolerated buffers."""
    if not missing:
        return
    actual_missing = set(missing) - _EXPECTED_MISSING_KEYS
    if actual_missing:
        print(f" ⚠ Unexpected missing keys: {actual_missing}")
    else:
        print(f" ✓ Missing only precomputed buffers (OK)")


def _step_from_name(step_name):
    """Return the integer after the last ``_`` in ``step_name``, or 0 if there is none."""
    if "_" not in step_name:
        return 0
    try:
        return int(step_name.split("_")[-1])
    except ValueError:
        return 0


def load_checkpoint(model, optimizer, scheduler, target):
    """Load model weights according to ``target``.

    Supported targets:
        * ``None`` / ``"none"``: start fresh.
        * ``"hub"``: ``model.safetensors`` from ``HF_REPO``.
        * ``"hub:<name>"``: ``checkpoints/<name>.safetensors`` (falling back
          to ``.pt``) from ``HF_REPO``; the start step is parsed from the
          trailing ``_<int>`` of ``<name>``.
        * ``"local:<path>"``: a local weights file.

    Any other value prints "No checkpoint found" and starts fresh. Only model
    weights are restored; ``optimizer`` and ``scheduler`` are accepted but not
    touched. A failed hub load is reported and treated as a fresh start.

    Returns:
        ``(start_step, start_epoch)``; ``start_epoch`` is always 0.
    """
    start_step, start_epoch = 0, 0

    if target == "none" or target is None:
        print("Starting fresh (no checkpoint)")
        return 0, 0

    # Hub loading
    if target == "hub" or (isinstance(target, str) and target.startswith("hub:")):
        try:
            if target == "hub":
                weights_path = hf_hub_download(repo_id=HF_REPO, filename="model.safetensors")
            else:
                step_name = target.split(":", 1)[1]
                try:
                    weights_path = hf_hub_download(
                        repo_id=HF_REPO, filename=f"checkpoints/{step_name}.safetensors"
                    )
                except Exception:
                    weights_path = hf_hub_download(
                        repo_id=HF_REPO, filename=f"checkpoints/{step_name}.pt"
                    )
                start_step = _step_from_name(step_name)

            weights = load_weights(weights_path)
            # strict=False: the tolerated missing buffers are checked below.
            missing, unexpected = model.load_state_dict(weights, strict=False)
            _report_missing_keys(missing)
            print(f"✓ Loaded from hub: {target}")
            return start_step, start_epoch
        except Exception as e:
            print(f"Hub load failed: {e}")
            return 0, 0

    # Local loading
    if isinstance(target, str) and target.startswith("local:"):
        path = target.split(":", 1)[1]
        weights = load_weights(path)
        missing, unexpected = model.load_state_dict(weights, strict=False)
        _report_missing_keys(missing)
        print(f"✓ Loaded from local: {path}")
        return 0, 0

    print("No checkpoint found, starting fresh")
    return 0, 0


# ============================================================================
# DATALOADER - Optimized
# ============================================================================
loader = DataLoader(
    ds,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=collate_fn,
    num_workers=num_workers,  # 2 for precomputed, 0 for online
    pin_memory=True,
    persistent_workers=(num_workers > 0),
    prefetch_factor=2 if num_workers > 0 else None,
)

# ============================================================================
# MODEL
# ============================================================================
config = TinyFluxConfig()
model = TinyFlux(config).to(DEVICE).to(DTYPE)
print(f"\nParams: {sum(p.numel() for p in model.parameters()):,}")

# ============================================================================
# OPTIMIZER - Fused for speed
# ============================================================================
opt = torch.optim.AdamW(
    model.parameters(),
    lr=LR,
    betas=(0.9, 0.99),
    weight_decay=0.01,
    fused=True,
)

total_steps = len(loader) * EPOCHS // GRAD_ACCUM
warmup = min(500, total_steps // 10)


def lr_fn(step):
    """LR multiplier: linear warmup for ``warmup`` steps, then cosine decay to 0.

    The decay length is floored at 1 and progress is capped at 1, so a very
    short run (``total_steps == warmup``) cannot divide by zero and stepping
    past ``total_steps`` keeps the multiplier at 0 instead of rising again.
    """
    if step < warmup:
        return step / warmup
    decay_steps = max(1, total_steps - warmup)
    progress = min(1.0, (step - warmup) / decay_steps)
    return 0.5 * (1 + math.cos(math.pi * progress))


sched = torch.optim.lr_scheduler.LambdaLR(opt, lr_fn)

# ============================================================================
# LOAD CHECKPOINT (before compile!)
# ============================================================================
print(f"\nLoad target: {LOAD_TARGET}")
start_step, start_epoch = load_checkpoint(model, opt, sched, LOAD_TARGET)

if RESUME_STEP is not None:
    print(f"Overriding start_step: {start_step} -> {RESUME_STEP}")
    start_step = RESUME_STEP

# ============================================================================
# COMPILE MODEL (after loading weights)
# ============================================================================
model = torch.compile(model, mode="default")

# Log config
writer.add_text("config", json.dumps(config.__dict__, indent=2), 0)
writer.add_text("training_config", json.dumps({
    "batch_size": BATCH_SIZE,
    "grad_accum": GRAD_ACCUM,
    "lr": LR,
    "epochs": EPOCHS,
    "min_snr": MIN_SNR,
    "shift": SHIFT,
    "optimizations": ["TF32", "fused_adamw", "precomputed_encodings", "flash_attention", "torch.compile"]
}, indent=2), 0)

# Sample prompts
SAMPLE_PROMPTS = [
    "a photo of a cat sitting on a windowsill",
    "a beautiful sunset over mountains",
    "a portrait of a woman with red hair",
    "a futuristic cityscape at night",
]

# ============================================================================
# TRAINING LOOP
# ============================================================================
print(f"\nTraining {EPOCHS} epochs, {total_steps} total steps")
print(f"Resuming from step {start_step}, epoch {start_epoch}")
print(f"Save: {SAVE_EVERY}, Upload: {UPLOAD_EVERY}, Sample: {SAMPLE_EVERY}, Log: {LOG_EVERY}")
print("Optimizations: TF32, fused AdamW, pre-encoded prompts, Flash Attention, torch.compile")

model.train()
step = start_step  # global step counter (may start above 0 when resuming)
best = float("inf")  # lowest epoch-average loss seen so far

# Placeholder kept from the original script; nothing in the code shown here reads it.
_cached_img_ids = None

for ep in range(start_epoch, EPOCHS):
    ep_loss = 0
    ep_batches = 0
    pbar = tqdm(loader, desc=f"E{ep + 1}")

    for i, batch in enumerate(pbar):
        # Move to GPU here (not in collate, to support multiprocessing)
        latents = batch["latents"].to(DEVICE, non_blocking=True)
        t5 = batch["t5_embeds"].to(DEVICE, non_blocking=True)
        clip = batch["clip_pooled"].to(DEVICE, non_blocking=True)

        B, C, H, W = latents.shape

        # Reshape: (B, C, H, W) -> (B, H*W, C)
        data = latents.permute(0, 2, 3, 1).reshape(B, H * W, C)
        noise = torch.randn_like(data)

        # Sample timesteps: sigmoid of a standard normal, then flux shift
        t = torch.sigmoid(torch.randn(B, device=DEVICE))
        t = flux_shift(t, s=SHIFT).to(DTYPE).clamp(1e-4, 1 - 1e-4)

        # Linear interpolation between noise (t=0) and data (t=1)
        t_expanded = t.view(B, 1, 1)
        x_t = (1 - t_expanded) * noise + t_expanded * data

        # Velocity target
        v_target = data - noise

        img_ids = TinyFlux.create_img_ids(B, H, W, DEVICE)

        # Random guidance in [1, 5)
        guidance = torch.rand(B, device=DEVICE, dtype=DTYPE) * 4 + 1

        # Forward
        with torch.autocast("cuda", dtype=DTYPE):
            v_pred = model(
                hidden_states=x_t,
                encoder_hidden_states=t5,
                pooled_projections=clip,
                timestep=t,
                img_ids=img_ids,
                guidance=guidance,
            )

        # Loss with Min-SNR weighting. The squared error and its per-sample
        # mean are taken in float32 so the reduction is not rounded to the
        # half-precision DTYPE.
        loss_raw = F.mse_loss(
            v_pred.float(), v_target.float(), reduction="none"
        ).mean(dim=[1, 2])
        snr_weights = min_snr_weight(t)
        loss = (loss_raw * snr_weights).mean() / GRAD_ACCUM
        loss.backward()

        # Read the loss back once per iteration; each .item() forces a GPU sync.
        raw_loss = loss.item()
        loss_val = raw_loss * GRAD_ACCUM

        if (i + 1) % GRAD_ACCUM == 0:
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()
            sched.step()
            opt.zero_grad(set_to_none=True)
            step += 1

            if step % LOG_EVERY == 0:
                writer.add_scalar("train/loss", loss_val, step)
                writer.add_scalar("train/lr", sched.get_last_lr()[0], step)
                writer.add_scalar("train/grad_norm", grad_norm.item(), step)
                writer.add_scalar("train/t_mean", t.mean().item(), step)

            if step % SAMPLE_EVERY == 0:
                print(f"\n Generating samples at step {step}...")
                images = generate_samples(model, SAMPLE_PROMPTS, num_steps=20)
                save_samples(images, SAMPLE_PROMPTS, step, SAMPLE_DIR)

            if step % SAVE_EVERY == 0:
                ckpt_path = os.path.join(CHECKPOINT_DIR, f"step_{step}.pt")
                weights_path = save_checkpoint(model, opt, sched, step, ep, raw_loss, ckpt_path)
                # Upload only right after a save, since it sends the file just written.
                if step % UPLOAD_EVERY == 0:
                    upload_checkpoint(weights_path, step, config)

        ep_loss += loss_val
        ep_batches += 1
        pbar.set_postfix(loss=f"{loss_val:.4f}", lr=f"{sched.get_last_lr()[0]:.1e}", step=step)

    avg = ep_loss / max(ep_batches, 1)
    print(f"Epoch {ep + 1} loss: {avg:.4f}")
    writer.add_scalar("train/epoch_loss", avg, ep + 1)

    # New best epoch: save locally and push as the repo's model.safetensors.
    if avg < best:
        best = avg
        best_path = os.path.join(CHECKPOINT_DIR, "best.pt")
        weights_path = save_checkpoint(model, opt, sched, step, ep, avg, best_path)
        try:
            api.upload_file(
                path_or_fileobj=weights_path,
                path_in_repo="model.safetensors",
                repo_id=HF_REPO,
                commit_message=f"Best model (epoch {ep + 1}, loss {avg:.4f})",
            )
            print(f" ✓ Uploaded best to {HF_REPO}")
        except Exception as e:
            print(f" ⚠ Upload failed: {e}")

# ============================================================================
# FINAL
# ============================================================================
print("\nSaving final model...")
final_path = os.path.join(CHECKPOINT_DIR, "final.pt")
# Final save / sample / upload. The TensorBoard writer is closed in `finally`
# so it is closed even if saving or sampling raises.
try:
    weights_path = save_checkpoint(model, opt, sched, step, EPOCHS, best, final_path)

    print("Generating final samples...")
    images = generate_samples(model, SAMPLE_PROMPTS, num_steps=20)
    save_samples(images, SAMPLE_PROMPTS, step, SAMPLE_DIR)

    # Best effort upload of the final weights and config; failures are only reported.
    # Note: this writes to the same "model.safetensors" path in the repo that
    # the best-epoch upload uses.
    try:
        api.upload_file(path_or_fileobj=weights_path, path_in_repo="model.safetensors", repo_id=HF_REPO)
        config_path = os.path.join(CHECKPOINT_DIR, "config.json")
        with open(config_path, "w") as f:
            json.dump(config.__dict__, f, indent=2)
        api.upload_file(path_or_fileobj=config_path, path_in_repo="config.json", repo_id=HF_REPO)
        print(f"\n✓ Training complete! https://huggingface.co/{HF_REPO}")
    except Exception as e:
        print(f"\n⚠ Final upload failed: {e}")
finally:
    writer.close()

print(f"Best loss: {best:.4f}")