| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from torch.utils.data import DataLoader, Dataset |
| | from datasets import load_dataset, concatenate_datasets |
| | from transformers import T5EncoderModel, T5Tokenizer, CLIPTextModel, CLIPTokenizer |
| | from huggingface_hub import HfApi, hf_hub_download |
| | from safetensors.torch import save_file, load_file |
| | from torch.utils.tensorboard import SummaryWriter |
| | from tqdm.auto import tqdm |
| | import numpy as np |
| | import math |
| | import json |
| | import random |
| | from typing import Tuple, Optional, Dict, List |
| | import os |
| | from datetime import datetime |
| | from PIL import Image |
| |
|
| | |
| | |
| | |
# Enable TF32 for CUDA matmuls and for cuDNN, and let cuDNN benchmark kernels.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cudnn.benchmark = True
torch.set_float32_matmul_precision('high')

import warnings
# Suppress any warning whose message matches the regex '.*TF32.*'.
warnings.filterwarnings('ignore', message='.*TF32.*')
| |
|
| | |
| | |
| | |
# ---- Core training settings ----
BATCH_SIZE = 16  # batch_size handed to the DataLoader
GRAD_ACCUM = 2  # presumably gradient-accumulation steps; consumer not visible here
LR = 3e-4
EPOCHS = 40
MAX_SEQ = 128  # max_length used when tokenizing prompts for T5
SHIFT = 3.0  # default `s` of flux_shift
DEVICE = "cuda"
# Use bfloat16 when torch reports CUDA bf16 support, otherwise float16.
DTYPE = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16

ALLOW_WEIGHT_UPGRADE = True  # NOTE(review): consumer is outside this view

# ---- Hub repo and step intervals (consumers of the *_EVERY values are outside this view) ----
HF_REPO = "AbstractPhil/tiny-flux-deep"  # repo_id used by the upload helpers
SAVE_EVERY = 625
UPLOAD_EVERY = 625
SAMPLE_EVERY = 312
LOG_EVERY = 10
LOG_UPLOAD_EVERY = 625

# ---- Checkpoint to start from (resolved by code outside this view) ----
LOAD_TARGET = "hub:step_305000"
RESUME_STEP = None

# ---- Expert (SD1.5-flow UNet) distillation ----
ENABLE_EXPERT_DISTILLATION = True
EXPERT_CHECKPOINT = "AbstractPhil/sd15-flow-lune-flux"  # repo_id for hf_hub_download
EXPERT_CHECKPOINT_PATH = "flux_t2_6_pose_t4_6_port_t1_4/checkpoint-00018765/unet/diffusion_pytorch_model.safetensors"
EXPERT_DIM = 1280  # last dimension of the cached expert feature tensors
EXPERT_HIDDEN_DIM = 512
EXPERT_DROPOUT = 0.1
DISTILL_LOSS_WEIGHT = 0.1
DISTILL_WARMUP_STEPS = 1000

# Timesteps at which expert features are extracted and cached.
EXPERT_T_BUCKETS = torch.linspace(0.05, 0.95, 10)

# ---- Dataset switches: each flag gates loading of one dataset ----
ENABLE_PORTRAIT = False
ENABLE_SCHNELL = True
ENABLE_SPORTFASHION = False
ENABLE_SYNTHMOCAP = False

PORTRAIT_REPO = "AbstractPhil/ffhq_flux_latents_repaired"
PORTRAIT_NUM_SHARDS = 11  # splits train_00 .. train_10 are loaded
SCHNELL_REPO = "AbstractPhil/flux-schnell-teacher-latents"
SCHNELL_CONFIGS = ["train_512"]
SPORTFASHION_REPO = "Pianokill/SportFashion_512x512"
SYNTHMOCAP_REPO = "toyxyz/SynthMoCap_smpl_512"

# ---- Loss weighting ----
FG_LOSS_WEIGHT = 2.0
BG_LOSS_WEIGHT = 0.5
USE_MASKED_LOSS = False  # forwarded to CombinedDataset as compute_masks
MIN_SNR_GAMMA = 5.0  # default `gamma` of min_snr_weight

# ---- Output / cache directories, created eagerly at import time ----
CHECKPOINT_DIR = "./tiny_flux_deep_checkpoints"
LOG_DIR = "./tiny_flux_deep_logs"
SAMPLE_DIR = "./tiny_flux_deep_samples"
ENCODING_CACHE_DIR = "./encoding_cache"
LATENT_CACHE_DIR = "./latent_cache"

os.makedirs(CHECKPOINT_DIR, exist_ok=True)
os.makedirs(LOG_DIR, exist_ok=True)
os.makedirs(SAMPLE_DIR, exist_ok=True)
os.makedirs(ENCODING_CACHE_DIR, exist_ok=True)
os.makedirs(LATENT_CACHE_DIR, exist_ok=True)

