# ============================================================================
# TinyFlux-Deep Training Cell - With Expert Distillation (Precached)
# ============================================================================
# Integrates SD1.5-flow-lune as a frozen timestep expert.
# Expert features are PRECACHED at 10 timestep buckets for speed.
# The ExpertPredictor learns to emulate expert features from (t, CLIP).
# At inference, no expert needed - predictor runs standalone.
#
# USAGE: Run model.py cell first, then this cell
# ============================================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from datasets import load_dataset, concatenate_datasets
from transformers import T5EncoderModel, T5Tokenizer, CLIPTextModel, CLIPTokenizer
from huggingface_hub import HfApi, hf_hub_download
from safetensors.torch import save_file, load_file
from torch.utils.tensorboard import SummaryWriter
from tqdm.auto import tqdm
import numpy as np
import math
import json
import random
from typing import Tuple, Optional, Dict, List
import os
from datetime import datetime
from PIL import Image

# ============================================================================
# CUDA OPTIMIZATIONS
# ============================================================================
# TF32 matmul/conv + cudnn autotune: faster on Ampere+, harmless elsewhere.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cudnn.benchmark = True
torch.set_float32_matmul_precision('high')
import warnings
warnings.filterwarnings('ignore', message='.*TF32.*')

# ============================================================================
# CONFIG
# ============================================================================
BATCH_SIZE = 16
GRAD_ACCUM = 2           # effective batch = BATCH_SIZE * GRAD_ACCUM
LR = 3e-4
EPOCHS = 40
MAX_SEQ = 128            # T5 token length per prompt
SHIFT = 3.0              # flow-matching timestep shift (see flux_shift)
DEVICE = "cuda"
# bf16 where supported, otherwise fp16.
DTYPE = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
ALLOW_WEIGHT_UPGRADE = True  #
# HuggingFace Hub
HF_REPO = "AbstractPhil/tiny-flux-deep"
SAVE_EVERY = 625         # steps between local checkpoint saves
UPLOAD_EVERY = 625       # steps between Hub uploads
SAMPLE_EVERY = 312       # steps between sample-image generations
LOG_EVERY = 10
LOG_UPLOAD_EVERY = 625

# Checkpoint loading
LOAD_TARGET = "hub:step_305000"  # "hub:<name>" pulls from HF_REPO
RESUME_STEP = None

# ============================================================================
# EXPERT DISTILLATION CONFIG
# ============================================================================
ENABLE_EXPERT_DISTILLATION = True
EXPERT_CHECKPOINT = "AbstractPhil/sd15-flow-lune-flux"
EXPERT_CHECKPOINT_PATH = "flux_t2_6_pose_t4_6_port_t1_4/checkpoint-00018765/unet/diffusion_pytorch_model.safetensors"
EXPERT_DIM = 1280            # SD1.5 UNet mid-block channel count
EXPERT_HIDDEN_DIM = 512
EXPERT_DROPOUT = 0.1  # Prob of forcing predictor (applied outside model)
DISTILL_LOSS_WEIGHT = 0.1
DISTILL_WARMUP_STEPS = 1000
# Timestep buckets for precaching
EXPERT_T_BUCKETS = torch.linspace(0.05, 0.95, 10)

# ============================================================================
# DATASET CONFIG
# ============================================================================
ENABLE_PORTRAIT = False
ENABLE_SCHNELL = True
ENABLE_SPORTFASHION = False
ENABLE_SYNTHMOCAP = False
PORTRAIT_REPO = "AbstractPhil/ffhq_flux_latents_repaired"
PORTRAIT_NUM_SHARDS = 11
SCHNELL_REPO = "AbstractPhil/flux-schnell-teacher-latents"
SCHNELL_CONFIGS = ["train_512"]
SPORTFASHION_REPO = "Pianokill/SportFashion_512x512"
SYNTHMOCAP_REPO = "toyxyz/SynthMoCap_smpl_512"
FG_LOSS_WEIGHT = 2.0     # loss weight on foreground latents (masked loss)
BG_LOSS_WEIGHT = 0.5     # loss weight on background latents
USE_MASKED_LOSS = False
MIN_SNR_GAMMA = 5.0      # Min-SNR loss weighting cap

# Paths
CHECKPOINT_DIR = "./tiny_flux_deep_checkpoints"
LOG_DIR = "./tiny_flux_deep_logs"
SAMPLE_DIR = "./tiny_flux_deep_samples"
ENCODING_CACHE_DIR = "./encoding_cache"
LATENT_CACHE_DIR = "./latent_cache"
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
os.makedirs(LOG_DIR, exist_ok=True)
os.makedirs(SAMPLE_DIR, exist_ok=True)
os.makedirs(ENCODING_CACHE_DIR, exist_ok=True)
os.makedirs(LATENT_CACHE_DIR, exist_ok=True)

# ============================================================================
#
REGULARIZATION CONFIG
# ============================================================================
TEXT_DROPOUT = 0.1       # prob of zeroing a sample's text conditioning (CFG-style)
GUIDANCE_DROPOUT = 0.1
EMA_DECAY = 0.9999

