| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| from __future__ import annotations |
| import os, json, math, random, re |
| from dataclasses import dataclass, asdict |
| from pathlib import Path |
| from typing import Dict, List, Tuple, Optional |
| import urllib.request |
| import subprocess |
| import shutil |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch.utils.data import Dataset, DataLoader |
| from torch.utils.tensorboard import SummaryWriter |
| from tqdm import tqdm |
|
|
| |
| from diffusers import StableDiffusionPipeline, DDPMScheduler |
| from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel |
|
|
| |
| from geovocab2.train.model.core.geo_david_collective import GeoDavidCollective |
| from geovocab2.data.prompt.symbolic_tree import SynthesisSystem |
|
|
| |
| from huggingface_hub import snapshot_download, HfApi, create_repo, hf_hub_download |
| from safetensors.torch import load_file |
|
|
|
|
| |
| |
| |
@dataclass
class BaseConfig:
    """Configuration for the SD1.5 flow-matching + David distillation run.

    Instantiation first validates a few fields, then creates ``out_dir``,
    ``ckpt_dir`` and ``david_cache_dir`` on disk and fills in default
    ``block_weights`` when none were given.

    Raises:
        ValueError: if ``save_every < 1``, ``pooling`` is not one of
            "mean" / "max" / "adaptive", or ``lambda_min > lambda_max``.
    """

    # --- run bookkeeping ---
    run_name: str = "sd15_flowmatch_david_hf"
    out_dir: str = "./runs/sd15_flowmatch_david_hf"      # TensorBoard logs go under out_dir/run_name
    ckpt_dir: str = "./checkpoints_sd15_flow_david_hf"   # temp .pt and converted checkpoints
    save_every: int = 1                                  # save every N epochs (must be >= 1)

    # --- data ---
    num_samples: int = 200_000   # dataset length, i.e. items per epoch
    batch_size: int = 32
    num_workers: int = 2
    seed: int = 42

    # --- models / hooked UNet blocks ---
    model_id: str = "runwayml/stable-diffusion-v1-5"
    active_blocks: Tuple[str, ...] = ("down_0", "down_1", "down_2", "down_3", "mid", "up_0", "up_1", "up_2", "up_3")
    pooling: str = "mean"        # spatial pooling of block features: "mean" | "max" | "adaptive"

    # --- optimisation ---
    epochs: int = 10
    lr: float = 1e-4
    weight_decay: float = 1e-3
    grad_clip: float = 1.0
    amp: bool = True

    # --- loss weights ---
    global_flow_weight: float = 1.0
    block_penalty_weight: float = 0.2
    use_local_flow_heads: bool = False
    local_flow_weight: float = 1.0

    # --- pooled-feature knowledge distillation ---
    use_kd: bool = True
    kd_weight: float = 0.25

    # --- David (GeoDavidCollective) checkpoint on the HF Hub ---
    david_repo_id: str = "AbstractPhil/geo-david-collective-sd15-base-e40"
    david_cache_dir: str = "./_hf_david_cache"
    david_state_key: Optional[str] = None  # NOTE(review): not read anywhere in this file

    # --- David-driven per-block penalty fusion ---
    alpha_timestep: float = 0.5      # weight of David's timestep error
    beta_pattern: float = 0.25       # weight of David's pattern entropy
    delta_incoherence: float = 0.25  # weight of (1 - coherence)
    lambda_min: float = 0.5          # per-block lambdas are clamped to [lambda_min, lambda_max]
    lambda_max: float = 3.0

    # Base per-block weights; None -> defaults filled in by __post_init__.
    block_weights: Optional[Dict[str, float]] = None

    # --- diffusion schedule ---
    num_train_timesteps: int = 1000

    # --- sampling ---
    sample_steps: int = 30
    guidance_scale: float = 7.5

    # --- HF Hub checkpointing ---
    hf_repo_id: Optional[str] = "AbstractPhil/sd15-flow-matching"
    upload_every_epoch: bool = True
    continue_training: bool = True

    def __post_init__(self):
        # Validate before touching the filesystem so a bad config leaves no directories behind.
        if self.save_every < 1:
            # save_every is used as a modulus at the end of every epoch.
            raise ValueError(f"save_every must be >= 1, got {self.save_every}")
        if self.pooling not in ("mean", "max", "adaptive"):
            raise ValueError(f"Unknown pooling: {self.pooling}")
        if self.lambda_min > self.lambda_max:
            raise ValueError(f"lambda_min ({self.lambda_min}) must be <= lambda_max ({self.lambda_max})")

        Path(self.out_dir).mkdir(parents=True, exist_ok=True)
        Path(self.ckpt_dir).mkdir(parents=True, exist_ok=True)
        Path(self.david_cache_dir).mkdir(parents=True, exist_ok=True)
        if self.block_weights is None:
            self.block_weights = {'down_0': 0.7, 'down_1': 0.9, 'down_2': 1.0, 'down_3': 1.1, 'mid': 1.2,
                                  'up_0': 1.1, 'up_1': 1.0, 'up_2': 0.9, 'up_3': 0.7}