# ---- Regularisation ----
TEXT_DROPOUT = 0.1
GUIDANCE_DROPOUT = 0.1
EMA_DECAY = 0.9999
| |
|
| |
|
| | |
| | |
| | |
| |
|
class ExpertFeatureCache:
    """
    Precached expert features with linear interpolation between timestep buckets.

    ``features`` is indexed as ``features[sample, bucket]``. The buckets are
    treated as ascending and evenly spaced: the spacing is taken from the first
    two entries of ``t_buckets``. A cache with a single bucket is supported and
    simply returns that bucket's features for every timestep.
    """

    def __init__(self, features: torch.Tensor, t_buckets: torch.Tensor, dtype=torch.float16):
        """
        Args:
            features: tensor whose first two dims are (sample, bucket).
            t_buckets: 1-D tensor of bucket timesteps, one per bucket.
            dtype: storage dtype of the features and dtype of returned tensors.

        Raises:
            ValueError: if ``t_buckets`` is empty or its length does not match
                the bucket dimension of ``features``.
        """
        n_buckets = len(t_buckets)
        if n_buckets == 0:
            raise ValueError("t_buckets must contain at least one timestep")
        if features.dim() < 2 or features.shape[1] != n_buckets:
            raise ValueError(
                f"features shape {tuple(features.shape)} does not match {n_buckets} timestep buckets"
            )

        self.features = features.to(dtype)
        self.t_buckets = t_buckets
        self.t_min = t_buckets[0].item()
        self.t_max = t_buckets[-1].item()
        # A single bucket has no spacing; 1.0 is a harmless placeholder because
        # get_features short-circuits before dividing in that case.
        self.t_step = (t_buckets[1] - t_buckets[0]).item() if n_buckets > 1 else 1.0
        self.n_buckets = n_buckets
        self.dtype = dtype

    def get_features(self, indices: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor:
        """
        Get interpolated expert features.

        Args:
            indices: [B] sample indices into the cached feature tensor
            timesteps: [B] timesteps; values outside the bucket range are clamped

        Returns:
            [B, D] interpolated features on ``timesteps.device`` in ``self.dtype``
        """
        device = timesteps.device
        idx_cpu = indices.cpu()

        if self.n_buckets == 1:
            # Nothing to interpolate between.
            return self.features[idx_cpu, 0].to(device=device, dtype=self.dtype)

        t_clamped = timesteps.float().clamp(self.t_min, self.t_max)

        # Fractional bucket position; the lower neighbour is capped at n-2 so
        # that the upper neighbour (low + 1) is always a valid bucket.
        t_idx_float = (t_clamped - self.t_min) / self.t_step
        t_idx_low = t_idx_float.long().clamp(0, self.n_buckets - 2)
        t_idx_high = t_idx_low + 1

        # Blend weight of the upper bucket, broadcast over the feature dim.
        alpha = (t_idx_float - t_idx_low.float()).unsqueeze(-1).cpu()

        # The feature table is indexed with CPU index tensors.
        f_low = self.features[idx_cpu, t_idx_low.cpu()]
        f_high = self.features[idx_cpu, t_idx_high.cpu()]

        result = (1 - alpha) * f_low + alpha * f_high
        return result.to(device=device, dtype=self.dtype)
| |
|
| |
|
def load_or_extract_expert_features(cache_path: str, prompts: List[str], name: str,
                                    clip_tok, clip_enc, t_buckets: torch.Tensor,
                                    batch_size: int = 32) -> Optional[ExpertFeatureCache]:
    """
    Load cached expert features, or extract them with the SD1.5-flow UNet and cache them.

    For every prompt and every timestep in ``t_buckets`` the UNet is run on
    freshly sampled Gaussian latents, and the spatial mean of the mid-block
    output is stored. Extraction is therefore not deterministic across runs.

    Args:
        cache_path: file the features are loaded from / saved to.
        prompts: prompts to extract features for; an empty list returns None.
        name: label used in progress messages.
        clip_tok, clip_enc: CLIP tokenizer and text encoder used to condition the UNet.
        t_buckets: 1-D tensor of timesteps to extract at.
        batch_size: number of prompts per UNet forward pass.

    Returns:
        An ExpertFeatureCache, or None when there are no prompts or expert
        distillation is disabled.
    """
    if not prompts or not ENABLE_EXPERT_DISTILLATION:
        return None

    if os.path.exists(cache_path):
        print(f"Loading cached {name} expert features...")
        cached = torch.load(cache_path, map_location="cpu")
        cache = ExpertFeatureCache(cached["features"], cached["t_buckets"], DTYPE)
        print(f" ✓ Loaded {cache.features.shape[0]} samples × {cache.n_buckets} timesteps")
        return cache

    print(f"Extracting {name} expert features ({len(prompts)} × {len(t_buckets)} timesteps)...")
    print(f" This is a one-time operation, will be cached for future runs.")

    checkpoint_path = hf_hub_download(
        repo_id=EXPERT_CHECKPOINT,
        filename=EXPERT_CHECKPOINT_PATH,
    )

    from diffusers import UNet2DConditionModel
    unet = UNet2DConditionModel.from_pretrained(
        "stable-diffusion-v1-5/stable-diffusion-v1-5",
        subfolder="unet",
        torch_dtype=DTYPE,
    ).to(DEVICE).eval()

    # strict=False tolerates key differences; report them so that a checkpoint
    # which matched nothing does not pass unnoticed.
    load_result = unet.load_state_dict(load_file(checkpoint_path), strict=False)
    if load_result.missing_keys or load_result.unexpected_keys:
        print(f" ⚠ Expert checkpoint key mismatch: {len(load_result.missing_keys)} missing, "
              f"{len(load_result.unexpected_keys)} unexpected")

    unet.requires_grad_(False)

    # The forward hook stashes the spatially averaged mid-block output.
    mid_features = [None]

    def hook_fn(module, inp, out):
        mid_features[0] = out.mean(dim=[2, 3])

    hook_handle = unet.mid_block.register_forward_hook(hook_fn)

    n_prompts = len(prompts)
    n_buckets = len(t_buckets)
    all_features = torch.zeros(n_prompts, n_buckets, EXPERT_DIM, dtype=torch.float16)

    try:
        with torch.no_grad():
            for start_idx in tqdm(range(0, n_prompts, batch_size), desc=f"Extracting {name}"):
                end_idx = min(start_idx + batch_size, n_prompts)
                batch_prompts = prompts[start_idx:end_idx]
                B = len(batch_prompts)

                clip_inputs = clip_tok(
                    batch_prompts, return_tensors="pt", padding="max_length",
                    max_length=77, truncation=True
                ).to(DEVICE)
                # The text conditioning is shared by every timestep bucket.
                clip_hidden = clip_enc(**clip_inputs).last_hidden_state.to(DTYPE)

                for t_idx, t_val in enumerate(t_buckets):
                    timesteps = torch.full((B,), t_val.item(), device=DEVICE)
                    latents = torch.randn(B, 4, 64, 64, device=DEVICE, dtype=DTYPE)

                    # Only the hook's side effect is needed; the UNet output is discarded.
                    _ = unet(latents, timesteps * 1000, encoder_hidden_states=clip_hidden)

                    all_features[start_idx:end_idx, t_idx] = mid_features[0].cpu().to(torch.float16)
    finally:
        # Release the hook and GPU memory even when extraction fails part-way.
        hook_handle.remove()
        del unet
        torch.cuda.empty_cache()

    # Write to a temporary file and rename, so an interrupted save never leaves
    # a truncated cache that a later run would try to load.
    tmp_path = cache_path + ".tmp"
    torch.save({"features": all_features, "t_buckets": t_buckets}, tmp_path)
    os.replace(tmp_path, cache_path)
    print(f" ✓ Cached to {cache_path}")
    print(f" Size: {all_features.numel() * 2 / 1e9:.2f} GB")

    return ExpertFeatureCache(all_features, t_buckets, DTYPE)
| |
|
| |
|
| | |
| | |
| | |
class EMA:
    """
    Exponential moving average over a model's ``state_dict``.

    Shadow copies are kept for every state_dict entry (parameters and buffers).
    Models wrapped by ``torch.compile`` are unwrapped through ``_orig_mod`` so
    the shadow keys match the plain module's keys.
    """

    def __init__(self, model, decay=0.9999):
        self.decay = decay
        self.shadow = {}
        self._backup = {}
        for k, v in self._unwrap(model).state_dict().items():
            self.shadow[k] = v.clone().detach()

    @staticmethod
    def _unwrap(model):
        """Return the underlying module of a compiled model, or the model itself."""
        return getattr(model, '_orig_mod', model)

    @torch.no_grad()
    def update(self, model):
        """Move every shadow entry towards the model's current value."""
        for k, v in self._unwrap(model).state_dict().items():
            shadow = self.shadow.get(k)
            if shadow is None:
                continue
            if shadow.is_floating_point() or shadow.is_complex():
                shadow.lerp_(v.to(shadow.dtype), 1 - self.decay)
            else:
                # lerp_ is not defined for integer/bool tensors (e.g. BatchNorm's
                # num_batches_tracked); track the latest value instead.
                shadow.copy_(v)

    def apply_shadow_for_eval(self, model):
        """Back up the model's weights and load the shadow weights into it."""
        target = self._unwrap(model)
        self._backup = {k: v.clone() for k, v in target.state_dict().items()}
        target.load_state_dict(self.shadow)

    def restore(self, model):
        """Undo apply_shadow_for_eval; a no-op when no backup is held."""
        if not self._backup:
            return
        self._unwrap(model).load_state_dict(self._backup)
        self._backup = {}

    def state_dict(self):
        return {'shadow': self.shadow, 'decay': self.decay}

    def load_state_dict(self, state):
        self.shadow = {k: v.clone() for k, v in state['shadow'].items()}
        self.decay = state.get('decay', self.decay)

    def load_shadow(self, shadow_state):
        """
        Load EMA shadow weights, handling architecture changes gracefully.

        Keys unknown to the current model are skipped, keys missing from
        ``shadow_state`` keep their current value, and entries whose shape no
        longer matches are left untouched. Loaded tensors are cast to the
        existing shadow entry's device and dtype, so restoring from a
        reduced-precision checkpoint does not lower the precision of the average.
        """
        loaded = 0
        skipped_old = 0
        mismatched = 0

        for k, v in shadow_state.items():
            current = self.shadow.get(k)
            if current is None:
                skipped_old += 1
            elif v.shape != current.shape:
                mismatched += 1
            else:
                self.shadow[k] = v.clone().to(device=current.device, dtype=current.dtype)
                loaded += 1

        kept_new = sum(1 for k in self.shadow if k not in shadow_state)

        print(f" ✓ Restored EMA: {loaded} loaded, {skipped_old} deprecated skipped, {kept_new} new (fresh init)")
        if mismatched:
            print(f" ⚠ {mismatched} EMA entries skipped due to shape mismatch (kept current values)")
| |
|
| |
|
| | |
| | |
| | |
def apply_text_dropout(t5_embeds, clip_pooled, dropout_prob=0.1):
    """
    Zero out the text conditioning of randomly chosen samples.

    One uniform draw is made per sample (first dimension of ``t5_embeds``);
    samples whose draw is below ``dropout_prob`` have both their T5 and CLIP
    entries set to zero. The inputs are not modified.

    Returns:
        (t5_embeds copy, clip_pooled copy, boolean drop mask of shape [B])
    """
    n_samples = t5_embeds.shape[0]
    drop = torch.rand(n_samples, device=t5_embeds.device) < dropout_prob

    def _zeroed_copy(tensor):
        # Work on a clone so the caller's tensor is left intact.
        out = tensor.clone()
        out[drop] = 0
        return out

    return _zeroed_copy(t5_embeds), _zeroed_copy(clip_pooled), drop
| |
|
| |
|
| | |
| | |
| | |
def detect_background_color(image: Image.Image, sample_size: int = 100) -> Tuple[int, int, int]:
    """
    Estimate an image's background colour from its four corners.

    The image is converted to RGB first, so grayscale, palette and RGBA inputs
    all yield three channels. The result is the per-channel median over the
    four ``sample_size`` x ``sample_size`` corner patches.

    Args:
        image: PIL image of any mode convertible to RGB.
        sample_size: side length in pixels of each corner patch (>= 1).

    Returns:
        (r, g, b) as plain Python ints in 0..255.

    Raises:
        ValueError: if ``sample_size`` is smaller than 1.
    """
    if sample_size < 1:
        # A size of 0 would make the "-0:" slices select the whole image.
        raise ValueError("sample_size must be >= 1")

    img = np.array(image.convert("RGB"))
    corners = [
        img[:sample_size, :sample_size],
        img[:sample_size, -sample_size:],
        img[-sample_size:, :sample_size],
        img[-sample_size:, -sample_size:],
    ]
    corner_pixels = np.concatenate([c.reshape(-1, 3) for c in corners], axis=0)
    bg_color = np.median(corner_pixels, axis=0).astype(np.uint8)
    return tuple(int(channel) for channel in bg_color)
| |
|
| |
|
def create_product_mask(image: Image.Image, threshold: int = 30) -> np.ndarray:
    """
    Build a foreground mask by thresholding the distance to the background colour.

    The image is converted to RGB up front, so inputs with an alpha channel or
    a single channel are compared against the 3-channel background colour
    without shape errors.

    Args:
        image: PIL image of any mode convertible to RGB.
        threshold: Euclidean RGB distance above which a pixel counts as foreground.

    Returns:
        float32 array of shape (H, W): 1.0 for foreground, 0.0 for background.
    """
    rgb = image.convert("RGB")
    img = np.array(rgb).astype(np.float32)
    bg_color = np.array(detect_background_color(rgb), dtype=np.float32)
    diff = np.sqrt(np.sum((img - bg_color) ** 2, axis=-1))
    return (diff > threshold).astype(np.float32)
| |
|
| |
|
def create_smpl_mask(conditioning_image: Image.Image, threshold: int = 20) -> np.ndarray:
    """
    Build a foreground mask from a conditioning render.

    Single-channel input: pixels brighter than ``threshold`` are foreground.
    Multi-channel input: a pixel is background when its green channel exceeds
    both red and blue by more than ``threshold``; everything else is foreground.
    The margin was previously hard-coded to 20, the default of ``threshold``,
    so default calls behave as before.

    Returns:
        float32 array of shape (H, W): 1.0 for foreground, 0.0 for background.
    """
    img = np.array(conditioning_image).astype(np.float32)
    if img.ndim == 2:
        return (img > threshold).astype(np.float32)
    r, g, b = img[:, :, 0], img[:, :, 1], img[:, :, 2]
    is_background = (g > r + threshold) & (g > b + threshold)
    return (~is_background).astype(np.float32)
| |
|
| |
|
def downsample_mask_to_latent(mask: np.ndarray, latent_h: int = 64, latent_w: int = 64) -> torch.Tensor:
    """
    Resize a pixel-space mask to the latent grid with bilinear resampling.

    The mask is scaled to 8-bit, resized through PIL to
    ``latent_w`` x ``latent_h`` and scaled back, so the result is a float32
    tensor of shape (latent_h, latent_w) with values in [0, 1].
    """
    as_uint8 = (mask * 255).astype(np.uint8)
    resized = Image.fromarray(as_uint8).resize((latent_w, latent_h), Image.Resampling.BILINEAR)
    scaled = np.array(resized).astype(np.float32) / 255.0
    return torch.from_numpy(scaled)
| |
|
| |
|
| | |
| | |
| | |
# Create the Hugging Face Hub client that the upload helpers in this file call.
print("Setting up HuggingFace Hub...")
api = HfApi()
| |
|
| |
|
| | |
| | |
| | |
def flux_shift(t, s=SHIFT):
    """Warp timestep(s) ``t`` by ``s * t / (1 + (s - 1) * t)``; ``s = 1`` is the identity."""
    numerator = s * t
    denominator = 1 + (s - 1) * t
    return numerator / denominator
| |
|
def min_snr_weight(t, gamma=MIN_SNR_GAMMA):
    """
    Per-timestep loss weight ``min(snr, gamma) / snr`` with ``snr = (t / (1 - t))^2``.

    Both the ``1 - t`` denominator and the final divisor are floored at 1e-5
    to avoid division by zero.
    """
    ratio = t / (1 - t).clamp(min=1e-5)
    snr = ratio.pow(2)
    capped = torch.clamp(snr, max=gamma)
    return capped / snr.clamp(min=1e-5)
| |
|
| |
|
| | |
| | |
| | |
# Frozen text encoders: flan-t5-base and CLIP ViT-L/14 (used for both hidden
# states and the pooled vector). Both are put in eval mode with gradients off.
print("Loading text encoders...")

t5_tok = T5Tokenizer.from_pretrained("google/flan-t5-base")
t5_enc = T5EncoderModel.from_pretrained("google/flan-t5-base", torch_dtype=DTYPE)
t5_enc = t5_enc.to(DEVICE).eval()
t5_enc.requires_grad_(False)

clip_tok = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
clip_enc = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14", torch_dtype=DTYPE)
clip_enc = clip_enc.to(DEVICE).eval()
clip_enc.requires_grad_(False)

print("✓ Text encoders loaded")
| |
|
| |
|
| | |
| | |
| | |
# Frozen VAE taken from the FLUX.1-dev repository; its configured scaling
# factor is kept in VAE_SCALE for encoding and decoding latents.
print("Loading VAE...")
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="vae",
    torch_dtype=DTYPE,
)
vae = vae.to(DEVICE).eval()
vae.requires_grad_(False)