# ============================================================================
# EXPERT FEATURE CACHE (precached, fast lookup + interpolation)
# ============================================================================
class ExpertFeatureCache:
    """
    Precached SD1.5-flow expert features with timestep interpolation.

    Features extracted at 10 timestep buckets [0.05, 0.15, ..., 0.95].
    At runtime, interpolates between nearest buckets.
    """
    def __init__(self, features: torch.Tensor, t_buckets: torch.Tensor, dtype=torch.float16):
        # features: [N, n_buckets, EXPERT_DIM] — one bucket row per dataset sample.
        self.features = features.to(dtype)  # [N, 10, 1280]
        self.t_buckets = t_buckets
        self.t_min = t_buckets[0].item()
        self.t_max = t_buckets[-1].item()
        # Uniform bucket spacing (buckets come from torch.linspace upstream).
        self.t_step = (t_buckets[1] - t_buckets[0]).item()
        self.n_buckets = len(t_buckets)
        self.dtype = dtype

    def get_features(self, indices: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor:
        """
        Get interpolated expert features.

        Args:
            indices: [B] sample indices into dataset
            timesteps: [B] timesteps in [0, 1]
        Returns:
            [B, 1280] interpolated features
        """
        device = timesteps.device
        # Clamp to valid range (linear interpolation only between first/last bucket)
        t_clamped = timesteps.float().clamp(self.t_min, self.t_max)
        # Find bucket indices
        t_idx_float = (t_clamped - self.t_min) / self.t_step
        # low is clamped to n-2 so that high = low+1 is always valid.
        t_idx_low = t_idx_float.long().clamp(0, self.n_buckets - 2)
        t_idx_high = (t_idx_low + 1).clamp(0, self.n_buckets - 1)
        # Interpolation alpha
        alpha = (t_idx_float - t_idx_low.float()).unsqueeze(-1)  # [B, 1]
        # Gather (on CPU for large caches)
        idx_cpu = indices.cpu()
        t_low_cpu = t_idx_low.cpu()
        t_high_cpu = t_idx_high.cpu()
        f_low = self.features[idx_cpu, t_low_cpu]    # [B, 1280]
        f_high = self.features[idx_cpu, t_high_cpu]  # [B, 1280]
        # Interpolate and move to device
        result = (1 - alpha.cpu()) * f_low + alpha.cpu() * f_high
        return result.to(device=device, dtype=self.dtype)


def load_or_extract_expert_features(cache_path: str, prompts: List[str], name: str,
                                    clip_tok, clip_enc, t_buckets: torch.Tensor,
                                    batch_size: int = 32) -> Optional[ExpertFeatureCache]:
    """
    Load cached expert features or extract them from SD1.5-flow.
    Follows same pattern as load_or_encode for text embeddings.
    """
    if not prompts or not ENABLE_EXPERT_DISTILLATION:
        return None
    # Check cache
    if os.path.exists(cache_path):
        print(f"Loading cached {name} expert features...")
        cached = torch.load(cache_path, map_location="cpu")
        cache = ExpertFeatureCache(cached["features"], cached["t_buckets"], DTYPE)
        print(f" ✓ Loaded {cache.features.shape[0]} samples × {cache.n_buckets} timesteps")
        return cache
    # Extract features
    print(f"Extracting {name} expert features ({len(prompts)} × {len(t_buckets)} timesteps)...")
    print(f" This is a one-time operation, will be cached for future runs.")
    # Load expert model temporarily (freed again after extraction below).
    checkpoint_path = hf_hub_download(
        repo_id=EXPERT_CHECKPOINT,
        filename=EXPERT_CHECKPOINT_PATH,
    )
    from diffusers import UNet2DConditionModel
    unet = UNet2DConditionModel.from_pretrained(
        "stable-diffusion-v1-5/stable-diffusion-v1-5",
        subfolder="unet",
        torch_dtype=DTYPE,
    ).to(DEVICE).eval()
    state_dict = load_file(checkpoint_path)
    # strict=False: checkpoint is a fine-tune; tolerate extra/missing keys.
    unet.load_state_dict(state_dict, strict=False)
    for p in unet.parameters():
        p.requires_grad = False
    # Hook for mid-block features: spatially pooled mid-block output -> [B, 1280]
    mid_features = [None]
    def hook_fn(module, inp, out):
        mid_features[0] = out.mean(dim=[2, 3])
    unet.mid_block.register_forward_hook(hook_fn)
    # Extract
    n_prompts = len(prompts)
    n_buckets = len(t_buckets)
    all_features = torch.zeros(n_prompts, n_buckets, EXPERT_DIM, dtype=torch.float16)
    with torch.no_grad():
        for start_idx in tqdm(range(0, n_prompts, batch_size), desc=f"Extracting {name}"):
            end_idx = min(start_idx + batch_size, n_prompts)
            batch_prompts = prompts[start_idx:end_idx]
            B = len(batch_prompts)
            # Encode CLIP hidden states
            clip_inputs = clip_tok(
                batch_prompts, return_tensors="pt", padding="max_length",
                max_length=77, truncation=True
            ).to(DEVICE)
            clip_hidden = clip_enc(**clip_inputs).last_hidden_state  # [B, 77, 768]
            # Extract at each timestep bucket
            for t_idx, t_val in enumerate(t_buckets):
                timesteps = torch.full((B,), t_val.item(), device=DEVICE)
                # Random latents: features depend mainly on (t, text) — TODO confirm
                latents = torch.randn(B, 4, 64, 64, device=DEVICE, dtype=DTYPE)
                # t in [0,1] scaled to SD's 0..1000 timestep convention.
                _ = unet(latents, timesteps * 1000, encoder_hidden_states=clip_hidden.to(DTYPE))
                all_features[start_idx:end_idx, t_idx] = mid_features[0].cpu().to(torch.float16)
    # Cleanup
    del unet
    torch.cuda.empty_cache()
    # Save cache
    torch.save({"features": all_features, "t_buckets": t_buckets}, cache_path)
    print(f" ✓ Cached to {cache_path}")
    print(f" Size: {all_features.numel() * 2 / 1e9:.2f} GB")
    return ExpertFeatureCache(all_features, t_buckets, DTYPE)


# ============================================================================
# EMA
# ============================================================================
class EMA:
    """Exponential moving average of model weights (torch.compile-aware)."""
    def __init__(self, model, decay=0.9999):
        self.decay = decay
        self.shadow = {}
        self._backup = {}
        # _orig_mod unwraps a torch.compile-wrapped module.
        if hasattr(model, '_orig_mod'):
            state = model._orig_mod.state_dict()
        else:
            state = model.state_dict()
        for k, v in state.items():
            self.shadow[k] = v.clone().detach()

    @torch.no_grad()
    def update(self, model):
        """Blend current weights into the shadow: shadow <- decay*shadow + (1-decay)*w."""
        if hasattr(model, '_orig_mod'):
            state = model._orig_mod.state_dict()
        else:
            state = model.state_dict()
        for k, v in state.items():
            if k in self.shadow:
                self.shadow[k].lerp_(v.to(self.shadow[k].dtype), 1 - self.decay)

    def apply_shadow_for_eval(self, model):
        """Swap EMA weights into the model, backing up the live weights."""
        if hasattr(model, '_orig_mod'):
            self._backup = {k: v.clone() for k, v in model._orig_mod.state_dict().items()}
            model._orig_mod.load_state_dict(self.shadow)
        else:
            self._backup = {k: v.clone() for k, v in model.state_dict().items()}
            model.load_state_dict(self.shadow)

    def restore(self, model):
        """Undo apply_shadow_for_eval by restoring the backed-up live weights."""
        if hasattr(model, '_orig_mod'):
            model._orig_mod.load_state_dict(self._backup)
        else:
            model.load_state_dict(self._backup)
        self._backup = {}

    def state_dict(self):
        return {'shadow': self.shadow, 'decay': self.decay}

    def load_state_dict(self, state):
        self.shadow = {k: v.clone() for k, v in state['shadow'].items()}
        self.decay = state.get('decay', self.decay)

    def load_shadow(self, shadow_state):
        """Load EMA shadow weights, handling architecture changes gracefully."""
        device = next(iter(self.shadow.values())).device if self.shadow else
'cuda'
        loaded = 0
        skipped_old = 0
        kept_new = 0
        for k, v in shadow_state.items():
            if k in self.shadow:
                # Key exists in current model - load it
                self.shadow[k] = v.clone().to(device)
                loaded += 1
            else:
                # Key doesn't exist (deprecated like guidance_in)
                skipped_old += 1
        # Count new keys not in checkpoint
        for k in self.shadow:
            if k not in shadow_state:
                kept_new += 1
        print(f" ✓ Restored EMA: {loaded} loaded, {skipped_old} deprecated skipped, {kept_new} new (fresh init)")


# ============================================================================
# REGULARIZATION
# ============================================================================
def apply_text_dropout(t5_embeds, clip_pooled, dropout_prob=0.1):
    """Zero out text conditioning for a random subset of the batch (CFG training).

    Returns (t5_embeds, clip_pooled, mask) — inputs are cloned, not mutated.
    """
    B = t5_embeds.shape[0]
    mask = torch.rand(B, device=t5_embeds.device) < dropout_prob
    t5_embeds = t5_embeds.clone()
    clip_pooled = clip_pooled.clone()
    t5_embeds[mask] = 0
    clip_pooled[mask] = 0
    return t5_embeds, clip_pooled, mask


# ============================================================================
# MASKING UTILITIES
# ============================================================================
def detect_background_color(image: Image.Image, sample_size: int = 100) -> Tuple[int, int, int]:
    """Estimate the background color as the median of the four corner patches."""
    img = np.array(image)
    if len(img.shape) == 2:
        # Grayscale -> fake RGB so the color math below works uniformly.
        img = np.stack([img] * 3, axis=-1)
    h, w = img.shape[:2]
    corners = [
        img[:sample_size, :sample_size],
        img[:sample_size, -sample_size:],
        img[-sample_size:, :sample_size],
        img[-sample_size:, -sample_size:],
    ]
    corner_pixels = np.concatenate([c.reshape(-1, 3) for c in corners], axis=0)
    bg_color = np.median(corner_pixels, axis=0).astype(np.uint8)
    return tuple(bg_color)


def create_product_mask(image: Image.Image, threshold: int = 30) -> np.ndarray:
    """Foreground mask (1.0 = foreground) by Euclidean distance from background color."""
    img = np.array(image).astype(np.float32)
    if len(img.shape) == 2:
        img = np.stack([img] * 3, axis=-1)
    bg_color = detect_background_color(image)
    bg_color = np.array(bg_color, dtype=np.float32)
    diff = np.sqrt(np.sum((img - bg_color) ** 2, axis=-1))
    mask = (diff > threshold).astype(np.float32)
    return mask


def create_smpl_mask(conditioning_image: Image.Image, threshold: int = 20) -> np.ndarray:
    """Foreground mask from an SMPL conditioning render (green-screen background)."""
    img = np.array(conditioning_image).astype(np.float32)
    if len(img.shape) == 2:
        # Grayscale conditioning: simple brightness threshold.
        return (img > threshold).astype(np.float32)
    r, g, b = img[:, :, 0], img[:, :, 1], img[:, :, 2]
    # Background = pixels where green clearly dominates both other channels.
    is_background = (g > r + 20) & (g > b + 20)
    mask = (~is_background).astype(np.float32)
    return mask


def downsample_mask_to_latent(mask: np.ndarray, latent_h: int = 64, latent_w: int = 64) -> torch.Tensor:
    """Resize a pixel-space {0,1} mask to latent resolution; returns floats in [0,1]."""
    mask_pil = Image.fromarray((mask * 255).astype(np.uint8))
    mask_pil = mask_pil.resize((latent_w, latent_h), Image.Resampling.BILINEAR)
    mask_latent = np.array(mask_pil).astype(np.float32) / 255.0
    return torch.from_numpy(mask_latent)


# ============================================================================
# HF HUB SETUP
# ============================================================================
print("Setting up HuggingFace Hub...")
api = HfApi()


# ============================================================================
# FLOW MATCHING HELPERS
# ============================================================================
def flux_shift(t, s=SHIFT):
    """Flux-style timestep shift: biases sampling toward higher noise for s > 1."""
    return s * t / (1 + (s - 1) * t)


def min_snr_weight(t, gamma=MIN_SNR_GAMMA):
    """Min-SNR loss weight: min(snr, gamma) / snr, with clamps against t -> 1 / snr -> 0."""
    snr = (t / (1 - t).clamp(min=1e-5)).pow(2)
    return torch.clamp(snr, max=gamma) / snr.clamp(min=1e-5)


# ============================================================================
# LOAD TEXT ENCODERS (frozen: T5 for sequence embeds, CLIP for pooled embeds)
# ============================================================================
print("Loading text encoders...")
t5_tok = T5Tokenizer.from_pretrained("google/flan-t5-base")
t5_enc = T5EncoderModel.from_pretrained("google/flan-t5-base", torch_dtype=DTYPE).to(DEVICE).eval()
for p in t5_enc.parameters():
    p.requires_grad = False
clip_tok = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
clip_enc = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14", torch_dtype=DTYPE).to(DEVICE).eval()
for p in clip_enc.parameters():
    p.requires_grad = False
print("✓ Text encoders loaded")


# ============================================================================
# LOAD VAE (frozen, used for encoding pixels and decoding samples)
# ============================================================================
print("Loading VAE...")
from diffusers import AutoencoderKL
vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=DTYPE).to(DEVICE).eval()
for p in vae.parameters():
    p.requires_grad = False
