|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from __future__ import annotations |
|
|
import os, json, math, random, re |
|
|
from dataclasses import dataclass, asdict |
|
|
from pathlib import Path |
|
|
from typing import Dict, List, Tuple, Optional |
|
|
import urllib.request |
|
|
import subprocess |
|
|
import shutil |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from torch.utils.data import Dataset, DataLoader |
|
|
from torch.utils.tensorboard import SummaryWriter |
|
|
from tqdm import tqdm |
|
|
|
|
|
|
|
|
from diffusers import StableDiffusionPipeline, DDPMScheduler |
|
|
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel |
|
|
|
|
|
|
|
|
from geovocab2.train.model.core.geo_david_collective import GeoDavidCollective |
|
|
from geovocab2.data.prompt.symbolic_tree import SynthesisSystem |
|
|
|
|
|
|
|
|
from huggingface_hub import snapshot_download, HfApi, create_repo, hf_hub_download |
|
|
from safetensors.torch import load_file |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class BaseConfig:
    """Every tunable knob for the SD1.5 -> flow-matching distillation run.

    ``__post_init__`` has two side effects:
    - it creates ``out_dir``, ``ckpt_dir`` and ``david_cache_dir`` on disk;
    - it fills ``block_weights`` with a default per-block table when left as None.
    """

    # --- run bookkeeping -------------------------------------------------
    run_name: str = "sd15_flowmatch_david_hf"
    out_dir: str = "./runs/sd15_flowmatch_david_hf"
    ckpt_dir: str = "./checkpoints_sd15_flow_david_hf"
    save_every: int = 1  # checkpoint every N epochs

    # --- data --------------------------------------------------------------
    num_samples: int = 200_000
    batch_size: int = 32
    num_workers: int = 2
    seed: int = 42

    # --- teacher model / feature taps --------------------------------------
    model_id: str = "runwayml/stable-diffusion-v1-5"
    active_blocks: Tuple[str, ...] = (
        "down_0", "down_1", "down_2", "down_3",
        "mid",
        "up_0", "up_1", "up_2", "up_3",
    )
    pooling: str = "mean"  # "mean" | "max" | "adaptive" (see spatial_pool)

    # --- optimisation --------------------------------------------------------
    epochs: int = 10
    lr: float = 1e-4
    weight_decay: float = 1e-3
    grad_clip: float = 1.0
    amp: bool = True

    # --- loss mixing -------------------------------------------------------
    global_flow_weight: float = 1.0
    block_penalty_weight: float = 0.2
    use_local_flow_heads: bool = False
    local_flow_weight: float = 1.0

    # --- pooled-feature knowledge distillation ----------------------------
    use_kd: bool = True
    kd_weight: float = 0.25

    # --- David (assessor) checkpoint on the Hugging Face Hub -------------
    david_repo_id: str = "AbstractPhil/geo-david-collective-sd15-base-e40"
    david_cache_dir: str = "./_hf_david_cache"
    david_state_key: Optional[str] = None

    # --- per-block lambda fusion: weights on David's signals and clamp range
    alpha_timestep: float = 0.5
    beta_pattern: float = 0.25
    delta_incoherence: float = 0.25
    lambda_min: float = 0.5
    lambda_max: float = 3.0

    # Per-block base weights; None means "use the default table in __post_init__".
    block_weights: Dict[str, float] = None

    # --- diffusion schedule ------------------------------------------------
    num_train_timesteps: int = 1000

    # --- sampling ----------------------------------------------------------
    sample_steps: int = 30
    guidance_scale: float = 7.5

    # --- Hub upload / resume ---------------------------------------------
    hf_repo_id: Optional[str] = "AbstractPhil/sd15-flow-matching"
    upload_every_epoch: bool = True
    continue_training: bool = True

    def __post_init__(self):
        for directory in (self.out_dir, self.ckpt_dir, self.david_cache_dir):
            Path(directory).mkdir(parents=True, exist_ok=True)

        if self.block_weights is None:
            # Symmetric U-shaped prior: the innermost blocks weigh more than the outermost.
            self.block_weights = {
                "down_0": 0.7,
                "down_1": 0.9,
                "down_2": 1.0,
                "down_3": 1.1,
                "mid": 1.2,
                "up_0": 1.1,
                "up_1": 1.0,
                "up_2": 0.9,
                "up_3": 0.7,
            }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SymbolicPromptDataset(Dataset):
    """Synthetic prompt source paired with uniformly drawn timesteps.

    ``idx`` is ignored: every access draws a fresh prompt (random complexity 1-5)
    and a fresh integer timestep in [0, 999]. ``n`` is only the nominal length
    reported to the DataLoader.
    """

    def __init__(self, n: int, seed: int = 42):
        self.n = n
        # Seeds the process-wide ``random`` module, which __getitem__ draws from.
        random.seed(seed)
        self.sys = SynthesisSystem(seed=seed)

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        complexity = random.choice([1, 2, 3, 4, 5])
        synthesized = self.sys.synthesize(complexity=complexity)
        # The timestep is drawn after synthesis, keeping the global RNG call order.
        return {"prompt": synthesized['text'], "t": random.randint(0, 999)}
