| |
| |
| |
| |
| |
| from __future__ import annotations |
| import os, json, math, random, re, shutil |
| from dataclasses import dataclass, asdict |
| from pathlib import Path |
| from typing import Dict, List, Tuple, Optional |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch.utils.data import Dataset, DataLoader |
| from torch.utils.tensorboard import SummaryWriter |
| from tqdm import tqdm |
|
|
| |
| from diffusers import StableDiffusionPipeline, DDPMScheduler |
| from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel |
|
|
| |
| from geovocab2.train.model.core.geo_david_collective import GeoDavidCollective |
| from geovocab2.data.prompt.symbolic_tree import SynthesisSystem |
|
|
| |
| from huggingface_hub import snapshot_download, HfApi, create_repo, hf_hub_download |
| from safetensors.torch import load_file |
|
|
|
|
| |
| |
| |
@dataclass
class BaseConfig:
    """Hyper-parameters and paths for SD1.5 flow-matching distillation guided by David.

    Creating an instance has a filesystem side effect: ``out_dir``, ``ckpt_dir``
    and ``david_cache_dir`` are created (with parents) in ``__post_init__``.
    """

    # --- run bookkeeping ---
    run_name: str = "sd15_flowmatch_david_weighted"
    out_dir: str = "./runs/sd15_flowmatch_david_weighted"
    ckpt_dir: str = "./checkpoints_sd15_flow_david_weighted"
    save_every: int = 1  # write a checkpoint every N epochs

    # --- data ---
    num_samples: int = 200_000  # synthetic prompts per epoch
    batch_size: int = 32
    num_workers: int = 2
    seed: int = 42

    # --- teacher model / hooked blocks ---
    model_id: str = "runwayml/stable-diffusion-v1-5"
    active_blocks: Tuple[str, ...] = ("down_0", "down_1", "down_2", "down_3", "mid",
                                      "up_0", "up_1", "up_2", "up_3")
    pooling: str = "mean"  # "mean", "max" or "adaptive" (see spatial_pool)

    # --- optimisation ---
    epochs: int = 10
    lr: float = 1e-4
    weight_decay: float = 1e-3
    grad_clip: float = 1.0
    amp: bool = True

    # --- loss weights ---
    global_flow_weight: float = 1.0
    block_penalty_weight: float = 0.2
    use_local_flow_heads: bool = False
    local_flow_weight: float = 1.0

    # --- feature knowledge distillation ---
    use_kd: bool = True
    kd_weight: float = 0.25

    # --- David (assessor) checkpoint on the HF Hub ---
    david_repo_id: str = "AbstractPhil/geo-david-collective-sd15-base-e40"
    david_cache_dir: str = "./_hf_david_cache"
    david_state_key: Optional[str] = None

    # --- per-block penalty fusion (coefficients used by BlockPenaltyFusion) ---
    alpha_timestep: float = 0.5
    beta_pattern: float = 0.25
    delta_incoherence: float = 0.25
    lambda_min: float = 0.5
    lambda_max: float = 3.0

    # Per-block base weights; None is replaced with per-block defaults in __post_init__.
    block_weights: Optional[Dict[str, float]] = None

    # --- timestep sampling ---
    use_timestep_weighting: bool = True
    use_david_weights: bool = True
    timestep_shift: float = 3.0
    base_jitter: int = 5
    adaptive_chaos: bool = True
    profile_samples: int = 500

    # --- diffusion schedule ---
    num_train_timesteps: int = 1000

    # --- sampling ---
    sample_steps: int = 30
    guidance_scale: float = 7.5

    # --- HF Hub checkpointing ---
    hf_repo_id: Optional[str] = "AbstractPhil/sd15-flow-matching"
    upload_every_epoch: bool = True
    continue_training: bool = True

    def __post_init__(self):
        """Create output directories and fill in default per-block weights."""
        for directory in (self.out_dir, self.ckpt_dir, self.david_cache_dir):
            Path(directory).mkdir(parents=True, exist_ok=True)
        if self.block_weights is None:
            # Outer blocks weighted lower, the bottleneck ("mid") highest.
            self.block_weights = {
                'down_0': 0.7, 'down_1': 0.9, 'down_2': 1.0, 'down_3': 1.1,
                'mid': 1.2,
                'up_0': 1.1, 'up_1': 1.0, 'up_2': 0.9, 'up_3': 0.7,
            }