VAE_SCALE = vae.config.scaling_factor
print(f"✓ VAE loaded (scale={VAE_SCALE})")


# ============================================================================
# ENCODING FUNCTIONS
# ============================================================================
@torch.no_grad()
def encode_prompt(prompt: str) -> Tuple[torch.Tensor, torch.Tensor]:
    """Encode one prompt -> (T5 sequence embeds [MAX_SEQ, D], CLIP pooled embed [D'])."""
    t5_inputs = t5_tok(prompt, return_tensors="pt", padding="max_length", max_length=MAX_SEQ, truncation=True).to(DEVICE)
    t5_out = t5_enc(**t5_inputs).last_hidden_state
    clip_inputs = clip_tok(prompt, return_tensors="pt", padding="max_length", max_length=77, truncation=True).to(DEVICE)
    clip_out = clip_enc(**clip_inputs).pooler_output
    return t5_out.squeeze(0), clip_out.squeeze(0)


@torch.no_grad()
def encode_prompts_batched(prompts: List[str], batch_size: int = 64) -> Tuple[torch.Tensor, torch.Tensor]:
    """Batched version of encode_prompt; results are accumulated on CPU."""
    all_t5 = []
    all_clip = []
    for i in tqdm(range(0, len(prompts), batch_size), desc="Encoding", leave=False):
        batch = prompts[i:i+batch_size]
        t5_inputs = t5_tok(batch, return_tensors="pt", padding="max_length", max_length=MAX_SEQ, truncation=True).to(DEVICE)
        t5_out = t5_enc(**t5_inputs).last_hidden_state
        all_t5.append(t5_out.cpu())
        clip_inputs = clip_tok(batch, return_tensors="pt", padding="max_length", max_length=77, truncation=True).to(DEVICE)
        clip_out = clip_enc(**clip_inputs).pooler_output
        all_clip.append(clip_out.cpu())
    return torch.cat(all_t5, dim=0), torch.cat(all_clip, dim=0)


@torch.no_grad()
def encode_image_to_latent(image: Image.Image) -> torch.Tensor:
    """Encode one PIL image (resized to 512x512) to a scaled VAE latent on CPU."""
    if image.mode != "RGB":
        image = image.convert("RGB")
    if image.size != (512, 512):
        image = image.resize((512, 512), Image.Resampling.LANCZOS)
    img_tensor = torch.from_numpy(np.array(image)).float() / 255.0
    img_tensor = img_tensor.permute(2, 0, 1).unsqueeze(0)
    # Map [0,1] -> [-1,1], the VAE's expected input range.
    img_tensor = (img_tensor * 2.0 - 1.0).to(DEVICE, dtype=DTYPE)
    latent = vae.encode(img_tensor).latent_dist.sample()
    latent = latent * VAE_SCALE
    return latent.squeeze(0).cpu()