VAE_SCALE = vae.config.scaling_factor
print(f"✓ VAE loaded (scale={VAE_SCALE})")
| |
|
| |
|
| | |
| | |
| | |
@torch.no_grad()
def encode_prompt(prompt: str) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Encode one prompt with the module-level T5 and CLIP encoders.

    Returns:
        (T5 last hidden state, CLIP pooled output), each with the leading
        batch dimension squeezed away. Both stay on DEVICE.
    """
    shared_kwargs = dict(return_tensors="pt", padding="max_length", truncation=True)

    t5_batch = t5_tok(prompt, max_length=MAX_SEQ, **shared_kwargs).to(DEVICE)
    t5_hidden = t5_enc(**t5_batch).last_hidden_state

    clip_batch = clip_tok(prompt, max_length=77, **shared_kwargs).to(DEVICE)
    clip_vec = clip_enc(**clip_batch).pooler_output

    return t5_hidden.squeeze(0), clip_vec.squeeze(0)
| |
|
| |
|
@torch.no_grad()
def encode_prompts_batched(prompts: List[str], batch_size: int = 64) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Encode many prompts with T5 and CLIP in chunks of ``batch_size``.

    Each chunk's outputs are moved to the CPU straight away, and the chunks
    are concatenated along dim 0 at the end.

    Returns:
        (T5 last hidden states, CLIP pooled outputs) as CPU tensors.
    """
    shared_kwargs = dict(return_tensors="pt", padding="max_length", truncation=True)
    t5_chunks, clip_chunks = [], []

    for start in tqdm(range(0, len(prompts), batch_size), desc="Encoding", leave=False):
        chunk = prompts[start:start + batch_size]

        t5_batch = t5_tok(chunk, max_length=MAX_SEQ, **shared_kwargs).to(DEVICE)
        t5_chunks.append(t5_enc(**t5_batch).last_hidden_state.cpu())

        clip_batch = clip_tok(chunk, max_length=77, **shared_kwargs).to(DEVICE)
        clip_chunks.append(clip_enc(**clip_batch).pooler_output.cpu())

    return torch.cat(t5_chunks, dim=0), torch.cat(clip_chunks, dim=0)
| |
|
| |
|
@torch.no_grad()
def encode_image_to_latent(image: Image.Image) -> torch.Tensor:
    """
    Encode a PIL image into a scaled VAE latent.

    The image is forced to RGB and 512x512 (LANCZOS), mapped to [-1, 1],
    encoded with the module-level VAE (sampling from its posterior) and
    multiplied by VAE_SCALE.

    Returns:
        The latent on the CPU with the batch dimension removed.
    """
    rgb = image if image.mode == "RGB" else image.convert("RGB")
    if rgb.size != (512, 512):
        rgb = rgb.resize((512, 512), Image.Resampling.LANCZOS)

    pixels = torch.from_numpy(np.array(rgb)).float() / 255.0
    pixels = pixels.permute(2, 0, 1).unsqueeze(0)  # HWC -> 1CHW
    pixels = (pixels * 2.0 - 1.0).to(DEVICE, dtype=DTYPE)

    posterior = vae.encode(pixels).latent_dist
    scaled = posterior.sample() * VAE_SCALE
    return scaled.squeeze(0).cpu()
| |
|
| |
|
| | |
| | |
| | |
| |
|
# ---- [1/4] Portraits: sharded latents, up to three captions per image ----
portrait_ds = None
portrait_indices = []   # for each expanded sample, the row in portrait_ds
portrait_prompts = []   # caption of each expanded sample

if not ENABLE_PORTRAIT:
    print("\n[1/4] Portrait dataset DISABLED")
else:
    print(f"\n[1/4] Loading portrait dataset from {PORTRAIT_REPO}...")
    portrait_shards = []
    for i in range(PORTRAIT_NUM_SHARDS):
        split_name = f"train_{i:02d}"
        print(f" Loading {split_name}...")
        portrait_shards.append(load_dataset(PORTRAIT_REPO, split=split_name))
    portrait_ds = concatenate_datasets(portrait_shards)
    print(f"✓ Portrait: {len(portrait_ds)} base samples")
    print(" Extracting prompts (columnar)...")
    florence_list = list(portrait_ds["text_florence"])
    llava_list = list(portrait_ds["text_llava"])
    blip_list = list(portrait_ds["text_blip"])
    # Every non-blank caption becomes its own sample pointing at the same row.
    for i, captions in enumerate(zip(florence_list, llava_list, blip_list)):
        for caption in captions:
            if caption and caption.strip():
                portrait_indices.append(i)
                portrait_prompts.append(caption)
    print(f" Expanded: {len(portrait_prompts)} samples (3 prompts/image)")

# ---- [2/4] Schnell teacher latents ----
schnell_ds = None
schnell_prompts = []

if not ENABLE_SCHNELL:
    print("\n[2/4] Schnell dataset DISABLED")
else:
    print(f"\n[2/4] Loading schnell teacher dataset from {SCHNELL_REPO}...")
    schnell_datasets = []
    for config in SCHNELL_CONFIGS:
        print(f" Loading {config}...")
        ds = load_dataset(SCHNELL_REPO, config, split="train")
        schnell_datasets.append(ds)
        print(f" {len(ds)} samples")
    schnell_ds = concatenate_datasets(schnell_datasets)
    schnell_prompts = list(schnell_ds["prompt"])
    print(f"✓ Schnell: {len(schnell_ds)} samples")

# ---- [3/4] SportFashion images ----
sportfashion_ds = None
sportfashion_prompts = []

if not ENABLE_SPORTFASHION:
    print("\n[3/4] SportFashion dataset DISABLED")
else:
    print(f"\n[3/4] Loading SportFashion dataset from {SPORTFASHION_REPO}...")
    sportfashion_ds = load_dataset(SPORTFASHION_REPO, split="train")
    sportfashion_prompts = list(sportfashion_ds["text"])
    print(f"✓ SportFashion: {len(sportfashion_ds)} samples")