|
|
|
|
| |
| |
| |
class DavidWeightedTimestepSampler:
    """
    Samples timesteps weighted by David's inherent difficulty + SD3 shift + adaptive chaos.

    Before ``compute_difficulty_from_david`` has run, ``sample`` draws uniform
    integer timesteps in ``[0, num_timesteps - 1]``. Afterwards each draw:

    1. picks a bin with probability ``difficulty_weights[bin]``,
    2. takes the bin centre, normalises it to ``[0, 1]`` and applies the
       SD3-style shift ``s*t / (1 + (s-1)*t)`` (skipped when ``shift <= 0``),
    3. adds uniform integer jitter in ``[-j, j]``. ``j`` is ``base_jitter``, or
       ``int(base_jitter * (0.5 + pattern_difficulty[bin]))`` with adaptive chaos.
    """

    _TIMESTEP_KEYS = ("timestep_logits", "logits_timestep", "timestep_head_logits")
    _PATTERN_KEYS = ("pattern_logits", "logits_pattern", "pattern_head_logits")

    def __init__(self, num_timesteps=1000, num_bins=100, shift=3.0, base_jitter=5, adaptive_chaos=True):
        self.num_timesteps = num_timesteps
        self.num_bins = num_bins
        self.shift = shift
        self.base_jitter = base_jitter
        self.adaptive_chaos = adaptive_chaos

        # Filled by compute_difficulty_from_david(); None means "sample uniformly".
        self.difficulty_weights = None
        self.pattern_difficulty = None

    @property
    def bin_width(self) -> int:
        """Raw timesteps per bin (at least 1). With the defaults this is 10."""
        return max(1, self.num_timesteps // self.num_bins)

    def _bin_index(self, t: torch.Tensor) -> torch.Tensor:
        """Map raw integer timesteps to bin indices in ``[0, num_bins - 1]``."""
        return torch.clamp(t // self.bin_width, max=self.num_bins - 1)

    def _apply_shift(self, t: float) -> float:
        """Apply SD3-style timestep shift (operates on normalized t ∈ [0,1])."""
        if self.shift <= 0:
            return t
        return self.shift * t / (1.0 + (self.shift - 1.0) * t)

    @staticmethod
    def _find_logits(outputs, candidates):
        """Return the first logits entry found under ``candidates`` or None.

        A dict of logits (presumably one per block) is averaged into one tensor.
        """
        for key in candidates:
            if key in outputs:
                logits = outputs[key]
                if isinstance(logits, dict):
                    logits = torch.stack(list(logits.values())).mean(0)
                return logits
        return None

    def compute_difficulty_from_david(self, david, teacher, device, num_samples=500):
        """Profile David's confusion patterns to create difficulty map.

        Random fp16 latents (32x4x64x64) and random text embeddings (32x77x768)
        are passed through ``teacher.unet``. The hooked block features are
        mean-pooled over dims 2 and 3 and scored by ``david``. For each timestep
        bin this accumulates:

        * the accuracy of David's timestep head against the true bin, and
        * the mean pattern-head entropy normalised by ``log(num_classes)``.

        Args:
            david: module called as ``david(pooled_feats, t_float)``; returns a
                dict that may contain timestep and/or pattern logits.
            teacher: object with ``.unet`` and a HookBank-like ``.hooks``.
            device: device on which the random probes are created.
            num_samples: approximate probe count, processed in batches of 32
                (always at least one batch).

        Returns:
            ``difficulty_weights``, a ``(num_bins,)`` tensor summing to 1.
            Bins that were never probed count as fully difficult (accuracy 0).
        """
        print("🔍 Profiling David's timestep & pattern difficulty...")

        david.eval()
        teacher.eval()

        batch = 32
        correct_per_bin = torch.zeros(self.num_bins)
        total_per_bin = torch.zeros(self.num_bins)
        entropy_per_bin = torch.zeros(self.num_bins)
        entropy_count_per_bin = torch.zeros(self.num_bins)
        ones = torch.ones(batch)

        with torch.no_grad():
            for _ in tqdm(range(max(1, num_samples // batch)), desc="Profiling David", leave=False):
                x = torch.randn(batch, 4, 64, 64, device=device, dtype=torch.float16)
                t = torch.randint(0, self.num_timesteps, (batch,), device=device)
                # Accumulators live on CPU, so index with CPU bins.
                t_bins = self._bin_index(t).cpu()

                ehs = torch.randn(batch, 77, 768, device=device, dtype=torch.float16)

                teacher.hooks.clear()
                _ = teacher.unet(x, t, encoder_hidden_states=ehs)
                feats = {k: v.float() for k, v in teacher.hooks.bank.items()}
                pooled = {name: f.mean(dim=(2, 3)) for name, f in feats.items()}

                outputs = david(pooled, t.float())

                ts_logits = self._find_logits(outputs, self._TIMESTEP_KEYS)
                if ts_logits is not None:
                    preds = ts_logits.argmax(dim=-1).cpu()
                    # Vectorised per-bin accumulation (avoids one .item() sync per sample).
                    correct_per_bin.index_add_(0, t_bins, (preds == t_bins).float())
                    total_per_bin.index_add_(0, t_bins, ones)

                pt_logits = self._find_logits(outputs, self._PATTERN_KEYS)
                if pt_logits is not None:
                    P = pt_logits.float().softmax(-1)
                    ent = -(P * P.clamp_min(1e-9).log()).sum(-1)
                    n_classes = P.shape[-1]
                    # log(1) == 0: a single-class head carries no uncertainty.
                    denom = math.log(n_classes) if n_classes > 1 else 1.0
                    norm_ent = (ent / denom).cpu()
                    entropy_per_bin.index_add_(0, t_bins, norm_ent)
                    entropy_count_per_bin.index_add_(0, t_bins, ones)

        accuracy_per_bin = correct_per_bin / total_per_bin.clamp(min=1)
        # +0.1 floor keeps easy bins reachable by the sampler.
        timestep_difficulty = (1.0 - accuracy_per_bin) + 0.1
        self.difficulty_weights = timestep_difficulty / timestep_difficulty.sum()

        self.pattern_difficulty = (entropy_per_bin / entropy_count_per_bin.clamp(min=1)).clamp(min=0.1, max=1.0)

        print(f"✓ David difficulty map computed:")
        print(f" Avg timestep accuracy: {float(accuracy_per_bin.mean()):.2%}")
        print(f" Hardest timestep bin: {accuracy_per_bin.argmin().item()} ({float(accuracy_per_bin.min()):.2%} acc)")
        print(f" Easiest timestep bin: {accuracy_per_bin.argmax().item()} ({float(accuracy_per_bin.max()):.2%} acc)")
        print(f" Avg pattern entropy: {float(self.pattern_difficulty.mean()):.3f}")

        return self.difficulty_weights

    def sample(self, batch_size: int) -> List[int]:
        """Sample timesteps with David weighting + shift + adaptive chaos.

        Returns ``batch_size`` ints clamped to ``[0, num_timesteps - 1]``.
        """
        if self.difficulty_weights is None:
            return [random.randint(0, self.num_timesteps - 1) for _ in range(batch_size)]
        if batch_size <= 0:
            return []

        width = self.bin_width
        bin_ids = torch.multinomial(self.difficulty_weights, batch_size, replacement=True).tolist()

        timesteps = []
        for bin_idx in bin_ids:
            bin_center_raw = bin_idx * width + width // 2
            t_shifted = self._apply_shift(bin_center_raw / self.num_timesteps)

            if self.adaptive_chaos and self.pattern_difficulty is not None:
                chaos_scale = float(self.pattern_difficulty[bin_idx])
                jitter = int(self.base_jitter * (0.5 + chaos_scale))
            else:
                jitter = self.base_jitter

            t_raw = int(t_shifted * self.num_timesteps) + random.randint(-jitter, jitter)
            timesteps.append(max(0, min(self.num_timesteps - 1, t_raw)))

        return timesteps
|
|
|
|
| |
| |
| |
class SymbolicPromptDataset(Dataset):
    """Map-style dataset of synthetic prompts paired with a training timestep.

    Items are generated on the fly and ``idx`` is ignored, so indexing the same
    position twice returns different samples. Each item is
    ``{"prompt": str, "t": int}``. The prompt comes from
    ``SynthesisSystem.synthesize`` at a random complexity in 1..5. ``t`` comes
    from ``timestep_sampler.sample(1)`` when a sampler is given, and otherwise
    is uniform in ``[0, num_timesteps - 1]``.
    """

    def __init__(self, n: int, seed: int = 42, timestep_sampler=None, num_timesteps: int = 1000):
        self.n = n
        self.timestep_sampler = timestep_sampler
        # Upper bound (exclusive) for the uniform fallback; 1000 matches the old hard-coded 999.
        self.num_timesteps = num_timesteps
        # Side effect: reseeds the *global* `random` module (used in __getitem__).
        random.seed(seed)
        self.sys = SynthesisSystem(seed=seed)
        # NOTE(review): DataLoader workers each get a copy of self.sys; whether
        # workers produce distinct prompts depends on SynthesisSystem's RNG — confirm.

    def __len__(self) -> int:
        return self.n

    def __getitem__(self, idx):
        result = self.sys.synthesize(complexity=random.choice([1, 2, 3, 4, 5]))
        prompt = result['text']

        if self.timestep_sampler is not None:
            t = self.timestep_sampler.sample(1)[0]
        else:
            t = random.randint(0, self.num_timesteps - 1)

        return {"prompt": prompt, "t": t}
|
|
def collate(batch: List[dict]):
    """Merge dataset items into a batch dict.

    Returns ``{"prompts": list of str, "t": LongTensor, "t_bins": LongTensor}``,
    where ``t_bins`` is ``t`` floor-divided by 10.
    """
    timesteps = torch.tensor([item["t"] for item in batch], dtype=torch.long)
    return {
        "prompts": [item["prompt"] for item in batch],
        "t": timesteps,
        "t_bins": torch.div(timesteps, 10, rounding_mode="floor"),
    }
|
|
|
|
| |
| |
| |
class HookBank:
    """Record the outputs of selected UNet blocks through forward hooks.

    Block names are ``down_<i>`` (``unet.down_blocks[i]``), ``mid``
    (``unet.mid_block``) and ``up_<i>`` (``unet.up_blocks[i]``). After a forward
    pass, ``bank[name]`` holds that block's output. When a block returns a
    tuple or list, the first element is stored.
    """

    def __init__(self, unet: UNet2DConditionModel, active: Tuple[str, ...]):
        self.active = set(active)
        self.bank: Dict[str, torch.Tensor] = {}
        self.hooks: List[torch.utils.hooks.RemovableHandle] = []
        self._register(unet)

    def _register(self, unet: UNet2DConditionModel):
        def make_hook(block_name):
            def hook(module, inputs, output):
                is_seq = isinstance(output, (tuple, list))
                self.bank[block_name] = output[0] if is_seq else output
            return hook

        # Registration order is preserved: down blocks, then mid, then up blocks.
        candidates = [(f"down_{idx}", block) for idx, block in enumerate(unet.down_blocks)]
        if "mid" in self.active:
            candidates.append(("mid", unet.mid_block))
        candidates.extend((f"up_{idx}", block) for idx, block in enumerate(unet.up_blocks))

        for block_name, block in candidates:
            if block_name in self.active:
                self.hooks.append(block.register_forward_hook(make_hook(block_name)))

    def clear(self):
        """Forget all captured outputs (hooks stay registered)."""
        self.bank.clear()

    def close(self):
        """Detach every registered hook."""
        while self.hooks:
            self.hooks.pop(0).remove()
|
|
def spatial_pool(x: torch.Tensor, name: str, policy: str) -> torch.Tensor:
    """Collapse dims 2 and 3 of a feature map.

    policy:
        "mean"     - average over dims (2, 3)
        "max"      - maximum over dims (2, 3)
        "adaptive" - mean for ``down*``/``mid`` blocks, max for every other name
    Raises ValueError for any other policy string.
    """
    axes = (2, 3)
    if policy == "adaptive":
        prefer_mean = name.startswith("down") or name == "mid"
        policy_for_block = "mean" if prefer_mean else "max"
    elif policy in ("mean", "max"):
        policy_for_block = policy
    else:
        raise ValueError(f"Unknown pooling: {policy}")

    if policy_for_block == "mean":
        return x.mean(dim=axes)
    return x.amax(dim=axes)
|
|
|
|
| |
| |
| |
class SD15Teacher(nn.Module):
    """Frozen Stable Diffusion 1.5 teacher loaded in fp16.

    Wraps the pipeline's UNet (hooked on ``cfg.active_blocks``), text encoder
    and tokenizer. A separate DDPMScheduler provides the alpha/sigma lookup
    tables. Every parameter reachable through registered submodules is frozen.
    """

    def __init__(self, cfg: BaseConfig, device: str):
        super().__init__()
        self.pipe = StableDiffusionPipeline.from_pretrained(
            cfg.model_id,
            torch_dtype=torch.float16,
            safety_checker=None,
        ).to(device)
        pipe = self.pipe
        self.unet = pipe.unet
        self.text_encoder = pipe.text_encoder
        self.tokenizer = pipe.tokenizer
        self.hooks = HookBank(self.unet, cfg.active_blocks)
        self.sched = DDPMScheduler(num_train_timesteps=cfg.num_train_timesteps)
        self.device = device
        self.requires_grad_(False)

    @torch.no_grad()
    def encode(self, prompts: List[str]) -> torch.Tensor:
        """Return the text encoder's first output for prompts padded/truncated to the max length."""
        tokens = self.tokenizer(
            prompts,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        input_ids = tokens.input_ids.to(self.device)
        return self.text_encoder(input_ids)[0]

    @torch.no_grad()
    def forward_eps_and_feats(self, x_t: torch.Tensor, t: torch.LongTensor, ehs: torch.Tensor):
        """Run the UNet once; return (eps prediction as fp32, {block: fp32 detached feature})."""
        self.hooks.clear()
        unet_out = self.unet(x_t, t, encoder_hidden_states=ehs)
        captured = {}
        for block_name, feat in self.hooks.bank.items():
            captured[block_name] = feat.detach().float()
        return unet_out.sample.float(), captured

    def alpha_sigma(self, t: torch.LongTensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return (sqrt(alpha_bar_t), sqrt(1 - alpha_bar_t)) as fp32 tensors shaped (B, 1, 1, 1)."""
        alpha_bar = self.sched.alphas_cumprod.to(self.device)[t]
        broadcast = (-1, 1, 1, 1)
        alpha = alpha_bar.sqrt().view(broadcast).float()
        sigma = (1.0 - alpha_bar).sqrt().view(broadcast).float()
        return alpha, sigma
|
|
|
|
| |
| |
| |
class StudentUNet(nn.Module):
    """Trainable copy of the teacher UNet whose selected block outputs are hooked.

    With ``use_local_heads`` a 1x1 ``Conv2d(C, 4)`` "local flow head" exists per
    hooked block. It projects the block's features to 4 channels so a per-block
    flow loss can be applied.
    """

    def __init__(self, teacher_unet: UNet2DConditionModel, active_blocks: Tuple[str, ...], use_local_heads: bool):
        super().__init__()
        self.unet = UNet2DConditionModel.from_config(teacher_unet.config)
        self.unet.load_state_dict(teacher_unet.state_dict(), strict=True)
        self.hooks = HookBank(self.unet, active_blocks)
        self.use_local_heads = use_local_heads
        self.local_heads = nn.ModuleDict()

        if use_local_heads:
            # Build heads eagerly. Creating them lazily on the first forward()
            # meant they did not exist when the trainer built its optimizer or
            # loaded a checkpoint, so they were never trained or restored.
            block_out_channels = getattr(self.unet.config, "block_out_channels", None)
            if block_out_channels is not None:
                target_dtype = next(self.unet.parameters()).dtype
                for name in active_blocks:
                    channels = self._feature_channels(block_out_channels, name)
                    if channels is not None:
                        self.local_heads[name] = nn.Conv2d(channels, 4, kernel_size=1).to(dtype=target_dtype)

    @staticmethod
    def _feature_channels(block_out_channels, name: str) -> Optional[int]:
        """Output channel count of hooked block ``name``, or None if it cannot be derived.

        Mapping: ``down_i`` -> ``block_out_channels[i]``, ``mid`` -> ``block_out_channels[-1]``,
        ``up_i`` -> ``reversed(block_out_channels)[i]``. This mirrors how
        diffusers' UNet2DConditionModel assigns block output channels.
        ``_ensure_heads`` re-checks the channel count against real features.
        """
        chans = list(block_out_channels)
        if not chans:
            return None
        if name == "mid":
            return int(chans[-1])
        prefix, _, index = name.partition("_")
        if not index.isdigit():
            return None
        i = int(index)
        if i >= len(chans):
            return None
        if prefix == "down":
            return int(chans[i])
        if prefix == "up":
            return int(chans[::-1][i])
        return None

    def _ensure_heads(self, feats: Dict[str, torch.Tensor]):
        """Make sure every hooked block has a head whose input channels match its features.

        This is a fallback: heads created here are NOT in an optimizer that was
        built earlier.
        """
        if not self.use_local_heads:
            return

        target_dtype = next(self.unet.parameters()).dtype

        for name, f in feats.items():
            c = f.shape[1]
            if name in self.local_heads and self.local_heads[name].in_channels == c:
                continue
            if name in self.local_heads:
                print(f"⚠️ Rebuilding local head '{name}' for {c} input channels")
            head = nn.Conv2d(c, 4, kernel_size=1)
            self.local_heads[name] = head.to(dtype=target_dtype, device=f.device)

    def forward(self, x_t: torch.Tensor, t: torch.LongTensor, ehs: torch.Tensor):
        """Return (UNet ``.sample`` output, {block name: hooked feature})."""
        self.hooks.clear()
        v_hat = self.unet(x_t, t, encoder_hidden_states=ehs).sample
        feats = dict(self.hooks.bank)
        self._ensure_heads(feats)
        return v_hat, feats
|
|
|
|
| |
| |
| |
class DavidLoader:
    """Download David (GeoDavidCollective) from the HF Hub and load it frozen in eval mode.

    Reads ``config.json`` and ``model.safetensors`` from a snapshot of
    ``cfg.david_repo_id``. Side effect: if the Hub config defines
    ``block_weights``, they overwrite ``cfg.block_weights``.
    """

    def __init__(self, cfg: BaseConfig, device: str):
        self.cfg = cfg
        self.device = device
        self.repo_dir = snapshot_download(repo_id=cfg.david_repo_id, local_dir=cfg.david_cache_dir, local_dir_use_symlinks=False)
        self.config_path = os.path.join(self.repo_dir, "config.json")
        self.weights_path = os.path.join(self.repo_dir, "model.safetensors")
        with open(self.config_path, "r", encoding="utf-8") as f:
            self.hf_config = json.load(f)

        block_configs = self.hf_config["block_configs"]
        self.gdc = GeoDavidCollective(
            block_configs=block_configs,
            num_timestep_bins=int(self.hf_config["num_timestep_bins"]),
            num_patterns_per_bin=int(self.hf_config["num_patterns_per_bin"]),
            # Default: weight 1.0 for every configured block.
            block_weights=self.hf_config.get("block_weights", {k: 1.0 for k in block_configs.keys()}),
            loss_config=self.hf_config.get("loss_config", {}),
        ).to(device).eval()

        state = load_file(self.weights_path)
        # strict=False tolerates checkpoint/code drift. Report the drift, because
        # silently skipped keys would leave parts of David randomly initialised.
        incompatible = self.gdc.load_state_dict(state, strict=False)
        missing = list(incompatible.missing_keys)
        unexpected = list(incompatible.unexpected_keys)
        if missing or unexpected:
            print(f"⚠️ David state_dict mismatch: {len(missing)} missing, {len(unexpected)} unexpected keys")
            if missing:
                print(f"   missing (first 5): {missing[:5]}")
            if unexpected:
                print(f"   unexpected (first 5): {unexpected[:5]}")
        # NOTE(review): cfg.david_state_key is not consulted here — confirm whether it should select a sub-dict of `state`.

        for p in self.gdc.parameters():
            p.requires_grad_(False)

        print(f"✓ David loaded from HF: {self.repo_dir}")
        print(f" blocks={len(block_configs)} bins={self.hf_config['num_timestep_bins']} patterns={self.hf_config['num_patterns_per_bin']}")

        if "block_weights" in self.hf_config:
            cfg.block_weights = self.hf_config["block_weights"]
|
|
class DavidAssessor(nn.Module):
    """Score student block features with frozen David to get per-block error signals.

    ``forward`` returns three dicts of Python floats keyed by block name:

    * ``e_t``: cross-entropy of David's timestep head against
      ``t // timestep_bin_width``; 0.0 when no timestep head output is found.
    * ``e_p``: mean pattern-head entropy divided by ``log(num_classes)``;
      0.0 when no pattern head output is found.
    * ``coh``: mean of ``gdc.get_cantor_alphas()``, shared by all blocks; 1.0 if
      that call fails or returns nothing.

    When a head returns one tensor instead of a per-block dict, its score is
    copied to every pooled block.
    """

    _TIMESTEP_KEYS = ("timestep_logits", "logits_timestep", "timestep_head_logits")
    _PATTERN_KEYS = ("pattern_logits", "logits_pattern", "pattern_head_logits")

    def __init__(self, gdc: GeoDavidCollective, pooling: str, timestep_bin_width: int = 10):
        super().__init__()
        self.gdc = gdc
        self.pooling = pooling
        # Raw timesteps per David timestep bin; 10 reproduces the previous hard-coded t // 10.
        self.timestep_bin_width = max(1, int(timestep_bin_width))

    def _pool(self, feats: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        # Cast to fp32: student features produced under autocast can be fp16,
        # which would not match fp32 David weights outside autocast.
        return {k: spatial_pool(v.float(), k, self.pooling) for k, v in feats.items()}

    @staticmethod
    def _find_key(outs, candidates) -> Optional[str]:
        """First key from ``candidates`` present in ``outs``, else None."""
        for key in candidates:
            if key in outs:
                return key
        return None

    @staticmethod
    def _normalized_entropy(logits: torch.Tensor) -> float:
        """Mean softmax entropy over the batch divided by log(num_classes)."""
        P = logits.float().softmax(-1)
        ent = -(P * P.clamp_min(1e-9).log()).sum(-1).mean()
        n_classes = P.shape[-1]
        # log(1) == 0: a single-class head carries no uncertainty.
        return float(ent.item() / math.log(n_classes)) if n_classes > 1 else 0.0

    @torch.no_grad()
    def forward(self, feats_student: Dict[str, torch.Tensor], t: torch.LongTensor
                ) -> Tuple[Dict[str, float], Dict[str, float], Dict[str, float]]:
        Zs = self._pool(feats_student)
        names = list(Zs.keys())
        outs = self.gdc(Zs, t.float())

        ts_key = self._find_key(outs, self._TIMESTEP_KEYS)
        pt_key = self._find_key(outs, self._PATTERN_KEYS)

        t_bins = (t // self.timestep_bin_width).to(next(self.gdc.parameters()).device)

        # Timestep-head error.
        if ts_key is None:
            e_t = {name: 0.0 for name in names}
        elif isinstance(outs[ts_key], dict):
            e_t = {name: float(F.cross_entropy(L.float(), t_bins, reduction="mean").item())
                   for name, L in outs[ts_key].items()}
        else:
            ce = float(F.cross_entropy(outs[ts_key].float(), t_bins, reduction="mean").item())
            e_t = {name: ce for name in names}

        # Pattern-head uncertainty.
        if pt_key is None:
            e_p = {name: 0.0 for name in names}
        elif isinstance(outs[pt_key], dict):
            e_p = {name: self._normalized_entropy(L) for name, L in outs[pt_key].items()}
        else:
            ent = self._normalized_entropy(outs[pt_key])
            e_p = {name: ent for name in names}

        # Coherence: best effort. Any failure falls back to a neutral 1.0.
        try:
            alphas = self.gdc.get_cantor_alphas()
        except Exception:
            alphas = {}
        avg_alpha = float(sum(alphas.values()) / max(len(alphas), 1)) if alphas else 1.0
        coh = {name: avg_alpha for name in names}

        return e_t, e_p, coh
|
|
class BlockPenaltyFusion:
    """Convert David's per-block error signals into clamped per-block loss weights.

    For every block in ``cfg.block_weights`` (base weight ``w``)::

        lambda = w * (1 + alpha_timestep*e_t + beta_pattern*e_p + delta_incoherence*(1 - coh))

    The result is clamped to ``[cfg.lambda_min, cfg.lambda_max]``. Blocks missing
    from the signal dicts use ``e_t = e_p = 0`` and ``coh = 1``.
    """

    def __init__(self, cfg: BaseConfig):
        self.cfg = cfg

    def lambdas(self, e_t: Dict[str, float], e_p: Dict[str, float], coh: Dict[str, float]) -> Dict[str, float]:
        c = self.cfg

        def weight_for(block: str, base: float) -> float:
            timestep_term = c.alpha_timestep * float(e_t.get(block, 0.0))
            pattern_term = c.beta_pattern * float(e_p.get(block, 0.0))
            incoherence_term = c.delta_incoherence * (1.0 - float(coh.get(block, 1.0)))
            raw = base * (1.0 + timestep_term + pattern_term + incoherence_term)
            return float(max(c.lambda_min, min(c.lambda_max, raw)))

        return {block: weight_for(block, base) for block, base in c.block_weights.items()}
|
|
|
|
| |
| |
| |
class FlowMatchDavidTrainer:
    """Distil the SD1.5 teacher into a v-prediction (flow) student with David-weighted block losses.

    Each training step:

    1. draws ``x_t`` as pure N(0, I) noise (independent of ``t``; no real latents),
    2. gets eps and block features from the frozen teacher, and converts eps into
       the v target ``alpha*eps - sigma*x0_hat``,
    3. minimises
       ``global_flow_weight * MSE(v_student, v_target)
       + block_penalty_weight * sum_b lambda_b * (kd_weight*KD_b + local_flow_weight*LF_b)``,
       where ``lambda_b`` comes from DavidAssessor + BlockPenaltyFusion.
    """

    # Matches checkpoints written by _save(), e.g. "<run>_e12.pt" (not "<run>_efinal.pt").
    _EPOCH_CKPT_RE = re.compile(r"_e(\d+)\.pt$")

    def __init__(self, cfg: BaseConfig, device: str = "cuda"):
        self.cfg = cfg
        self.device = device
        self.start_epoch = 0
        self.start_gstep = 0

        # Frozen David assessor. DavidLoader may overwrite cfg.block_weights.
        self.david_loader = DavidLoader(cfg, device)
        self.david = self.david_loader.gdc
        self.assessor = DavidAssessor(self.david, cfg.pooling)
        self.fusion = BlockPenaltyFusion(cfg)

        # Frozen fp16 teacher.
        self.teacher = SD15Teacher(cfg, device).eval()

        # Optional difficulty-weighted timestep sampler (profiles David once, here).
        self.timestep_sampler = self._build_timestep_sampler()

        self.dataset = SymbolicPromptDataset(cfg.num_samples, cfg.seed, self.timestep_sampler)
        self.loader = DataLoader(self.dataset, batch_size=cfg.batch_size, shuffle=True,
                                 num_workers=cfg.num_workers, pin_memory=True, collate_fn=collate)

        # Trainable student initialised from the teacher's weights.
        self.student = StudentUNet(self.teacher.unet, cfg.active_blocks, cfg.use_local_flow_heads).to(device)

        self.opt = torch.optim.AdamW(self.student.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
        # max(1, ...) avoids a zero-length cosine schedule when the loader is empty.
        self.sched = torch.optim.lr_scheduler.CosineAnnealingLR(self.opt, T_max=max(1, cfg.epochs * len(self.loader)))
        self.scaler = torch.cuda.amp.GradScaler(enabled=cfg.amp)

        self._maybe_resume()

        self.writer = SummaryWriter(log_dir=os.path.join(cfg.out_dir, cfg.run_name))

    # ------------------------------------------------------------------ setup

    def _build_timestep_sampler(self) -> Optional[DavidWeightedTimestepSampler]:
        """Create the timestep sampler if enabled, profiling David when ``use_david_weights``."""
        cfg = self.cfg
        if not cfg.use_timestep_weighting:
            return None

        print("\n" + "=" * 70)
        print("🎯 ADAPTIVE TIMESTEP SAMPLING ENABLED")
        print(f" David weighting: {cfg.use_david_weights}")
        print(f" SD3 shift: {cfg.timestep_shift}")
        print(f" Base jitter: ±{cfg.base_jitter}")
        print(f" Adaptive chaos: {cfg.adaptive_chaos}")

        sampler = DavidWeightedTimestepSampler(
            num_timesteps=cfg.num_train_timesteps,
            num_bins=100,
            # The shift only matters on the David-weighted path; uniform sampling ignores it.
            shift=cfg.timestep_shift if cfg.use_david_weights else 0.0,
            base_jitter=cfg.base_jitter,
            adaptive_chaos=cfg.adaptive_chaos,
        )
        if cfg.use_david_weights:
            sampler.compute_difficulty_from_david(
                david=self.david,
                teacher=self.teacher,
                device=self.device,
                num_samples=cfg.profile_samples,
            )
        print("=" * 70 + "\n")
        return sampler

    def _maybe_resume(self):
        """Resume from the emergency checkpoint if available, else (optionally) from the Hub.

        NOTE(review): a local EMERGENCY_SAVE_SUCCESS.pt always takes precedence,
        even over newer epoch checkpoints, and it is loaded regardless of
        ``cfg.continue_training`` — confirm this is intended.
        """
        emergency_path = Path("./EMERGENCY_SAVE_SUCCESS.pt")
        if not emergency_path.exists():
            print("\n🔍 Emergency checkpoint not found locally, checking HuggingFace...")
            emergency_path = self._download_emergency_checkpoint()

        if emergency_path and emergency_path.exists():
            self._load_emergency_checkpoint(emergency_path)
        elif self.cfg.continue_training:
            self._load_latest_from_hf()

    def _download_emergency_checkpoint(self) -> Optional[Path]:
        """Download emergency checkpoint from HuggingFace backup repo; None on any failure."""
        emergency_repo = "AbstractPhil/sd15-flow-emergency-backup"
        emergency_file = "EMERGENCY_SAVE_SUCCESS.pt"

        try:
            print(f"📥 Downloading emergency checkpoint from {emergency_repo}...")
            local_path = hf_hub_download(
                repo_id=emergency_repo,
                filename=emergency_file,
                repo_type="model",
                cache_dir="./_emergency_cache",
            )

            target_path = Path("./EMERGENCY_SAVE_SUCCESS.pt")
            shutil.copy(local_path, target_path)

            size_mb = target_path.stat().st_size / 1e6
            print(f"✅ Downloaded emergency checkpoint ({size_mb:.1f} MB)")
            return target_path

        except Exception as e:
            print(f"⚠️ Could not download emergency checkpoint: {e}")
            return None

    def _load_student_weights(self, checkpoint: dict) -> bool:
        """Load student weights from ``student_unet`` or ``student``; return False if neither exists."""
        if "student_unet" in checkpoint:
            result = self.student.unet.load_state_dict(checkpoint["student_unet"], strict=False)
        elif "student" in checkpoint:
            result = self.student.load_state_dict(checkpoint["student"], strict=False)
        else:
            print("⚠️ Checkpoint contains neither 'student_unet' nor 'student' weights")
            return False
        if result.missing_keys or result.unexpected_keys:
            print(f"   state_dict mismatch: {len(result.missing_keys)} missing, "
                  f"{len(result.unexpected_keys)} unexpected keys")
        return True

    def _load_emergency_checkpoint(self, path: Path):
        """Load an emergency checkpoint (weights, optimizer, scheduler, global step); best effort."""
        try:
            print(f"\n🚨 Found emergency checkpoint: {path}")
            # NOTE(review): torch.load unpickles; only load checkpoints from trusted repos.
            checkpoint = torch.load(path, map_location='cpu')

            print("📦 Loading emergency checkpoint format...")
            if not self._load_student_weights(checkpoint):
                print("⚠️ Emergency checkpoint had no student weights; not resuming from it")
                return
            print("✓ Loaded student UNet")

            if 'opt' in checkpoint:
                self.opt.load_state_dict(checkpoint['opt'])
                print("✓ Loaded optimizer state")

            if 'sched' in checkpoint:
                self.sched.load_state_dict(checkpoint['sched'])
                print("✓ Loaded scheduler state")

            if 'gstep' in checkpoint:
                self.start_gstep = int(checkpoint['gstep'])
                self.start_epoch = self.start_gstep // max(1, len(self.loader))
                print(f"✓ Resuming from global step {self.start_gstep} (epoch ~{self.start_epoch})")

            print("✅ Emergency checkpoint loaded successfully!")
            del checkpoint

        except Exception as e:
            print(f"⚠️ Failed to load emergency checkpoint: {e}")

    def _load_latest_from_hf(self):
        """Resume from the highest-epoch ``*_e<N>.pt`` file in ``cfg.hf_repo_id``; best effort."""
        if not self.cfg.hf_repo_id:
            return

        try:
            api = HfApi()
            print(f"\n🔍 Searching for latest checkpoint in {self.cfg.hf_repo_id}...")

            files = api.list_repo_files(repo_id=self.cfg.hf_repo_id, repo_type="model")
            epochs = []
            for fname in files:
                match = self._EPOCH_CKPT_RE.search(fname)
                if match:
                    epochs.append((int(match.group(1)), fname))

            if not epochs:
                return

            latest_epoch, latest_file = max(epochs, key=lambda x: x[0])
            print(f"📥 Downloading: {latest_file}")

            local_path = hf_hub_download(
                repo_id=self.cfg.hf_repo_id,
                filename=latest_file,
                repo_type="model",
                cache_dir=self.cfg.ckpt_dir,
            )

            checkpoint = torch.load(local_path, map_location='cpu')
            if not self._load_student_weights(checkpoint):
                return

            if 'opt' in checkpoint:
                self.opt.load_state_dict(checkpoint['opt'])
            if 'sched' in checkpoint:
                self.sched.load_state_dict(checkpoint['sched'])

            self.start_epoch = latest_epoch
            # Prefer the exact step stored by _save(); fall back to the epoch estimate.
            self.start_gstep = int(checkpoint.get('gstep', latest_epoch * len(self.loader)))

            print(f"✅ Resuming from epoch {self.start_epoch + 1}")
            del checkpoint
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        except Exception as e:
            print(f"⚠️ Failed to load from HF: {e}")

    # ------------------------------------------------------------------ loss pieces

    def _v_star(self, x_t, t, eps_hat):
        """v target from the teacher's eps: v = alpha*eps - sigma*x0_hat."""
        alpha, sigma = self.teacher.alpha_sigma(t)
        x0_hat = (x_t - sigma * eps_hat) / (alpha + 1e-8)
        return alpha * eps_hat - sigma * x0_hat

    def _down_like(self, tgt: torch.Tensor, ref: torch.Tensor) -> torch.Tensor:
        """Bilinearly resize ``tgt`` to ``ref``'s last two (spatial) dims."""
        return F.interpolate(tgt, size=ref.shape[-2:], mode="bilinear", align_corners=False)

    def _kd_cos(self, s: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """Mean cosine distance (1 - cos) along the last dim."""
        s = F.normalize(s, dim=-1)
        t = F.normalize(t, dim=-1)
        return 1.0 - (s * t).sum(-1).mean()

    def _block_loss(self, s_feats, t_feats, v_star, lam) -> torch.Tensor:
        """Sum over hooked blocks of ``lambda_b * (kd_weight*KD_b + local_flow_weight*LF_b)``."""
        cfg = self.cfg
        total = torch.zeros((), device=self.device)
        for name, s_feat in s_feats.items():
            L_kd = torch.zeros((), device=self.device)
            if cfg.use_kd:
                s_pool = spatial_pool(s_feat, name, cfg.pooling)
                t_pool = spatial_pool(t_feats[name], name, cfg.pooling)
                L_kd = self._kd_cos(s_pool, t_pool)

            L_lf = torch.zeros((), device=self.device)
            if cfg.use_local_flow_heads and name in self.student.local_heads:
                v_loc = self.student.local_heads[name](s_feat)
                L_lf = F.mse_loss(v_loc, self._down_like(v_star, v_loc))

            total = total + lam.get(name, 1.0) * (cfg.kd_weight * L_kd + cfg.local_flow_weight * L_lf)
        return total

    def _optimizer_step(self, loss: torch.Tensor):
        """Backward, clip, step the optimizer and LR scheduler (AMP-aware).

        With AMP the gradients are unscaled *before* clipping, so ``grad_clip``
        bounds the true gradient norm. Clipping scaled gradients would shrink
        every update by roughly the loss-scale factor.
        """
        cfg = self.cfg
        self.opt.zero_grad(set_to_none=True)
        if cfg.amp:
            self.scaler.scale(loss).backward()
            self.scaler.unscale_(self.opt)
            nn.utils.clip_grad_norm_(self.student.parameters(), cfg.grad_clip)
            self.scaler.step(self.opt)
            self.scaler.update()
        else:
            loss.backward()
            nn.utils.clip_grad_norm_(self.student.parameters(), cfg.grad_clip)
            self.opt.step()
        self.sched.step()

    # ------------------------------------------------------------------ training

    def train(self):
        """Run epochs ``start_epoch .. cfg.epochs - 1``, logging, checkpointing and finally saving ``final``."""
        cfg = self.cfg
        gstep = self.start_gstep
        n_batches = len(self.loader)

        for ep in range(self.start_epoch, cfg.epochs):
            self.student.train()
            pbar = tqdm(self.loader, desc=f"Epoch {ep+1}/{cfg.epochs}",
                        dynamic_ncols=True, leave=True, position=0)
            acc = {"L": 0.0, "Lf": 0.0, "Lb": 0.0}

            for it, batch in enumerate(pbar):
                prompts = batch["prompts"]
                t = batch["t"].to(self.device)

                with torch.no_grad():
                    ehs = self.teacher.encode(prompts)
                    # x_t is pure Gaussian noise for every t (no real latents are used).
                    x_t = torch.randn(len(prompts), 4, 64, 64, device=self.device, dtype=torch.float16)
                    eps_hat, t_feats_spatial = self.teacher.forward_eps_and_feats(x_t.half(), t, ehs)
                    v_star = self._v_star(x_t, t, eps_hat)

                with torch.cuda.amp.autocast(enabled=cfg.amp):
                    v_hat, s_feats_spatial = self.student(x_t, t, ehs)
                    L_flow = F.mse_loss(v_hat, v_star)

                    e_t, e_p, coh = self.assessor(s_feats_spatial, t)
                    lam = self.fusion.lambdas(e_t, e_p, coh)
                    L_blocks = self._block_loss(s_feats_spatial, t_feats_spatial, v_star, lam)

                    L_total = cfg.global_flow_weight * L_flow + cfg.block_penalty_weight * L_blocks

                self._optimizer_step(L_total)
                gstep += 1

                total_val = float(L_total.item())
                flow_val = float(L_flow.item())
                blocks_val = float(L_blocks.item())
                acc["L"] += total_val
                acc["Lf"] += flow_val
                acc["Lb"] += blocks_val

                if it % 50 == 0:
                    self.writer.add_scalar("train/total", total_val, gstep)
                    self.writer.add_scalar("train/flow", flow_val, gstep)
                    self.writer.add_scalar("train/blocks", blocks_val, gstep)
                    for k in list(lam.keys())[:4]:
                        self.writer.add_scalar(f"lambda/{k}", lam[k], gstep)

                if it % 10 == 0 or it == n_batches - 1:
                    pbar.set_postfix({
                        "L": f"{total_val:.4f}",
                        "Lf": f"{flow_val:.4f}",
                        "Lb": f"{blocks_val:.4f}",
                    }, refresh=False)

                del x_t, eps_hat, v_star, v_hat, s_feats_spatial, t_feats_spatial

            pbar.close()

            n = max(n_batches, 1)
            print(f"\n[Epoch {ep+1}] L={acc['L']/n:.4f} | L_flow={acc['Lf']/n:.4f} | L_blocks={acc['Lb']/n:.4f}")
            self.writer.add_scalar("epoch/total", acc['L'] / n, ep + 1)
            self.writer.add_scalar("epoch/flow", acc['Lf'] / n, ep + 1)
            self.writer.add_scalar("epoch/blocks", acc['Lb'] / n, ep + 1)

            if (ep + 1) % cfg.save_every == 0:
                self._save(ep + 1, gstep)

        self._save("final", gstep)
        self.writer.close()

    # ------------------------------------------------------------------ persistence

    def _save(self, tag, gstep):
        """Save checkpoint and upload to HuggingFace."""
        pt_path = Path(self.cfg.ckpt_dir) / f"{self.cfg.run_name}_e{tag}.pt"
        torch.save({
            "cfg": asdict(self.cfg),
            "student": self.student.state_dict(),
            "opt": self.opt.state_dict(),
            "sched": self.sched.state_dict(),
            "gstep": gstep,
        }, pt_path)

        size_mb = pt_path.stat().st_size / 1e6
        print(f"✓ Saved checkpoint: {pt_path.name} ({size_mb:.1f} MB)")

        if self.cfg.upload_every_epoch and self.cfg.hf_repo_id:
            self._upload_to_hf(pt_path, tag)

    def _upload_to_hf(self, path: Path, tag):
        """Upload checkpoint to HuggingFace; failures are reported, not raised."""
        try:
            api = HfApi()
            create_repo(self.cfg.hf_repo_id, exist_ok=True, private=False, repo_type="model")

            print(f"📤 Uploading {path.name} to {self.cfg.hf_repo_id}...")
            api.upload_file(
                path_or_fileobj=str(path),
                path_in_repo=path.name,
                repo_id=self.cfg.hf_repo_id,
                repo_type="model",
                commit_message=f"Epoch {tag}",
            )
            print(f"✅ Uploaded: https://huggingface.co/{self.cfg.hf_repo_id}/{path.name}")

        except Exception as e:
            print(f"⚠️ Upload failed: {e}")

    # ------------------------------------------------------------------ sampling

    @torch.no_grad()
    def sample(self, prompts: List[str], steps: Optional[int] = None, guidance: Optional[float] = None) -> torch.Tensor:
        """Generate images with the student (classifier-free guidance on v), decoded by the teacher VAE.

        The guided v is converted to eps and stepped with the teacher's DDPM
        scheduler. Returns decoded images clamped to [-1, 1].
        """
        steps = steps or self.cfg.sample_steps
        guidance = guidance if guidance is not None else self.cfg.guidance_scale

        was_training = self.student.training
        self.student.eval()
        student_dtype = next(self.student.parameters()).dtype

        # The text encoder runs in fp16; without autocast the student needs matching dtypes.
        cond_e = self.teacher.encode(prompts).to(student_dtype)
        uncond_e = self.teacher.encode([""] * len(prompts)).to(student_dtype)

        sched = self.teacher.sched
        sched.set_timesteps(steps, device=self.device)
        x_t = torch.randn(len(prompts), 4, 64, 64, device=self.device, dtype=student_dtype)

        try:
            for t_scalar in sched.timesteps:
                t = torch.full((x_t.shape[0],), t_scalar, device=self.device, dtype=torch.long)
                v_u, _ = self.student(x_t, t, uncond_e)
                v_c, _ = self.student(x_t, t, cond_e)
                v_hat = v_u + guidance * (v_c - v_u)

                # v -> x0 -> eps so the eps-parameterised DDPM scheduler can step.
                alpha, sigma = self.teacher.alpha_sigma(t)
                denom = alpha ** 2 + sigma ** 2
                x0_hat = (alpha * x_t - sigma * v_hat) / (denom + 1e-8)
                eps_hat = (x_t - alpha * x0_hat) / (sigma + 1e-8)

                x_t = sched.step(model_output=eps_hat, timestep=t_scalar, sample=x_t).prev_sample
        finally:
            self.student.train(was_training)

        vae = self.teacher.pipe.vae
        # The VAE is fp16; cast latents to its dtype to avoid a dtype mismatch.
        imgs = vae.decode((x_t / 0.18215).to(vae.dtype)).sample
        return imgs.float().clamp(-1, 1)
|
|
|
|
| |
| |
| |
def main() -> None:
    """Script entry point.

    Builds the default config and prints it as JSON, then trains. Afterwards it
    draws one 10-step smoke-test sample (the images are discarded).
    """
    config = BaseConfig()
    print(json.dumps(asdict(config), indent=2))

    has_cuda = torch.cuda.is_available()
    device = "cuda" if has_cuda else "cpu"
    if not has_cuda:
        print("⚠️ A100 strongly recommended.")

    trainer = FlowMatchDavidTrainer(config, device=device)
    trainer.train()
    trainer.sample(["a castle at sunset"], steps=10, guidance=7.0)
    print("✓ Training complete.")


if __name__ == "__main__":
    main()