# ============================================================================
# LOAD DATASETS
# ============================================================================
portrait_ds = None
portrait_indices = []
portrait_prompts = []
if ENABLE_PORTRAIT:
    print(f"\n[1/4] Loading portrait dataset from {PORTRAIT_REPO}...")
    portrait_shards = []
    for i in range(PORTRAIT_NUM_SHARDS):
        split_name = f"train_{i:02d}"
        print(f" Loading {split_name}...")
        shard = load_dataset(PORTRAIT_REPO, split=split_name)
        portrait_shards.append(shard)
    portrait_ds = concatenate_datasets(portrait_shards)
    print(f"✓ Portrait: {len(portrait_ds)} base samples")
    print(" Extracting prompts (columnar)...")
    florence_list = list(portrait_ds["text_florence"])
    llava_list = list(portrait_ds["text_llava"])
    blip_list = list(portrait_ds["text_blip"])
    # Expand each image into up to 3 (index, prompt) pairs — one per captioner.
    for i, (f, l, b) in enumerate(zip(florence_list, llava_list, blip_list)):
        if f and f.strip():
            portrait_indices.append(i)
            portrait_prompts.append(f)
        if l and l.strip():
            portrait_indices.append(i)
            portrait_prompts.append(l)
        if b and b.strip():
            portrait_indices.append(i)
            portrait_prompts.append(b)
    print(f" Expanded: {len(portrait_prompts)} samples (3 prompts/image)")
else:
    print("\n[1/4] Portrait dataset DISABLED")

schnell_ds = None
schnell_prompts = []
if ENABLE_SCHNELL:
    print(f"\n[2/4] Loading schnell teacher dataset from {SCHNELL_REPO}...")
    schnell_datasets = []
    for config in SCHNELL_CONFIGS:
        print(f" Loading {config}...")
        ds = load_dataset(SCHNELL_REPO, config, split="train")
        schnell_datasets.append(ds)
        print(f" {len(ds)} samples")
    schnell_ds = concatenate_datasets(schnell_datasets)