# ---- [4/4] SynthMoCap images ----
synthmocap_ds = None
synthmocap_prompts = []

if not ENABLE_SYNTHMOCAP:
    print("\n[4/4] SynthMoCap dataset DISABLED")
else:
    print(f"\n[4/4] Loading SynthMoCap dataset from {SYNTHMOCAP_REPO}...")
    synthmocap_ds = load_dataset(SYNTHMOCAP_REPO, split="train")
    synthmocap_prompts = list(synthmocap_ds["text"])
    print(f"✓ SynthMoCap: {len(synthmocap_ds)} samples")
| |
|
| |
|
| | |
| | |
| | |
# Number of training samples across all enabled datasets (one per prompt).
total_samples = sum(
    len(prompt_list)
    for prompt_list in (portrait_prompts, schnell_prompts, sportfashion_prompts, synthmocap_prompts)
)
print(f"\nTotal combined samples: {total_samples}")
| |
|
def load_or_encode(cache_path, prompts, name):
    """
    Return (t5_embeds, clip_pooled) for ``prompts``, using an on-disk cache.

    The cache is loaded onto the CPU regardless of where it was saved from. A
    cache whose row count differs from ``len(prompts)`` is treated as stale:
    the prompts are re-encoded and the file is overwritten, so embeddings
    cannot silently fall out of step with the dataset.

    Args:
        cache_path: file to load the encodings from / save them to.
        prompts: list of prompt strings; an empty list yields (None, None).
        name: label used in progress messages.
    """
    if not prompts:
        return None, None

    if os.path.exists(cache_path):
        print(f"Loading cached {name} encodings...")
        cached = torch.load(cache_path, map_location="cpu")
        t5, clip = cached["t5_embeds"], cached["clip_pooled"]
        if len(t5) == len(prompts) and len(clip) == len(prompts):
            return t5, clip
        print(f" ⚠ Cached {name} encodings have {len(t5)} rows, expected {len(prompts)}; re-encoding")

    print(f"Encoding {len(prompts)} {name} prompts...")
    t5, clip = encode_prompts_batched(prompts, batch_size=64)
    torch.save({"t5_embeds": t5, "clip_pooled": clip}, cache_path)
    print(f"✓ Cached to {cache_path}")
    return t5, clip
| |
|
| |
|
| | |
# Text encodings per dataset. A dataset with no prompts keeps (None, None);
# the cache file name embeds the prompt count.
def _encoding_cache_file(tag, prompts):
    """Path of the text-encoding cache for dataset ``tag``."""
    return os.path.join(ENCODING_CACHE_DIR, f"{tag}_encodings_{len(prompts)}.pt")


portrait_t5 = portrait_clip = None
schnell_t5 = schnell_clip = None
sportfashion_t5 = sportfashion_clip = None
synthmocap_t5 = synthmocap_clip = None

if portrait_prompts:
    portrait_enc_cache = _encoding_cache_file("portrait", portrait_prompts)
    portrait_t5, portrait_clip = load_or_encode(portrait_enc_cache, portrait_prompts, "portrait")

if schnell_prompts:
    schnell_enc_cache = _encoding_cache_file("schnell", schnell_prompts)
    schnell_t5, schnell_clip = load_or_encode(schnell_enc_cache, schnell_prompts, "schnell")

if sportfashion_prompts:
    sportfashion_enc_cache = _encoding_cache_file("sportfashion", sportfashion_prompts)
    sportfashion_t5, sportfashion_clip = load_or_encode(sportfashion_enc_cache, sportfashion_prompts, "sportfashion")

if synthmocap_prompts:
    synthmocap_enc_cache = _encoding_cache_file("synthmocap", synthmocap_prompts)
    synthmocap_t5, synthmocap_clip = load_or_encode(synthmocap_enc_cache, synthmocap_prompts, "synthmocap")
| |
|
| |
|
| | |
| | |
| | |
# Expert feature caches per dataset. A cache stays None when the dataset has
# no prompts or expert distillation is switched off.
print("\n" + "="*60)
print("Expert Feature Caching")
print("="*60)


def _expert_cache_file(tag, prompts):
    """Path of the expert-feature cache for dataset ``tag``."""
    return os.path.join(ENCODING_CACHE_DIR, f"{tag}_expert_{len(prompts)}.pt")


schnell_expert_cache = None
portrait_expert_cache = None
sportfashion_expert_cache = None
synthmocap_expert_cache = None

if ENABLE_EXPERT_DISTILLATION and schnell_prompts:
    schnell_expert_path = _expert_cache_file("schnell", schnell_prompts)
    schnell_expert_cache = load_or_extract_expert_features(
        schnell_expert_path, schnell_prompts, "schnell",
        clip_tok, clip_enc, EXPERT_T_BUCKETS,
    )

if ENABLE_EXPERT_DISTILLATION and portrait_prompts:
    portrait_expert_path = _expert_cache_file("portrait", portrait_prompts)
    portrait_expert_cache = load_or_extract_expert_features(
        portrait_expert_path, portrait_prompts, "portrait",
        clip_tok, clip_enc, EXPERT_T_BUCKETS,
    )

if ENABLE_EXPERT_DISTILLATION and sportfashion_prompts:
    sportfashion_expert_path = _expert_cache_file("sportfashion", sportfashion_prompts)
    sportfashion_expert_cache = load_or_extract_expert_features(
        sportfashion_expert_path, sportfashion_prompts, "sportfashion",
        clip_tok, clip_enc, EXPERT_T_BUCKETS,
    )

if ENABLE_EXPERT_DISTILLATION and synthmocap_prompts:
    synthmocap_expert_path = _expert_cache_file("synthmocap", synthmocap_prompts)
    synthmocap_expert_cache = load_or_extract_expert_features(
        synthmocap_expert_path, synthmocap_prompts, "synthmocap",
        clip_tok, clip_enc, EXPERT_T_BUCKETS,
    )
| |
|
| |
|
| | |
| | |
| | |
class CombinedDataset(Dataset):
    """
    Concatenation of up to four datasets behind one flat index.

    Index layout: [portraits | schnell | sportfashion | synthmocap]. Each item
    carries the global index, the index local to its source dataset (used for
    the expert feature lookup) and a ``dataset_id`` in 0..3 identifying the
    source. Portraits and schnell provide precomputed latents; sportfashion
    and synthmocap images are encoded with the VAE on access.
    """

    def __init__(
        self,
        portrait_ds, portrait_indices, portrait_t5, portrait_clip,
        schnell_ds, schnell_t5, schnell_clip,
        sportfashion_ds, sportfashion_t5, sportfashion_clip,
        synthmocap_ds, synthmocap_t5, synthmocap_clip,
        vae, vae_scale, device, dtype,
        compute_masks=True,
    ):
        self.portrait_ds = portrait_ds
        self.portrait_indices = portrait_indices
        self.portrait_t5 = portrait_t5
        self.portrait_clip = portrait_clip

        self.schnell_ds = schnell_ds
        self.schnell_t5 = schnell_t5
        self.schnell_clip = schnell_clip

        self.sportfashion_ds = sportfashion_ds
        self.sportfashion_t5 = sportfashion_t5
        self.sportfashion_clip = sportfashion_clip

        self.synthmocap_ds = synthmocap_ds
        self.synthmocap_t5 = synthmocap_t5
        self.synthmocap_clip = synthmocap_clip

        self.vae = vae
        self.vae_scale = vae_scale
        self.device = device
        self.dtype = dtype
        self.compute_masks = compute_masks

        # Portraits are counted per expanded (image, caption) pair.
        self.n_portrait = len(portrait_indices) if portrait_indices else 0
        self.n_schnell = len(schnell_ds) if schnell_ds else 0
        self.n_sportfashion = len(sportfashion_ds) if sportfashion_ds else 0
        self.n_synthmocap = len(synthmocap_ds) if synthmocap_ds else 0

        # Cumulative boundaries of the flat index space.
        self.c1 = self.n_portrait
        self.c2 = self.c1 + self.n_schnell
        self.c3 = self.c2 + self.n_sportfashion
        self.total = self.c3 + self.n_synthmocap

    def __len__(self):
        return self.total

    def _get_latent_from_array(self, latent_data):
        """Convert a stored latent (tensor or array-like) to a tensor of ``self.dtype``."""
        if isinstance(latent_data, torch.Tensor):
            return latent_data.to(self.dtype)
        return torch.tensor(np.array(latent_data), dtype=self.dtype)

    @torch.no_grad()
    def _encode_image(self, image):
        """Encode a PIL image (forced to RGB, 512x512) into a scaled VAE latent on the CPU."""
        if image.mode != "RGB":
            image = image.convert("RGB")
        if image.size != (512, 512):
            image = image.resize((512, 512), Image.Resampling.LANCZOS)
        img_tensor = torch.from_numpy(np.array(image)).float() / 255.0
        img_tensor = img_tensor.permute(2, 0, 1).unsqueeze(0)
        img_tensor = (img_tensor * 2.0 - 1.0).to(self.device, dtype=self.dtype)
        latent = self.vae.encode(img_tensor).latent_dist.sample()
        latent = latent * self.vae_scale
        return latent.squeeze(0).cpu()

    def __getitem__(self, idx):
        """
        Return the sample at flat index ``idx``.

        Negative indices count from the end, as for Python sequences.

        Raises:
            IndexError: if ``idx`` is outside ``[-len(self), len(self))``.
        """
        if idx < 0:
            idx += self.total
        if not 0 <= idx < self.total:
            # Without this check an out-of-range index would fall through to
            # the last branch and fail with an unrelated error.
            raise IndexError(f"index {idx} out of range for CombinedDataset of size {self.total}")

        mask = None

        if idx < self.c1:
            # Portraits: the text encodings are per expanded sample, the latent
            # is per source image.
            local_idx = idx
            orig_idx = self.portrait_indices[idx]
            item = self.portrait_ds[orig_idx]
            latent = self._get_latent_from_array(item["latent"])
            t5 = self.portrait_t5[idx]
            clip = self.portrait_clip[idx]
            dataset_id = 0

        elif idx < self.c2:
            # Schnell teacher latents.
            local_idx = idx - self.c1
            item = self.schnell_ds[local_idx]
            latent = self._get_latent_from_array(item["latent"])
            t5 = self.schnell_t5[local_idx]
            clip = self.schnell_clip[local_idx]
            dataset_id = 1

        elif idx < self.c3:
            # SportFashion: encode the image, optionally mask the product.
            local_idx = idx - self.c2
            item = self.sportfashion_ds[local_idx]
            image = item["image"]
            latent = self._encode_image(image)
            t5 = self.sportfashion_t5[local_idx]
            clip = self.sportfashion_clip[local_idx]
            dataset_id = 2
            if self.compute_masks:
                pixel_mask = create_product_mask(image)
                mask = downsample_mask_to_latent(pixel_mask, 64, 64)

        else:
            # SynthMoCap: encode the image, mask from the conditioning render.
            local_idx = idx - self.c3
            item = self.synthmocap_ds[local_idx]
            image = item["image"]
            conditioning = item["conditioning_image"]
            latent = self._encode_image(image)
            t5 = self.synthmocap_t5[local_idx]
            clip = self.synthmocap_clip[local_idx]
            dataset_id = 3
            if self.compute_masks:
                pixel_mask = create_smpl_mask(conditioning)
                mask = downsample_mask_to_latent(pixel_mask, 64, 64)

        result = {
            "latent": latent,
            "t5_embed": t5.to(self.dtype),
            "clip_pooled": clip.to(self.dtype),
            "sample_idx": idx,
            "local_idx": local_idx,
            "dataset_id": dataset_id,
        }

        # The "mask" key is only present when a mask was computed.
        if mask is not None:
            result["mask"] = mask.to(self.dtype)

        return result