|
|
|
|
|
def collate(batch: List[dict]):
    """Stack dataset items into a batch dict.

    Returns ``prompts`` (list of str), ``t`` (long tensor of timesteps) and
    ``t_bins`` (timesteps floor-divided by 10, i.e. 100 bins over 0..999).
    """
    timesteps = torch.tensor([item["t"] for item in batch], dtype=torch.long)
    return {
        "prompts": [item["prompt"] for item in batch],
        "t": timesteps,
        "t_bins": timesteps // 10,
    }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class HookBank:
    """Capture the outputs of selected UNet blocks with forward hooks.

    Block names are ``down_<i>`` (``unet.down_blocks``), ``mid`` (``unet.mid_block``)
    and ``up_<i>`` (``unet.up_blocks``). Only names in ``active`` are hooked. Each
    forward pass overwrites ``bank[name]`` with that block's latest output.
    """

    def __init__(self, unet: UNet2DConditionModel, active: Tuple[str, ...]):
        self.active = set(active)
        self.bank: Dict[str, torch.Tensor] = {}
        self.hooks: List[torch.utils.hooks.RemovableHandle] = []
        self._register(unet)

    def _register(self, unet: UNet2DConditionModel):
        def capture(block_name):
            def hook(_module, _inputs, output):
                # For tuple/list outputs keep only the first element.
                primary = output[0] if isinstance(output, (tuple, list)) else output
                self.bank[block_name] = primary
            return hook

        def attach(block_name, block):
            if block_name in self.active:
                self.hooks.append(block.register_forward_hook(capture(block_name)))

        for index, block in enumerate(unet.down_blocks):
            attach(f"down_{index}", block)
        # mid_block is only touched when it is requested.
        if "mid" in self.active:
            attach("mid", unet.mid_block)
        for index, block in enumerate(unet.up_blocks):
            attach(f"up_{index}", block)

    def clear(self):
        """Drop all captured tensors (hooks stay registered)."""
        self.bank.clear()

    def close(self):
        """Unregister every hook this bank added."""
        while self.hooks:
            self.hooks.pop(0).remove()