|
|
|
|
| |
| |
| |
class SymbolicPromptDataset(Dataset):
    """Synthetic prompt dataset: each item is a freshly synthesised prompt plus a random timestep.

    ``idx`` is ignored -- items are drawn at random on every access -- so ``n``
    only sets the reported length (items per epoch).
    """

    def __init__(self, n: int, seed: int = 42, num_timesteps: int = 1000):
        """
        Args:
            n: reported dataset length.
            seed: seeds Python's global ``random`` module and the SynthesisSystem.
            num_timesteps: timesteps are drawn uniformly from ``[0, num_timesteps - 1]``;
                the default 1000 reproduces the previous hard-coded 0..999 range.

        Raises:
            ValueError: if ``num_timesteps < 1``.
        """
        if num_timesteps < 1:
            raise ValueError(f"num_timesteps must be >= 1, got {num_timesteps}")
        self.n = n
        self.num_timesteps = num_timesteps
        # The global `random` module is used on purpose. NOTE(review): torch's
        # DataLoader re-seeds Python's global `random` inside each worker process,
        # so workers draw different items (and this seed only governs the
        # num_workers=0 case); a private random.Random instance would instead be
        # copied identically into every worker. Whether SynthesisSystem's own RNG
        # diverges across workers is not visible here -- TODO confirm.
        random.seed(seed)
        self.sys = SynthesisSystem(seed=seed)

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        r = self.sys.synthesize(complexity=random.choice([1, 2, 3, 4, 5]))
        prompt = r['text']
        t = random.randint(0, self.num_timesteps - 1)
        return {"prompt": prompt, "t": t}
|
|
def collate(batch: List[dict]):
    """Merge dataset items into one batch dictionary.

    Returns a dict with, in this order:
      "prompts": list of the items' prompt strings,
      "t":       int64 tensor of the items' timesteps,
      "t_bins":  ``t // 10`` (integer bin index per timestep).
    """
    timesteps = torch.tensor([item["t"] for item in batch], dtype=torch.long)
    return {
        "prompts": [item["prompt"] for item in batch],
        "t": timesteps,
        "t_bins": timesteps // 10,
    }
|
|
|
|
| |
| |
| |
class HookBank:
    """Record the outputs of selected UNet blocks through forward hooks.

    Block names follow the UNet layout: ``down_{i}`` -> ``unet.down_blocks[i]``,
    ``mid`` -> ``unet.mid_block``, ``up_{i}`` -> ``unet.up_blocks[i]``. Only names
    listed in ``active`` get a hook. After a forward pass ``bank[name]`` holds that
    block's output -- its first element when the block returns a tuple/list.
    Stored tensors are neither detached nor copied.
    """

    def __init__(self, unet: UNet2DConditionModel, active: Tuple[str, ...]):
        self.active = set(active)
        self.bank: Dict[str, torch.Tensor] = {}
        self.hooks: List[torch.utils.hooks.RemovableHandle] = []
        self._register(unet)

    def _register(self, unet: UNet2DConditionModel):
        def recorder(key):
            def hook(module, inputs, output):
                first = output[0] if isinstance(output, (tuple, list)) else output
                self.bank[key] = first
            return hook

        def named_blocks():
            # Yields (name, module) in down -> mid -> up order; mid_block is only
            # touched when "mid" is active.
            for idx, block in enumerate(unet.down_blocks):
                yield f"down_{idx}", block
            if "mid" in self.active:
                yield "mid", unet.mid_block
            for idx, block in enumerate(unet.up_blocks):
                yield f"up_{idx}", block

        for key, block in named_blocks():
            if key not in self.active:
                continue
            self.hooks.append(block.register_forward_hook(recorder(key)))

    def clear(self):
        """Forget every recorded output."""
        self.bank.clear()

    def close(self):
        """Remove all registered hooks (recorded outputs are left in place)."""
        for handle in self.hooks:
            handle.remove()
        self.hooks.clear()
|
|
def spatial_pool(x: torch.Tensor, name: str, policy: str) -> torch.Tensor:
    """Reduce dims 2 and 3 of a block feature map.

    policy "mean" averages, "max" takes the maximum, and "adaptive" averages for
    ``down*`` / ``mid`` blocks and takes the maximum for every other block name.

    Raises:
        ValueError: for any other policy.
    """
    mode = policy
    if mode == "adaptive":
        mode = "mean" if (name.startswith("down") or name == "mid") else "max"
    if mode == "mean":
        return x.mean(dim=(2, 3))
    if mode == "max":
        return x.amax(dim=(2, 3))
    raise ValueError(f"Unknown pooling: {policy}")