| |
|
| |
|
| | |
| | |
| | |
def collate_fn(batch):
    """
    Stack a list of CombinedDataset items into batch tensors.

    ``masks`` is None when no item carries a mask. When only some items do,
    the others get an all-ones mask with the same shape and dtype as the masks
    present in the batch, rather than a fixed 64x64 one.
    """
    latents = torch.stack([b["latent"] for b in batch])
    t5_embeds = torch.stack([b["t5_embed"] for b in batch])
    clip_pooled = torch.stack([b["clip_pooled"] for b in batch])
    sample_indices = torch.tensor([b["sample_idx"] for b in batch], dtype=torch.long)
    local_indices = torch.tensor([b["local_idx"] for b in batch], dtype=torch.long)
    dataset_ids = torch.tensor([b["dataset_id"] for b in batch], dtype=torch.long)

    masks = None
    reference_mask = next((b["mask"] for b in batch if "mask" in b), None)
    if reference_mask is not None:
        # Items without a mask count as entirely foreground (all ones).
        masks = torch.stack([
            b["mask"] if "mask" in b else torch.ones_like(reference_mask)
            for b in batch
        ])

    return {
        "latents": latents,
        "t5_embeds": t5_embeds,
        "clip_pooled": clip_pooled,
        "sample_indices": sample_indices,
        "local_indices": local_indices,
        "dataset_ids": dataset_ids,
        "masks": masks,
    }
| |
|
| |
|
| | |
| | |
| | |
def get_expert_features_for_batch(
    local_indices: torch.Tensor,
    dataset_ids: torch.Tensor,
    timesteps: torch.Tensor,
    portrait_cache: Optional[ExpertFeatureCache],
    schnell_cache: Optional[ExpertFeatureCache],
    sportfashion_cache: Optional[ExpertFeatureCache],
    synthmocap_cache: Optional[ExpertFeatureCache],
) -> Optional[torch.Tensor]:
    """
    Gather expert features for a mixed batch from the per-dataset caches.

    The position of each cache in the argument list corresponds to the
    ``dataset_id`` it serves (0 = portrait ... 3 = synthmocap). Rows whose
    dataset has no cache stay zero.

    Returns:
        [B, EXPERT_DIM] tensor on ``timesteps.device``, or None when every
        cache is None.
    """
    caches = (portrait_cache, schnell_cache, sportfashion_cache, synthmocap_cache)
    if all(cache is None for cache in caches):
        return None

    batch_size = local_indices.shape[0]
    out = torch.zeros(batch_size, EXPERT_DIM, device=timesteps.device, dtype=DTYPE)

    for ds_id, cache in enumerate(caches):
        if cache is None:
            continue
        # Rows of the batch that come from this dataset.
        selected = dataset_ids == ds_id
        if selected.any():
            out[selected] = cache.get_features(local_indices[selected], timesteps[selected])

    return out
| |
|
| |
|
| | |
| | |
| | |
def masked_mse_loss(pred, target, mask=None, fg_weight=2.0, bg_weight=0.5, snr_weights=None):
    """
    Mean squared error with optional foreground/background and per-sample weights.

    Args:
        pred, target: [B, N, C] tensors.
        mask: optional per-token foreground weights in [0, 1] holding B * N
            elements in any layout (e.g. [B, H, W] with H * W == N). The token
            grid does not have to be square.
        fg_weight, bg_weight: weights applied where the mask is 1 / 0.
        snr_weights: optional [B] multipliers applied to each sample's loss.

    Returns:
        Scalar loss averaged over the batch.

    Raises:
        ValueError: if ``mask`` does not hold exactly B * N elements.
    """
    B, N, C = pred.shape
    sq_error = (pred - target) ** 2

    if mask is not None:
        if mask.numel() != B * N:
            raise ValueError(
                f"mask has {mask.numel()} elements, expected {B * N} (batch {B} x {N} tokens)"
            )
        # reshape (not view) so non-contiguous masks are accepted too.
        mask_flat = mask.reshape(B, N, 1).to(pred.device)
        weights = mask_flat * fg_weight + (1 - mask_flat) * bg_weight
        sq_error = sq_error * weights

    loss_per_sample = sq_error.mean(dim=[1, 2])
    if snr_weights is not None:
        loss_per_sample = loss_per_sample * snr_weights
    return loss_per_sample.mean()
| |
|
| |
|
| | |
| | |
| | |
# Assemble the flat training dataset from whatever was loaded above; datasets
# that are disabled are passed as None / empty and contribute zero samples.
print("\nCreating combined dataset...")
combined_ds = CombinedDataset(
    portrait_ds=portrait_ds,
    portrait_indices=portrait_indices,
    portrait_t5=portrait_t5,
    portrait_clip=portrait_clip,
    schnell_ds=schnell_ds,
    schnell_t5=schnell_t5,
    schnell_clip=schnell_clip,
    sportfashion_ds=sportfashion_ds,
    sportfashion_t5=sportfashion_t5,
    sportfashion_clip=sportfashion_clip,
    synthmocap_ds=synthmocap_ds,
    synthmocap_t5=synthmocap_t5,
    synthmocap_clip=synthmocap_clip,
    vae=vae,
    vae_scale=VAE_SCALE,
    device=DEVICE,
    dtype=DTYPE,
    compute_masks=USE_MASKED_LOSS,
)
print(f"✓ Combined dataset: {len(combined_ds)} samples")
for summary_line in (
    f" - Portraits (3x): {combined_ds.n_portrait:,}",
    f" - Schnell teacher: {combined_ds.n_schnell:,}",
    f" - SportFashion: {combined_ds.n_sportfashion:,}",
    f" - SynthMoCap: {combined_ds.n_synthmocap:,}",
    f" - Expert distillation: {ENABLE_EXPERT_DISTILLATION}",
):
    print(summary_line)
| |
|
| |
|
| | |
| | |
| | |
# CombinedDataset encodes SportFashion / SynthMoCap images with the VAE on
# DEVICE inside __getitem__. CUDA cannot be used from forked DataLoader worker
# processes, so loading stays in the main process whenever those datasets
# contribute samples. Latent-only datasets keep the 8 workers.
_getitem_uses_vae = (combined_ds.n_sportfashion + combined_ds.n_synthmocap) > 0