schnell_prompts = list(schnell_ds["prompt"]) print(f"✓ Schnell: {len(schnell_ds)} samples") else: print("\n[2/4] Schnell dataset DISABLED") sportfashion_ds = None sportfashion_prompts = [] if ENABLE_SPORTFASHION: print(f"\n[3/4] Loading SportFashion dataset from {SPORTFASHION_REPO}...") sportfashion_ds = load_dataset(SPORTFASHION_REPO, split="train") sportfashion_prompts = list(sportfashion_ds["text"]) print(f"✓ SportFashion: {len(sportfashion_ds)} samples") else: print("\n[3/4] SportFashion dataset DISABLED") synthmocap_ds = None synthmocap_prompts = [] if ENABLE_SYNTHMOCAP: print(f"\n[4/4] Loading SynthMoCap dataset from {SYNTHMOCAP_REPO}...") synthmocap_ds = load_dataset(SYNTHMOCAP_REPO, split="train") synthmocap_prompts = list(synthmocap_ds["text"]) print(f"✓ SynthMoCap: {len(synthmocap_ds)} samples") else: print("\n[4/4] SynthMoCap dataset DISABLED") # ============================================================================ # ENCODE ALL PROMPTS # ============================================================================ total_samples = len(portrait_prompts) + len(schnell_prompts) + len(sportfashion_prompts) + len(synthmocap_prompts) print(f"\nTotal combined samples: {total_samples}") def load_or_encode(cache_path, prompts, name): if not prompts: return None, None if os.path.exists(cache_path): print(f"Loading cached {name} encodings...") cached = torch.load(cache_path) return cached["t5_embeds"], cached["clip_pooled"] else: print(f"Encoding {len(prompts)} {name} prompts...") t5, clip = encode_prompts_batched(prompts, batch_size=64) torch.save({"t5_embeds": t5, "clip_pooled": clip}, cache_path) print(f"✓ Cached to {cache_path}") return t5, clip # Standard text encodings portrait_t5, portrait_clip = None, None schnell_t5, schnell_clip = None, None sportfashion_t5, sportfashion_clip = None, None synthmocap_t5, synthmocap_clip = None, None if portrait_prompts: portrait_enc_cache = os.path.join(ENCODING_CACHE_DIR, 
f"portrait_encodings_{len(portrait_prompts)}.pt") portrait_t5, portrait_clip = load_or_encode(portrait_enc_cache, portrait_prompts, "portrait") if schnell_prompts: schnell_enc_cache = os.path.join(ENCODING_CACHE_DIR, f"schnell_encodings_{len(schnell_prompts)}.pt") schnell_t5, schnell_clip = load_or_encode(schnell_enc_cache, schnell_prompts, "schnell") if sportfashion_prompts: sportfashion_enc_cache = os.path.join(ENCODING_CACHE_DIR, f"sportfashion_encodings_{len(sportfashion_prompts)}.pt") sportfashion_t5, sportfashion_clip = load_or_encode(sportfashion_enc_cache, sportfashion_prompts, "sportfashion") if synthmocap_prompts: synthmocap_enc_cache = os.path.join(ENCODING_CACHE_DIR, f"synthmocap_encodings_{len(synthmocap_prompts)}.pt") synthmocap_t5, synthmocap_clip = load_or_encode(synthmocap_enc_cache, synthmocap_prompts, "synthmocap") # ============================================================================ # EXTRACT/LOAD EXPERT FEATURES (precached) # ============================================================================ print("\n" + "="*60) print("Expert Feature Caching") print("="*60) schnell_expert_cache = None portrait_expert_cache = None sportfashion_expert_cache = None synthmocap_expert_cache = None if schnell_prompts and ENABLE_EXPERT_DISTILLATION: schnell_expert_path = os.path.join(ENCODING_CACHE_DIR, f"schnell_expert_{len(schnell_prompts)}.pt") schnell_expert_cache = load_or_extract_expert_features( schnell_expert_path, schnell_prompts, "schnell", clip_tok, clip_enc, EXPERT_T_BUCKETS ) if portrait_prompts and ENABLE_EXPERT_DISTILLATION: portrait_expert_path = os.path.join(ENCODING_CACHE_DIR, f"portrait_expert_{len(portrait_prompts)}.pt") portrait_expert_cache = load_or_extract_expert_features( portrait_expert_path, portrait_prompts, "portrait", clip_tok, clip_enc, EXPERT_T_BUCKETS ) if sportfashion_prompts and ENABLE_EXPERT_DISTILLATION: sportfashion_expert_path = os.path.join(ENCODING_CACHE_DIR, 
f"sportfashion_expert_{len(sportfashion_prompts)}.pt") sportfashion_expert_cache = load_or_extract_expert_features( sportfashion_expert_path, sportfashion_prompts, "sportfashion", clip_tok, clip_enc, EXPERT_T_BUCKETS ) if synthmocap_prompts and ENABLE_EXPERT_DISTILLATION: synthmocap_expert_path = os.path.join(ENCODING_CACHE_DIR, f"synthmocap_expert_{len(synthmocap_prompts)}.pt") synthmocap_expert_cache = load_or_extract_expert_features( synthmocap_expert_path, synthmocap_prompts, "synthmocap", clip_tok, clip_enc, EXPERT_T_BUCKETS ) # ============================================================================ # COMBINED DATASET CLASS (with sample_idx for expert lookup) # ============================================================================ class CombinedDataset(Dataset): """Combined dataset returning sample index for expert feature lookup.""" def __init__( self, portrait_ds, portrait_indices, portrait_t5, portrait_clip, schnell_ds, schnell_t5, schnell_clip, sportfashion_ds, sportfashion_t5, sportfashion_clip, synthmocap_ds, synthmocap_t5, synthmocap_clip, vae, vae_scale, device, dtype, compute_masks=True, ): self.portrait_ds = portrait_ds self.portrait_indices = portrait_indices self.portrait_t5 = portrait_t5 self.portrait_clip = portrait_clip self.schnell_ds = schnell_ds self.schnell_t5 = schnell_t5 self.schnell_clip = schnell_clip self.sportfashion_ds = sportfashion_ds self.sportfashion_t5 = sportfashion_t5 self.sportfashion_clip = sportfashion_clip self.synthmocap_ds = synthmocap_ds self.synthmocap_t5 = synthmocap_t5 self.synthmocap_clip = synthmocap_clip self.vae = vae self.vae_scale = vae_scale self.device = device self.dtype = dtype self.compute_masks = compute_masks self.n_portrait = len(portrait_indices) if portrait_indices else 0 self.n_schnell = len(schnell_ds) if schnell_ds else 0 self.n_sportfashion = len(sportfashion_ds) if sportfashion_ds else 0 self.n_synthmocap = len(synthmocap_ds) if synthmocap_ds else 0 self.c1 = self.n_portrait self.c2 = 
self.c1 + self.n_schnell self.c3 = self.c2 + self.n_sportfashion self.total = self.c3 + self.n_synthmocap def __len__(self): return self.total def _get_latent_from_array(self, latent_data): if isinstance(latent_data, torch.Tensor): return latent_data.to(self.dtype) return torch.tensor(np.array(latent_data), dtype=self.dtype) @torch.no_grad() def _encode_image(self, image): if image.mode != "RGB": image = image.convert("RGB") if image.size != (512, 512): image = image.resize((512, 512), Image.Resampling.LANCZOS) img_tensor = torch.from_numpy(np.array(image)).float() / 255.0 img_tensor = img_tensor.permute(2, 0, 1).unsqueeze(0) img_tensor = (img_tensor * 2.0 - 1.0).to(self.device, dtype=self.dtype) latent = self.vae.encode(img_tensor).latent_dist.sample() latent = latent * self.vae_scale return latent.squeeze(0).cpu() def __getitem__(self, idx): mask = None # Determine which dataset and local index if idx < self.c1: # Portrait local_idx = idx orig_idx = self.portrait_indices[idx] item = self.portrait_ds[orig_idx] latent = self._get_latent_from_array(item["latent"]) t5 = self.portrait_t5[idx] clip = self.portrait_clip[idx] dataset_id = 0 elif idx < self.c2: # Schnell local_idx = idx - self.c1 item = self.schnell_ds[local_idx] latent = self._get_latent_from_array(item["latent"]) t5 = self.schnell_t5[local_idx] clip = self.schnell_clip[local_idx] dataset_id = 1 elif idx < self.c3: # SportFashion local_idx = idx - self.c2 item = self.sportfashion_ds[local_idx] image = item["image"] latent = self._encode_image(image) t5 = self.sportfashion_t5[local_idx] clip = self.sportfashion_clip[local_idx] dataset_id = 2 if self.compute_masks: pixel_mask = create_product_mask(image) mask = downsample_mask_to_latent(pixel_mask, 64, 64) else: # SynthMoCap local_idx = idx - self.c3 item = self.synthmocap_ds[local_idx] image = item["image"] conditioning = item["conditioning_image"] latent = self._encode_image(image) t5 = self.synthmocap_t5[local_idx] clip = self.synthmocap_clip[local_idx] 
dataset_id = 3 if self.compute_masks: pixel_mask = create_smpl_mask(conditioning) mask = downsample_mask_to_latent(pixel_mask, 64, 64) result = { "latent": latent, "t5_embed": t5.to(self.dtype), "clip_pooled": clip.to(self.dtype), "sample_idx": idx, # Global index for expert cache lookup "local_idx": local_idx, # Local index within dataset "dataset_id": dataset_id, # Which dataset (0=portrait, 1=schnell, etc) } if mask is not None: result["mask"] = mask.to(self.dtype) return result # ============================================================================ # COLLATE FUNCTION # ============================================================================ def collate_fn(batch): latents = torch.stack([b["latent"] for b in batch]) t5_embeds = torch.stack([b["t5_embed"] for b in batch]) clip_pooled = torch.stack([b["clip_pooled"] for b in batch]) sample_indices = torch.tensor([b["sample_idx"] for b in batch], dtype=torch.long) local_indices = torch.tensor([b["local_idx"] for b in batch], dtype=torch.long) dataset_ids = torch.tensor([b["dataset_id"] for b in batch], dtype=torch.long) masks = None if any("mask" in b for b in batch): masks = [] for b in batch: if "mask" in b: masks.append(b["mask"]) else: masks.append(torch.ones(64, 64, dtype=latents.dtype)) masks = torch.stack(masks) return { "latents": latents, "t5_embeds": t5_embeds, "clip_pooled": clip_pooled, "sample_indices": sample_indices, "local_indices": local_indices, "dataset_ids": dataset_ids, "masks": masks, } # ============================================================================ # EXPERT FEATURE LOOKUP (handles multiple datasets) # ============================================================================ def get_expert_features_for_batch( local_indices: torch.Tensor, dataset_ids: torch.Tensor, timesteps: torch.Tensor, portrait_cache: Optional[ExpertFeatureCache], schnell_cache: Optional[ExpertFeatureCache], sportfashion_cache: Optional[ExpertFeatureCache], synthmocap_cache: 
Optional[ExpertFeatureCache], ) -> Optional[torch.Tensor]: """Get expert features from the appropriate cache for each sample.""" caches = [portrait_cache, schnell_cache, sportfashion_cache, synthmocap_cache] # Check if any cache is available if not any(c is not None for c in caches): return None B = local_indices.shape[0] device = timesteps.device features = torch.zeros(B, EXPERT_DIM, device=device, dtype=DTYPE) for ds_id, cache in enumerate(caches): if cache is None: continue # Find samples from this dataset mask = dataset_ids == ds_id if not mask.any(): continue # Get features for these samples ds_local_indices = local_indices[mask] ds_timesteps = timesteps[mask] ds_features = cache.get_features(ds_local_indices, ds_timesteps) features[mask] = ds_features return features # ============================================================================ # MASKED LOSS FUNCTION # ============================================================================ def masked_mse_loss(pred, target, mask=None, fg_weight=2.0, bg_weight=0.5, snr_weights=None): B, N, C = pred.shape if mask is None: loss_per_sample = ((pred - target) ** 2).mean(dim=[1, 2]) else: H = W = int(math.sqrt(N)) mask_flat = mask.view(B, H * W, 1).to(pred.device) sq_error = (pred - target) ** 2 weights = mask_flat * fg_weight + (1 - mask_flat) * bg_weight weighted_error = sq_error * weights loss_per_sample = weighted_error.mean(dim=[1, 2]) if snr_weights is not None: loss_per_sample = loss_per_sample * snr_weights return loss_per_sample.mean() # ============================================================================ # CREATE DATASET # ============================================================================ print("\nCreating combined dataset...") combined_ds = CombinedDataset( portrait_ds, portrait_indices, portrait_t5, portrait_clip, schnell_ds, schnell_t5, schnell_clip, sportfashion_ds, sportfashion_t5, sportfashion_clip, synthmocap_ds, synthmocap_t5, synthmocap_clip, vae, VAE_SCALE, DEVICE, DTYPE, 
compute_masks=USE_MASKED_LOSS, ) print(f"✓ Combined dataset: {len(combined_ds)} samples") print(f" - Portraits (3x): {combined_ds.n_portrait:,}") print(f" - Schnell teacher: {combined_ds.n_schnell:,}") print(f" - SportFashion: {combined_ds.n_sportfashion:,}") print(f" - SynthMoCap: {combined_ds.n_synthmocap:,}") print(f" - Expert distillation: {ENABLE_EXPERT_DISTILLATION}") # ============================================================================ # DATALOADER # ============================================================================ loader = DataLoader( combined_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=8, pin_memory=True, collate_fn=collate_fn, drop_last=True, ) print(f"✓ DataLoader: {len(loader)} batches/epoch") # ============================================================================ # SAMPLING FUNCTION # ============================================================================ @torch.inference_mode() def generate_samples(model, prompts, num_steps=28, guidance_scale=3.5, H=64, W=64, use_ema=True): was_training = model.training model.eval() if use_ema and 'ema' in globals() and ema is not None: ema.apply_shadow_for_eval(model) B = len(prompts) C = 16 t5_list, clip_list = [], [] for p in prompts: t5, clip = encode_prompt(p) t5_list.append(t5) clip_list.append(clip) t5_embeds = torch.stack(t5_list).to(DTYPE) clip_pooleds = torch.stack(clip_list).to(DTYPE) x = torch.randn(B, H * W, C, device=DEVICE, dtype=DTYPE) img_ids = TinyFluxDeep.create_img_ids(B, H, W, DEVICE) t_linear = torch.linspace(0, 1, num_steps + 1, device=DEVICE, dtype=DTYPE) timesteps = flux_shift(t_linear, s=SHIFT) for i in range(num_steps): t_curr = timesteps[i] t_next = timesteps[i + 1] dt = t_next - t_curr t_batch = t_curr.expand(B).to(DTYPE) with torch.autocast("cuda", dtype=DTYPE): # No expert_features at inference - predictor runs standalone v_cond = model( hidden_states=x, encoder_hidden_states=t5_embeds, pooled_projections=clip_pooleds, timestep=t_batch, 
img_ids=img_ids, ) x = x + v_cond * dt latents = x.reshape(B, H, W, C).permute(0, 3, 1, 2) latents = latents / VAE_SCALE with torch.autocast("cuda", dtype=DTYPE): images = vae.decode(latents.to(vae.dtype)).sample images = (images / 2 + 0.5).clamp(0, 1) if use_ema and 'ema' in globals() and ema is not None: ema.restore(model) if was_training: model.train() return images def save_samples(images, prompts, step, output_dir): from torchvision.utils import save_image os.makedirs(output_dir, exist_ok=True) timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") grid_path = os.path.join(output_dir, f"samples_step_{step}.png") save_image(images, grid_path, nrow=2, padding=2) try: api.upload_file( path_or_fileobj=grid_path, path_in_repo=f"samples/{timestamp}_step_{step}.png", repo_id=HF_REPO, ) except: pass # ============================================================================ # CHECKPOINT FUNCTIONS # ============================================================================ def save_checkpoint(model, optimizer, scheduler, step, epoch, loss, path, ema=None): os.makedirs(os.path.dirname(path) if os.path.dirname(path) else ".", exist_ok=True) if hasattr(model, '_orig_mod'): state_dict = model._orig_mod.state_dict() else: state_dict = model.state_dict() state_dict = {k: v.to(DTYPE) if v.is_floating_point() else v for k, v in state_dict.items()} weights_path = path.replace(".pt", ".safetensors") save_file(state_dict, weights_path) if ema is not None: ema_weights = {k: v.to(DTYPE) if v.is_floating_point() else v for k, v in ema.shadow.items()} ema_weights_path = path.replace(".pt", "_ema.safetensors") save_file(ema_weights, ema_weights_path) state = { "step": step, "epoch": epoch, "loss": loss, "optimizer": optimizer.state_dict(), "scheduler": scheduler.state_dict(), } if ema is not None: state["ema_decay"] = ema.decay torch.save(state, path) print(f" ✓ Saved checkpoint: step {step}") return weights_path def upload_checkpoint(weights_path, step): try: api.upload_file( 
path_or_fileobj=weights_path, path_in_repo=f"checkpoints/step_{step}.safetensors", repo_id=HF_REPO, ) ema_path = weights_path.replace(".safetensors", "_ema.safetensors") if os.path.exists(ema_path): api.upload_file( path_or_fileobj=ema_path, path_in_repo=f"checkpoints/step_{step}_ema.safetensors", repo_id=HF_REPO, ) print(f" ✓ Uploaded checkpoint to {HF_REPO}") except Exception as e: print(f" ⚠ Upload failed: {e}") def load_with_weight_upgrade(model, state_dict): """ Load state dict with automatic handling of: - Missing ExpertPredictor weights → initialize fresh - Missing Q/K norm weights → initialize to ones (identity) - Unexpected keys → ignore (e.g., old guidance_in, sin_basis caches) """ model_state = model.state_dict() # Patterns for new weights that may not exist in old checkpoints NEW_WEIGHT_PATTERNS = [ 'expert_predictor.', # New ExpertPredictor module '.norm_q.weight', '.norm_k.weight', '.norm_added_q.weight', '.norm_added_k.weight', ] # Keys that may exist in old checkpoints but not new model DEPRECATED_PATTERNS = [ 'guidance_in.', # Replaced by expert_predictor '.sin_basis', # Old cached sin embeddings ] loaded_keys = [] missing_keys = [] unexpected_keys = [] initialized_keys = [] # First pass: load matching weights for key in state_dict.keys(): if key in model_state: if state_dict[key].shape == model_state[key].shape: model_state[key] = state_dict[key] loaded_keys.append(key) else: print(f" ⚠ Shape mismatch for {key}: checkpoint {state_dict[key].shape} vs model {model_state[key].shape}") unexpected_keys.append(key) else: is_deprecated = any(pat in key for pat in DEPRECATED_PATTERNS) if is_deprecated: unexpected_keys.append(key) else: print(f" ⚠ Unexpected key (not in model): {key}") unexpected_keys.append(key) # Second pass: handle missing keys for key in model_state.keys(): if key not in loaded_keys: is_new = any(pat in key for pat in NEW_WEIGHT_PATTERNS) if is_new: # Keep default initialization for new modules initialized_keys.append(key) else: 
missing_keys.append(key) print(f" ⚠ Missing key (not in checkpoint): {key}") # Load the updated state model.load_state_dict(model_state, strict=False) # Report if initialized_keys: # Group by module for cleaner output modules = set() for k in initialized_keys: parts = k.split('.') if len(parts) >= 2: modules.add(parts[0] + '.' + parts[1] if parts[0] == 'expert_predictor' else parts[0]) print(f" ✓ Initialized new modules (fresh): {sorted(modules)}") if unexpected_keys: deprecated = [k for k in unexpected_keys if any(p in k for p in DEPRECATED_PATTERNS)] if deprecated: print(f" ✓ Ignored deprecated keys: {len(deprecated)} (guidance_in, etc)") return missing_keys, unexpected_keys def load_checkpoint(model, optimizer, scheduler, target): """ Load checkpoint with weight upgrade support for ExpertPredictor. When ALLOW_WEIGHT_UPGRADE=True: - Missing ExpertPredictor weights are initialized fresh - Old guidance_in weights are ignored - Model continues training with new architecture """ start_step = 0 start_epoch = 0 ema_state = None if target == "none": print("Starting fresh (no checkpoint)") return start_step, start_epoch, None ckpt_path = None weights_path = None ema_weights_path = None if target == "latest": if os.path.exists(CHECKPOINT_DIR): ckpts = [f for f in os.listdir(CHECKPOINT_DIR) if f.startswith("step_") and f.endswith(".pt")] if ckpts: steps = [int(f.split("_")[1].split(".")[0]) for f in ckpts] latest_step = max(steps) ckpt_path = os.path.join(CHECKPOINT_DIR, f"step_{latest_step}.pt") weights_path = ckpt_path.replace(".pt", ".safetensors") ema_weights_path = ckpt_path.replace(".pt", "_ema.safetensors") elif target == "hub" or target.startswith("hub:"): try: from huggingface_hub import list_repo_files if target.startswith("hub:"): step_name = target.split(":")[1] weights_path = hf_hub_download(HF_REPO, f"checkpoints/{step_name}.safetensors") try: ema_weights_path = hf_hub_download(HF_REPO, f"checkpoints/{step_name}_ema.safetensors") print(f" Found EMA weights on 
hub") except: ema_weights_path = None print(f" No EMA weights on hub (will start fresh)") start_step = int(step_name.split("_")[1]) if "_" in step_name else 0 print(f"Downloaded {step_name} from hub") else: files = list_repo_files(HF_REPO) ckpts = [f for f in files if f.startswith("checkpoints/step_") and f.endswith(".safetensors") and "_ema" not in f] if ckpts: steps = [int(f.split("_")[1].split(".")[0]) for f in ckpts] latest = max(steps) weights_path = hf_hub_download(HF_REPO, f"checkpoints/step_{latest}.safetensors") try: ema_weights_path = hf_hub_download(HF_REPO, f"checkpoints/step_{latest}_ema.safetensors") print(f" Found EMA weights on hub") except: ema_weights_path = None print(f" No EMA weights on hub (will start fresh)") start_step = latest print(f"Downloaded step_{latest} from hub") except Exception as e: print(f"Could not download from hub: {e}") return start_step, start_epoch, None elif target == "best": ckpt_path = os.path.join(CHECKPOINT_DIR, "best.pt") weights_path = ckpt_path.replace(".pt", ".safetensors") ema_weights_path = ckpt_path.replace(".pt", "_ema.safetensors") elif os.path.exists(target): if target.endswith(".safetensors"): weights_path = target ckpt_path = target.replace(".safetensors", ".pt") ema_weights_path = target.replace(".safetensors", "_ema.safetensors") else: ckpt_path = target weights_path = target.replace(".pt", ".safetensors") ema_weights_path = target.replace(".pt", "_ema.safetensors") # Load main model weights if weights_path and os.path.exists(weights_path): print(f"Loading weights from {weights_path}") state_dict = load_file(weights_path) state_dict = {k: v.to(DTYPE) if v.is_floating_point() else v for k, v in state_dict.items()} # Get model reference (handle torch.compile wrapper) model_ref = model._orig_mod if hasattr(model, '_orig_mod') else model if ALLOW_WEIGHT_UPGRADE: # Flexible loading with weight upgrade missing, unexpected = load_with_weight_upgrade(model_ref, state_dict) if missing: print(f" ⚠ {len(missing)} 
truly missing parameters (may need attention)") else: # Strict loading - must match exactly model_ref.load_state_dict(state_dict, strict=True) print(f"✓ Loaded model weights") # Load EMA weights if they exist if ema_weights_path and os.path.exists(ema_weights_path): ema_state = load_file(ema_weights_path) ema_state = {k: v.to(DTYPE) if v.is_floating_point() else v for k, v in ema_state.items()} print(f"✓ Loaded EMA weights ({len(ema_state)} params)") else: print(f" ℹ No EMA weights found (will initialize fresh)") else: print(f" ⚠ Weights file not found: {weights_path}") print(f" Starting with fresh model") return start_step, start_epoch, None # Load optimizer/scheduler state if ckpt_path and os.path.exists(ckpt_path): state = torch.load(ckpt_path, map_location="cpu") start_step = state.get("step", 0) start_epoch = state.get("epoch", 0) try: optimizer.load_state_dict(state["optimizer"]) scheduler.load_state_dict(state["scheduler"]) print(f"✓ Loaded optimizer/scheduler state") except Exception as e: print(f" ⚠ Could not load optimizer state: {e}") print(f" Will use fresh optimizer (this is fine for architecture changes)") print(f"Resuming from step {start_step}, epoch {start_epoch}") return start_step, start_epoch, ema_state # ============================================================================ # CREATE MODEL # ============================================================================ print("\nCreating TinyFluxDeep model with ExpertPredictor...") config = TinyFluxDeepConfig( use_expert_predictor=ENABLE_EXPERT_DISTILLATION, expert_dim=EXPERT_DIM, expert_hidden_dim=EXPERT_HIDDEN_DIM, expert_dropout=EXPERT_DROPOUT, guidance_embeds=False, ) model = TinyFluxDeep(config).to(device=DEVICE, dtype=DTYPE) total_params = sum(p.numel() for p in model.parameters()) print(f"Total parameters: {total_params:,}") if hasattr(model, 'expert_predictor') and model.expert_predictor is not None: expert_params = sum(p.numel() for p in model.expert_predictor.parameters()) 
print(f"Expert predictor parameters: {expert_params:,}") trainable_params = [p for p in model.parameters() if p.requires_grad] print(f"Trainable parameters: {sum(p.numel() for p in trainable_params):,}") # ============================================================================ # OPTIMIZER # ============================================================================ opt = torch.optim.AdamW(trainable_params, lr=LR, betas=(0.9, 0.99), weight_decay=0.01, fused=True) total_steps = len(loader) * EPOCHS // GRAD_ACCUM warmup = min(1000, total_steps // 10) def lr_fn(step): if step < warmup: return step / warmup return 0.5 * (1 + math.cos(math.pi * (step - warmup) / (total_steps - warmup))) sched = torch.optim.lr_scheduler.LambdaLR(opt, lr_fn) # ============================================================================ # LOAD CHECKPOINT # ============================================================================ start_step, start_epoch, ema_state = load_checkpoint(model, opt, sched, LOAD_TARGET) if RESUME_STEP is not None: start_step = RESUME_STEP # ============================================================================ # COMPILE # ============================================================================ model = torch.compile(model, mode="default") # ============================================================================ # EMA # ============================================================================ print("Initializing EMA...") ema = EMA(model, decay=EMA_DECAY) if ema_state is not None: ema.load_shadow(ema_state) else: print(" Starting fresh EMA from current weights") # ============================================================================ # TENSORBOARD # ============================================================================ run_name = f"run_{datetime.now().strftime('%Y%m%d_%H%M%S')}" writer = SummaryWriter(os.path.join(LOG_DIR, run_name)) SAMPLE_PROMPTS = [ "a photo of a cat sitting on a windowsill", "a portrait of a woman with red 
hair", "a black backpack on white background", "a person standing in a t-pose", ] # ============================================================================ # DISTILLATION WEIGHT SCHEDULE # ============================================================================ def get_distill_weight(step): if step < DISTILL_WARMUP_STEPS: return DISTILL_LOSS_WEIGHT * (step / DISTILL_WARMUP_STEPS) return DISTILL_LOSS_WEIGHT # ============================================================================ # TRAINING LOOP # ============================================================================ print(f"\n{'='*60}") print(f"Training TinyFlux-Deep with Expert Distillation (Precached)") print(f"{'='*60}") print(f"Total: {len(combined_ds):,} samples") print(f"Epochs: {EPOCHS}, Steps/epoch: {len(loader)}, Total: {total_steps}") print(f"Batch: {BATCH_SIZE} x {GRAD_ACCUM} = {BATCH_SIZE * GRAD_ACCUM}") print(f"Expert distillation: {ENABLE_EXPERT_DISTILLATION} (PRECACHED)") if ENABLE_EXPERT_DISTILLATION: print(f" - Expert: {EXPERT_CHECKPOINT}") print(f" - Timestep buckets: {len(EXPERT_T_BUCKETS)}") print(f" - Distill weight: {DISTILL_LOSS_WEIGHT} (warmup: {DISTILL_WARMUP_STEPS} steps)") print(f" - Expert dropout: {EXPERT_DROPOUT}") print(f"Masked loss: {USE_MASKED_LOSS}") print(f"Min-SNR gamma: {MIN_SNR_GAMMA}") print(f"Resume: step {start_step}, epoch {start_epoch}") model.train() step = start_step best = float("inf") for ep in range(start_epoch, EPOCHS): ep_loss = 0 ep_main_loss = 0 ep_distill_loss = 0 ep_batches = 0 pbar = tqdm(loader, desc=f"E{ep + 1}") for i, batch in enumerate(pbar): latents = batch["latents"].to(DEVICE, non_blocking=True) t5 = batch["t5_embeds"].to(DEVICE, non_blocking=True) clip = batch["clip_pooled"].to(DEVICE, non_blocking=True) local_indices = batch["local_indices"] dataset_ids = batch["dataset_ids"] masks = batch["masks"] if masks is not None: masks = masks.to(DEVICE, non_blocking=True) B, C, H, W = latents.shape data = latents.permute(0, 2, 3, 
1).reshape(B, H * W, C) noise = torch.randn_like(data) if TEXT_DROPOUT > 0: t5, clip, _ = apply_text_dropout(t5, clip, TEXT_DROPOUT) t = torch.sigmoid(torch.randn(B, device=DEVICE)) t = flux_shift(t, s=SHIFT).to(DTYPE).clamp(1e-4, 1 - 1e-4) t_expanded = t.view(B, 1, 1) x_t = (1 - t_expanded) * noise + t_expanded * data v_target = data - noise img_ids = TinyFluxDeep.create_img_ids(B, H, W, DEVICE) # Get expert features from CACHE (fast!) expert_features = None if ENABLE_EXPERT_DISTILLATION: expert_features = get_expert_features_for_batch( local_indices, dataset_ids, t, portrait_expert_cache, schnell_expert_cache, sportfashion_expert_cache, synthmocap_expert_cache, ) # Apply dropout OUTSIDE model (no graph break) if expert_features is not None and random.random() < EXPERT_DROPOUT: expert_features = None with torch.autocast("cuda", dtype=DTYPE): v_pred, expert_info = model( hidden_states=x_t, encoder_hidden_states=t5, pooled_projections=clip, timestep=t, img_ids=img_ids, expert_features=expert_features, return_expert_pred=True, ) # Compute losses snr_weights = min_snr_weight(t) main_loss = masked_mse_loss( v_pred, v_target, mask=masks if USE_MASKED_LOSS else None, fg_weight=FG_LOSS_WEIGHT, bg_weight=BG_LOSS_WEIGHT, snr_weights=snr_weights ) # Distillation loss distill_loss = torch.tensor(0.0, device=DEVICE) if expert_features is not None and expert_info is not None and 'expert_pred' in expert_info: distill_weight = get_distill_weight(step) distill_loss = F.mse_loss(expert_info['expert_pred'], expert_features) total_loss = main_loss + distill_weight * distill_loss else: total_loss = main_loss loss = total_loss / GRAD_ACCUM loss.backward() if (i + 1) % GRAD_ACCUM == 0: grad_norm = torch.nn.utils.clip_grad_norm_(trainable_params, 1.0) opt.step() sched.step() opt.zero_grad(set_to_none=True) ema.update(model) step += 1 if step % LOG_EVERY == 0: writer.add_scalar("train/loss", total_loss.item(), step) writer.add_scalar("train/main_loss", main_loss.item(), step) if 
ENABLE_EXPERT_DISTILLATION: writer.add_scalar("train/distill_loss", distill_loss.item(), step) writer.add_scalar("train/distill_weight", get_distill_weight(step), step) writer.add_scalar("train/lr", sched.get_last_lr()[0], step) writer.add_scalar("train/grad_norm", grad_norm.item(), step) if step % SAMPLE_EVERY == 0: print(f"\n Generating samples at step {step}...") images = generate_samples(model, SAMPLE_PROMPTS, num_steps=20, use_ema=True) save_samples(images, SAMPLE_PROMPTS, step, SAMPLE_DIR) if step % SAVE_EVERY == 0: ckpt_path = os.path.join(CHECKPOINT_DIR, f"step_{step}.pt") weights_path = save_checkpoint(model, opt, sched, step, ep, total_loss.item(), ckpt_path, ema=ema) if step % UPLOAD_EVERY == 0: upload_checkpoint(weights_path, step) ep_loss += total_loss.item() ep_main_loss += main_loss.item() ep_distill_loss += distill_loss.item() ep_batches += 1 pbar.set_postfix( loss=f"{total_loss.item():.4f}", main=f"{main_loss.item():.4f}", dist=f"{distill_loss.item():.4f}" if ENABLE_EXPERT_DISTILLATION else "off", step=step ) avg = ep_loss / max(ep_batches, 1) avg_main = ep_main_loss / max(ep_batches, 1) avg_distill = ep_distill_loss / max(ep_batches, 1) print(f"Epoch {ep + 1} - total: {avg:.4f}, main: {avg_main:.4f}, distill: {avg_distill:.4f}") if avg < best: best = avg weights_path = save_checkpoint(model, opt, sched, step, ep, avg, os.path.join(CHECKPOINT_DIR, "best.pt"), ema=ema) try: api.upload_file(path_or_fileobj=weights_path, path_in_repo="model.safetensors", repo_id=HF_REPO) except: pass print(f"\n✓ Training complete! Best loss: {best:.4f}") writer.close()