|
|
|
|
| |
| |
| |
class SD15Teacher(nn.Module):
    """Frozen fp16 SD1.5 teacher: tokenizer/text encoder, hooked UNet, and noise schedule.

    Every parameter registered on this module (UNet, text encoder) has
    ``requires_grad`` disabled, and the pipeline's VAE is frozen explicitly
    because the pipeline object itself is not an ``nn.Module``.
    """

    def __init__(self, cfg: BaseConfig, device: str):
        super().__init__()
        self.pipe = StableDiffusionPipeline.from_pretrained(cfg.model_id, torch_dtype=torch.float16, safety_checker=None).to(device)
        self.unet: UNet2DConditionModel = self.pipe.unet
        self.text_encoder = self.pipe.text_encoder
        self.tokenizer = self.pipe.tokenizer
        self.hooks = HookBank(self.unet, cfg.active_blocks)
        # Build the DDPM schedule from the pipeline's own scheduler config so that
        # alphas_cumprod match the schedule the teacher UNet was trained with.
        # A bare DDPMScheduler(...) falls back to diffusers' defaults (linear betas
        # 1e-4 -> 0.02), which would make alpha_sigma() -- and therefore the eps->v
        # targets and sampling -- inconsistent with the teacher. clip_sample=False
        # because latents are not confined to [-1, 1].
        self.sched = DDPMScheduler.from_config(
            self.pipe.scheduler.config,
            num_train_timesteps=cfg.num_train_timesteps,
            clip_sample=False,
        )
        self.device = device
        for p in self.parameters():
            p.requires_grad_(False)
        self.pipe.vae.requires_grad_(False)

    @torch.no_grad()
    def encode(self, prompts: List[str]) -> torch.Tensor:
        """Tokenize (padded/truncated to the tokenizer max length) and return the text encoder's first output."""
        tok = self.tokenizer(prompts, padding="max_length", max_length=self.tokenizer.model_max_length,
                             truncation=True, return_tensors="pt")
        return self.text_encoder(tok.input_ids.to(self.device))[0]

    @torch.no_grad()
    def forward_eps_and_feats(self, x_t: torch.Tensor, t: torch.LongTensor, ehs: torch.Tensor):
        """Run the teacher UNet once.

        Returns ``(eps_hat, feats)``: the UNet's ``.sample`` output cast to float32,
        and the hooked block outputs, detached and cast to float32.
        """
        self.hooks.clear()
        eps_hat = self.unet(x_t, t, encoder_hidden_states=ehs).sample
        feats = {k: v.detach().float() for k, v in self.hooks.bank.items()}
        return eps_hat.float(), feats

    def alpha_sigma(self, t: torch.LongTensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(sqrt(alphas_cumprod[t]), sqrt(1 - alphas_cumprod[t]))`` shaped (B, 1, 1, 1), float32."""
        ac = self.sched.alphas_cumprod.to(self.device)[t]
        alpha = ac.sqrt().view(-1, 1, 1, 1).float()
        sigma = (1.0 - ac).sqrt().view(-1, 1, 1, 1).float()
        return alpha, sigma
|
|
|
|
| |
| |
| |
class StudentUNet(nn.Module):
    """Trainable student: a copy of the teacher UNet plus block hooks and optional local flow heads.

    The UNet is rebuilt from the teacher's config and initialised from the
    teacher's weights (``strict=True``). With ``use_local_heads`` a 1x1
    ``Conv2d(C_block -> 4)`` head is created for every active block at
    construction time.
    """

    def __init__(self, teacher_unet: UNet2DConditionModel, active_blocks: Tuple[str, ...], use_local_heads: bool):
        super().__init__()
        self.unet = UNet2DConditionModel.from_config(teacher_unet.config)
        self.unet.load_state_dict(teacher_unet.state_dict(), strict=True)
        self.hooks = HookBank(self.unet, active_blocks)
        self.use_local_heads = use_local_heads
        self.local_heads = nn.ModuleDict()
        if use_local_heads:
            # Build the heads eagerly: the trainer creates its optimizer right after
            # constructing the student, so heads that only appeared lazily during the
            # first forward pass would never receive optimizer updates.
            target_dtype = next(self.unet.parameters()).dtype
            channels = self._block_channels(self.unet.config.block_out_channels, active_blocks)
            for name, c in channels.items():
                self.local_heads[name] = nn.Conv2d(c, 4, kernel_size=1).to(dtype=target_dtype)

    @staticmethod
    def _block_channels(block_out_channels, active_blocks) -> Dict[str, int]:
        """Map active block names to the channel count of that block's output.

        down_i -> block_out_channels[i], mid -> block_out_channels[-1],
        up_i -> reversed(block_out_channels)[i] (the UNet2DConditionModel layout).
        Names that do not exist for the given channel list are skipped.
        """
        chans = [int(c) for c in block_out_channels]
        table: Dict[str, int] = {f"down_{i}": c for i, c in enumerate(chans)}
        if chans:
            table["mid"] = chans[-1]
        table.update({f"up_{i}": c for i, c in enumerate(reversed(chans))})
        return {name: table[name] for name in active_blocks if name in table}

    def _ensure_heads(self, feats: Dict[str, torch.Tensor]):
        """Fallback: create heads for any hooked block that has none yet.

        NOTE: heads created here are NOT part of an optimizer that was built
        before this call; normally every head already exists from ``__init__``.
        """
        if not self.use_local_heads:
            return
        missing = [name for name in feats if name not in self.local_heads]
        if not missing:
            return
        target_dtype = next(self.unet.parameters()).dtype
        for name in missing:
            f = feats[name]
            self.local_heads[name] = nn.Conv2d(f.shape[1], 4, kernel_size=1).to(dtype=target_dtype, device=f.device)

    def forward(self, x_t: torch.Tensor, t: torch.LongTensor, ehs: torch.Tensor):
        """Return ``(v_hat, feats)``: the UNet ``.sample`` output and the hooked block outputs (graph attached)."""
        self.hooks.clear()
        v_hat = self.unet(x_t, t, encoder_hidden_states=ehs).sample
        feats = {k: v for k, v in self.hooks.bank.items()}
        self._ensure_heads(feats)
        return v_hat, feats
|
|
|
|
| |
| |
| |
class DavidLoader:
    """
    Download the David (GeoDavidCollective) repo from the HF Hub, build the model
    from the repo's ``config.json``, load ``model.safetensors`` and freeze it.

    Attributes set: ``repo_dir``, ``config_path``, ``weights_path``, ``hf_config``
    (the parsed config.json) and ``gdc`` (the model, in eval mode on ``device``,
    all parameters with ``requires_grad=False``).
    Side effect: when config.json contains "block_weights", they replace
    ``cfg.block_weights`` on the config object passed in.
    NOTE(review): ``cfg.david_state_key`` is not used here.
    """
    def __init__(self, cfg: BaseConfig, device: str):
        self.cfg = cfg
        self.device = device
        # NOTE(review): `local_dir_use_symlinks` is deprecated (ignored) in recent
        # huggingface_hub releases; kept for compatibility with older versions.
        self.repo_dir = snapshot_download(repo_id=cfg.david_repo_id, local_dir=cfg.david_cache_dir, local_dir_use_symlinks=False)
        self.config_path = os.path.join(self.repo_dir, "config.json")
        self.weights_path = os.path.join(self.repo_dir, "model.safetensors")
        with open(self.config_path, "r", encoding="utf-8") as f:
            self.hf_config = json.load(f)

        block_configs = self.hf_config["block_configs"]
        self.gdc = GeoDavidCollective(
            block_configs=block_configs,
            num_timestep_bins=int(self.hf_config["num_timestep_bins"]),
            num_patterns_per_bin=int(self.hf_config["num_patterns_per_bin"]),
            block_weights=self.hf_config.get("block_weights", {k: 1.0 for k in block_configs.keys()}),
            loss_config=self.hf_config.get("loss_config", {}),
        ).to(device).eval()

        state = load_file(self.weights_path)
        # strict=False tolerates key drift between checkpoint and code, but report
        # it: silently unloaded weights would make David's scores meaningless.
        missing, unexpected = self.gdc.load_state_dict(state, strict=False)
        if missing or unexpected:
            print(f"WARNING: David weights partially loaded: {len(missing)} missing, {len(unexpected)} unexpected keys")
            if missing:
                print(f"   missing (first 5): {list(missing)[:5]}")
            if unexpected:
                print(f"   unexpected (first 5): {list(unexpected)[:5]}")
        for p in self.gdc.parameters():
            p.requires_grad_(False)

        print(f"David loaded from HF: {self.repo_dir}")
        print(f" blocks={len(self.hf_config['block_configs'])} bins={self.hf_config['num_timestep_bins']} patterns={self.hf_config['num_patterns_per_bin']}")

        # The repo's block weights take precedence over the config defaults.
        if "block_weights" in self.hf_config:
            cfg.block_weights = self.hf_config["block_weights"]
|
|
class DavidAssessor(nn.Module):
    """
    Uses David to score STUDENT pooled features (per block) and timesteps.

    Produces three dicts keyed by block name:
      e_t[name] : cross-entropy of David's timestep logits against the bin ``t // 10``
                  (0.0 when David's output has no timestep logits)
      e_p[name] : mean entropy of David's pattern distribution divided by
                  log(num_patterns), i.e. in [0, 1] (0.0 when there are no
                  pattern logits or only one pattern class)
      coh[name] : mean of ``gdc.get_cantor_alphas()`` values, or 1.0 when that
                  call fails or returns nothing (same value for every block)
    When David returns a single (non-dict) logits tensor, its score is copied to
    every pooled block.
    NOTE(review): ``t // 10`` assumes 10 timesteps per David bin (e.g. 1000 steps /
    100 bins) -- confirm against hf_config["num_timestep_bins"].
    """

    # Output keys probed, in order, for timestep / pattern logits.
    _TS_KEYS = ("timestep_logits", "logits_timestep", "timestep_head_logits")
    _PT_KEYS = ("pattern_logits", "logits_pattern", "pattern_head_logits")

    def __init__(self, gdc: GeoDavidCollective, pooling: str):
        super().__init__()
        self.gdc = gdc
        self.pooling = pooling

    def _pool(self, feats: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        # Cast to fp32: student features can be fp16 under autocast while David is
        # presumably fp32, so this avoids a dtype mismatch outside autocast regions.
        return {k: spatial_pool(v, k, self.pooling).float() for k, v in feats.items()}

    @staticmethod
    def _first_present(outs, keys) -> Optional[str]:
        """Return the first of ``keys`` found in ``outs``, or None."""
        for key in keys:
            if key in outs:
                return key
        return None

    @staticmethod
    def _normalized_entropy(logits: torch.Tensor) -> float:
        """Mean softmax entropy over the last dim divided by log(K); 0.0 when K <= 1."""
        k = logits.shape[-1]
        if k <= 1:
            # log(1) == 0 would divide by zero; a single class carries no uncertainty.
            return 0.0
        p = logits.float().softmax(-1)
        ent = -(p * p.clamp_min(1e-9).log()).sum(-1).mean()
        return float(ent.item() / math.log(k))

    @torch.no_grad()
    def forward(self, feats_student: Dict[str, torch.Tensor], t: torch.LongTensor
                ) -> Tuple[Dict[str, float], Dict[str, float], Dict[str, float]]:
        Zs = self._pool(feats_student)
        outs = self.gdc(Zs, t.float())
        names = list(Zs.keys())
        t_bins = t // 10

        def timestep_ce(logits: torch.Tensor) -> float:
            # Targets follow the logits' device (David may sit on a different device than t).
            return float(F.cross_entropy(logits.float(), t_bins.to(logits.device), reduction="mean").item())

        # --- timestep error ---
        ts_key = self._first_present(outs, self._TS_KEYS)
        if ts_key is None:
            e_t = {name: 0.0 for name in names}
        elif isinstance(outs[ts_key], dict):
            e_t = {name: timestep_ce(L) for name, L in outs[ts_key].items()}
        else:
            shared = timestep_ce(outs[ts_key])
            e_t = {name: shared for name in names}

        # --- pattern uncertainty (normalised entropy) ---
        pt_key = self._first_present(outs, self._PT_KEYS)
        if pt_key is None:
            e_p = {name: 0.0 for name in names}
        elif isinstance(outs[pt_key], dict):
            e_p = {name: self._normalized_entropy(L) for name, L in outs[pt_key].items()}
        else:
            shared = self._normalized_entropy(outs[pt_key])
            e_p = {name: shared for name in names}

        # --- coherence proxy ---
        try:
            alphas = self.gdc.get_cantor_alphas()
        except Exception:
            # Deliberate best-effort: fall back to "fully coherent" if David does
            # not expose Cantor alphas.
            alphas = {}
        avg_alpha = float(sum(alphas.values()) / max(len(alphas), 1)) if alphas else 1.0
        coh = {name: avg_alpha for name in names}

        return e_t, e_p, coh
|
|
class BlockPenaltyFusion:
    """Convert David's per-block error signals into clamped per-block loss multipliers.

    For every block in ``cfg.block_weights``:
        lambda = base * (1 + alpha_timestep * e_t + beta_pattern * e_p
                         + delta_incoherence * (1 - coh))
    clamped to ``[cfg.lambda_min, cfg.lambda_max]``. Missing signals default to
    e_t = 0, e_p = 0, coh = 1.
    """

    def __init__(self, cfg: BaseConfig):
        self.cfg = cfg

    def lambdas(self, e_t: Dict[str, float], e_p: Dict[str, float], coh: Dict[str, float]) -> Dict[str, float]:
        c = self.cfg

        def scaled(name: str, base: float) -> float:
            raw = base * (1.0
                          + c.alpha_timestep * float(e_t.get(name, 0.0))
                          + c.beta_pattern * float(e_p.get(name, 0.0))
                          + c.delta_incoherence * (1.0 - float(coh.get(name, 1.0))))
            return float(max(c.lambda_min, min(c.lambda_max, raw)))

        return {name: scaled(name, base) for name, base in c.block_weights.items()}
|
|
|
|
| |
| |
| |
class FlowMatchDavidTrainer:
    """Distil the frozen SD1.5 teacher into a v-predicting student UNet, guided by David.

    Each step draws Gaussian latents ``x_t`` at random timesteps ``t``. The teacher
    predicts eps, which is converted into a v target (``_v_star``). The student
    regresses that target (global flow loss) plus per-block penalties: pooled
    feature cosine KD and optional local flow heads. Each block's penalty is
    weighted by a lambda derived from David's assessment of the student's features.

    Checkpoints are written as a temporary ``.pt`` and converted to a single-file
    ``.safetensors``. They are optionally uploaded to ``cfg.hf_repo_id`` and can be
    resumed from there when ``cfg.continue_training`` is set.
    """

    def __init__(self, cfg: BaseConfig, device: str = "cuda"):
        self.cfg = cfg
        self.device = device
        self.start_epoch = 0
        self.start_gstep = 0

        # --- data ---
        self.dataset = SymbolicPromptDataset(cfg.num_samples, cfg.seed)
        self.loader = DataLoader(self.dataset, batch_size=cfg.batch_size, shuffle=True,
                                 num_workers=cfg.num_workers,
                                 pin_memory=str(device).startswith("cuda"),  # pinning only helps host->GPU copies
                                 collate_fn=collate)

        # --- teacher (frozen) and student (trainable copy of the teacher UNet) ---
        self.teacher = SD15Teacher(cfg, device).eval()
        self.student = StudentUNet(self.teacher.unet, cfg.active_blocks, cfg.use_local_flow_heads).to(device)

        # --- David ---
        self.david_loader = DavidLoader(cfg, device)
        self.david = self.david_loader.gdc
        self.assessor = DavidAssessor(self.david, cfg.pooling)
        self.fusion = BlockPenaltyFusion(cfg)

        # --- optimisation (built after the student so all of its parameters are included) ---
        self.opt = torch.optim.AdamW(self.student.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
        # max(1, ...) guards an empty loader (T_max=0 would divide by zero inside the scheduler).
        self.sched = torch.optim.lr_scheduler.CosineAnnealingLR(self.opt, T_max=max(1, cfg.epochs * len(self.loader)))
        self.scaler = torch.cuda.amp.GradScaler(enabled=cfg.amp)

        if cfg.continue_training:
            self._load_latest_from_hf()

        self.writer = SummaryWriter(log_dir=os.path.join(cfg.out_dir, cfg.run_name))

    # ------------------------------------------------------------------ resume

    def _load_latest_from_hf(self):
        """Resume from the highest-epoch ``*_e<N>.safetensors`` file in ``cfg.hf_repo_id``.

        Only the student UNet weights are restored, because the uploaded files are
        single-file SD checkpoints without optimizer state. The LR schedule is
        fast-forwarded to the resumed global step. Any failure is reported and
        training starts from scratch (best-effort).
        """
        if not self.cfg.hf_repo_id:
            print("WARNING: continue_training=True but no hf_repo_id specified")
            return

        try:
            api = HfApi()
            print(f"\nSearching for latest checkpoint in {self.cfg.hf_repo_id}...")

            try:
                files = api.list_repo_files(repo_id=self.cfg.hf_repo_id, repo_type="model")
            except Exception as e:
                print(f"WARNING: Could not access repo: {e}")
                print("   Starting training from scratch")
                return

            if not files:
                print("Repo is empty, starting from scratch")
                return

            print(f"Found {len(files)} files in repo:")
            for f in files:
                print(f"   - {f}")

            # Only numbered epochs count; the "final" tag does not match \d+.
            epochs = []
            for f in files:
                match = re.search(r'_e(\d+)\.safetensors$', f)
                if match:
                    epoch_num = int(match.group(1))
                    epochs.append((epoch_num, f))
                    print(f"Found checkpoint: {f} (epoch {epoch_num})")

            if not epochs:
                print("No checkpoint files found (looking for *_e<num>.safetensors)")
                return

            latest_epoch, latest_file = max(epochs, key=lambda x: x[0])
            print(f"\nDownloading latest checkpoint: {latest_file} (epoch {latest_epoch})")
            local_path = hf_hub_download(
                repo_id=self.cfg.hf_repo_id,
                filename=latest_file,
                repo_type="model",
                cache_dir=self.cfg.ckpt_dir,
            )
            print(f"Downloaded to: {local_path}")

            print("Loading checkpoint into pipeline...")
            pipe = StableDiffusionPipeline.from_single_file(
                local_path,
                torch_dtype=torch.float16,
                safety_checker=None,
                load_safety_checker=False,
            )
            unet_state = pipe.unet.state_dict()
            del pipe

            missing, unexpected = self.student.unet.load_state_dict(unet_state, strict=False)
            del unet_state
            print(f"Loaded student UNet from epoch {latest_epoch}")
            if missing:
                print(f"   Missing keys: {len(missing)}")
            if unexpected:
                print(f"   Unexpected keys: {len(unexpected)}")

            self.start_epoch = latest_epoch
            self.start_gstep = latest_epoch * len(self.loader)
            # Without this the cosine schedule would restart at peak LR on every resume.
            self._fast_forward_lr(self.start_gstep)
            print(f"Resuming training from epoch {self.start_epoch + 1}")

            torch.cuda.empty_cache()

        except Exception as e:
            print(f"WARNING: Failed to load checkpoint from HF: {e}")
            print("   Starting training from scratch")
            import traceback
            traceback.print_exc()

    def _fast_forward_lr(self, steps: int) -> None:
        """Advance the LR scheduler by ``steps`` calls without changing any weights."""
        import warnings
        with warnings.catch_warnings():
            # Stepping the scheduler before optimizer.step() is intentional here.
            warnings.simplefilter("ignore", UserWarning)
            for _ in range(steps):
                self.sched.step()

    # ------------------------------------------------------------------ losses

    def _v_star(self, x_t, t, eps_hat):
        """Convert the teacher's eps prediction to a v target.

        x0_hat = (x_t - sigma * eps) / alpha (with 1e-8 guarding the division), then
        v = alpha * eps - sigma * x0_hat.
        """
        alpha, sigma = self.teacher.alpha_sigma(t)
        x0_hat = (x_t - sigma * eps_hat) / (alpha + 1e-8)
        return alpha * eps_hat - sigma * x0_hat

    def _down_like(self, tgt: torch.Tensor, ref: torch.Tensor) -> torch.Tensor:
        """Bilinearly resize ``tgt`` to the spatial size of ``ref``."""
        return F.interpolate(tgt, size=ref.shape[-2:], mode="bilinear", align_corners=False)

    def _kd_cos(self, s: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """1 - mean cosine similarity along the last dim."""
        s = F.normalize(s, dim=-1)
        t = F.normalize(t, dim=-1)
        return 1.0 - (s * t).sum(-1).mean()

    def _block_losses(self, s_feats, t_feats, v_star, lam) -> torch.Tensor:
        """Sum over hooked blocks of ``lam[name] * (kd_weight * KD + local_flow_weight * local-flow MSE)``."""
        cfg = self.cfg
        L_blocks = torch.zeros((), device=self.device)
        for name, s_feat in s_feats.items():
            L_kd = torch.zeros((), device=self.device)
            if cfg.use_kd:
                s_pool = spatial_pool(s_feat, name, cfg.pooling)
                t_pool = spatial_pool(t_feats[name], name, cfg.pooling)
                L_kd = self._kd_cos(s_pool, t_pool)

            L_lf = torch.zeros((), device=self.device)
            if cfg.use_local_flow_heads and name in self.student.local_heads:
                v_loc = self.student.local_heads[name](s_feat)
                L_lf = F.mse_loss(v_loc, self._down_like(v_star, v_loc))

            L_blocks = L_blocks + lam.get(name, 1.0) * (cfg.kd_weight * L_kd + cfg.local_flow_weight * L_lf)
        return L_blocks

    def _optimizer_step(self, loss: torch.Tensor) -> None:
        """Backward, clip to ``cfg.grad_clip``, step the optimizer (AMP-aware), then step the LR schedule."""
        cfg = self.cfg
        self.opt.zero_grad(set_to_none=True)
        if cfg.amp:
            self.scaler.scale(loss).backward()
            # Grads are still multiplied by the loss scale here; unscale first so the
            # clip threshold applies to the true gradient norm.
            self.scaler.unscale_(self.opt)
            nn.utils.clip_grad_norm_(self.student.parameters(), cfg.grad_clip)
            self.scaler.step(self.opt)
            self.scaler.update()
        else:
            loss.backward()
            nn.utils.clip_grad_norm_(self.student.parameters(), cfg.grad_clip)
            self.opt.step()
        self.sched.step()

    # ------------------------------------------------------------------ training

    def train(self):
        """Train epochs ``start_epoch .. cfg.epochs - 1``.

        Saves every ``cfg.save_every`` epochs and once more with tag "final". The
        TensorBoard writer is closed even if training raises.
        """
        cfg = self.cfg
        gstep = self.start_gstep
        n_batches = len(self.loader)

        try:
            for ep in range(self.start_epoch, cfg.epochs):
                self.student.train()
                pbar = tqdm(self.loader, desc=f"Epoch {ep+1}/{cfg.epochs}",
                            dynamic_ncols=True, leave=True, position=0)
                acc = {"L": 0.0, "Lf": 0.0, "Lb": 0.0}

                for it, batch in enumerate(pbar):
                    prompts = batch["prompts"]
                    t = batch["t"].to(self.device)

                    with torch.no_grad():
                        ehs = self.teacher.encode(prompts)
                        # Data-free distillation: x_t is pure Gaussian noise for every t;
                        # no clean latents are noised here.
                        x_t = torch.randn(len(prompts), 4, 64, 64, device=self.device, dtype=torch.float16)
                        eps_hat, t_feats_spatial = self.teacher.forward_eps_and_feats(x_t, t, ehs)
                        v_star = self._v_star(x_t, t, eps_hat)

                    with torch.cuda.amp.autocast(enabled=cfg.amp):
                        v_hat, s_feats_spatial = self.student(x_t, t, ehs)
                        L_flow = F.mse_loss(v_hat, v_star)

                        # David scores the student's block features -> per-block lambdas.
                        e_t, e_p, coh = self.assessor(s_feats_spatial, t)
                        lam = self.fusion.lambdas(e_t, e_p, coh)

                        L_blocks = self._block_losses(s_feats_spatial, t_feats_spatial, v_star, lam)
                        L_total = cfg.global_flow_weight * L_flow + cfg.block_penalty_weight * L_blocks

                    self._optimizer_step(L_total)
                    gstep += 1

                    l_total = float(L_total.item())
                    l_flow = float(L_flow.item())
                    l_blocks = float(L_blocks.item())
                    acc["L"] += l_total
                    acc["Lf"] += l_flow
                    acc["Lb"] += l_blocks

                    if it % 50 == 0:
                        self.writer.add_scalar("train/total", l_total, gstep)
                        self.writer.add_scalar("train/flow", l_flow, gstep)
                        self.writer.add_scalar("train/blocks", l_blocks, gstep)
                        # Only the first four lambdas are logged to keep TensorBoard compact.
                        for k in list(lam.keys())[:4]:
                            self.writer.add_scalar(f"lambda/{k}", lam[k], gstep)

                    if it % 10 == 0 or it == n_batches - 1:
                        pbar.set_postfix({
                            "L": f"{l_total:.4f}",
                            "Lf": f"{l_flow:.4f}",
                            "Lb": f"{l_blocks:.4f}",
                        }, refresh=False)

                    # The hook banks still reference this step's activations; drop them now.
                    self.student.hooks.clear()
                    self.teacher.hooks.clear()
                    del x_t, eps_hat, v_star, v_hat, s_feats_spatial, t_feats_spatial

                pbar.close()

                n = max(n_batches, 1)
                print(f"\n[Epoch {ep+1}] L={acc['L']/n:.4f} | L_flow={acc['Lf']/n:.4f} | L_blocks={acc['Lb']/n:.4f}")
                self.writer.add_scalar("epoch/total", acc['L'] / n, ep + 1)
                self.writer.add_scalar("epoch/flow", acc['Lf'] / n, ep + 1)
                self.writer.add_scalar("epoch/blocks", acc['Lb'] / n, ep + 1)

                if (ep + 1) % cfg.save_every == 0:
                    self._save(ep + 1, gstep)

            self._save("final", gstep)
        finally:
            self.writer.close()

    # ------------------------------------------------------------------ checkpoints

    def _save(self, tag, gstep):
        """Write a temp ``.pt``, convert it to single-file safetensors, and optionally upload it.

        The ``.pt`` is deleted only after a successful conversion, so a failed
        conversion never loses the epoch's weights.
        """
        pt_path = Path(self.cfg.ckpt_dir) / f"{self.cfg.run_name}_e{tag}.pt"
        torch.save({
            "cfg": asdict(self.cfg),
            "student": self.student.state_dict(),
            "opt": self.opt.state_dict(),
            "sched": self.sched.state_dict(),
            "gstep": gstep,
        }, pt_path)
        print(f"Saved temp .pt: {pt_path}")

        safetensors_path = self._convert_to_comfyui(pt_path, tag)
        if safetensors_path is None:
            print(f"WARNING: conversion failed; keeping {pt_path}")
            return

        if self.cfg.upload_every_epoch and self.cfg.hf_repo_id:
            self._upload_to_hf(safetensors_path, tag)

        pt_path.unlink()
        print("Cleaned up temp .pt file")

    @staticmethod
    def _unet_state_from_student(state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Extract bare-UNet weights from a StudentUNet ``state_dict``.

        StudentUNet holds the UNet as ``self.unet`` (keys ``unet.*``) next to
        ``local_heads.*``. A bare UNet2DConditionModel expects the keys without the
        prefix, and the heads are dropped. A dict with no ``unet.`` keys is
        returned as a shallow copy.
        """
        prefix = "unet."
        if not any(k.startswith(prefix) for k in state):
            return dict(state)
        return {k[len(prefix):]: v for k, v in state.items() if k.startswith(prefix)}

    def _convert_to_comfyui(self, pt_path: Path, tag) -> Optional[Path]:
        """Convert the temp ``.pt`` into a single-file SD checkpoint (``.safetensors``).

        Steps:
          1. Fetch diffusers' converter script (cached in ``ckpt_dir``).
          2. Build a diffusers pipeline whose UNet holds the student weights.
          3. Save that pipeline to a temporary directory.
          4. Run the converter with ``--half --use_safetensors``.

        Returns the output path, or None on any failure (errors are printed).
        The temporary pipeline directory is always removed.
        """
        import sys

        ckpt_dir = Path(self.cfg.ckpt_dir)
        temp_pipeline = ckpt_dir / f"temp_pipeline_e{tag}"
        output_safetensors = ckpt_dir / f"{self.cfg.run_name}_e{tag}.safetensors"
        try:
            converter_path = ckpt_dir / "convert_diffusers_to_original_stable_diffusion.py"
            if not converter_path.exists():
                # NOTE(review): this downloads, and later executes, a script from
                # diffusers@main; pin a release tag for reproducibility.
                print("Downloading official converter...")
                url = "https://raw.githubusercontent.com/huggingface/diffusers/main/scripts/convert_diffusers_to_original_stable_diffusion.py"
                urllib.request.urlretrieve(url, str(converter_path))
                print("Converter downloaded")

            print("Creating diffusers pipeline from checkpoint...")
            checkpoint = torch.load(pt_path, map_location='cpu')
            student_state = checkpoint.get('student', checkpoint)
            unet_state = self._unet_state_from_student(student_state)
            del checkpoint, student_state  # release optimizer state before loading models

            print("Loading base UNet...")
            unet = UNet2DConditionModel.from_pretrained(self.cfg.model_id, subfolder="unet", torch_dtype=torch.float16)
            missing, unexpected = unet.load_state_dict(unet_state, strict=False)
            del unet_state
            if missing:
                # Exporting base weights under the student's name would silently lose training progress.
                raise RuntimeError(f"student state is missing {len(missing)} UNet keys (e.g. {list(missing)[:3]})")
            if unexpected:
                print(f"WARNING: ignoring {len(unexpected)} unexpected UNet keys")
            print("Loaded student weights into UNet")

            print("Loading base SD1.5 pipeline...")
            pipe = StableDiffusionPipeline.from_pretrained(
                self.cfg.model_id,
                unet=unet,
                torch_dtype=torch.float16,
                safety_checker=None,
            )
            print("Replaced UNet with student")

            print("Saving diffusers pipeline...")
            pipe.save_pretrained(str(temp_pipeline), safe_serialization=True)
            del pipe, unet
            print(f"Pipeline saved to {temp_pipeline}")

            print("Converting to ComfyUI format...")
            cmd = [
                sys.executable, str(converter_path),  # same interpreter/env as this process
                "--model_path", str(temp_pipeline),
                "--checkpoint_path", str(output_safetensors),
                "--half",
                "--use_safetensors",  # otherwise the converter writes a pickled .ckpt despite the extension
            ]
            result = subprocess.run(cmd, capture_output=True, text=True)
            if result.returncode != 0:
                print(f"Conversion failed: {result.stderr}")
                return None

            if not output_safetensors.exists():
                print("Output file not created")
                return None

            size_mb = output_safetensors.stat().st_size / 1e6
            print(f"Converted: {output_safetensors.name} ({size_mb:.1f}MB)")
            return output_safetensors

        except Exception as e:
            print(f"Conversion failed: {e}")
            import traceback
            traceback.print_exc()
            return None
        finally:
            if temp_pipeline.exists():
                shutil.rmtree(temp_pipeline, ignore_errors=True)
                print("Cleaned up temp pipeline")

    def _upload_to_hf(self, path: Path, tag):
        """Upload ``path`` to ``cfg.hf_repo_id`` (best-effort: failures are printed, not raised)."""
        try:
            api = HfApi()

            try:
                create_repo(self.cfg.hf_repo_id, exist_ok=True, private=False, repo_type="model")
                print(f"Repo ready: {self.cfg.hf_repo_id}")
            except Exception as e:
                # The upload below surfaces the real error if the repo is unusable.
                print(f"WARNING: create_repo failed ({e}); attempting upload anyway")

            print(f"Uploading to {self.cfg.hf_repo_id}...")
            api.upload_file(
                path_or_fileobj=str(path),
                path_in_repo=path.name,
                repo_id=self.cfg.hf_repo_id,
                repo_type="model",
                commit_message=f"Epoch {tag}",
            )
            print(f"Uploaded: https://huggingface.co/{self.cfg.hf_repo_id}/{path.name}")

        except Exception as e:
            print(f"WARNING: Upload failed: {e}")

    # ------------------------------------------------------------------ inference

    @torch.no_grad()
    def sample(self, prompts: List[str], steps: Optional[int] = None, guidance: Optional[float] = None) -> torch.Tensor:
        """Generate images with the student using classifier-free guidance.

        The student's v prediction is converted to eps so the teacher's DDPM
        scheduler can step it. Returns decoded images clamped to [-1, 1].
        Leaves the student in eval mode; ``train()`` switches it back each epoch.
        """
        steps = steps or self.cfg.sample_steps
        guidance = guidance if guidance is not None else self.cfg.guidance_scale
        self.student.eval()

        # The student UNet is not fp16 like the teacher, so inputs must match its dtype.
        unet_dtype = next(self.student.unet.parameters()).dtype
        cond_e = self.teacher.encode(prompts).to(unet_dtype)
        uncond_e = self.teacher.encode([""] * len(prompts)).to(unet_dtype)

        sched = self.teacher.sched
        sched.set_timesteps(steps, device=self.device)
        x_t = torch.randn(len(prompts), 4, 64, 64, device=self.device, dtype=unet_dtype) * sched.init_noise_sigma

        for t_scalar in sched.timesteps:
            t = torch.full((x_t.shape[0],), int(t_scalar), device=self.device, dtype=torch.long)
            v_u, _ = self.student(x_t, t, uncond_e)
            v_c, _ = self.student(x_t, t, cond_e)
            v_hat = v_u + guidance * (v_c - v_u)

            # Invert v = alpha*eps - sigma*x0 given x_t = alpha*x0 + sigma*eps.
            alpha, sigma = self.teacher.alpha_sigma(t)
            denom = alpha ** 2 + sigma ** 2
            x0_hat = (alpha * x_t - sigma * v_hat) / (denom + 1e-8)
            eps_hat = (x_t - alpha * x0_hat) / (sigma + 1e-8)

            x_t = sched.step(model_output=eps_hat, timestep=t_scalar, sample=x_t).prev_sample

        self.student.hooks.clear()
        vae = self.teacher.pipe.vae
        imgs = vae.decode(x_t.to(vae.dtype) / vae.config.scaling_factor).sample
        return imgs.clamp(-1, 1)
|
|
|
|
| |
| |
| |
def main():
    """Build the default config, train, then run a short sampling sanity check."""
    cfg = BaseConfig()
    print(json.dumps(asdict(cfg), indent=2))
    device = "cuda" if torch.cuda.is_available() else "cpu"
    if device != "cuda":
        # The teacher is loaded in fp16, so CPU-only runs are unlikely to be practical.
        print("WARNING: A100 strongly recommended.")
    trainer = FlowMatchDavidTrainer(cfg, device=device)
    trainer.train()

    _ = trainer.sample(["a castle at sunset"], steps=10, guidance=7.0)
    print("Inference sanity done.")


if __name__ == "__main__":
    main()