loader = DataLoader(
    combined_ds,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0 if _getitem_uses_vae else 8,
    pin_memory=True,
    collate_fn=collate_fn,
    drop_last=True,
)
print(f"✓ DataLoader: {len(loader)} batches/epoch")
| |
|
| |
|
| | |
| | |
| | |
@torch.inference_mode()
def generate_samples(model, prompts, num_steps=28, guidance_scale=3.5, H=64, W=64, use_ema=True):
    """
    Sample images for ``prompts`` with Euler integration of the model's velocity.

    Integration runs over ``flux_shift``-warped timesteps from 0 to 1, starting
    from Gaussian noise, followed by a VAE decode. When ``use_ema`` is set and a
    module-level ``ema`` object exists, its shadow weights are swapped in for
    the duration of the call. The model's weights and train/eval mode are
    restored even if sampling raises.

    Args:
        model: the denoiser; called with hidden_states, encoder_hidden_states,
            pooled_projections, timestep and img_ids.
        prompts: list of prompt strings.
        num_steps: number of Euler steps.
        guidance_scale: accepted for interface compatibility; not used by this
            sampler (only the conditional prediction is evaluated).
        H, W: latent grid height and width.
        use_ema: sample with EMA weights when available.

    Returns:
        Decoded images with values clamped to [0, 1].
    """
    was_training = model.training
    model.eval()

    # Resolve the optional module-level EMA object once.
    ema_obj = globals().get('ema') if use_ema else None
    if ema_obj is not None:
        ema_obj.apply_shadow_for_eval(model)

    try:
        B = len(prompts)
        C = 16

        t5_list, clip_list = [], []
        for p in prompts:
            t5, clip = encode_prompt(p)
            t5_list.append(t5)
            clip_list.append(clip)
        t5_embeds = torch.stack(t5_list).to(DTYPE)
        clip_pooleds = torch.stack(clip_list).to(DTYPE)

        x = torch.randn(B, H * W, C, device=DEVICE, dtype=DTYPE)
        img_ids = TinyFluxDeep.create_img_ids(B, H, W, DEVICE)

        t_linear = torch.linspace(0, 1, num_steps + 1, device=DEVICE, dtype=DTYPE)
        timesteps = flux_shift(t_linear, s=SHIFT)

        for i in range(num_steps):
            t_curr = timesteps[i]
            t_next = timesteps[i + 1]
            dt = t_next - t_curr

            t_batch = t_curr.expand(B).to(DTYPE)

            with torch.autocast("cuda", dtype=DTYPE):
                v_cond = model(
                    hidden_states=x,
                    encoder_hidden_states=t5_embeds,
                    pooled_projections=clip_pooleds,
                    timestep=t_batch,
                    img_ids=img_ids,
                )
            # Euler step along the predicted velocity.
            x = x + v_cond * dt

        # Token sequence [B, H*W, C] -> latent image [B, C, H, W].
        latents = x.reshape(B, H, W, C).permute(0, 3, 1, 2)
        latents = latents / VAE_SCALE

        with torch.autocast("cuda", dtype=DTYPE):
            images = vae.decode(latents.to(vae.dtype)).sample
        images = (images / 2 + 0.5).clamp(0, 1)
    finally:
        if ema_obj is not None:
            ema_obj.restore(model)
        if was_training:
            model.train()

    return images
| |
|
| |
|
def save_samples(images, prompts, step, output_dir):
    """Write a sample grid to disk and try to upload it to the Hub.

    Args:
        images: Image batch passed to ``torchvision.utils.save_image``.
        prompts: Unused; kept so existing callers keep working.
        step: Training step, used in the local and remote file names.
        output_dir: Directory for the PNG grid (created if missing).

    The upload is best-effort: a failure is reported but never raised, so a
    Hub outage cannot interrupt training.
    """
    from torchvision.utils import save_image

    os.makedirs(output_dir, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    grid_path = os.path.join(output_dir, f"samples_step_{step}.png")
    save_image(images, grid_path, nrow=2, padding=2)
    try:
        api.upload_file(
            path_or_fileobj=grid_path,
            path_in_repo=f"samples/{timestamp}_step_{step}.png",
            repo_id=HF_REPO,
        )
    except Exception as e:
        # Narrowed from a bare except so Ctrl-C still works and failures are visible.
        print(f" ⚠ Sample upload failed: {e}")
| |
|
| |
|
| | |
| | |
| | |
def save_checkpoint(model, optimizer, scheduler, step, epoch, loss, path, ema=None):
    """Persist model weights, optional EMA weights and training state.

    Three files are derived from ``path`` (normally ``<name>.pt``):
      * ``<name>.safetensors``      - model weights cast to ``DTYPE``
      * ``<name>_ema.safetensors``  - ``ema.shadow`` cast to ``DTYPE`` (if ``ema`` given)
      * ``path`` itself             - step/epoch/loss plus optimizer and scheduler state

    Args:
        model: Module to save; a compiled wrapper is unwrapped via ``_orig_mod``.
        optimizer, scheduler: Objects whose ``state_dict()`` is stored.
        step, epoch, loss: Bookkeeping values stored alongside the state.
        path: Destination of the ``torch.save`` state file.
        ema: Optional EMA helper exposing ``shadow`` and ``decay``.

    Returns:
        Path of the written model ``.safetensors`` file.
    """
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)

    # Only the trailing ".pt" is swapped. str.replace would also rewrite ".pt"
    # inside directory names, and would leave the name unchanged (colliding
    # with the state file) when the path has no ".pt" at all.
    stem = path[:-len(".pt")] if path.endswith(".pt") else path

    source = model._orig_mod if hasattr(model, '_orig_mod') else model
    state_dict = {
        k: v.to(DTYPE) if v.is_floating_point() else v
        for k, v in source.state_dict().items()
    }
    weights_path = stem + ".safetensors"
    save_file(state_dict, weights_path)

    if ema is not None:
        ema_weights = {
            k: v.to(DTYPE) if v.is_floating_point() else v
            for k, v in ema.shadow.items()
        }
        save_file(ema_weights, stem + "_ema.safetensors")

    state = {
        "step": step,
        "epoch": epoch,
        "loss": loss,
        "optimizer": optimizer.state_dict(),
        "scheduler": scheduler.state_dict(),
    }
    if ema is not None:
        state["ema_decay"] = ema.decay
    torch.save(state, path)
    print(f" ✓ Saved checkpoint: step {step}")
    return weights_path
| |
|
| |
|
def upload_checkpoint(weights_path, step):
    """Push a checkpoint's weights (and its EMA file, when present) to the Hub.

    Any exception raised while uploading is caught and printed as a warning
    rather than propagated.
    """
    try:
        api.upload_file(
            path_or_fileobj=weights_path,
            path_in_repo=f"checkpoints/step_{step}.safetensors",
            repo_id=HF_REPO,
        )
        # The EMA file sits next to the weights; upload it only if it exists.
        shadow_file = weights_path.replace(".safetensors", "_ema.safetensors")
        if os.path.exists(shadow_file):
            api.upload_file(
                path_or_fileobj=shadow_file,
                path_in_repo=f"checkpoints/step_{step}_ema.safetensors",
                repo_id=HF_REPO,
            )
    except Exception as err:
        print(f" ⚠ Upload failed: {err}")
    else:
        print(f" ✓ Uploaded checkpoint to {HF_REPO}")
| |
|
| |
|
def load_with_weight_upgrade(model, state_dict):
    """Load ``state_dict`` into ``model`` while tolerating architecture drift.

    Behaviour:
      * Keys present in both with equal shapes are copied from the checkpoint.
      * Keys with a shape mismatch are reported, kept at the model's current
        value, and returned in BOTH lists (they are in the checkpoint but were
        not loaded).
      * Model keys absent from the checkpoint that match a "new weight"
        pattern (expert predictor, Q/K norm weights) are left at whatever
        value the model already holds. This function does not re-initialize them.
      * Checkpoint keys absent from the model are ignored; those matching a
        deprecated pattern are ignored silently, the rest are printed.

    Args:
        model: Module providing ``state_dict()`` / ``load_state_dict()``.
        state_dict: Mapping of parameter name to tensor from the checkpoint.

    Returns:
        ``(missing_keys, unexpected_keys)``: model keys that were not loaded
        and are not known-new, and checkpoint keys that were not used.
    """
    model_state = model.state_dict()

    # Parameters added after older checkpoints were written.
    NEW_WEIGHT_PATTERNS = [
        'expert_predictor.',
        '.norm_q.weight',
        '.norm_k.weight',
        '.norm_added_q.weight',
        '.norm_added_k.weight',
    ]

    # Entries older checkpoints may still carry but the model no longer has.
    DEPRECATED_PATTERNS = [
        'guidance_in.',
        '.sin_basis',
    ]

    loaded_keys = set()  # set: membership is tested once per model key below
    missing_keys = []
    unexpected_keys = []
    initialized_keys = []

    # Pass 1: take everything from the checkpoint that fits the model.
    for key, tensor in state_dict.items():
        if key in model_state:
            if tensor.shape == model_state[key].shape:
                model_state[key] = tensor
                loaded_keys.add(key)
            else:
                print(f" ⚠ Shape mismatch for {key}: checkpoint {tensor.shape} vs model {model_state[key].shape}")
                unexpected_keys.append(key)
        else:
            if not any(pat in key for pat in DEPRECATED_PATTERNS):
                print(f" ⚠ Unexpected key (not in model): {key}")
            unexpected_keys.append(key)

    # Pass 2: classify model keys the checkpoint did not provide.
    for key in model_state:
        if key in loaded_keys:
            continue
        if any(pat in key for pat in NEW_WEIGHT_PATTERNS):
            initialized_keys.append(key)
        else:
            missing_keys.append(key)
            print(f" ⚠ Missing key (not in checkpoint): {key}")

    # model_state now mixes checkpoint tensors with the model's own values.
    model.load_state_dict(model_state, strict=False)

    if initialized_keys:
        # Summarize by top-level module (two levels deep for the expert predictor).
        modules = set()
        for k in initialized_keys:
            parts = k.split('.')
            if len(parts) >= 2:
                modules.add(parts[0] + '.' + parts[1] if parts[0] == 'expert_predictor' else parts[0])
        print(f" ✓ Initialized new modules (fresh): {sorted(modules)}")

    if unexpected_keys:
        deprecated = [k for k in unexpected_keys if any(p in k for p in DEPRECATED_PATTERNS)]
        if deprecated:
            print(f" ✓ Ignored deprecated keys: {len(deprecated)} (guidance_in, etc)")

    return missing_keys, unexpected_keys
