# Source: sd15-flow-matching / trainer_v2.py (Hugging Face repository by AbstractPhil)
# Commit 4dad82a (verified): "Introduced new trainer with improved systems and an included timestep"
# File size: 37.3 kB
# =====================================================================================
# SD1.5 Flow-Matching Trainer — David-Driven Adaptive Timestep Sampling
# Quartermaster: Mirel
# NEW: David-weighted timesteps + SD3 shift + adaptive chaos
# =====================================================================================
from __future__ import annotations
import os, json, math, random, re, shutil
from dataclasses import dataclass, asdict
from pathlib import Path
from typing import Dict, List, Tuple, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
# Diffusers
from diffusers import StableDiffusionPipeline, DDPMScheduler
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
# Repo deps
from geovocab2.train.model.core.geo_david_collective import GeoDavidCollective
from geovocab2.data.prompt.symbolic_tree import SynthesisSystem
# HF / safetensors
from huggingface_hub import snapshot_download, HfApi, create_repo, hf_hub_download
from safetensors.torch import load_file
# =====================================================================================
# 1) CONFIG
# =====================================================================================
@dataclass
class BaseConfig:
    """Hyper-parameters, paths and switches for the flow-matching run.

    Constructing an instance has a filesystem side effect: ``__post_init__``
    creates ``out_dir``, ``ckpt_dir`` and ``david_cache_dir`` when missing.
    """

    run_name: str = "sd15_flowmatch_david_weighted"
    out_dir: str = "./runs/sd15_flowmatch_david_weighted"
    ckpt_dir: str = "./checkpoints_sd15_flow_david_weighted"
    save_every: int = 1  # Save a checkpoint every N epochs

    # Data
    num_samples: int = 200_000
    batch_size: int = 32
    num_workers: int = 2
    seed: int = 42

    # Models / Blocks
    model_id: str = "runwayml/stable-diffusion-v1-5"
    active_blocks: Tuple[str, ...] = ("down_0","down_1","down_2","down_3","mid","up_0","up_1","up_2","up_3")
    pooling: str = "mean"  # One of "mean", "max", "adaptive"

    # Flow training
    epochs: int = 10
    lr: float = 1e-4
    weight_decay: float = 1e-3
    grad_clip: float = 1.0
    amp: bool = True
    global_flow_weight: float = 1.0
    block_penalty_weight: float = 0.2
    use_local_flow_heads: bool = False
    local_flow_weight: float = 1.0

    # KD
    use_kd: bool = True
    kd_weight: float = 0.25

    # David
    david_repo_id: str = "AbstractPhil/geo-david-collective-sd15-base-e40"
    david_cache_dir: str = "./_hf_david_cache"
    david_state_key: Optional[str] = None

    # Fusion (per-block penalty multipliers, clamped to [lambda_min, lambda_max])
    alpha_timestep: float = 0.5
    beta_pattern: float = 0.25
    delta_incoherence: float = 0.25
    lambda_min: float = 0.5
    lambda_max: float = 3.0
    # None means "use the built-in table"; it is filled in by __post_init__.
    block_weights: Optional[Dict[str, float]] = None

    # Timestep Weighting (David-guided adaptive sampling)
    use_timestep_weighting: bool = True
    use_david_weights: bool = True
    # SD3-style shift s applied as s*t / (1 + (s-1)*t) on normalized t;
    # for s > 1 this moves bin centers toward LARGER timestep indices.
    timestep_shift: float = 3.0
    base_jitter: int = 5          # Base ±jitter around bin center
    adaptive_chaos: bool = True   # Scale jitter by pattern difficulty
    profile_samples: int = 500    # Samples to profile David's difficulty

    # Scheduler
    num_train_timesteps: int = 1000

    # Inference
    sample_steps: int = 30
    guidance_scale: float = 7.5

    # HuggingFace
    hf_repo_id: Optional[str] = "AbstractPhil/sd15-flow-matching"
    upload_every_epoch: bool = True
    continue_training: bool = True

    def __post_init__(self):
        """Create the working directories and fill in default block weights."""
        for directory in (self.out_dir, self.ckpt_dir, self.david_cache_dir):
            Path(directory).mkdir(parents=True, exist_ok=True)
        if self.block_weights is None:
            # A fresh dict per instance, so configs never share mutable state.
            self.block_weights = {
                'down_0': 0.7, 'down_1': 0.9, 'down_2': 1.0, 'down_3': 1.1,
                'mid': 1.2,
                'up_0': 1.1, 'up_1': 1.0, 'up_2': 0.9, 'up_3': 0.7,
            }