|
|
|
|
|
def spatial_pool(x: torch.Tensor, name: str, policy: str) -> torch.Tensor:
    """Reduce dims 2 and 3 of ``x`` by mean or max.

    ``policy`` is "mean", "max" or "adaptive". Adaptive uses the mean for blocks
    whose name starts with "down" or equals "mid", and the max otherwise.
    Any other policy raises ValueError.
    """
    if policy == "adaptive":
        use_mean = name.startswith("down") or name == "mid"
    elif policy in ("mean", "max"):
        use_mean = policy == "mean"
    else:
        raise ValueError(f"Unknown pooling: {policy}")
    return x.mean(dim=(2, 3)) if use_mean else x.amax(dim=(2, 3))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SD15Teacher(nn.Module):
    """Frozen Stable Diffusion teacher.

    Bundles:
    - the fp16 pipeline loaded from ``cfg.model_id``;
    - a HookBank over the UNet's active blocks;
    - a DDPM scheduler used for alpha/sigma lookups and by the trainer's sampler.

    All registered parameters (UNet and text encoder) are frozen.
    """

    def __init__(self, cfg: BaseConfig, device: str):
        super().__init__()
        self.pipe = StableDiffusionPipeline.from_pretrained(cfg.model_id, torch_dtype=torch.float16, safety_checker=None).to(device)
        self.unet: UNet2DConditionModel = self.pipe.unet
        self.text_encoder = self.pipe.text_encoder
        self.tokenizer = self.pipe.tokenizer
        self.hooks = HookBank(self.unet, cfg.active_blocks)
        # Build the DDPM scheduler from the pipeline's own scheduler config so that
        # alphas_cumprod (and clip_sample) match the schedule the teacher was trained
        # with. A bare DDPMScheduler(...) uses the diffusers defaults instead
        # (linear betas 1e-4 -> 0.02, clip_sample=True), which do not match SD1.5.
        self.sched = DDPMScheduler.from_config(
            self.pipe.scheduler.config,
            num_train_timesteps=cfg.num_train_timesteps,
        )
        self.device = device
        for p in self.parameters():
            p.requires_grad_(False)

    @torch.no_grad()
    def encode(self, prompts: List[str]) -> torch.Tensor:
        """Tokenise (pad/truncate to the tokenizer max length) and return text-encoder hidden states."""
        tok = self.tokenizer(prompts, padding="max_length", max_length=self.tokenizer.model_max_length,
                             truncation=True, return_tensors="pt")
        return self.text_encoder(tok.input_ids.to(self.device))[0]

    @torch.no_grad()
    def forward_eps_and_feats(self, x_t: torch.Tensor, t: torch.LongTensor, ehs: torch.Tensor):
        """Run the teacher UNet.

        Returns ``(eps_hat, feats)``. Both are cast to float32; ``feats`` holds the
        detached outputs of the hooked blocks.
        """
        self.hooks.clear()
        eps_hat = self.unet(x_t, t, encoder_hidden_states=ehs).sample
        feats = {k: v.detach().float() for k, v in self.hooks.bank.items()}
        return eps_hat.float(), feats

    def alpha_sigma(self, t: torch.LongTensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return (sqrt(abar_t), sqrt(1 - abar_t)) as float32 tensors shaped (B, 1, 1, 1)."""
        ac = self.sched.alphas_cumprod.to(self.device)[t]
        alpha = ac.sqrt().view(-1, 1, 1, 1).float()
        sigma = (1.0 - ac).sqrt().view(-1, 1, 1, 1).float()
        return alpha, sigma
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class StudentUNet(nn.Module):
    """Trainable copy of the teacher UNet with block-feature hooks.

    When ``use_local_heads`` is True, each active block also gets a 1x1 conv
    "local flow" head that maps its features to 4 channels.
    """

    def __init__(self, teacher_unet: UNet2DConditionModel, active_blocks: Tuple[str, ...], use_local_heads: bool):
        super().__init__()
        self.unet = UNet2DConditionModel.from_config(teacher_unet.config)
        self.unet.load_state_dict(teacher_unet.state_dict(), strict=True)
        self.hooks = HookBank(self.unet, active_blocks)
        self.use_local_heads = use_local_heads
        self.local_heads = nn.ModuleDict()
        if use_local_heads:
            # Create heads up front so their parameters already exist when the trainer
            # builds its optimizer. Heads created lazily inside forward() would not be in
            # the optimizer's param groups and would therefore never be trained.
            for name, channels in self._block_channels(self.unet.config, active_blocks).items():
                self.local_heads[name] = nn.Conv2d(channels, 4, kernel_size=1)

    @staticmethod
    def _block_channels(config, active_blocks) -> Dict[str, int]:
        """Map each active block name to its output channel count.

        The mapping follows ``config.block_out_channels``:
        - down_i -> block_out_channels[i]
        - mid -> block_out_channels[-1]
        - up_i -> reversed(block_out_channels)[i]

        Unknown names are skipped.
        """
        outs = list(getattr(config, "block_out_channels", ()) or ())
        if not outs:
            return {}
        table = {f"down_{i}": c for i, c in enumerate(outs)}
        table["mid"] = outs[-1]
        table.update({f"up_{i}": c for i, c in enumerate(reversed(outs))})
        return {name: table[name] for name in active_blocks if name in table}

    def _ensure_heads(self, feats: Dict[str, torch.Tensor]):
        """Fallback: add a 1x1 conv head (C -> 4) for any hooked block that still lacks one.

        Heads added here are created in the UNet's parameter dtype on the feature's device.
        """
        if not self.use_local_heads:
            return
        missing = [name for name in feats if name not in self.local_heads]
        if not missing:
            return
        target_dtype = next(self.unet.parameters()).dtype
        for name in missing:
            f = feats[name]
            head = nn.Conv2d(f.shape[1], 4, kernel_size=1)
            self.local_heads[name] = head.to(dtype=target_dtype, device=f.device)

    def forward(self, x_t: torch.Tensor, t: torch.LongTensor, ehs: torch.Tensor):
        """Return ``(v_hat, feats)``: the UNet output and the hooked block outputs (not detached)."""
        self.hooks.clear()
        v_hat = self.unet(x_t, t, encoder_hidden_states=ehs).sample
        feats = dict(self.hooks.bank)
        self._ensure_heads(feats)
        return v_hat, feats
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class DavidLoader:
    """
    Downloads HF repo (config + safetensors), instantiates GeoDavidCollective with HF config,
    loads weights, returns a frozen model + the parsed HF config.

    Side effect: if the HF config provides ``block_weights``, they overwrite
    ``cfg.block_weights``.
    """

    def __init__(self, cfg: BaseConfig, device: str):
        self.cfg = cfg
        self.device = device
        self.repo_dir = snapshot_download(repo_id=cfg.david_repo_id, local_dir=cfg.david_cache_dir, local_dir_use_symlinks=False)
        self.config_path = os.path.join(self.repo_dir, "config.json")
        self.weights_path = os.path.join(self.repo_dir, "model.safetensors")
        with open(self.config_path, "r", encoding="utf-8") as f:
            self.hf_config = json.load(f)

        block_configs = self.hf_config["block_configs"]
        self.gdc = GeoDavidCollective(
            block_configs=block_configs,
            num_timestep_bins=int(self.hf_config["num_timestep_bins"]),
            num_patterns_per_bin=int(self.hf_config["num_patterns_per_bin"]),
            block_weights=self.hf_config.get("block_weights", {k: 1.0 for k in block_configs.keys()}),
            loss_config=self.hf_config.get("loss_config", {})
        ).to(device).eval()

        state = load_file(self.weights_path)
        # strict=False tolerates benign extras, but any mismatch leaves part of David at
        # its random initialisation, so report it instead of silently continuing.
        missing, unexpected = self.gdc.load_state_dict(state, strict=False)
        if missing or unexpected:
            print(f"⚠️ David weights: {len(missing)} missing / {len(unexpected)} unexpected keys")
        for p in self.gdc.parameters():
            p.requires_grad_(False)

        print(f"✓ David loaded from HF: {self.repo_dir}")
        print(f" blocks={len(self.hf_config['block_configs'])} bins={self.hf_config['num_timestep_bins']} patterns={self.hf_config['num_patterns_per_bin']}")

        if "block_weights" in self.hf_config:
            cfg.block_weights = self.hf_config["block_weights"]
|
|
|
|
|
class DavidAssessor(nn.Module):
    """
    Uses David to score STUDENT pooled features (per block) and timesteps.

    Produces three dicts keyed by block name:
      e_t[name] : mean cross-entropy of David's timestep-bin logits against the
                  true bin (0.0 if David returns no timestep logits)
      e_p[name] : entropy of David's pattern distribution, normalised to [0, 1]
                  by log(num_patterns) (0.0 if no pattern logits)
      coh[name] : mean of gdc.get_cantor_alphas() values, or 1.0 if unavailable

    When David returns a single (non-dict) logits tensor, its score is shared by every block.
    """

    _TIMESTEP_KEYS = ("timestep_logits", "logits_timestep", "timestep_head_logits")
    _PATTERN_KEYS = ("pattern_logits", "logits_pattern", "pattern_head_logits")

    def __init__(self, gdc: GeoDavidCollective, pooling: str, num_train_timesteps: int = 1000):
        super().__init__()
        self.gdc = gdc
        self.pooling = pooling
        # Size of the integer timestep range; timestep bins are equal-width slices of it.
        self.num_train_timesteps = num_train_timesteps

    def _pool(self, feats: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        return {k: spatial_pool(v, k, self.pooling) for k, v in feats.items()}

    def _timestep_ce(self, logits: torch.Tensor, t: torch.LongTensor) -> float:
        """Mean CE of ``logits`` against each timestep's bin.

        Assumes logits are (batch, num_bins); the bin count is read from the last dim.
        With 100 bins over 1000 timesteps, the bin is t // 10 (same as collate()).
        Deriving the bin count avoids out-of-range targets when David uses another count.
        """
        num_bins = logits.shape[-1]
        bins = torch.div(t.to(logits.device).long() * num_bins, self.num_train_timesteps,
                         rounding_mode="floor").clamp(0, num_bins - 1)
        return float(F.cross_entropy(logits.float(), bins, reduction="mean").item())

    @staticmethod
    def _normalized_entropy(logits: torch.Tensor) -> float:
        """Batch-mean softmax entropy divided by log(K); 0.0 when K == 1 (avoids 0/0)."""
        P = logits.float().softmax(-1)
        ent = -(P * P.clamp_min(1e-9).log()).sum(-1).mean()
        k = P.shape[-1]
        return float(ent.item() / math.log(k)) if k > 1 else 0.0

    @torch.no_grad()
    def forward(self, feats_student: Dict[str, torch.Tensor], t: torch.LongTensor
                ) -> Tuple[Dict[str, float], Dict[str, float], Dict[str, float]]:
        Zs = self._pool(feats_student)
        outs = self.gdc(Zs, t.float())
        names = list(Zs.keys())

        # David's output key names vary; take the first known alias present.
        ts_key = next((k for k in self._TIMESTEP_KEYS if k in outs), None)
        pt_key = next((k for k in self._PATTERN_KEYS if k in outs), None)

        if ts_key is None:
            e_t = dict.fromkeys(names, 0.0)
        elif isinstance(outs[ts_key], dict):
            e_t = {name: self._timestep_ce(L, t) for name, L in outs[ts_key].items()}
        else:
            e_t = dict.fromkeys(names, self._timestep_ce(outs[ts_key], t))

        if pt_key is None:
            e_p = dict.fromkeys(names, 0.0)
        elif isinstance(outs[pt_key], dict):
            e_p = {name: self._normalized_entropy(L) for name, L in outs[pt_key].items()}
        else:
            e_p = dict.fromkeys(names, self._normalized_entropy(outs[pt_key]))

        try:
            alphas = self.gdc.get_cantor_alphas()
        except Exception:
            # Deliberate best effort: a David without Cantor alphas counts as fully coherent.
            alphas = {}
        avg_alpha = float(sum(alphas.values()) / max(len(alphas), 1)) if alphas else 1.0
        coh = dict.fromkeys(names, avg_alpha)

        return e_t, e_p, coh
|
|
|
|
|
class BlockPenaltyFusion:
    """Turn David's per-block diagnostics into clamped per-block loss multipliers.

    For each (name, base) in ``cfg.block_weights``:
        lambda = base * (1 + alpha_timestep * e_t + beta_pattern * e_p
                         + delta_incoherence * (1 - coh))
    The result is clamped to [lambda_min, lambda_max]. Missing diagnostics
    default to e_t = 0, e_p = 0 and coh = 1.
    """

    def __init__(self, cfg: BaseConfig):
        self.cfg = cfg

    def lambdas(self, e_t: Dict[str, float], e_p: Dict[str, float], coh: Dict[str, float]) -> Dict[str, float]:
        c = self.cfg

        def fused(name, base):
            raw = base * (1.0
                          + c.alpha_timestep * float(e_t.get(name, 0.0))
                          + c.beta_pattern * float(e_p.get(name, 0.0))
                          + c.delta_incoherence * (1.0 - float(coh.get(name, 1.0))))
            return float(max(c.lambda_min, min(c.lambda_max, raw)))

        return {name: fused(name, base) for name, base in c.block_weights.items()}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class FlowMatchDavidTrainer:
    """Distil the frozen SD1.5 teacher into a v-prediction student UNet.

    Each step:
    - encode the prompts with the teacher and draw Gaussian latents ``x_t``;
    - turn the teacher's epsilon prediction into a v-target;
    - train the student on MSE to that target, plus per-block penalties (pooled-feature
      cosine KD and optional local flow heads) scaled by David-derived lambdas.

    Checkpoints are converted to a single-file SD checkpoint and optionally
    uploaded to the Hugging Face Hub.
    """

    def __init__(self, cfg: BaseConfig, device: str = "cuda"):
        self.cfg = cfg
        self.device = device
        self.start_epoch = 0
        self.start_gstep = 0

        # Data
        self.dataset = SymbolicPromptDataset(cfg.num_samples, cfg.seed)
        self.loader = DataLoader(self.dataset, batch_size=cfg.batch_size, shuffle=True,
                                 num_workers=cfg.num_workers, pin_memory=True, collate_fn=collate)

        # Teacher / student
        self.teacher = SD15Teacher(cfg, device).eval()
        self.student = StudentUNet(self.teacher.unet, cfg.active_blocks, cfg.use_local_flow_heads).to(device)

        # David assessor (DavidLoader may overwrite cfg.block_weights from the HF config)
        self.david_loader = DavidLoader(cfg, device)
        self.david = self.david_loader.gdc
        self.assessor = DavidAssessor(self.david, cfg.pooling)
        self.fusion = BlockPenaltyFusion(cfg)

        # Optimisation
        self.opt = torch.optim.AdamW(self.student.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
        self.sched = torch.optim.lr_scheduler.CosineAnnealingLR(self.opt, T_max=cfg.epochs * len(self.loader))
        self.scaler = torch.cuda.amp.GradScaler(enabled=cfg.amp)

        if cfg.continue_training:
            self._load_latest_from_hf()

        self.writer = SummaryWriter(log_dir=os.path.join(cfg.out_dir, cfg.run_name))

    def _load_latest_from_hf(self):
        """Download and load the latest checkpoint from HuggingFace.

        Looks for ``*_e<N>.safetensors`` files in ``cfg.hf_repo_id`` and loads the
        highest-N file's UNet weights into the student. It then sets
        ``start_epoch``/``start_gstep`` and replays the LR schedule up to that step.
        Optimizer moments are not in that file and are not restored. Any failure
        falls back to training from scratch.
        """
        if not self.cfg.hf_repo_id:
            print("⚠️ continue_training=True but no hf_repo_id specified")
            return

        try:
            api = HfApi()
            print(f"\n🔍 Searching for latest checkpoint in {self.cfg.hf_repo_id}...")

            # Probe access first so a missing/private repo is reported clearly.
            try:
                api.repo_info(repo_id=self.cfg.hf_repo_id, repo_type="model")
            except Exception as e:
                print(f"⚠️ Could not access repo: {e}")
                print(" Starting training from scratch")
                return

            files = api.list_repo_files(repo_id=self.cfg.hf_repo_id, repo_type="model")
            if not files:
                print("ℹ️ Repo is empty, starting from scratch")
                return

            print(f"📂 Found {len(files)} files in repo:")
            for f in files:
                print(f" - {f}")

            epoch_re = re.compile(r'_e(\d+)\.safetensors$')
            epochs = []
            for f in files:
                if not f.endswith('.safetensors'):
                    continue
                match = epoch_re.search(f)
                if match:
                    epoch_num = int(match.group(1))
                    epochs.append((epoch_num, f))
                    print(f"✓ Found checkpoint: {f} (epoch {epoch_num})")

            if not epochs:
                print("ℹ️ No checkpoint files found (looking for *_e<num>.safetensors)")
                return

            latest_epoch, latest_file = max(epochs, key=lambda x: x[0])
            print(f"\n📥 Downloading latest checkpoint: {latest_file} (epoch {latest_epoch})")

            local_path = hf_hub_download(
                repo_id=self.cfg.hf_repo_id,
                filename=latest_file,
                repo_type="model",
                cache_dir=self.cfg.ckpt_dir
            )
            print(f"✓ Downloaded to: {local_path}")

            print("📦 Loading checkpoint into pipeline...")
            pipe = StableDiffusionPipeline.from_single_file(
                local_path,
                torch_dtype=torch.float16,
                safety_checker=None,
                load_safety_checker=False
            )
            unet_state = pipe.unet.state_dict()

            missing, unexpected = self.student.unet.load_state_dict(unet_state, strict=False)
            print(f"✓ Loaded student UNet from epoch {latest_epoch}")
            if missing:
                print(f" Missing keys: {len(missing)}")
            if unexpected:
                print(f" Unexpected keys: {len(unexpected)}")

            self.start_epoch = latest_epoch
            self.start_gstep = latest_epoch * len(self.loader)
            # Continue the cosine curve from the resumed step instead of restarting at full LR.
            self._fast_forward_lr(self.start_gstep)
            print(f"🎯 Resuming training from epoch {self.start_epoch + 1}")

            del pipe, unet_state
            torch.cuda.empty_cache()

        except Exception as e:
            print(f"⚠️ Failed to load checkpoint from HF: {e}")
            print(" Starting training from scratch")
            import traceback
            traceback.print_exc()

    def _fast_forward_lr(self, steps: int) -> None:
        """Advance the LR scheduler by ``steps`` (capped at its T_max, if any) without training."""
        import warnings

        steps = max(0, int(steps))
        t_max = getattr(self.sched, "T_max", None)
        if t_max is not None:
            steps = min(steps, int(t_max))
        with warnings.catch_warnings():
            # Stepping a scheduler before any optimizer.step() emits a UserWarning.
            warnings.simplefilter("ignore")
            for _ in range(steps):
                self.sched.step()

    def _v_star(self, x_t, t, eps_hat):
        """v-target: alpha * eps - sigma * x0, with x0 recovered from the teacher's epsilon."""
        alpha, sigma = self.teacher.alpha_sigma(t)
        x0_hat = (x_t - sigma * eps_hat) / (alpha + 1e-8)
        return alpha * eps_hat - sigma * x0_hat

    def _down_like(self, tgt: torch.Tensor, ref: torch.Tensor) -> torch.Tensor:
        """Bilinearly resize ``tgt`` to the spatial size of ``ref``."""
        return F.interpolate(tgt, size=ref.shape[-2:], mode="bilinear", align_corners=False)

    def _kd_cos(self, s: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """1 - mean cosine similarity along the last dim."""
        s = F.normalize(s, dim=-1)
        t = F.normalize(t, dim=-1)
        return 1.0 - (s * t).sum(-1).mean()

    def train(self):
        """Run epochs ``start_epoch..cfg.epochs-1``, logging to TensorBoard.

        Saves every ``cfg.save_every`` epochs and once more as "final".
        """
        cfg = self.cfg
        gstep = self.start_gstep

        for ep in range(self.start_epoch, cfg.epochs):
            self.student.train()
            pbar = tqdm(self.loader, desc=f"Epoch {ep+1}/{cfg.epochs}",
                        dynamic_ncols=True, leave=True, position=0)
            acc = {"L": 0.0, "Lf": 0.0, "Lb": 0.0}

            for it, batch in enumerate(pbar):
                prompts = batch["prompts"]
                t = batch["t"].to(self.device)

                with torch.no_grad():
                    ehs = self.teacher.encode(prompts)

                # Data-free distillation: x_t is pure Gaussian noise at every timestep.
                x_t = torch.randn(len(prompts), 4, 64, 64, device=self.device, dtype=torch.float16)

                with torch.no_grad():
                    eps_hat, t_feats_spatial = self.teacher.forward_eps_and_feats(x_t.half(), t, ehs)
                    v_star = self._v_star(x_t, t, eps_hat)

                with torch.autocast(device_type="cuda", enabled=cfg.amp):
                    # The student holds fp32 weights: feed fp32 so amp=False also works.
                    # Under autocast these are recast to fp16 exactly as before.
                    v_hat, s_feats_spatial = self.student(x_t.float(), t, ehs.float())
                    L_flow = F.mse_loss(v_hat, v_star)

                    e_t, e_p, coh = self.assessor(s_feats_spatial, t)
                    lam = self.fusion.lambdas(e_t, e_p, coh)

                    L_blocks = torch.zeros((), device=self.device)
                    for name, s_feat in s_feats_spatial.items():
                        L_kd = torch.zeros((), device=self.device)
                        if cfg.use_kd:
                            s_pool = spatial_pool(s_feat, name, cfg.pooling)
                            t_pool = spatial_pool(t_feats_spatial[name], name, cfg.pooling)
                            L_kd = self._kd_cos(s_pool, t_pool)

                        L_lf = torch.zeros((), device=self.device)
                        if cfg.use_local_flow_heads and name in self.student.local_heads:
                            v_loc = self.student.local_heads[name](s_feat)
                            v_ds = self._down_like(v_star, v_loc)
                            L_lf = F.mse_loss(v_loc, v_ds)
                        L_blocks = L_blocks + lam.get(name, 1.0) * (cfg.kd_weight * L_kd + cfg.local_flow_weight * L_lf)

                    L_total = cfg.global_flow_weight * L_flow + cfg.block_penalty_weight * L_blocks

                self.opt.zero_grad(set_to_none=True)
                if cfg.amp:
                    self.scaler.scale(L_total).backward()
                    # Unscale before clipping. Otherwise the clip threshold applies to the
                    # loss-scaled gradients and the real update shrinks by the scale factor.
                    self.scaler.unscale_(self.opt)
                    nn.utils.clip_grad_norm_(self.student.parameters(), cfg.grad_clip)
                    self.scaler.step(self.opt)
                    self.scaler.update()
                else:
                    L_total.backward()
                    nn.utils.clip_grad_norm_(self.student.parameters(), cfg.grad_clip)
                    self.opt.step()
                self.sched.step()
                gstep += 1

                acc["L"] += float(L_total.item())
                acc["Lf"] += float(L_flow.item())
                acc["Lb"] += float(L_blocks.item())

                if it % 50 == 0:
                    self.writer.add_scalar("train/total", float(L_total.item()), gstep)
                    self.writer.add_scalar("train/flow", float(L_flow.item()), gstep)
                    self.writer.add_scalar("train/blocks", float(L_blocks.item()), gstep)
                    for k in list(lam.keys())[:4]:
                        self.writer.add_scalar(f"lambda/{k}", lam[k], gstep)

                if it % 10 == 0 or it == len(self.loader) - 1:
                    pbar.set_postfix({
                        "L": f"{float(L_total.item()):.4f}",
                        "Lf": f"{float(L_flow.item()):.4f}",
                        "Lb": f"{float(L_blocks.item()):.4f}"
                    }, refresh=False)

                del x_t, eps_hat, v_star, v_hat, s_feats_spatial, t_feats_spatial

            pbar.close()

            n = len(self.loader)
            print(f"\n[Epoch {ep+1}] L={acc['L']/n:.4f} | L_flow={acc['Lf']/n:.4f} | L_blocks={acc['Lb']/n:.4f}")
            self.writer.add_scalar("epoch/total", acc['L']/n, ep+1)
            self.writer.add_scalar("epoch/flow", acc['Lf']/n, ep+1)
            self.writer.add_scalar("epoch/blocks", acc['Lb']/n, ep+1)

            if (ep+1) % cfg.save_every == 0:
                self._save(ep+1, gstep)

        self._save("final", gstep)
        self.writer.close()

    def _save(self, tag, gstep):
        """Save and convert to ComfyUI format, then upload.

        The temporary .pt (student, optimizer and scheduler state) is deleted only
        after a successful conversion, so a failed conversion never loses the weights.
        """
        pt_path = Path(self.cfg.ckpt_dir) / f"{self.cfg.run_name}_e{tag}.pt"
        torch.save({
            "cfg": asdict(self.cfg),
            "student": self.student.state_dict(),
            "opt": self.opt.state_dict(),
            "sched": self.sched.state_dict(),
            "gstep": gstep
        }, pt_path)
        print(f"✓ Saved temp .pt: {pt_path}")

        safetensors_path = self._convert_to_comfyui(pt_path, tag)
        if safetensors_path is None:
            print(f"⚠️ Keeping {pt_path} because conversion failed")
            return

        if self.cfg.upload_every_epoch and self.cfg.hf_repo_id:
            self._upload_to_hf(safetensors_path, tag)

        pt_path.unlink()
        print(f"✓ Cleaned up temp .pt file")

    def _convert_to_comfyui(self, pt_path: Path, tag) -> Optional[Path]:
        """Convert .pt to ComfyUI-compatible safetensors.

        Steps:
        1. Load the student UNet weights into ``cfg.model_id``'s pipeline.
        2. Save that pipeline in diffusers format to a temp directory.
        3. Run the official diffusers -> original-SD converter (fp16).

        Returns the output path, or None on any failure. The temp directory is
        always removed.
        """
        import sys

        ckpt_dir = Path(self.cfg.ckpt_dir)
        temp_pipeline = ckpt_dir / f"temp_pipeline_e{tag}"
        output_safetensors = ckpt_dir / f"{self.cfg.run_name}_e{tag}.safetensors"
        try:
            converter_path = ckpt_dir / "convert_diffusers_to_original_stable_diffusion.py"
            if not converter_path.exists():
                print("📥 Downloading official converter...")
                url = "https://raw.githubusercontent.com/huggingface/diffusers/main/scripts/convert_diffusers_to_original_stable_diffusion.py"
                urllib.request.urlretrieve(url, str(converter_path))
                print("✓ Converter downloaded")

            print(f"📦 Creating diffusers pipeline from checkpoint...")
            checkpoint = torch.load(pt_path, map_location='cpu')
            student_state = checkpoint.get('student', checkpoint)
            # StudentUNet.state_dict() prefixes UNet weights with "unet." (and may contain
            # "local_heads.*"). A bare UNet2DConditionModel expects unprefixed keys;
            # loading the prefixed dict with strict=False silently kept the base weights.
            prefix = "unet."
            unet_state = {k[len(prefix):]: v for k, v in student_state.items() if k.startswith(prefix)}
            if not unet_state:
                unet_state = student_state  # already a bare UNet state dict

            print("📥 Loading base UNet...")
            unet = UNet2DConditionModel.from_pretrained(
                self.cfg.model_id,
                subfolder="unet",
                torch_dtype=torch.float16
            )
            unet.load_state_dict(unet_state, strict=True)
            print("✓ Loaded student weights into UNet")

            print("📥 Loading base SD1.5 pipeline...")
            pipe = StableDiffusionPipeline.from_pretrained(
                self.cfg.model_id,
                unet=unet,
                torch_dtype=torch.float16,
                safety_checker=None
            )
            print("✓ Replaced UNet with student")

            print(f"💾 Saving diffusers pipeline...")
            pipe.save_pretrained(str(temp_pipeline), safe_serialization=True)
            print(f"✓ Pipeline saved to {temp_pipeline}")
            del pipe, unet, checkpoint, student_state, unet_state

            print(f"🔄 Converting to ComfyUI format...")
            cmd = [
                sys.executable, str(converter_path),
                "--model_path", str(temp_pipeline),
                "--checkpoint_path", str(output_safetensors),
                "--half"
            ]
            result = subprocess.run(cmd, capture_output=True, text=True)
            if result.returncode != 0:
                print(f"❌ Conversion failed: {result.stderr}")
                return None

            if output_safetensors.exists():
                size_mb = output_safetensors.stat().st_size / 1e6
                print(f"✓ Converted: {output_safetensors.name} ({size_mb:.1f}MB)")
                return output_safetensors

            print(f"❌ Output file not created")
            return None

        except Exception as e:
            print(f"❌ Conversion failed: {e}")
            import traceback
            traceback.print_exc()
            return None

        finally:
            if temp_pipeline.exists():
                shutil.rmtree(temp_pipeline, ignore_errors=True)
                print("✓ Cleaned up temp pipeline")

    def _upload_to_hf(self, path: Path, tag):
        """Upload safetensors to HuggingFace (best effort: failures are printed, not raised)."""
        try:
            api = HfApi()

            try:
                create_repo(self.cfg.hf_repo_id, exist_ok=True, private=False, repo_type="model")
                print(f"✓ Repo ready: {self.cfg.hf_repo_id}")
            except Exception as e:
                # Keep going: the repo may already exist even if creating it is not permitted.
                print(f"⚠️ Could not create/verify repo: {e}")

            print(f"📤 Uploading to {self.cfg.hf_repo_id}...")
            api.upload_file(
                path_or_fileobj=str(path),
                path_in_repo=path.name,
                repo_id=self.cfg.hf_repo_id,
                repo_type="model",
                commit_message=f"Epoch {tag}"
            )
            print(f"✅ Uploaded: https://huggingface.co/{self.cfg.hf_repo_id}/{path.name}")

        except Exception as e:
            print(f"⚠️ Upload failed: {e}")

    @torch.no_grad()
    def sample(self, prompts: List[str], steps: Optional[int]=None, guidance: Optional[float]=None) -> torch.Tensor:
        """Generate images with the student using classifier-free guidance.

        Each step converts the student's v prediction to epsilon and applies a
        DDPM scheduler step. Returns decoded images clamped to [-1, 1].
        """
        steps = steps or self.cfg.sample_steps
        guidance = guidance if guidance is not None else self.cfg.guidance_scale
        # The teacher's text encoder is fp16 while the student keeps fp32 weights.
        cond_e = self.teacher.encode(prompts).float()
        uncond_e = self.teacher.encode([""] * len(prompts)).float()
        sched = self.teacher.sched
        sched.set_timesteps(steps, device=self.device)
        x_t = torch.randn(len(prompts), 4, 64, 64, device=self.device)

        was_training = self.student.training
        self.student.eval()
        try:
            for t_scalar in sched.timesteps:
                t = torch.full((x_t.shape[0],), t_scalar, device=self.device, dtype=torch.long)
                v_u, _ = self.student(x_t, t, uncond_e)
                v_c, _ = self.student(x_t, t, cond_e)
                v_hat = v_u + guidance * (v_c - v_u)

                # Invert v = alpha*eps - sigma*x0 (with x_t = alpha*x0 + sigma*eps).
                alpha, sigma = self.teacher.alpha_sigma(t)
                denom = (alpha**2 + sigma**2)
                x0_hat = (alpha * x_t - sigma * v_hat) / (denom + 1e-8)
                eps_hat = (x_t - alpha * x0_hat) / (sigma + 1e-8)

                step = sched.step(model_output=eps_hat, timestep=t_scalar, sample=x_t)
                x_t = step.prev_sample
        finally:
            self.student.train(was_training)

        vae = self.teacher.pipe.vae
        scaling = getattr(vae.config, "scaling_factor", 0.18215)
        # The VAE is fp16; match its dtype to avoid a dtype mismatch in decode.
        imgs = vae.decode(x_t.to(vae.dtype) / scaling).sample
        return imgs.clamp(-1, 1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def main():
    """Entry point: print the config, train, then run a short sampling smoke test."""
    cfg = BaseConfig()
    print(json.dumps(asdict(cfg), indent=2))

    has_cuda = torch.cuda.is_available()
    if not has_cuda:
        print("⚠️ A100 strongly recommended.")

    trainer = FlowMatchDavidTrainer(cfg, device="cuda" if has_cuda else "cpu")
    trainer.train()

    # Quick end-to-end check that the trained student can sample.
    trainer.sample(["a castle at sunset"], steps=10, guidance=7.0)
    print("✓ Inference sanity done.")


if __name__ == "__main__":
    main()