| |
|
| |
|
def _swap_suffix(path, old, new):
    """Return ``path`` with a trailing ``old`` replaced by ``new`` (unchanged if absent)."""
    return path[:-len(old)] + new if path.endswith(old) else path


def _fetch_hub_checkpoint(name):
    """Download ``checkpoints/<name>.safetensors`` and, if available, its EMA file."""
    weights_path = hf_hub_download(HF_REPO, f"checkpoints/{name}.safetensors")
    try:
        ema_weights_path = hf_hub_download(HF_REPO, f"checkpoints/{name}_ema.safetensors")
        print(f" Found EMA weights on hub")
    except Exception:
        # A missing EMA file is expected for older checkpoints. Narrowed from
        # a bare except so Ctrl-C during the download is not swallowed.
        ema_weights_path = None
        print(f" No EMA weights on hub (will start fresh)")
    return weights_path, ema_weights_path


def load_checkpoint(model, optimizer, scheduler, target):
    """Restore model weights and, when available, EMA and optimizer state.

    ``target`` selects the source:
      * ``"none"``           - start fresh
      * ``"latest"``         - highest ``step_<N>.pt`` in ``CHECKPOINT_DIR``
      * ``"hub"``            - highest ``checkpoints/step_<N>`` in ``HF_REPO``
      * ``"hub:<name>"``     - ``checkpoints/<name>.safetensors`` in ``HF_REPO``
      * ``"best"``           - ``best.pt`` in ``CHECKPOINT_DIR``
      * an existing path     - a ``.pt`` or ``.safetensors`` file

    With ``ALLOW_WEIGHT_UPGRADE`` the weights go through
    ``load_with_weight_upgrade``; otherwise they are loaded strictly.
    Optimizer/scheduler state is restored only when a local ``.pt`` file
    exists; a failure there is reported and the fresh optimizer is kept.

    Returns:
        ``(start_step, start_epoch, ema_state)``. ``ema_state`` is a dict of
        tensors or ``None``. If no weights could be loaded, ``ema_state`` is
        ``None`` and the step/epoch values are whatever was determined so far.
    """
    start_step = 0
    start_epoch = 0
    ema_state = None

    if target == "none":
        print("Starting fresh (no checkpoint)")
        return start_step, start_epoch, None

    ckpt_path = None
    weights_path = None
    ema_weights_path = None

    if target == "latest":
        if os.path.exists(CHECKPOINT_DIR):
            ckpts = [f for f in os.listdir(CHECKPOINT_DIR) if f.startswith("step_") and f.endswith(".pt")]
            if ckpts:
                steps = [int(f.split("_")[1].split(".")[0]) for f in ckpts]
                latest_step = max(steps)
                ckpt_path = os.path.join(CHECKPOINT_DIR, f"step_{latest_step}.pt")
                weights_path = _swap_suffix(ckpt_path, ".pt", ".safetensors")
                ema_weights_path = _swap_suffix(ckpt_path, ".pt", "_ema.safetensors")

    elif target == "hub" or target.startswith("hub:"):
        try:
            from huggingface_hub import list_repo_files

            if target.startswith("hub:"):
                step_name = target.split(":")[1]
                weights_path, ema_weights_path = _fetch_hub_checkpoint(step_name)
                start_step = int(step_name.split("_")[1]) if "_" in step_name else 0
                print(f"Downloaded {step_name} from hub")
            else:
                files = list_repo_files(HF_REPO)
                ckpts = [f for f in files if f.startswith("checkpoints/step_") and f.endswith(".safetensors") and "_ema" not in f]
                if ckpts:
                    steps = [int(f.split("_")[1].split(".")[0]) for f in ckpts]
                    latest = max(steps)
                    weights_path, ema_weights_path = _fetch_hub_checkpoint(f"step_{latest}")
                    start_step = latest
                    print(f"Downloaded step_{latest} from hub")
        except Exception as e:
            print(f"Could not download from hub: {e}")
            return start_step, start_epoch, None

    elif target == "best":
        ckpt_path = os.path.join(CHECKPOINT_DIR, "best.pt")
        weights_path = _swap_suffix(ckpt_path, ".pt", ".safetensors")
        ema_weights_path = _swap_suffix(ckpt_path, ".pt", "_ema.safetensors")

    elif os.path.exists(target):
        if target.endswith(".safetensors"):
            weights_path = target
            ckpt_path = _swap_suffix(target, ".safetensors", ".pt")
            ema_weights_path = _swap_suffix(target, ".safetensors", "_ema.safetensors")
        else:
            ckpt_path = target
            weights_path = _swap_suffix(target, ".pt", ".safetensors")
            ema_weights_path = _swap_suffix(target, ".pt", "_ema.safetensors")

    # ---- model + EMA weights ----
    if weights_path and os.path.exists(weights_path):
        print(f"Loading weights from {weights_path}")
        state_dict = load_file(weights_path)
        state_dict = {k: v.to(DTYPE) if v.is_floating_point() else v for k, v in state_dict.items()}

        # Load into the underlying module when handed a compiled wrapper.
        model_ref = model._orig_mod if hasattr(model, '_orig_mod') else model

        if ALLOW_WEIGHT_UPGRADE:
            missing, unexpected = load_with_weight_upgrade(model_ref, state_dict)
            if missing:
                print(f" ⚠ {len(missing)} truly missing parameters (may need attention)")
        else:
            model_ref.load_state_dict(state_dict, strict=True)

        print(f"✓ Loaded model weights")

        if ema_weights_path and os.path.exists(ema_weights_path):
            ema_state = load_file(ema_weights_path)
            ema_state = {k: v.to(DTYPE) if v.is_floating_point() else v for k, v in ema_state.items()}
            print(f"✓ Loaded EMA weights ({len(ema_state)} params)")
        else:
            print(f" ℹ No EMA weights found (will initialize fresh)")
    else:
        print(f" ⚠ Weights file not found: {weights_path}")
        print(f" Starting with fresh model")
        return start_step, start_epoch, None

    # ---- optimizer / scheduler state (local .pt only) ----
    if ckpt_path and os.path.exists(ckpt_path):
        state = torch.load(ckpt_path, map_location="cpu")
        start_step = state.get("step", 0)
        start_epoch = state.get("epoch", 0)
        try:
            optimizer.load_state_dict(state["optimizer"])
            scheduler.load_state_dict(state["scheduler"])
            print(f"✓ Loaded optimizer/scheduler state")
        except Exception as e:
            print(f" ⚠ Could not load optimizer state: {e}")
            print(f" Will use fresh optimizer (this is fine for architecture changes)")
        print(f"Resuming from step {start_step}, epoch {start_epoch}")

    return start_step, start_epoch, ema_state
| |
|
| |
|
| | |
| | |
| | |
# Build the model from the expert-distillation settings and report its size.
print("\nCreating TinyFluxDeep model with ExpertPredictor...")

config = TinyFluxDeepConfig(
    guidance_embeds=False,
    use_expert_predictor=ENABLE_EXPERT_DISTILLATION,
    expert_dim=EXPERT_DIM,
    expert_hidden_dim=EXPERT_HIDDEN_DIM,
    expert_dropout=EXPERT_DROPOUT,
)
model = TinyFluxDeep(config)
model = model.to(device=DEVICE, dtype=DTYPE)

total_params = sum(param.numel() for param in model.parameters())
print(f"Total parameters: {total_params:,}")

# The expert predictor is optional; report its size only when it exists.
if getattr(model, 'expert_predictor', None) is not None:
    expert_params = sum(param.numel() for param in model.expert_predictor.parameters())
    print(f"Expert predictor parameters: {expert_params:,}")

# Only parameters that require gradients are handed to the optimizer later.
trainable_params = [param for param in model.parameters() if param.requires_grad]
print(f"Trainable parameters: {sum(param.numel() for param in trainable_params):,}")
| |
|
| |
|
| | |
| | |
| | |
# Optimizer over the trainable parameters (fused AdamW kernel).
opt = torch.optim.AdamW(trainable_params, lr=LR, betas=(0.9, 0.99), weight_decay=0.01, fused=True)