# =====================================================================================
# 2) DAVID-WEIGHTED TIMESTEP SAMPLER
# =====================================================================================
class DavidWeightedTimestepSampler:
    """
    Samples timesteps weighted by David's inherent difficulty + SD3 shift + adaptive chaos.

    Until ``compute_difficulty_from_david`` has run (or ``difficulty_weights``
    is assigned), ``sample`` falls back to uniform timesteps.
    """

    # Number of synthetic probes pushed through teacher + David per profiling step.
    _PROFILE_BATCH = 32
    # Output keys probed, in order, for David's two classification heads.
    _TIMESTEP_KEYS = ("timestep_logits", "logits_timestep", "timestep_head_logits")
    _PATTERN_KEYS = ("pattern_logits", "logits_pattern", "pattern_head_logits")

    def __init__(self, num_timesteps=1000, num_bins=100, shift=3.0, base_jitter=5, adaptive_chaos=True):
        self.num_timesteps = num_timesteps
        self.num_bins = num_bins
        self.shift = shift
        self.base_jitter = base_jitter
        self.adaptive_chaos = adaptive_chaos
        self.difficulty_weights = None  # Per-bin sampling probabilities
        self.pattern_difficulty = None  # Per-bin normalized pattern entropy

    @property
    def _bin_width(self) -> int:
        """Number of raw timesteps covered by one bin (at least 1)."""
        return max(1, self.num_timesteps // self.num_bins)

    def _apply_shift(self, t: float) -> float:
        """Apply SD3-style timestep shift (operates on normalized t ∈ [0,1])."""
        if self.shift <= 0:
            return t
        return self.shift * t / (1.0 + (self.shift - 1.0) * t)

    @staticmethod
    def _pick_logits(outputs, keys):
        """Return the logits stored under the first matching key, or None.

        A dict of per-block logits is averaged into a single tensor.
        """
        for key in keys:
            if key in outputs:
                logits = outputs[key]
                if isinstance(logits, dict):
                    logits = torch.stack(list(logits.values())).mean(0)
                return logits
        return None

    def compute_difficulty_from_david(self, david, teacher, device, num_samples=500):
        """Profile David's confusion patterns to create difficulty map.

        Random latents, timesteps and conditioning are run through the teacher
        UNet; the hooked block features are mean-pooled and handed to David.
        Per timestep bin this accumulates David's timestep-classification
        accuracy and the normalized entropy of its pattern head.

        Returns the per-bin sampling weights (also kept in
        ``self.difficulty_weights``); ``self.pattern_difficulty`` is set too.
        """
        print("🔍 Profiling David's timestep & pattern difficulty...")
        david.eval()
        teacher.eval()

        batch = self._PROFILE_BATCH
        width = self._bin_width

        # Track David's accuracy and pattern entropy per bin
        correct_per_bin = torch.zeros(self.num_bins)
        total_per_bin = torch.zeros(self.num_bins)
        entropy_per_bin = torch.zeros(self.num_bins)
        entropy_count_per_bin = torch.zeros(self.num_bins)

        with torch.no_grad():
            for _ in tqdm(range(num_samples // batch), desc="Profiling David", leave=False):
                # Random latents and timesteps
                x = torch.randn(batch, 4, 64, 64, device=device, dtype=torch.float16)
                t = torch.randint(0, self.num_timesteps, (batch,), device=device)
                # Bin width follows the configured sizes; the clamp keeps the
                # index valid when num_timesteps is not a multiple of num_bins.
                t_bins = (t // width).clamp(max=self.num_bins - 1)
                bins_cpu = t_bins.cpu()

                # Dummy conditioning
                ehs = torch.randn(batch, 77, 768, device=device, dtype=torch.float16)

                # Get teacher features
                teacher.hooks.clear()
                _ = teacher.unet(x, t, encoder_hidden_states=ehs)

                # Pool features
                pooled = {name: f.float().mean(dim=(2, 3)) for name, f in teacher.hooks.bank.items()}

                # Get David's outputs
                outputs = david(pooled, t.float())

                # 1. Timestep difficulty (from classification error)
                ts_logits = self._pick_logits(outputs, self._TIMESTEP_KEYS)
                if ts_logits is not None:
                    hits = (ts_logits.argmax(dim=-1) == t_bins).float().cpu()
                    correct_per_bin.index_add_(0, bins_cpu, hits)
                    total_per_bin.index_add_(0, bins_cpu, torch.ones_like(hits))

                # 2. Pattern difficulty (from entropy)
                pt_logits = self._pick_logits(outputs, self._PATTERN_KEYS)
                if pt_logits is not None:
                    P = pt_logits.softmax(-1)
                    ent = -(P * P.clamp_min(1e-9).log()).sum(-1)
                    norm_ent = (ent / math.log(P.shape[-1])).float().cpu()  # Normalize by max entropy
                    entropy_per_bin.index_add_(0, bins_cpu, norm_ent)
                    entropy_count_per_bin.index_add_(0, bins_cpu, torch.ones_like(norm_ent))

        # Compute timestep difficulty (inverse of accuracy); the +0.1 floor keeps
        # every bin reachable even where David is always right.
        accuracy_per_bin = correct_per_bin / (total_per_bin.clamp(min=1))
        timestep_difficulty = (1.0 - accuracy_per_bin) + 0.1  # Higher = harder
        self.difficulty_weights = timestep_difficulty / timestep_difficulty.sum()

        # Compute pattern difficulty (average entropy per bin)
        self.pattern_difficulty = entropy_per_bin / (entropy_count_per_bin.clamp(min=1))
        self.pattern_difficulty = self.pattern_difficulty.clamp(min=0.1, max=1.0)

        print(f"✓ David difficulty map computed:")
        print(f" Avg timestep accuracy: {accuracy_per_bin.mean():.2%}")
        print(f" Hardest timestep bin: {accuracy_per_bin.argmin().item()} ({accuracy_per_bin.min():.2%} acc)")
        print(f" Easiest timestep bin: {accuracy_per_bin.argmax().item()} ({accuracy_per_bin.max():.2%} acc)")
        print(f" Avg pattern entropy: {self.pattern_difficulty.mean():.3f}")
        return self.difficulty_weights

    def sample(self, batch_size: int) -> List[int]:
        """Sample timesteps with David weighting + shift + adaptive chaos.

        Returns ``batch_size`` integers in ``[0, num_timesteps - 1]``.
        """
        if batch_size <= 0:
            return []
        if self.difficulty_weights is None:
            # Fallback to uniform
            return [random.randint(0, self.num_timesteps - 1) for _ in range(batch_size)]

        width = self._bin_width
        # 1. Sample all bins in one call, weighted by David's difficulty
        bins = torch.multinomial(self.difficulty_weights, batch_size, replacement=True).tolist()

        timesteps = []
        for bin_idx in bins:
            # 2. Get bin center as normalized t
            bin_center_raw = bin_idx * width + width // 2
            t_normalized = bin_center_raw / self.num_timesteps

            # 3. Apply SD3 shift
            t_shifted = self._apply_shift(t_normalized)

            # 4. Add adaptive chaos (jitter scaled by pattern difficulty).
            # Without a pattern profile there is nothing to scale by, so the
            # plain base jitter is used.
            if self.adaptive_chaos and self.pattern_difficulty is not None:
                chaos_scale = self.pattern_difficulty[bin_idx].item()
                jitter = int(self.base_jitter * (0.5 + chaos_scale))  # 0.5-1.5x base jitter
            else:
                jitter = self.base_jitter

            # 5. Convert back to raw timestep with jitter
            t_raw = int(t_shifted * self.num_timesteps)
            t_raw += random.randint(-jitter, jitter)
            timesteps.append(max(0, min(self.num_timesteps - 1, t_raw)))
        return timesteps
# =====================================================================================
# 3) DATA
# =====================================================================================
class SymbolicPromptDataset(Dataset):
    """Endless stream of synthesized prompts paired with a training timestep.

    Items are generated on the fly: ``idx`` is ignored, so the same index
    yields a different prompt/timestep on every access.

    Args:
        n: Reported dataset length (items per epoch).
        seed: Seeds the global ``random`` module and the ``SynthesisSystem``.
        timestep_sampler: Optional object exposing ``sample(k) -> list[int]``;
            when absent, timesteps are drawn uniformly.
        num_timesteps: Exclusive upper bound of the uniform fallback draw.
    """

    def __init__(self, n: int, seed: int = 42, timestep_sampler=None, num_timesteps: int = 1000):
        self.n = n
        self.timestep_sampler = timestep_sampler
        self.num_timesteps = num_timesteps
        # NOTE: this reseeds the process-wide RNG, not a private one.
        random.seed(seed)
        self.sys = SynthesisSystem(seed=seed)

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        """Return ``{"prompt": str, "t": int}`` for a freshly synthesized prompt."""
        r = self.sys.synthesize(complexity=random.choice([1, 2, 3, 4, 5]))
        prompt = r['text']
        if self.timestep_sampler is not None:
            t = self.timestep_sampler.sample(1)[0]
        else:
            t = random.randint(0, self.num_timesteps - 1)
        return {"prompt": prompt, "t": t}
def collate(batch: List[dict]):
    """Merge dataset items into one batch dict.

    Returns ``prompts`` (list of str), ``t`` (int64 tensor of timesteps) and
    ``t_bins`` (``t`` floor-divided by 10).
    """
    texts, steps = [], []
    for sample in batch:
        texts.append(sample["prompt"])
        steps.append(sample["t"])
    t = torch.tensor(steps, dtype=torch.long)
    return {"prompts": texts, "t": t, "t_bins": t // 10}
# =====================================================================================
# 4) HOOKS + POOLING
# =====================================================================================
class HookBank:
    """Forward-hook registry that captures the outputs of selected UNet blocks.

    Blocks are named ``down_<i>``, ``mid`` and ``up_<i>``; only names listed in
    ``active`` get a hook. Each forward pass overwrites ``bank[name]`` with the
    block output (first element when the block returns a tuple or list).
    """

    def __init__(self, unet: UNet2DConditionModel, active: Tuple[str, ...]):
        self.active = set(active)
        self.bank: Dict[str, torch.Tensor] = {}
        self.hooks: List[torch.utils.hooks.RemovableHandle] = []
        self._register(unet)

    def _capture(self, name):
        """Build the hook that stores a block's output under ``name``."""
        def _hook(_module, _inputs, output):
            first = output[0] if isinstance(output, (tuple, list)) else output
            self.bank[name] = first
        return _hook

    def _attach(self, name, block):
        """Hook ``block`` when its name was requested."""
        if name in self.active:
            handle = block.register_forward_hook(self._capture(name))
            self.hooks.append(handle)

    def _register(self, unet: UNet2DConditionModel):
        """Attach hooks in down -> mid -> up order."""
        for idx, block in enumerate(unet.down_blocks):
            self._attach(f"down_{idx}", block)
        if "mid" in self.active:
            self._attach("mid", unet.mid_block)
        for idx, block in enumerate(unet.up_blocks):
            self._attach(f"up_{idx}", block)

    def clear(self):
        """Drop all captured outputs (hooks stay installed)."""
        self.bank.clear()

    def close(self):
        """Remove every installed hook."""
        for handle in self.hooks:
            handle.remove()
        self.hooks.clear()
def spatial_pool(x: torch.Tensor, name: str, policy: str) -> torch.Tensor:
    """Reduce dims 2 and 3 of ``x`` according to ``policy``.

    ``"mean"`` averages, ``"max"`` takes the maximum, and ``"adaptive"``
    averages for blocks named ``down*`` or ``mid`` and takes the maximum for
    every other block.

    Raises:
        ValueError: if ``policy`` is none of the three known values.
    """
    if policy not in ("mean", "max", "adaptive"):
        raise ValueError(f"Unknown pooling: {policy}")

    if policy == "adaptive":
        encoder_side = name.startswith("down") or name == "mid"
        policy = "mean" if encoder_side else "max"

    spatial_axes = (2, 3)
    if policy == "mean":
        return x.mean(dim=spatial_axes)
    return x.amax(dim=spatial_axes)
# =====================================================================================
# 5) TEACHER
# =====================================================================================
class SD15Teacher(nn.Module):
    """Frozen fp16 Stable Diffusion pipeline used as the distillation teacher.

    Exposes the pipeline's UNet, text encoder and tokenizer, a ``HookBank`` on
    the UNet's active blocks, and a ``DDPMScheduler`` used for the
    ``alphas_cumprod`` table.
    """

    def __init__(self, cfg: BaseConfig, device: str):
        super().__init__()
        pipe = StableDiffusionPipeline.from_pretrained(
            cfg.model_id, torch_dtype=torch.float16, safety_checker=None
        )
        self.pipe = pipe.to(device)
        self.unet: UNet2DConditionModel = self.pipe.unet
        self.text_encoder = self.pipe.text_encoder
        self.tokenizer = self.pipe.tokenizer
        self.hooks = HookBank(self.unet, cfg.active_blocks)
        self.sched = DDPMScheduler(num_train_timesteps=cfg.num_train_timesteps)
        self.device = device
        # Freeze every parameter registered on this module.
        for param in self.parameters():
            param.requires_grad_(False)

    @torch.no_grad()
    def encode(self, prompts: List[str]) -> torch.Tensor:
        """Tokenize ``prompts`` (padded/truncated to the tokenizer maximum) and
        return the text encoder's first output."""
        tokens = self.tokenizer(
            prompts,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        input_ids = tokens.input_ids.to(self.device)
        return self.text_encoder(input_ids)[0]

    @torch.no_grad()
    def forward_eps_and_feats(self, x_t: torch.Tensor, t: torch.LongTensor, ehs: torch.Tensor):
        """Run the UNet once; return its ``.sample`` output and the hooked block
        features, both as float32 (features detached)."""
        self.hooks.clear()
        unet_out = self.unet(x_t, t, encoder_hidden_states=ehs)
        feats = {}
        for block_name, fmap in self.hooks.bank.items():
            feats[block_name] = fmap.detach().float()
        return unet_out.sample.float(), feats

    def alpha_sigma(self, t: torch.LongTensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``sqrt(alphas_cumprod[t])`` and ``sqrt(1 - alphas_cumprod[t])``
        as float32 tensors reshaped to ``(-1, 1, 1, 1)``."""
        table = self.sched.alphas_cumprod.to(self.device)
        ac = table[t]
        sigma = (1.0 - ac).sqrt().view(-1, 1, 1, 1).float()
        alpha = ac.sqrt().view(-1, 1, 1, 1).float()
        return alpha, sigma
# =====================================================================================
# 6) STUDENT
# =====================================================================================
class StudentUNet(nn.Module):
    """Trainable UNet initialised as a copy of the teacher's UNet.

    The forward pass returns the UNet's ``.sample`` output together with the
    block features captured by a ``HookBank``. When ``use_local_heads`` is set,
    one 1x1 convolution (block channels -> 4) per captured block is created
    lazily on first sight of that block's features.
    """

    def __init__(self, teacher_unet: UNet2DConditionModel, active_blocks: Tuple[str, ...], use_local_heads: bool):
        super().__init__()
        student = UNet2DConditionModel.from_config(teacher_unet.config)
        student.load_state_dict(teacher_unet.state_dict(), strict=True)
        self.unet = student
        self.hooks = HookBank(self.unet, active_blocks)
        self.use_local_heads = use_local_heads
        self.local_heads = nn.ModuleDict()

    def _ensure_heads(self, feats: Dict[str, torch.Tensor]):
        """Create any missing local head, matching the UNet's parameter dtype
        and the feature map's device."""
        if not self.use_local_heads or len(self.local_heads) == len(feats):
            return
        target_dtype = next(self.unet.parameters()).dtype
        for block_name, fmap in feats.items():
            if block_name in self.local_heads:
                continue
            head = nn.Conv2d(fmap.shape[1], 4, kernel_size=1)
            self.local_heads[block_name] = head.to(dtype=target_dtype, device=fmap.device)

    def forward(self, x_t: torch.Tensor, t: torch.LongTensor, ehs: torch.Tensor):
        """Return ``(v_hat, feats)``; ``feats`` is a shallow copy of the hook bank."""
        self.hooks.clear()
        unet_out = self.unet(x_t, t, encoder_hidden_states=ehs)
        feats = dict(self.hooks.bank)
        self._ensure_heads(feats)
        return unet_out.sample, feats
# =====================================================================================
# 7) DAVID + ASSESSOR + FUSION
# =====================================================================================
class DavidLoader:
    """Download the David checkpoint from the Hub and build a frozen model.

    After construction:
        repo_dir / config_path / weights_path: local snapshot locations.
        hf_config: parsed ``config.json`` of the snapshot.
        gdc: ``GeoDavidCollective`` in eval mode with gradients disabled.

    Side effect: when the downloaded config carries ``block_weights``, they
    overwrite ``cfg.block_weights`` on the config object passed in.
    """

    def __init__(self, cfg: BaseConfig, device: str):
        self.cfg = cfg
        self.device = device
        self.repo_dir = snapshot_download(repo_id=cfg.david_repo_id, local_dir=cfg.david_cache_dir, local_dir_use_symlinks=False)
        self.config_path = os.path.join(self.repo_dir, "config.json")
        self.weights_path = os.path.join(self.repo_dir, "model.safetensors")
        with open(self.config_path, "r", encoding="utf-8") as f:
            self.hf_config = json.load(f)

        block_configs = self.hf_config["block_configs"]
        self.gdc = GeoDavidCollective(
            block_configs=block_configs,
            num_timestep_bins=int(self.hf_config["num_timestep_bins"]),
            num_patterns_per_bin=int(self.hf_config["num_patterns_per_bin"]),
            block_weights=self.hf_config.get("block_weights", {k: 1.0 for k in block_configs}),
            loss_config=self.hf_config.get("loss_config", {})
        ).to(device).eval()

        state = load_file(self.weights_path)
        # strict=False tolerates key drift between checkpoint and model, but a
        # silent mismatch would leave David partly (or wholly) untrained, so
        # whatever was skipped is reported. Loading stays best-effort.
        result = self.gdc.load_state_dict(state, strict=False)
        missing = list(getattr(result, "missing_keys", None) or [])
        unexpected = list(getattr(result, "unexpected_keys", None) or [])
        if missing:
            print(f"⚠️ David: {len(missing)} parameter(s) missing from checkpoint (e.g. {missing[:3]})")
        if unexpected:
            print(f"⚠️ David: {len(unexpected)} unexpected checkpoint key(s) ignored (e.g. {unexpected[:3]})")

        for p in self.gdc.parameters():
            p.requires_grad_(False)

        print(f"✓ David loaded from HF: {self.repo_dir}")
        print(f" blocks={len(self.hf_config['block_configs'])} bins={self.hf_config['num_timestep_bins']} patterns={self.hf_config['num_patterns_per_bin']}")
        if "block_weights" in self.hf_config:
            cfg.block_weights = self.hf_config["block_weights"]
class DavidAssessor(nn.Module):
    """Score the student's block features with David.

    ``forward`` returns three dicts keyed by block name:
        e_t: cross-entropy of David's timestep logits against ``t // 10``.
        e_p: pattern-logit entropy divided by ``log(num_classes)``.
        coh: mean of ``gdc.get_cantor_alphas()`` (1.0 when unavailable),
             the same value for every block.
    A head that is absent from David's outputs contributes 0.0 per block.
    """

    _TIMESTEP_KEYS = ("timestep_logits", "logits_timestep", "timestep_head_logits")
    _PATTERN_KEYS = ("pattern_logits", "logits_pattern", "pattern_head_logits")

    def __init__(self, gdc: GeoDavidCollective, pooling: str):
        super().__init__()
        self.gdc = gdc
        self.pooling = pooling

    def _pool(self, feats: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Spatially pool every feature map with the configured policy."""
        pooled = {}
        for block_name, fmap in feats.items():
            pooled[block_name] = spatial_pool(fmap, block_name, self.pooling)
        return pooled

    @staticmethod
    def _normalized_entropy(logits: torch.Tensor) -> float:
        """Batch-mean softmax entropy divided by the log of the class count."""
        P = logits.softmax(-1)
        ent = -(P * (P.clamp_min(1e-9)).log()).sum(-1).mean()
        return float(ent.item() / math.log(P.shape[-1]))

    @torch.no_grad()
    def forward(self, feats_student: Dict[str, torch.Tensor], t: torch.LongTensor
                ) -> Tuple[Dict[str, float], Dict[str, float], Dict[str, float]]:
        Zs = self._pool(feats_student)
        outs = self.gdc(Zs, t.float())
        block_names = list(Zs.keys())

        # First matching key wins; None when David exposes no such head.
        ts_key = next((k for k in self._TIMESTEP_KEYS if k in outs), None)
        pt_key = next((k for k in self._PATTERN_KEYS if k in outs), None)

        gdc_device = next(self.gdc.parameters()).device
        t_bins = (t // 10).to(gdc_device)

        # Timestep error.
        if ts_key is None:
            e_t = dict.fromkeys(block_names, 0.0)
        else:
            ts_logits = outs[ts_key]
            if isinstance(ts_logits, dict):
                e_t = {}
                for name, logits in ts_logits.items():
                    ce = F.cross_entropy(logits, t_bins, reduction="mean")
                    e_t[name] = float(ce.item())
            else:
                # One shared head: every block receives the same value.
                ce = F.cross_entropy(ts_logits, t_bins, reduction="mean")
                e_t = {name: float(ce.item()) for name in block_names}

        # Pattern entropy.
        if pt_key is None:
            e_p = dict.fromkeys(block_names, 0.0)
        else:
            pt_logits = outs[pt_key]
            if isinstance(pt_logits, dict):
                e_p = {name: self._normalized_entropy(logits) for name, logits in pt_logits.items()}
            else:
                shared_entropy = self._normalized_entropy(pt_logits)
                e_p = dict.fromkeys(block_names, shared_entropy)

        # Coherence: best effort, falling back to 1.0.
        try:
            alphas = self.gdc.get_cantor_alphas()
        except Exception:
            alphas = {}
        if alphas:
            avg_alpha = float(sum(alphas.values()) / max(len(alphas), 1))
        else:
            avg_alpha = 1.0
        coh = dict.fromkeys(block_names, avg_alpha)

        return e_t, e_p, coh
class BlockPenaltyFusion:
    """Turn David's per-block scores into clamped loss multipliers."""

    def __init__(self, cfg: BaseConfig):
        self.cfg = cfg

    def lambdas(self, e_t: Dict[str, float], e_p: Dict[str, float], coh: Dict[str, float]) -> Dict[str, float]:
        """Return one multiplier per entry of ``cfg.block_weights``.

        lambda = base * (1 + alpha*e_t + beta*e_p + delta*(1 - coh)), clamped
        to ``[lambda_min, lambda_max]``. Missing scores default to 0.0 for the
        errors and 1.0 for coherence.
        """
        cfg = self.cfg
        multipliers = {}
        for name, base in cfg.block_weights.items():
            timestep_term = cfg.alpha_timestep * float(e_t.get(name, 0.0))
            pattern_term = cfg.beta_pattern * float(e_p.get(name, 0.0))
            incoherence_term = cfg.delta_incoherence * (1.0 - float(coh.get(name, 1.0)))
            val = base * (1.0 + timestep_term + pattern_term + incoherence_term)
            capped = min(cfg.lambda_max, val)
            multipliers[name] = float(max(cfg.lambda_min, capped))
        return multipliers
# =====================================================================================
# 8) TRAINER
# =====================================================================================
class FlowMatchDavidTrainer:
    """Distil the SD1.5 teacher into a velocity-predicting student UNet.

    Each step draws Gaussian latents, asks the frozen teacher for an epsilon
    prediction and block features, converts epsilon into a velocity target and
    trains the student on it. David scores the student's block features and
    the scores become per-block multipliers for the block-level losses
    (cosine KD and, optionally, local flow heads).

    Construction downloads models, optionally profiles David for timestep
    sampling, and tries to resume from an emergency checkpoint or from the
    latest checkpoint on the Hub.
    """

    def __init__(self, cfg: BaseConfig, device: str = "cuda"):
        self.cfg = cfg
        self.device = device
        self.start_epoch = 0
        self.start_gstep = 0

        # Initialize David first (needed for timestep sampler)
        self.david_loader = DavidLoader(cfg, device)
        self.david = self.david_loader.gdc
        self.assessor = DavidAssessor(self.david, cfg.pooling)
        self.fusion = BlockPenaltyFusion(cfg)

        # Initialize teacher (needed for David profiling)
        self.teacher = SD15Teacher(cfg, device).eval()

        # Initialize timestep sampler
        self.timestep_sampler = None
        if cfg.use_timestep_weighting:
            print("\n" + "="*70)
            print("🎯 ADAPTIVE TIMESTEP SAMPLING ENABLED")
            print(f" David weighting: {cfg.use_david_weights}")
            print(f" SD3 shift: {cfg.timestep_shift}")
            print(f" Base jitter: ±{cfg.base_jitter}")
            print(f" Adaptive chaos: {cfg.adaptive_chaos}")
            self.timestep_sampler = DavidWeightedTimestepSampler(
                num_timesteps=cfg.num_train_timesteps,
                num_bins=100,
                shift=cfg.timestep_shift if cfg.use_david_weights else 0.0,
                base_jitter=cfg.base_jitter,
                adaptive_chaos=cfg.adaptive_chaos
            )
            if cfg.use_david_weights:
                self.timestep_sampler.compute_difficulty_from_david(
                    david=self.david,
                    teacher=self.teacher,
                    device=device,
                    num_samples=cfg.profile_samples
                )
            print("="*70 + "\n")

        # Initialize dataset with sampler
        self.dataset = SymbolicPromptDataset(cfg.num_samples, cfg.seed, self.timestep_sampler)
        self.loader = DataLoader(self.dataset, batch_size=cfg.batch_size, shuffle=True,
                                 num_workers=cfg.num_workers, pin_memory=True, collate_fn=collate)

        # Initialize student
        self.student = StudentUNet(self.teacher.unet, cfg.active_blocks, cfg.use_local_flow_heads).to(device)
        if cfg.use_local_flow_heads:
            # The student builds its local heads lazily inside forward(). They
            # must exist before the optimizer snapshots student.parameters(),
            # otherwise they would receive gradients but never be updated.
            self._materialize_local_heads()
        self.opt = torch.optim.AdamW(self.student.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
        self.sched = torch.optim.lr_scheduler.CosineAnnealingLR(self.opt, T_max=cfg.epochs * len(self.loader))
        self.scaler = torch.cuda.amp.GradScaler(enabled=cfg.amp)

        # Load checkpoints: emergency save first, then the latest Hub epoch.
        emergency_path: Optional[Path] = Path("./EMERGENCY_SAVE_SUCCESS.pt")
        if not emergency_path.exists():
            print("\n🔍 Emergency checkpoint not found locally, checking HuggingFace...")
            emergency_path = self._download_emergency_checkpoint()
        resumed = False
        if emergency_path and emergency_path.exists():
            resumed = bool(self._load_emergency_checkpoint(emergency_path))
        # A broken emergency file must not prevent a regular resume.
        if not resumed and cfg.continue_training:
            self._load_latest_from_hf()

        self.writer = SummaryWriter(log_dir=os.path.join(cfg.out_dir, cfg.run_name))

    @torch.no_grad()
    def _materialize_local_heads(self):
        """Run one throw-away student forward so its lazy local heads exist."""
        dtype = next(self.student.unet.parameters()).dtype
        ehs = self.teacher.encode([""]).to(dtype)
        x = torch.randn(1, 4, 64, 64, device=self.device, dtype=dtype)
        t = torch.zeros(1, device=self.device, dtype=torch.long)
        was_training = self.student.training
        self.student.eval()
        try:
            self.student(x, t, ehs)
        finally:
            self.student.hooks.clear()
            self.student.train(was_training)

    def _download_emergency_checkpoint(self) -> Optional[Path]:
        """Download emergency checkpoint from HuggingFace backup repo.

        Returns the local path, or None when the download fails.
        """
        emergency_repo = "AbstractPhil/sd15-flow-emergency-backup"
        emergency_file = "EMERGENCY_SAVE_SUCCESS.pt"
        try:
            print(f"📥 Downloading emergency checkpoint from {emergency_repo}...")
            local_path = hf_hub_download(
                repo_id=emergency_repo,
                filename=emergency_file,
                repo_type="model",
                cache_dir="./_emergency_cache"
            )
            target_path = Path("./EMERGENCY_SAVE_SUCCESS.pt")
            shutil.copy(local_path, target_path)
            size_mb = target_path.stat().st_size / 1e6
            print(f"✅ Downloaded emergency checkpoint ({size_mb:.1f} MB)")
            return target_path
        except Exception as e:
            print(f"⚠️ Could not download emergency checkpoint: {e}")
            return None

    def _load_emergency_checkpoint(self, path: Path):
        """Load emergency checkpoint with student_unet structure.

        Returns True when the checkpoint was recognised and loaded, False
        otherwise. Failures are reported, never raised.
        """
        try:
            print(f"\n🚨 Found emergency checkpoint: {path}")
            # NOTE(security): torch.load unpickles arbitrary objects and this
            # file may have been downloaded from a remote repo. Only use
            # trusted repos; consider weights_only=True if the installed
            # torch version supports it for this payload.
            checkpoint = torch.load(path, map_location='cpu')
            if 'student_unet' not in checkpoint:
                print("⚠️ Emergency checkpoint has no 'student_unet' entry; ignoring it")
                return False

            print("📦 Loading emergency checkpoint format...")
            missing, unexpected = self.student.unet.load_state_dict(checkpoint['student_unet'], strict=False)
            print(f"✓ Loaded student UNet")
            if missing or unexpected:
                print(f"⚠️ State dict mismatch: {len(missing)} missing, {len(unexpected)} unexpected keys")
            if 'opt' in checkpoint:
                self.opt.load_state_dict(checkpoint['opt'])
                print("✓ Loaded optimizer state")
            if 'sched' in checkpoint:
                self.sched.load_state_dict(checkpoint['sched'])
                print("✓ Loaded scheduler state")
            if 'gstep' in checkpoint:
                self.start_gstep = checkpoint['gstep']
                self.start_epoch = self.start_gstep // len(self.loader)
                print(f"✓ Resuming from global step {self.start_gstep} (epoch ~{self.start_epoch})")
            print("✅ Emergency checkpoint loaded successfully!")
            return True
        except Exception as e:
            print(f"⚠️ Failed to load emergency checkpoint: {e}")
            return False

    def _load_latest_from_hf(self):
        """Resume from the highest-numbered ``*_e<N>.pt`` file in ``cfg.hf_repo_id``.

        Does nothing when no repo is configured or no such file exists.
        Failures are reported, never raised.
        """
        if not self.cfg.hf_repo_id:
            return
        try:
            api = HfApi()
            print(f"\n🔍 Searching for latest checkpoint in {self.cfg.hf_repo_id}...")
            files = api.list_repo_files(repo_id=self.cfg.hf_repo_id, repo_type="model")
            epochs = []
            for f in files:
                match = re.search(r'_e(\d+)\.pt$', f)
                if match:
                    epochs.append((int(match.group(1)), f))
            if not epochs:
                return
            latest_epoch, latest_file = max(epochs, key=lambda x: x[0])
            print(f"📥 Downloading: {latest_file}")
            local_path = hf_hub_download(
                repo_id=self.cfg.hf_repo_id,
                filename=latest_file,
                repo_type="model",
                cache_dir=self.cfg.ckpt_dir
            )
            # NOTE(security): torch.load unpickles arbitrary objects from a
            # downloaded file; only point hf_repo_id at repos you trust.
            checkpoint = torch.load(local_path, map_location='cpu')
            if 'student_unet' in checkpoint:
                self.student.unet.load_state_dict(checkpoint['student_unet'], strict=False)
            elif 'student' in checkpoint:
                self.student.load_state_dict(checkpoint['student'], strict=False)
            if 'opt' in checkpoint:
                self.opt.load_state_dict(checkpoint['opt'])
            if 'sched' in checkpoint:
                self.sched.load_state_dict(checkpoint['sched'])
            # Files are tagged with the 1-based number of the finished epoch,
            # which equals the 0-based index of the next epoch to run.
            self.start_epoch = latest_epoch
            # Prefer the recorded step counter so logging stays continuous even
            # if the loader length differs from the run that saved the file.
            self.start_gstep = checkpoint.get('gstep', latest_epoch * len(self.loader))
            print(f"✅ Resuming from epoch {self.start_epoch + 1}")
            del checkpoint
            torch.cuda.empty_cache()
        except Exception as e:
            print(f"⚠️ Failed to load from HF: {e}")

    def _v_star(self, x_t, t, eps_hat):
        """Velocity target ``alpha * eps - sigma * x0`` derived from the teacher's
        epsilon, with ``x0`` recovered as ``(x_t - sigma * eps) / alpha``."""
        alpha, sigma = self.teacher.alpha_sigma(t)
        x0_hat = (x_t - sigma * eps_hat) / (alpha + 1e-8)
        return alpha * eps_hat - sigma * x0_hat

    def _down_like(self, tgt: torch.Tensor, ref: torch.Tensor) -> torch.Tensor:
        """Bilinearly resize ``tgt`` to the spatial size of ``ref``."""
        return F.interpolate(tgt, size=ref.shape[-2:], mode="bilinear", align_corners=False)

    def _kd_cos(self, s: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """Mean cosine distance (1 - cosine similarity) along the last dim."""
        s = F.normalize(s, dim=-1)
        t = F.normalize(t, dim=-1)
        return 1.0 - (s * t).sum(-1).mean()

    def train(self):
        """Run the training loop from ``start_epoch`` to ``cfg.epochs``.

        Logs to TensorBoard, saves every ``cfg.save_every`` epochs and once
        more under the tag ``"final"``.
        """
        cfg = self.cfg
        gstep = self.start_gstep
        student_dtype = next(self.student.unet.parameters()).dtype

        for ep in range(self.start_epoch, cfg.epochs):
            self.student.train()
            pbar = tqdm(self.loader, desc=f"Epoch {ep+1}/{cfg.epochs}",
                        dynamic_ncols=True, leave=True, position=0)
            acc = {"L": 0.0, "Lf": 0.0, "Lb": 0.0}

            for it, batch in enumerate(pbar):
                prompts = batch["prompts"]
                t = batch["t"].to(self.device)

                with torch.no_grad():
                    ehs = self.teacher.encode(prompts)
                x_t = torch.randn(len(prompts), 4, 64, 64, device=self.device, dtype=torch.float16)
                with torch.no_grad():
                    eps_hat, t_feats_spatial = self.teacher.forward_eps_and_feats(x_t.half(), t, ehs)
                v_star = self._v_star(x_t, t, eps_hat)

                # Under autocast the fp16 teacher tensors can be fed as-is.
                # Without it they must match the student's parameter dtype,
                # or the first convolution fails on mixed dtypes.
                if cfg.amp:
                    student_x, student_ehs = x_t, ehs
                else:
                    student_x, student_ehs = x_t.to(student_dtype), ehs.to(student_dtype)

                with torch.cuda.amp.autocast(enabled=cfg.amp):
                    v_hat, s_feats_spatial = self.student(student_x, t, student_ehs)
                    L_flow = F.mse_loss(v_hat, v_star)

                    e_t, e_p, coh = self.assessor(s_feats_spatial, t)
                    lam = self.fusion.lambdas(e_t, e_p, coh)

                    L_blocks = torch.zeros((), device=self.device)
                    for name, s_feat in s_feats_spatial.items():
                        L_kd = torch.zeros((), device=self.device)
                        if cfg.use_kd:
                            s_pool = spatial_pool(s_feat, name, cfg.pooling)
                            t_pool = spatial_pool(t_feats_spatial[name], name, cfg.pooling)
                            L_kd = self._kd_cos(s_pool, t_pool)
                        L_lf = torch.zeros((), device=self.device)
                        if cfg.use_local_flow_heads and name in self.student.local_heads:
                            v_loc = self.student.local_heads[name](s_feat)
                            v_ds = self._down_like(v_star, v_loc)
                            L_lf = F.mse_loss(v_loc, v_ds)
                        L_blocks = L_blocks + lam.get(name, 1.0) * (cfg.kd_weight * L_kd + cfg.local_flow_weight * L_lf)

                    L_total = cfg.global_flow_weight * L_flow + cfg.block_penalty_weight * L_blocks

                self.opt.zero_grad(set_to_none=True)
                if cfg.amp:
                    self.scaler.scale(L_total).backward()
                    # Gradients are still multiplied by the loss scale here;
                    # unscale first so grad_clip applies to the true norms.
                    self.scaler.unscale_(self.opt)
                    nn.utils.clip_grad_norm_(self.student.parameters(), cfg.grad_clip)
                    self.scaler.step(self.opt)
                    self.scaler.update()
                else:
                    L_total.backward()
                    nn.utils.clip_grad_norm_(self.student.parameters(), cfg.grad_clip)
                    self.opt.step()
                self.sched.step()
                gstep += 1

                # One host sync per loss instead of one per use.
                loss_total = float(L_total.item())
                loss_flow = float(L_flow.item())
                loss_blocks = float(L_blocks.item())
                acc["L"] += loss_total
                acc["Lf"] += loss_flow
                acc["Lb"] += loss_blocks

                if it % 50 == 0:
                    self.writer.add_scalar("train/total", loss_total, gstep)
                    self.writer.add_scalar("train/flow", loss_flow, gstep)
                    self.writer.add_scalar("train/blocks", loss_blocks, gstep)
                    for k in list(lam.keys())[:4]:
                        self.writer.add_scalar(f"lambda/{k}", lam[k], gstep)

                if it % 10 == 0 or it == len(self.loader) - 1:
                    pbar.set_postfix({
                        "L": f"{loss_total:.4f}",
                        "Lf": f"{loss_flow:.4f}",
                        "Lb": f"{loss_blocks:.4f}"
                    }, refresh=False)

                del x_t, eps_hat, v_star, v_hat, s_feats_spatial, t_feats_spatial
                del student_x, student_ehs

            pbar.close()
            n = max(len(self.loader), 1)  # guard the averages against an empty loader
            print(f"\n[Epoch {ep+1}] L={acc['L']/n:.4f} | L_flow={acc['Lf']/n:.4f} | L_blocks={acc['Lb']/n:.4f}")
            self.writer.add_scalar("epoch/total", acc['L']/n, ep+1)
            self.writer.add_scalar("epoch/flow", acc['Lf']/n, ep+1)
            self.writer.add_scalar("epoch/blocks", acc['Lb']/n, ep+1)

            if (ep+1) % cfg.save_every == 0:
                self._save(ep+1, gstep)

        self._save("final", gstep)
        self.writer.close()

    def _save(self, tag, gstep):
        """Save checkpoint and upload to HuggingFace.

        The file is ``<ckpt_dir>/<run_name>_e<tag>.pt`` and holds the config,
        student weights, optimizer state, scheduler state and step counter.
        """
        pt_path = Path(self.cfg.ckpt_dir) / f"{self.cfg.run_name}_e{tag}.pt"
        torch.save({
            "cfg": asdict(self.cfg),
            "student": self.student.state_dict(),
            "opt": self.opt.state_dict(),
            "sched": self.sched.state_dict(),
            "gstep": gstep
        }, pt_path)
        size_mb = pt_path.stat().st_size / 1e6
        print(f"✓ Saved checkpoint: {pt_path.name} ({size_mb:.1f} MB)")
        if self.cfg.upload_every_epoch and self.cfg.hf_repo_id:
            self._upload_to_hf(pt_path, tag)

    def _upload_to_hf(self, path: Path, tag):
        """Upload checkpoint to HuggingFace. Failures are reported, never raised."""
        try:
            api = HfApi()
            create_repo(self.cfg.hf_repo_id, exist_ok=True, private=False, repo_type="model")
            print(f"📤 Uploading {path.name} to {self.cfg.hf_repo_id}...")
            api.upload_file(
                path_or_fileobj=str(path),
                path_in_repo=path.name,
                repo_id=self.cfg.hf_repo_id,
                repo_type="model",
                commit_message=f"Epoch {tag}"
            )
            # Files in a Hub repo are addressed under /blob/<revision>/.
            print(f"✅ Uploaded: https://huggingface.co/{self.cfg.hf_repo_id}/blob/main/{path.name}")
        except Exception as e:
            print(f"⚠️ Upload failed: {e}")

    @torch.no_grad()
    def sample(self, prompts: List[str], steps: Optional[int] = None, guidance: Optional[float] = None) -> torch.Tensor:
        """Generate images for ``prompts`` with classifier-free guidance.

        The student's velocity output is converted to an epsilon estimate and
        stepped with the teacher's DDPM scheduler; the final latents are
        decoded by the teacher pipeline's VAE.

        Args:
            prompts: Text prompts, one image each.
            steps: Number of scheduler steps (defaults to ``cfg.sample_steps``).
            guidance: Guidance scale (defaults to ``cfg.guidance_scale``).

        Returns:
            Float32 image tensor clamped to ``[-1, 1]``.
        """
        steps = steps or self.cfg.sample_steps
        guidance = guidance if guidance is not None else self.cfg.guidance_scale

        # The teacher emits fp16 embeddings while the student keeps its own
        # parameter dtype; without autocast the two must agree.
        student_dtype = next(self.student.unet.parameters()).dtype
        cond_e = self.teacher.encode(prompts).to(student_dtype)
        uncond_e = self.teacher.encode([""] * len(prompts)).to(student_dtype)

        sched = self.teacher.sched
        sched.set_timesteps(steps, device=self.device)
        x_t = torch.randn(len(prompts), 4, 64, 64, device=self.device, dtype=student_dtype)

        was_training = self.student.training
        self.student.eval()
        try:
            for t_scalar in sched.timesteps:
                t = torch.full((x_t.shape[0],), int(t_scalar), device=self.device, dtype=torch.long)
                v_u, _ = self.student(x_t, t, uncond_e)
                v_c, _ = self.student(x_t, t, cond_e)
                v_hat = v_u + guidance * (v_c - v_u)

                alpha, sigma = self.teacher.alpha_sigma(t)
                denom = (alpha**2 + sigma**2)
                x0_hat = (alpha * x_t - sigma * v_hat) / (denom + 1e-8)
                eps_hat = (x_t - alpha * x0_hat) / (sigma + 1e-8)
                step = sched.step(model_output=eps_hat, timestep=t_scalar, sample=x_t)
                x_t = step.prev_sample.to(student_dtype)
        finally:
            self.student.hooks.clear()
            self.student.train(was_training)

        vae = self.teacher.pipe.vae
        # 0.18215 is presumably the SD 1.x latent scaling factor — confirm
        # against vae.config. Latents are cast to the VAE's parameter dtype.
        latents = (x_t / 0.18215).to(next(vae.parameters()).dtype)
        imgs = vae.decode(latents).sample
        return imgs.float().clamp(-1, 1)
# =====================================================================================
# 9) MAIN
# =====================================================================================
def main():
    """Build the default config, train, then draw one smoke-test sample."""
    cfg = BaseConfig()
    print(json.dumps(asdict(cfg), indent=2))

    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
        print("⚠️ A100 strongly recommended.")

    trainer = FlowMatchDavidTrainer(cfg, device=device)
    trainer.train()
    # The generated image is discarded; the call only exercises the sampler.
    trainer.sample(["a castle at sunset"], steps=10, guidance=7.0)
    print("✓ Training complete.")


if __name__ == "__main__":
    main()