# ===================================================================================== # SD1.5 Flow-Matching Trainer — David-Driven Block Penalties (HF-loaded) # Author: AbstractPhil # Assistant: Claude Sonnet 4.5 + GPT 4o # - BaseConfig at top # - Functionality (teacher/student/david/assessor/fusion/trainer) # - Activations at bottom # ===================================================================================== # try: # !pip uninstall -qy geometricvocab # except: # pass # # !pip install -q git+https://github.com/AbstractEyes/lattice_vocabulary.git # # ===================================================================================== from __future__ import annotations import os, json, math, random, re from dataclasses import dataclass, asdict from pathlib import Path from typing import Dict, List, Tuple, Optional import urllib.request import subprocess import shutil import torch import torch.nn as nn import torch.nn.functional as F from torch.utils.data import Dataset, DataLoader from torch.utils.tensorboard import SummaryWriter from tqdm import tqdm # Diffusers from diffusers import StableDiffusionPipeline, DDPMScheduler from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel # Repo deps (present in your repo) from geovocab2.train.model.core.geo_david_collective import GeoDavidCollective from geovocab2.data.prompt.symbolic_tree import SynthesisSystem # HF / safetensors from huggingface_hub import snapshot_download, HfApi, create_repo, hf_hub_download from safetensors.torch import load_file # ===================================================================================== # 1) CONFIG (BaseConfig) # ===================================================================================== @dataclass class BaseConfig: run_name: str = "sd15_flowmatch_david_hf" out_dir: str = "./runs/sd15_flowmatch_david_hf" ckpt_dir: str = "./checkpoints_sd15_flow_david_hf" save_every: int = 1 # Data num_samples: int = 200_000 batch_size: int = 32 
num_workers: int = 2 seed: int = 42 # Models / Blocks model_id: str = "runwayml/stable-diffusion-v1-5" active_blocks: Tuple[str, ...] = ("down_0","down_1","down_2","down_3","mid","up_0","up_1","up_2","up_3") pooling: str = "mean" # mean | max | adaptive # Flow training epochs: int = 10 lr: float = 1e-4 weight_decay: float = 1e-3 grad_clip: float = 1.0 amp: bool = True global_flow_weight: float = 1.0 block_penalty_weight: float = 0.2 # ← NEW: Start very low! use_local_flow_heads: bool = False local_flow_weight: float = 1.0 # KD (optional) use_kd: bool = True kd_weight: float = 0.25 # David (ALWAYS used, HF) david_repo_id: str = "AbstractPhil/geo-david-collective-sd15-base-e40" david_cache_dir: str = "./_hf_david_cache" david_state_key: Optional[str] = None # None→raw state; or "model_state_dict" if ckpt-style # Fusion: λ_b = w_b * (1 + α·e_t + β·e_p + δ·(1−coh)) alpha_timestep: float = 0.5 beta_pattern: float = 0.25 delta_incoherence: float = 0.25 lambda_min: float = 0.5 lambda_max: float = 3.0 # Block weights (overridden by HF config if present) block_weights: Dict[str, float] = None # Scheduler num_train_timesteps: int = 1000 # Inference sample_steps: int = 30 guidance_scale: float = 7.5 # HuggingFace upload & resume hf_repo_id: Optional[str] = "AbstractPhil/sd15-flow-matching" upload_every_epoch: bool = True continue_training: bool = True # Download latest checkpoint and resume def __post_init__(self): Path(self.out_dir).mkdir(parents=True, exist_ok=True) Path(self.ckpt_dir).mkdir(parents=True, exist_ok=True) Path(self.david_cache_dir).mkdir(parents=True, exist_ok=True) if self.block_weights is None: self.block_weights = {'down_0':0.7,'down_1':0.9,'down_2':1.0,'down_3':1.1,'mid':1.2,'up_0':1.1,'up_1':1.0,'up_2':0.9,'up_3':0.7} # ===================================================================================== # 2) DATA # ===================================================================================== class SymbolicPromptDataset(Dataset): def 
__init__(self, n:int, seed:int=42): self.n = n random.seed(seed) self.sys = SynthesisSystem(seed=seed) def __len__(self): return self.n def __getitem__(self, idx): r = self.sys.synthesize(complexity=random.choice([1,2,3,4,5])) prompt = r['text'] t = random.randint(0, 999) return {"prompt": prompt, "t": t} def collate(batch: List[dict]): prompts = [b["prompt"] for b in batch] t = torch.tensor([b["t"] for b in batch], dtype=torch.long) t_bins = t // 10 return {"prompts": prompts, "t": t, "t_bins": t_bins} # ===================================================================================== # 3) HOOKS + POOLING # ===================================================================================== class HookBank: def __init__(self, unet: UNet2DConditionModel, active: Tuple[str, ...]): self.active = set(active) self.bank: Dict[str, torch.Tensor] = {} self.hooks: List[torch.utils.hooks.RemovableHandle] = [] self._register(unet) def _register(self, unet: UNet2DConditionModel): def mk(name): def h(m, i, o): out = o[0] if isinstance(o,(tuple,list)) else o self.bank[name] = out return h for i, blk in enumerate(unet.down_blocks): nm = f"down_{i}" if nm in self.active: self.hooks.append(blk.register_forward_hook(mk(nm))) if "mid" in self.active: self.hooks.append(unet.mid_block.register_forward_hook(mk("mid"))) for i, blk in enumerate(unet.up_blocks): nm = f"up_{i}" if nm in self.active: self.hooks.append(blk.register_forward_hook(mk(nm))) def clear(self): self.bank.clear() def close(self): for h in self.hooks: h.remove() self.hooks.clear() def spatial_pool(x: torch.Tensor, name: str, policy: str) -> torch.Tensor: if policy == "mean": return x.mean(dim=(2,3)) if policy == "max": return x.amax(dim=(2,3)) if policy == "adaptive": return x.mean(dim=(2,3)) if (name.startswith("down") or name=="mid") else x.amax(dim=(2,3)) raise ValueError(f"Unknown pooling: {policy}") # ===================================================================================== # 4) TEACHER (SD1.5) # 
# =====================================================================================
class SD15Teacher(nn.Module):
    """Frozen SD1.5 teacher (fp16) supplying epsilon predictions and hooked block features.

    Owns the text encoder / tokenizer used for conditioning, a ``HookBank`` on the teacher UNet,
    and the DDPM noise schedule used to derive ``alpha_t`` / ``sigma_t``.
    """

    def __init__(self, cfg: BaseConfig, device: str):
        super().__init__()
        self.pipe = StableDiffusionPipeline.from_pretrained(
            cfg.model_id, torch_dtype=torch.float16, safety_checker=None
        ).to(device)
        self.unet: UNet2DConditionModel = self.pipe.unet
        self.text_encoder = self.pipe.text_encoder
        self.tokenizer = self.pipe.tokenizer
        self.hooks = HookBank(self.unet, cfg.active_blocks)
        # Build the schedule from the pipeline's own scheduler config so alpha/sigma use the
        # betas the teacher was trained with; a bare DDPMScheduler(...) falls back to diffusers'
        # default linear betas. Latents are not bounded to [-1, 1], so x0 is never clipped.
        self.sched = DDPMScheduler.from_config(
            self.pipe.scheduler.config,
            num_train_timesteps=cfg.num_train_timesteps,
            clip_sample=False,
        )
        self.device = device
        for p in self.parameters():
            p.requires_grad_(False)

    @torch.no_grad()
    def encode(self, prompts: List[str]) -> torch.Tensor:
        """Tokenize (pad/truncate to the tokenizer max length) and return text-encoder states."""
        tok = self.tokenizer(
            prompts,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        return self.text_encoder(tok.input_ids.to(self.device))[0]

    @torch.no_grad()
    def forward_eps_and_feats(self, x_t: torch.Tensor, t: torch.LongTensor, ehs: torch.Tensor):
        """Run the teacher UNet; return (eps prediction in fp32, {block: hooked output in fp32})."""
        self.hooks.clear()
        eps_hat = self.unet(x_t, t, encoder_hidden_states=ehs).sample
        feats = {k: v.detach().float() for k, v in self.hooks.bank.items()}
        return eps_hat.float(), feats

    def alpha_sigma(self, t: torch.LongTensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(sqrt(abar_t), sqrt(1 - abar_t))`` reshaped to ``[B, 1, 1, 1]`` in fp32."""
        ac = self.sched.alphas_cumprod.to(self.device)[t]
        alpha = ac.sqrt().view(-1, 1, 1, 1).float()
        sigma = (1.0 - ac).sqrt().view(-1, 1, 1, 1).float()
        return alpha, sigma


# =====================================================================================
# 5) STUDENT (v-pred) + LOCAL FLOW HEADS
# =====================================================================================
class StudentUNet(nn.Module):
    """Trainable copy of the teacher UNet (trained to predict v) plus optional 1x1 flow heads.

    Local heads are built in ``__init__`` so they are registered before an optimizer collects
    ``parameters()``; heads created lazily during a forward pass would never be optimized.
    """

    def __init__(self, teacher_unet: UNet2DConditionModel, active_blocks: Tuple[str, ...],
                 use_local_heads: bool):
        super().__init__()
        self.unet = UNet2DConditionModel.from_config(teacher_unet.config)
        self.unet.load_state_dict(teacher_unet.state_dict(), strict=True)
        self.hooks = HookBank(self.unet, active_blocks)
        self.use_local_heads = use_local_heads
        self.local_heads = nn.ModuleDict()
        if use_local_heads:
            channels = self._block_out_channels()
            target_dtype = next(self.unet.parameters()).dtype
            out_ch = self.unet.config.out_channels
            for name in active_blocks:
                if name in channels:
                    self.local_heads[name] = nn.Conv2d(
                        channels[name], out_ch, kernel_size=1
                    ).to(dtype=target_dtype)

    def _block_out_channels(self) -> Dict[str, int]:
        """Map hook names (down_i / mid / up_i) to that block's output channel count.

        Mirrors diffusers' UNet2DConditionModel layout (down i → block_out_channels[i],
        mid → last entry, up i → reversed list); verify if the UNet class changes.
        """
        boc = list(self.unet.config.block_out_channels)
        channels = {f"down_{i}": c for i, c in enumerate(boc)}
        channels["mid"] = boc[-1]
        channels.update({f"up_{i}": c for i, c in enumerate(reversed(boc))})
        return channels

    def _ensure_heads(self, feats: Dict[str, torch.Tensor]):
        """Fallback: create heads for hooked blocks that have none yet.

        Heads created here are NOT seen by an optimizer that was built earlier.
        """
        if not self.use_local_heads:
            return
        target_dtype = next(self.unet.parameters()).dtype
        for name, f in feats.items():
            if name in self.local_heads:
                continue
            head = nn.Conv2d(f.shape[1], self.unet.config.out_channels, kernel_size=1)
            self.local_heads[name] = head.to(dtype=target_dtype, device=f.device)

    def forward(self, x_t: torch.Tensor, t: torch.LongTensor, ehs: torch.Tensor):
        """Return (v prediction, {block: hooked output}) — features keep their original dtype."""
        self.hooks.clear()
        v_hat = self.unet(x_t, t, encoder_hidden_states=ehs).sample
        feats = dict(self.hooks.bank)
        self._ensure_heads(feats)
        return v_hat, feats


# =====================================================================================
# 6) DAVID LOADER (HF) + ASSESSOR + FUSION
# =====================================================================================
class DavidLoader:
    """
    Downloads HF repo (config + safetensors), instantiates GeoDavidCollective with HF config,
    loads weights, returns a frozen model + the parsed HF config.

    Side effect: if the HF config has ``block_weights`` they overwrite ``cfg.block_weights``.
    """

    def __init__(self, cfg: BaseConfig, device: str):
        self.cfg = cfg
        self.device = device
        self.repo_dir = snapshot_download(repo_id=cfg.david_repo_id, local_dir=cfg.david_cache_dir,
                                          local_dir_use_symlinks=False)
        self.config_path = os.path.join(self.repo_dir, "config.json")
        self.weights_path = os.path.join(self.repo_dir, "model.safetensors")

        with open(self.config_path, "r", encoding="utf-8") as f:
            self.hf_config = json.load(f)

        # Instantiate GeoDavidCollective from HF config
        block_configs = self.hf_config["block_configs"]
        self.gdc = GeoDavidCollective(
            block_configs=block_configs,
            num_timestep_bins=int(self.hf_config["num_timestep_bins"]),
            num_patterns_per_bin=int(self.hf_config["num_patterns_per_bin"]),
            block_weights=self.hf_config.get("block_weights", {k: 1.0 for k in block_configs.keys()}),
            loss_config=self.hf_config.get("loss_config", {}),
        ).to(device).eval()

        # Load weights (non-strict), but report anything that did not match.
        state = load_file(self.weights_path)
        incompatible = self.gdc.load_state_dict(state, strict=False)
        for p in self.gdc.parameters():
            p.requires_grad_(False)

        # Report
        print(f"✓ David loaded from HF: {self.repo_dir}")
        print(f"  blocks={len(block_configs)} bins={self.hf_config['num_timestep_bins']} "
              f"patterns={self.hf_config['num_patterns_per_bin']}")
        if incompatible.missing_keys or incompatible.unexpected_keys:
            print(f"  ⚠️ David weights partially loaded: missing={len(incompatible.missing_keys)} "
                  f"unexpected={len(incompatible.unexpected_keys)}")

        # Override block weights in main cfg if provided
        if "block_weights" in self.hf_config:
            cfg.block_weights = self.hf_config["block_weights"]


class DavidAssessor(nn.Module):
    """
    Uses David to score STUDENT pooled features (per block) and timesteps.
    Produces:
      e_t[name] : timestep CE error proxy (scalar), 0 if no timestep logits are returned
      e_p[name] : normalized pattern-logit entropy if pattern logits present, else 0
      coh[name] : coherence proxy (avg Cantor alpha if available, else 1)
    """

    _TS_KEYS = ("timestep_logits", "logits_timestep", "timestep_head_logits")
    _PT_KEYS = ("pattern_logits", "logits_pattern", "pattern_head_logits")

    def __init__(self, gdc: GeoDavidCollective, pooling: str):
        super().__init__()
        self.gdc = gdc
        self.pooling = pooling

    def _pool(self, feats: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        return {k: spatial_pool(v, k, self.pooling) for k, v in feats.items()}

    @staticmethod
    def _first_key(outs, candidates) -> Optional[str]:
        """Return the first candidate key present in ``outs`` (or None)."""
        return next((key for key in candidates if key in outs), None)

    @staticmethod
    def _normalized_entropy(logits: torch.Tensor) -> float:
        """Mean softmax entropy divided by log(num_classes); 0 for a single class."""
        P = logits.softmax(-1)
        num_classes = P.shape[-1]
        if num_classes <= 1:
            return 0.0
        ent = -(P * P.clamp_min(1e-9).log()).sum(-1).mean()
        return float(ent.item() / math.log(num_classes))

    @torch.no_grad()
    def forward(self, feats_student: Dict[str, torch.Tensor], t: torch.LongTensor
                ) -> Tuple[Dict[str, float], Dict[str, float], Dict[str, float]]:
        Zs = self._pool(feats_student)  # [B, C] per block
        outs = self.gdc(Zs, t.float())  # forward for predictions/logits
        names = list(Zs.keys())

        ts_key = self._first_key(outs, self._TS_KEYS)
        pt_key = self._first_key(outs, self._PT_KEYS)
        t_bins = (t // 10).to(next(self.gdc.parameters()).device)

        e_t: Dict[str, float] = {}
        if ts_key is None:
            e_t = {name: 0.0 for name in names}
        elif isinstance(outs[ts_key], dict):
            # One timestep head per block.
            for name, L in outs[ts_key].items():
                e_t[name] = float(F.cross_entropy(L, t_bins, reduction="mean").item())
        else:
            # Single head: broadcast the same CE to all blocks.
            ce = float(F.cross_entropy(outs[ts_key], t_bins, reduction="mean").item())
            e_t = {name: ce for name in names}

        # No pattern labels are available, so entropy serves as the "error" proxy.
        e_p: Dict[str, float] = {}
        if pt_key is None:
            e_p = {name: 0.0 for name in names}
        elif isinstance(outs[pt_key], dict):
            for name, L in outs[pt_key].items():
                e_p[name] = self._normalized_entropy(L)
        else:
            ent = self._normalized_entropy(outs[pt_key])
            e_p = {name: ent for name in names}

        # Cantor alphas / coherence (best effort: the method may not exist).
        try:
            alphas = self.gdc.get_cantor_alphas()  # presumably dict of scalars — TODO confirm
        except Exception:
            alphas = {}
        avg_alpha = float(sum(alphas.values()) / max(len(alphas), 1)) if alphas else 1.0
        coh = {name: avg_alpha for name in names}  # higher = more coherent

        return e_t, e_p, coh


class BlockPenaltyFusion:
    """Turn David's per-block signals into clamped block-loss multipliers λ_b."""

    def __init__(self, cfg: BaseConfig):
        self.cfg = cfg

    def lambdas(self, e_t: Dict[str, float], e_p: Dict[str, float],
                coh: Dict[str, float]) -> Dict[str, float]:
        """λ_b = w_b·(1 + α·e_t + β·e_p + δ·(1−coh)), clamped to [lambda_min, lambda_max]."""
        cfg = self.cfg
        lam = {}
        for name, base in cfg.block_weights.items():
            val = base * (1.0
                          + cfg.alpha_timestep * float(e_t.get(name, 0.0))
                          + cfg.beta_pattern * float(e_p.get(name, 0.0))
                          + cfg.delta_incoherence * (1.0 - float(coh.get(name, 1.0))))
            lam[name] = float(max(cfg.lambda_min, min(cfg.lambda_max, val)))
        return lam


# =====================================================================================
# 7) TRAINER + INFERENCE
# =====================================================================================
class FlowMatchDavidTrainer:
    """Distill the SD1.5 teacher into a v-predicting student with David-weighted block losses.

    Per step: the teacher predicts eps on random latents, which gives a v target; the student's
    v prediction gives the global flow MSE. David then scores the student's block features, and
    those scores become per-block λ that scale the KD / local-flow losses.
    """

    def __init__(self, cfg: BaseConfig, device: str = "cuda"):
        self.cfg = cfg
        self.device = device
        self.start_epoch = 0
        self.start_gstep = 0

        # Data
        self.dataset = SymbolicPromptDataset(cfg.num_samples, cfg.seed)
        self.loader = DataLoader(self.dataset, batch_size=cfg.batch_size, shuffle=True,
                                 num_workers=cfg.num_workers,
                                 pin_memory=str(device).startswith("cuda"),
                                 collate_fn=collate)

        # Teacher & Student
        self.teacher = SD15Teacher(cfg, device).eval()
        self.student = StudentUNet(self.teacher.unet, cfg.active_blocks,
                                   cfg.use_local_flow_heads).to(device)

        # David (may override cfg.block_weights)
        self.david_loader = DavidLoader(cfg, device)
        self.david = self.david_loader.gdc

        # Assessor + Fusion
        self.assessor = DavidAssessor(self.david, cfg.pooling)
        self.fusion = BlockPenaltyFusion(cfg)

        # Opt/Sched/AMP — built after the student (including local heads) exists.
        self.opt = torch.optim.AdamW(self.student.parameters(), lr=cfg.lr,
                                     weight_decay=cfg.weight_decay)
        self.sched = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.opt, T_max=cfg.epochs * len(self.loader))
        self.scaler = torch.cuda.amp.GradScaler(enabled=cfg.amp)

        # Try to resume from HF if enabled
        if cfg.continue_training:
            self._load_latest_from_hf()
            if self.start_gstep > 0:
                self._fast_forward_lr(self.start_gstep)

        # Logs
        self.writer = SummaryWriter(log_dir=os.path.join(cfg.out_dir, cfg.run_name))

    def _fast_forward_lr(self, steps: int) -> None:
        """Advance the cosine LR schedule by ``steps`` so a resumed run continues the curve.

        Resume restores only UNet weights (no optimizer/scheduler state is uploaded), so
        without this the LR would restart at ``cfg.lr``.
        """
        import warnings

        with warnings.catch_warnings():
            # Stepping before any optimizer.step() emits an ordering warning; harmless here.
            warnings.simplefilter("ignore")
            for _ in range(steps):
                self.sched.step()

    def _latent_shape(self, batch_size: int) -> Tuple[int, int, int, int]:
        """Latent noise shape from the student UNet config (SD1.5: 4 x 64 x 64)."""
        ucfg = self.student.unet.config
        size = ucfg.sample_size
        h, w = (size, size) if isinstance(size, int) else tuple(size)
        return (batch_size, ucfg.in_channels, h, w)

    def _load_latest_from_hf(self):
        """Download the highest-epoch ``*_e<N>.safetensors`` from HF and load it into the student.

        Best effort: any failure is printed and training starts from scratch. On success sets
        ``start_epoch = N`` and ``start_gstep = N * len(loader)``.
        """
        if not self.cfg.hf_repo_id:
            print("⚠️ continue_training=True but no hf_repo_id specified")
            return

        try:
            api = HfApi()
            print(f"\n🔍 Searching for latest checkpoint in {self.cfg.hf_repo_id}...")

            # Check if repo exists
            try:
                api.repo_info(repo_id=self.cfg.hf_repo_id, repo_type="model")
            except Exception as e:
                print(f"⚠️ Could not access repo: {e}")
                print("   Starting training from scratch")
                return

            # List all files in repo
            files = api.list_repo_files(repo_id=self.cfg.hf_repo_id, repo_type="model")
            if not files:
                print("ℹ️ Repo is empty, starting from scratch")
                return

            print(f"📂 Found {len(files)} files in repo:")
            for f in files:
                print(f"   - {f}")

            # Collect (epoch, filename) for every "*_e<N>.safetensors" file.
            epoch_pattern = re.compile(r'_e(\d+)\.safetensors$')
            epochs = []
            for f in files:
                match = epoch_pattern.search(f)
                if match:
                    epoch_num = int(match.group(1))
                    epochs.append((epoch_num, f))
                    print(f"✓ Found checkpoint: {f} (epoch {epoch_num})")

            if not epochs:
                print("ℹ️ No checkpoint files found (looking for *_e<N>.safetensors)")
                return

            # Get latest epoch
            latest_epoch, latest_file = max(epochs, key=lambda x: x[0])
            print(f"\n📥 Downloading latest checkpoint: {latest_file} (epoch {latest_epoch})")

            local_path = hf_hub_download(
                repo_id=self.cfg.hf_repo_id,
                filename=latest_file,
                repo_type="model",
                cache_dir=self.cfg.ckpt_dir,
            )
            print(f"✓ Downloaded to: {local_path}")

            # Load the single-file checkpoint into a pipeline and take its UNet weights.
            print("📦 Loading checkpoint into pipeline...")
            pipe = StableDiffusionPipeline.from_single_file(
                local_path,
                torch_dtype=torch.float16,
                safety_checker=None,
                load_safety_checker=False,
            )
            unet_state = pipe.unet.state_dict()

            missing, unexpected = self.student.unet.load_state_dict(unet_state, strict=False)
            print(f"✓ Loaded student UNet from epoch {latest_epoch}")
            if missing:
                print(f"  Missing keys: {len(missing)}")
            if unexpected:
                print(f"  Unexpected keys: {len(unexpected)}")

            # Resume from the epoch after the loaded one.
            self.start_epoch = latest_epoch
            self.start_gstep = latest_epoch * len(self.loader)
            print(f"🎯 Resuming training from epoch {self.start_epoch + 1}")

            # Clean up
            del pipe
            torch.cuda.empty_cache()

        except Exception as e:
            print(f"⚠️ Failed to load checkpoint from HF: {e}")
            print("   Starting training from scratch")
            import traceback
            traceback.print_exc()

    # math helpers
    def _v_star(self, x_t, t, eps_hat):
        """v target from teacher eps: v = α·eps − σ·x0̂, with x0̂ = (x_t − σ·eps) / α."""
        alpha, sigma = self.teacher.alpha_sigma(t)
        x0_hat = (x_t - sigma * eps_hat) / (alpha + 1e-8)
        return alpha * eps_hat - sigma * x0_hat

    def _down_like(self, tgt: torch.Tensor, ref: torch.Tensor) -> torch.Tensor:
        """Bilinearly resize ``tgt`` to ``ref``'s spatial size."""
        return F.interpolate(tgt, size=ref.shape[-2:], mode="bilinear", align_corners=False)

    def _kd_cos(self, s: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """1 − mean cosine similarity along the last dim."""
        s = F.normalize(s, dim=-1)
        t = F.normalize(t, dim=-1)
        return 1.0 - (s * t).sum(-1).mean()

    # training
    def train(self):
        """Train epochs ``start_epoch .. cfg.epochs-1``; save/convert/upload every ``save_every``."""
        cfg = self.cfg
        gstep = self.start_gstep
        n_batches = len(self.loader)
        # Teacher text states are fp16; the student may be fp32 and must run without autocast too.
        student_dtype = next(self.student.parameters()).dtype

        for ep in range(self.start_epoch, cfg.epochs):
            self.student.train()
            pbar = tqdm(self.loader, desc=f"Epoch {ep+1}/{cfg.epochs}",
                        dynamic_ncols=True, leave=True, position=0)

            acc = {"L": 0.0, "Lf": 0.0, "Lb": 0.0}

            for it, batch in enumerate(pbar):
                prompts = batch["prompts"]
                t = batch["t"].to(self.device)

                with torch.no_grad():
                    ehs = self.teacher.encode(prompts)
                    # Pure-noise latents in fp32: the fp16 teacher gets a half copy, while the
                    # flow target and the student keep full precision.
                    x_t = torch.randn(self._latent_shape(len(prompts)),
                                      device=self.device, dtype=torch.float32)
                    eps_hat, t_feats_spatial = self.teacher.forward_eps_and_feats(x_t.half(), t, ehs)
                    v_star = self._v_star(x_t, t, eps_hat)

                with torch.cuda.amp.autocast(enabled=cfg.amp):
                    # Student
                    v_hat, s_feats_spatial = self.student(x_t, t, ehs.to(student_dtype))
                    L_flow = F.mse_loss(v_hat, v_star)

                    # David assessor on STUDENT pooled features
                    e_t, e_p, coh = self.assessor(s_feats_spatial, t)
                    lam = self.fusion.lambdas(e_t, e_p, coh)

                    # Per-block KD + Local flow
                    L_blocks = torch.zeros((), device=self.device)
                    for name, s_feat in s_feats_spatial.items():
                        L_kd = torch.zeros((), device=self.device)
                        if cfg.use_kd:
                            s_pool = spatial_pool(s_feat, name, cfg.pooling)
                            t_pool = spatial_pool(t_feats_spatial[name], name, cfg.pooling)
                            L_kd = self._kd_cos(s_pool, t_pool)

                        L_lf = torch.zeros((), device=self.device)
                        if cfg.use_local_flow_heads and name in self.student.local_heads:
                            v_loc = self.student.local_heads[name](s_feat)
                            v_ds = self._down_like(v_star, v_loc)
                            L_lf = F.mse_loss(v_loc, v_ds)

                        L_blocks = L_blocks + lam.get(name, 1.0) * (
                            cfg.kd_weight * L_kd + cfg.local_flow_weight * L_lf)

                    L_total = cfg.global_flow_weight * L_flow + cfg.block_penalty_weight * L_blocks

                self.opt.zero_grad(set_to_none=True)
                if cfg.amp:
                    self.scaler.scale(L_total).backward()
                    # Unscale first so clipping acts on the true gradient norm.
                    self.scaler.unscale_(self.opt)
                    nn.utils.clip_grad_norm_(self.student.parameters(), cfg.grad_clip)
                    self.scaler.step(self.opt)
                    self.scaler.update()
                else:
                    L_total.backward()
                    nn.utils.clip_grad_norm_(self.student.parameters(), cfg.grad_clip)
                    self.opt.step()
                self.sched.step()
                gstep += 1

                l_total = float(L_total.item())
                l_flow = float(L_flow.item())
                l_blocks = float(L_blocks.item())
                acc["L"] += l_total
                acc["Lf"] += l_flow
                acc["Lb"] += l_blocks

                # Only log to tensorboard every 50 iterations
                if it % 50 == 0:
                    self.writer.add_scalar("train/total", l_total, gstep)
                    self.writer.add_scalar("train/flow", l_flow, gstep)
                    self.writer.add_scalar("train/blocks", l_blocks, gstep)
                    # log a few lambdas
                    for k in list(lam.keys())[:4]:
                        self.writer.add_scalar(f"lambda/{k}", lam[k], gstep)

                # Update the progress bar sparingly to avoid double display.
                if it % 10 == 0 or it == n_batches - 1:
                    pbar.set_postfix({
                        "L": f"{l_total:.4f}",
                        "Lf": f"{l_flow:.4f}",
                        "Lb": f"{l_blocks:.4f}",
                    }, refresh=False)

                del x_t, eps_hat, v_star, v_hat, s_feats_spatial, t_feats_spatial

            pbar.close()

            n = n_batches
            print(f"\n[Epoch {ep+1}] L={acc['L']/n:.4f} | L_flow={acc['Lf']/n:.4f} | "
                  f"L_blocks={acc['Lb']/n:.4f}")
            self.writer.add_scalar("epoch/total", acc['L'] / n, ep + 1)
            self.writer.add_scalar("epoch/flow", acc['Lf'] / n, ep + 1)
            self.writer.add_scalar("epoch/blocks", acc['Lb'] / n, ep + 1)

            if (ep + 1) % cfg.save_every == 0:
                self._save(ep + 1, gstep)

        self._save("final", gstep)
        self.writer.close()

    def _save(self, tag, gstep):
        """Save a .pt, convert it to a ComfyUI-style safetensors, then optionally upload.

        The .pt is deleted only after a successful conversion; otherwise it is kept so the
        epoch's weights are not lost.
        """
        # 1. Save .pt first
        pt_path = Path(self.cfg.ckpt_dir) / f"{self.cfg.run_name}_e{tag}.pt"
        torch.save({
            "cfg": asdict(self.cfg),
            "student": self.student.state_dict(),
            "opt": self.opt.state_dict(),
            "sched": self.sched.state_dict(),
            "gstep": gstep,
        }, pt_path)
        print(f"✓ Saved temp .pt: {pt_path}")

        # 2. Convert to ComfyUI safetensors
        safetensors_path = self._convert_to_comfyui(pt_path, tag)
        if safetensors_path is None:
            print(f"⚠️ Conversion failed; keeping {pt_path} so this checkpoint is not lost")
            return

        # 3. Upload to HF
        if self.cfg.upload_every_epoch and self.cfg.hf_repo_id:
            self._upload_to_hf(safetensors_path, tag)

        # 4. Clean up large .pt file
        pt_path.unlink()
        print("✓ Cleaned up temp .pt file")

    def _convert_to_comfyui(self, pt_path: Path, tag) -> Optional[Path]:
        """Convert a saved .pt into a single-file SD checkpoint (safetensors) via diffusers' script.

        Returns the output path on success, ``None`` on any failure. The temporary diffusers
        pipeline directory is always removed.
        """
        import sys

        temp_pipeline = Path(self.cfg.ckpt_dir) / f"temp_pipeline_e{tag}"
        output_safetensors = Path(self.cfg.ckpt_dir) / f"{self.cfg.run_name}_e{tag}.safetensors"
        try:
            # Download converter if needed
            converter_path = Path(self.cfg.ckpt_dir) / "convert_diffusers_to_original_stable_diffusion.py"
            if not converter_path.exists():
                print("📥 Downloading official converter...")
                url = "https://raw.githubusercontent.com/huggingface/diffusers/main/scripts/convert_diffusers_to_original_stable_diffusion.py"
                urllib.request.urlretrieve(url, str(converter_path))
                print("✓ Converter downloaded")

            print("📦 Creating diffusers pipeline from checkpoint...")
            checkpoint = torch.load(pt_path, map_location='cpu')
            student_state = checkpoint.get('student', checkpoint)
            # StudentUNet.state_dict() keys look like "unet.<...>" (plus "local_heads.<...>").
            # Strip the prefix so they match a bare UNet; with strict=False unmatched keys would
            # otherwise be dropped silently and the base model exported unchanged.
            unet_state = {k[len("unet."):]: v for k, v in student_state.items()
                          if k.startswith("unet.")}
            if not unet_state:  # already a bare UNet state dict
                unet_state = student_state

            print("📥 Loading base UNet...")
            unet = UNet2DConditionModel.from_pretrained(
                self.cfg.model_id, subfolder="unet", torch_dtype=torch.float16)
            missing, unexpected = unet.load_state_dict(unet_state, strict=False)
            if missing or unexpected:
                print(f"⚠️ Student→UNet key mismatch: missing={len(missing)} unexpected={len(unexpected)}")
            print("✓ Loaded student weights into UNet")

            print("📥 Loading base SD1.5 pipeline...")
            pipe = StableDiffusionPipeline.from_pretrained(
                self.cfg.model_id, unet=unet, torch_dtype=torch.float16, safety_checker=None)
            print("✓ Replaced UNet with student")

            print("💾 Saving diffusers pipeline...")
            pipe.save_pretrained(str(temp_pipeline), safe_serialization=True)
            print(f"✓ Pipeline saved to {temp_pipeline}")
            del pipe, unet, checkpoint, student_state, unet_state

            # --use_safetensors is required: without it the script writes a torch pickle.
            print("🔄 Converting to ComfyUI format...")
            cmd = [
                sys.executable, str(converter_path),
                "--model_path", str(temp_pipeline),
                "--checkpoint_path", str(output_safetensors),
                "--half",
                "--use_safetensors",
            ]
            result = subprocess.run(cmd, capture_output=True, text=True)
            if result.returncode != 0:
                print(f"❌ Conversion failed: {result.stderr}")
                return None

            if not output_safetensors.exists():
                print("❌ Output file not created")
                return None

            size_mb = output_safetensors.stat().st_size / 1e6
            print(f"✓ Converted: {output_safetensors.name} ({size_mb:.1f}MB)")
            return output_safetensors

        except Exception as e:
            print(f"❌ Conversion failed: {e}")
            import traceback
            traceback.print_exc()
            return None
        finally:
            if temp_pipeline.exists():
                shutil.rmtree(temp_pipeline, ignore_errors=True)
                print("✓ Cleaned up temp pipeline")

    def _upload_to_hf(self, path: Path, tag):
        """Upload ``path`` to ``cfg.hf_repo_id`` (best effort; failures are printed)."""
        try:
            api = HfApi()
            try:
                create_repo(self.cfg.hf_repo_id, exist_ok=True, private=False, repo_type="model")
                print(f"✓ Repo ready: {self.cfg.hf_repo_id}")
            except Exception as e:
                # Not fatal: the repo may already exist / be writable; upload_file reports real errors.
                print(f"⚠️ create_repo failed ({e}); attempting upload anyway")

            print(f"📤 Uploading to {self.cfg.hf_repo_id}...")
            api.upload_file(
                path_or_fileobj=str(path),
                path_in_repo=path.name,
                repo_id=self.cfg.hf_repo_id,
                repo_type="model",
                commit_message=f"Epoch {tag}",
            )
            print(f"✅ Uploaded: https://huggingface.co/{self.cfg.hf_repo_id}/{path.name}")
        except Exception as e:
            print(f"⚠️ Upload failed: {e}")

    # ---------- Inference (v-pred sampling; use teacher VAE for decode) ----------
    @torch.no_grad()
    def sample(self, prompts: List[str], steps: Optional[int] = None,
               guidance: Optional[float] = None) -> torch.Tensor:
        """Sample images with the student (v-pred) using CFG + the teacher's DDPM scheduler.

        Each v prediction is converted to eps for ``sched.step``; the final latents are decoded
        with the teacher VAE. Returns images clamped to [-1, 1].
        """
        steps = steps or self.cfg.sample_steps
        guidance = guidance if guidance is not None else self.cfg.guidance_scale
        student_dtype = next(self.student.parameters()).dtype
        was_training = self.student.training
        self.student.eval()
        try:
            # Teacher text states are fp16; cast so an fp32 student accepts them outside autocast.
            cond_e = self.teacher.encode(prompts).to(student_dtype)
            uncond_e = self.teacher.encode([""] * len(prompts)).to(student_dtype)

            sched = self.teacher.sched
            sched.set_timesteps(steps, device=self.device)
            x_t = torch.randn(self._latent_shape(len(prompts)), device=self.device,
                              dtype=student_dtype)
            for t_scalar in sched.timesteps:
                t = torch.full((x_t.shape[0],), int(t_scalar), device=self.device, dtype=torch.long)
                v_u, _ = self.student(x_t, t, uncond_e)
                v_c, _ = self.student(x_t, t, cond_e)
                v_hat = v_u + guidance * (v_c - v_u)
                alpha, sigma = self.teacher.alpha_sigma(t)
                denom = alpha ** 2 + sigma ** 2
                x0_hat = (alpha * x_t - sigma * v_hat) / (denom + 1e-8)
                eps_hat = (x_t - alpha * x0_hat) / (sigma + 1e-8)
                x_t = sched.step(model_output=eps_hat, timestep=t_scalar, sample=x_t).prev_sample

            vae = self.teacher.pipe.vae
            scale = getattr(vae.config, "scaling_factor", 0.18215)
            # The VAE is fp16; match its dtype before decoding.
            imgs = vae.decode((x_t / scale).to(vae.dtype)).sample
        finally:
            self.student.train(was_training)
        return imgs.clamp(-1, 1)


# =====================================================================================
# 8) ACTIVATION
# =====================================================================================
def main():
    """Print the config, train, then run a tiny sampling smoke test."""
    cfg = BaseConfig()
    print(json.dumps(asdict(cfg), indent=2))
    device = "cuda" if torch.cuda.is_available() else "cpu"
    if device != "cuda":
        print("⚠️ A100 strongly recommended.")
    trainer = FlowMatchDavidTrainer(cfg, device=device)
    trainer.train()
    # quick sanity
    _ = trainer.sample(["a castle at sunset"], steps=10, guidance=7.0)
    print("✓ Inference sanity done.")


if __name__ == "__main__":
    main()