# One scheduler step per optimizer step, i.e. per GRAD_ACCUM micro-batches.
total_steps = len(loader) * EPOCHS // GRAD_ACCUM
warmup = min(1000, total_steps // 10)


def lr_fn(step):
    """Return the LR multiplier: linear warmup, then cosine decay to zero.

    Steps past ``total_steps`` are clamped to the end of the schedule so the
    cosine does not rise again, and the decay length is floored at 1 so a
    degenerate schedule (``total_steps == warmup``) does not divide by zero.
    Inside the schedule the values match the unclamped formula exactly.
    """
    if step < warmup:
        return step / warmup
    decay_steps = max(1, total_steps - warmup)
    clamped = min(step, warmup + decay_steps)
    return 0.5 * (1 + math.cos(math.pi * (clamped - warmup) / decay_steps))


sched = torch.optim.lr_scheduler.LambdaLR(opt, lr_fn)
| |
|
| |
|
| | |
| | |
| | |
# Restore weights / optimizer state first; compilation and EMA setup follow.
start_step, start_epoch, ema_state = load_checkpoint(model, opt, sched, LOAD_TARGET)

# An explicit RESUME_STEP overrides whatever step the checkpoint reported.
start_step = start_step if RESUME_STEP is None else RESUME_STEP

model = torch.compile(model, mode="default")

print("Initializing EMA...")
ema = EMA(model, decay=EMA_DECAY)
if ema_state is None:
    print(" Starting fresh EMA from current weights")
else:
    ema.load_shadow(ema_state)
| |
|
| |
|
| | |
| | |
| | |
# Each launch logs to its own timestamped TensorBoard directory.
run_name = "run_" + datetime.now().strftime('%Y%m%d_%H%M%S')
writer = SummaryWriter(os.path.join(LOG_DIR, run_name))

# Fixed prompts used for the periodic sample grids during training.
SAMPLE_PROMPTS = [
    "a photo of a cat sitting on a windowsill",
    "a portrait of a woman with red hair",
    "a black backpack on white background",
    "a person standing in a t-pose",
]
| |
|
| |
|
| | |
| | |
| | |
def get_distill_weight(step):
    """Return the distillation loss weight for ``step``.

    The weight grows linearly from 0 to ``DISTILL_LOSS_WEIGHT`` over the first
    ``DISTILL_WARMUP_STEPS`` steps and stays constant afterwards.
    """
    if step >= DISTILL_WARMUP_STEPS:
        return DISTILL_LOSS_WEIGHT
    warmup_fraction = step / DISTILL_WARMUP_STEPS
    return DISTILL_LOSS_WEIGHT * warmup_fraction
| |
|
| |
|
| | |
| | |
| | |
# ---------------------------------------------------------------------------
# Training loop
# NOTE(review): the source this was rebuilt from had lost its indentation.
# Logging / sampling / checkpointing are nested under the optimizer-step
# branch because `grad_norm` is only defined there. The loss is computed
# outside the autocast block. Confirm both against the original script.
# ---------------------------------------------------------------------------
print(f"\n{'='*60}")
print(f"Training TinyFlux-Deep with Expert Distillation (Precached)")
print(f"{'='*60}")
print(f"Total: {len(combined_ds):,} samples")
print(f"Epochs: {EPOCHS}, Steps/epoch: {len(loader)}, Total: {total_steps}")
print(f"Batch: {BATCH_SIZE} x {GRAD_ACCUM} = {BATCH_SIZE * GRAD_ACCUM}")
print(f"Expert distillation: {ENABLE_EXPERT_DISTILLATION} (PRECACHED)")
if ENABLE_EXPERT_DISTILLATION:
    print(f" - Expert: {EXPERT_CHECKPOINT}")
    print(f" - Timestep buckets: {len(EXPERT_T_BUCKETS)}")
    print(f" - Distill weight: {DISTILL_LOSS_WEIGHT} (warmup: {DISTILL_WARMUP_STEPS} steps)")
    print(f" - Expert dropout: {EXPERT_DROPOUT}")
print(f"Masked loss: {USE_MASKED_LOSS}")
print(f"Min-SNR gamma: {MIN_SNR_GAMMA}")
print(f"Resume: step {start_step}, epoch {start_epoch}")

model.train()
step = start_step
best = float("inf")

for ep in range(start_epoch, EPOCHS):
    ep_loss = 0
    ep_main_loss = 0
    ep_distill_loss = 0
    ep_batches = 0
    pbar = tqdm(loader, desc=f"E{ep + 1}")

    for i, batch in enumerate(pbar):
        latents = batch["latents"].to(DEVICE, non_blocking=True)
        t5 = batch["t5_embeds"].to(DEVICE, non_blocking=True)
        clip = batch["clip_pooled"].to(DEVICE, non_blocking=True)
        local_indices = batch["local_indices"]
        dataset_ids = batch["dataset_ids"]
        masks = batch["masks"]

        if masks is not None:
            masks = masks.to(DEVICE, non_blocking=True)

        # (B, C, H, W) latents -> (B, H*W, C) token sequence.
        B, C, H, W = latents.shape
        data = latents.permute(0, 2, 3, 1).reshape(B, H * W, C)
        noise = torch.randn_like(data)

        if TEXT_DROPOUT > 0:
            t5, clip, _ = apply_text_dropout(t5, clip, TEXT_DROPOUT)

        # Logit-normal timestep sample, shifted and kept away from 0 and 1.
        t = torch.sigmoid(torch.randn(B, device=DEVICE))
        t = flux_shift(t, s=SHIFT).to(DTYPE).clamp(1e-4, 1 - 1e-4)

        # Linear interpolation noise -> data; the regression target is the
        # constant velocity along that path.
        t_expanded = t.view(B, 1, 1)
        x_t = (1 - t_expanded) * noise + t_expanded * data
        v_target = data - noise

        img_ids = TinyFluxDeep.create_img_ids(B, H, W, DEVICE)

        # Precached expert features for this batch, when distillation is on.
        expert_features = None
        if ENABLE_EXPERT_DISTILLATION:
            expert_features = get_expert_features_for_batch(
                local_indices, dataset_ids, t,
                portrait_expert_cache, schnell_expert_cache,
                sportfashion_expert_cache, synthmocap_expert_cache,
            )

            # With probability EXPERT_DROPOUT the whole batch runs without them.
            if expert_features is not None and random.random() < EXPERT_DROPOUT:
                expert_features = None

        with torch.autocast("cuda", dtype=DTYPE):
            v_pred, expert_info = model(
                hidden_states=x_t,
                encoder_hidden_states=t5,
                pooled_projections=clip,
                timestep=t,
                img_ids=img_ids,
                expert_features=expert_features,
                return_expert_pred=True,
            )

        snr_weights = min_snr_weight(t)

        main_loss = masked_mse_loss(
            v_pred, v_target,
            mask=masks if USE_MASKED_LOSS else None,
            fg_weight=FG_LOSS_WEIGHT,
            bg_weight=BG_LOSS_WEIGHT,
            snr_weights=snr_weights
        )

        # The distillation term is added only when features were fed and the
        # model returned its own expert prediction.
        distill_loss = torch.tensor(0.0, device=DEVICE)
        if expert_features is not None and expert_info is not None and 'expert_pred' in expert_info:
            distill_weight = get_distill_weight(step)
            distill_loss = F.mse_loss(expert_info['expert_pred'], expert_features)
            total_loss = main_loss + distill_weight * distill_loss
        else:
            total_loss = main_loss

        loss = total_loss / GRAD_ACCUM
        loss.backward()

        # Read each scalar once per micro-batch instead of at every use site;
        # every .item() call is a separate device-to-host sync.
        total_val = total_loss.item()
        main_val = main_loss.item()
        distill_val = distill_loss.item()

        if (i + 1) % GRAD_ACCUM == 0:
            grad_norm = torch.nn.utils.clip_grad_norm_(trainable_params, 1.0)
            opt.step()
            sched.step()
            opt.zero_grad(set_to_none=True)

            ema.update(model)
            step += 1

            if step % LOG_EVERY == 0:
                writer.add_scalar("train/loss", total_val, step)
                writer.add_scalar("train/main_loss", main_val, step)
                if ENABLE_EXPERT_DISTILLATION:
                    writer.add_scalar("train/distill_loss", distill_val, step)
                    writer.add_scalar("train/distill_weight", get_distill_weight(step), step)
                writer.add_scalar("train/lr", sched.get_last_lr()[0], step)
                writer.add_scalar("train/grad_norm", grad_norm.item(), step)

            if step % SAMPLE_EVERY == 0:
                print(f"\n Generating samples at step {step}...")
                images = generate_samples(model, SAMPLE_PROMPTS, num_steps=20, use_ema=True)
                save_samples(images, SAMPLE_PROMPTS, step, SAMPLE_DIR)

            if step % SAVE_EVERY == 0:
                ckpt_path = os.path.join(CHECKPOINT_DIR, f"step_{step}.pt")
                weights_path = save_checkpoint(model, opt, sched, step, ep, total_val, ckpt_path, ema=ema)
                if step % UPLOAD_EVERY == 0:
                    upload_checkpoint(weights_path, step)

        ep_loss += total_val
        ep_main_loss += main_val
        ep_distill_loss += distill_val
        ep_batches += 1

        pbar.set_postfix(
            loss=f"{total_val:.4f}",
            main=f"{main_val:.4f}",
            dist=f"{distill_val:.4f}" if ENABLE_EXPERT_DISTILLATION else "off",
            step=step
        )

    avg = ep_loss / max(ep_batches, 1)
    avg_main = ep_main_loss / max(ep_batches, 1)
    avg_distill = ep_distill_loss / max(ep_batches, 1)

    print(f"Epoch {ep + 1} - total: {avg:.4f}, main: {avg_main:.4f}, distill: {avg_distill:.4f}")

    # Keep (and publish) the checkpoint with the best epoch-average loss.
    if avg < best:
        best = avg
        weights_path = save_checkpoint(model, opt, sched, step, ep, avg, os.path.join(CHECKPOINT_DIR, "best.pt"), ema=ema)
        try:
            api.upload_file(path_or_fileobj=weights_path, path_in_repo="model.safetensors", repo_id=HF_REPO)
        except Exception as e:
            # Still best-effort, but the failure is reported and Ctrl-C is
            # no longer swallowed by a bare except.
            print(f" ⚠ Upload of best model failed: {e}")

print(f"\n✓ Training complete! Best loss: {best:.4f}")
writer.close()