rapid-anima / scripts /distill /train_draftp.py
darask0's picture
Initial commit: rapid-anima distillation codebase
77cc641 verified
Raw
History Blame Contribute Delete
10 kB
#!/usr/bin/env python3
"""
Anima DRaFT+ / AlignProp training with HPSv2 reward
====================================================
これは **品質向上** (速度向上ではない) 用の蒸留 LoRA fine-tuning。
既に蒸留された student LoRA を warm-start として、HPSv2 (Human Preference Score)
を maximize する方向に追加学習する。
Algorithm (DRaFT-K LV with KL regularization):
1. caption → cond_pos (no_grad)
2. init noise → student で N step rollout
- 前 N-K step: no_grad
- 後 K step: grad on (K=1 が paper の best)
3. final x0 → VAE decode (grad on) → image
4. reward = HPSv2(image, prompt)
5. KL term: ||v_pred - v_pred_base||² (LoRA disable して frozen base v_pred と比較)
6. loss = -reward + kl_coeff * KL
"""
from __future__ import annotations
import argparse
import copy
import json
import os
import sys
import time
from pathlib import Path
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
from distill.anima_loader import build_anima, AnimaBundle
from distill.dmd2_trainer import attach_wide_lora
from distill.train_traj import (
TextOnlyDataset, text_collate, load_warm_lora, save_lora_state,
)
from distill.traj_scheduler import make_schedule
def vae_decode_with_grad(bundle: AnimaBundle, latents: torch.Tensor) -> torch.Tensor:
    """Decode latents through the bundle's VAE without disabling autograd.

    The original note on this helper states that the loader's own ``vae_decode``
    is wrapped in ``no_grad``, which is why a separate implementation exists.

    Args:
        bundle: Object exposing ``vae.model`` (with ``parameters()`` and
            ``decode(latents, scale)``) and ``vae_scale``.
        latents: Latent tensor; it is cast to the dtype of the VAE's first
            parameter before decoding.

    Returns:
        Whatever ``bundle.vae.model.decode`` returns for the cast latents.
    """
    decoder = bundle.vae.model
    # The first parameter's dtype is taken as the dtype the VAE runs in.
    target_dtype = next(decoder.parameters()).dtype
    cast_latents = latents.to(dtype=target_dtype)
    return decoder.decode(cast_latents, bundle.vae_scale)
def student_rollout_with_truncation(
    student_v_fn,
    base_v_fn_for_kl,
    init_noise: torch.Tensor,
    schedule_ts: torch.Tensor,  # (N+1,)
    cond_pos: torch.Tensor,
    student_cfg: float,
    K: int,
    capture_kl: bool = True,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Run an N-step Euler rollout, tracking gradients only on the last K steps.

    On the final step (when it is a gradient step and ``capture_kl`` is set) the
    base model is also evaluated at the same input, and the mean squared
    difference between the student and base predictions is returned as the
    regularization term.

    Args:
        student_v_fn: Callable ``(x, t, cond) -> v`` for the trainable student.
        base_v_fn_for_kl: Callable ``(x, t, cond) -> v`` for the reference
            model; only called under ``no_grad``.
        init_noise: Starting latent; its batch size, device and dtype drive
            the rollout.
        schedule_ts: 1-D tensor of N+1 timesteps.
        cond_pos: Conditioning passed through unchanged to both callables.
        student_cfg: Accepted for interface compatibility; not used here.
        K: Number of trailing steps executed with autograd enabled.
        capture_kl: Whether to compute the regularization term.

    Returns:
        ``(x0_final, kl_loss)``; ``kl_loss`` is a zero scalar when it was not
        captured.
    """
    batch = init_noise.size(0)
    dev, work_dtype = init_noise.device, init_noise.dtype
    num_steps = len(schedule_ts) - 1
    final_idx = num_steps - 1
    first_grad_idx = num_steps - K

    latent = init_noise
    kl_term = torch.zeros((), device=dev)
    step_pairs = zip(schedule_ts[:-1], schedule_ts[1:])
    for idx, (t_from, t_to) in enumerate(step_pairs):
        tracked = idx >= first_grad_idx
        t_batch = t_from.expand(batch).to(dtype=work_dtype)
        with torch.set_grad_enabled(tracked):
            velocity = student_v_fn(latent, t_batch, cond_pos)
            if tracked and capture_kl and idx == final_idx:
                # Final step: compare the student prediction to the frozen reference.
                with torch.no_grad():
                    reference = base_v_fn_for_kl(latent, t_batch, cond_pos)
                kl_term = (velocity - reference.detach()).float().pow(2).mean()
        # Euler update; performed outside the grad-mode context, as in the original.
        delta_t = (t_to - t_from).to(device=dev, dtype=work_dtype)
        latent = latent + delta_t * velocity
    return latent, kl_term
def main():
    """Fine-tune a warm-started student LoRA against an HPSv2 reward.

    Flow, as implemented below:
      1. Parse CLI arguments and validate the ones that would otherwise fail
         (or silently do nothing) only after the models are loaded.
      2. Build the Anima bundle on CUDA in bfloat16, keep a frozen deep copy of
         the transformer as the regularization reference, attach LoRA to the
         student and load the warm-start weights.
      3. For each optimizer step, run ``grad_accum`` micro-batches: encode
         captions, roll the student out with truncated backprop, decode with
         the VAE (gradients enabled), score with HPSv2 and backpropagate
         ``-reward + kl_coeff * kl``.
      4. Append metrics to ``draftp_log.jsonl`` every ``log_every`` steps and
         save LoRA state every ``sample_every`` steps, plus once at the end.

    Logged metrics are averaged over the micro-batches of an optimizer step.
    The log file is closed even if training raises.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--dataset", required=True, type=str,
                    help="caption-only dir (画像不要)")
    ap.add_argument("--out", required=True, type=str)
    ap.add_argument("--warm-lora", required=True, type=str,
                    help="必須: 既に蒸留された student LoRA (例 ① Z-Image の出力)")
    ap.add_argument("--hps-weights", default="/models/hpsv2/HPS_v2_compressed.pt",
                    help="HPSv2 weights path、無ければ OpenCLIP baseline fallback")
    ap.add_argument("--total-steps", type=int, default=1500)
    ap.add_argument("--batch-size", type=int, default=2)
    ap.add_argument("--grad-accum", type=int, default=1)
    ap.add_argument("--n-student-steps", type=int, default=8)
    ap.add_argument("--K", type=int, default=1, help="gradient truncation depth, paper best=1")
    # NOTE(review): this option is parsed and printed, but nothing below consumes it.
    ap.add_argument("--n-lv-samples", type=int, default=2,
                    help="DRaFT-LV: extra noise samples at last step, averaged")
    ap.add_argument("--resolution", type=int, default=768)
    ap.add_argument("--student-cfg", type=float, default=1.0)
    ap.add_argument("--sigma-shift", type=float, default=3.0)
    ap.add_argument("--lr", type=float, default=1e-4)
    ap.add_argument("--kl-coeff", type=float, default=0.2,
                    help="NeMo DRaFT+ default、reward hacking 防止")
    ap.add_argument("--lora-rank", type=int, default=32)
    ap.add_argument("--grad-clip", type=float, default=1.0)
    ap.add_argument("--log-every", type=int, default=5)
    ap.add_argument("--sample-every", type=int, default=200)
    ap.add_argument("--num-workers", type=int, default=2)
    ap.add_argument("--seed", type=int, default=42)
    args = ap.parse_args()

    # Fail fast: these values would otherwise surface as a ZeroDivisionError or
    # as a rollout with no gradient through the student, after model loading.
    if args.K < 1:
        ap.error("--K must be >= 1 (no rollout step would carry gradients)")
    if args.grad_accum < 1:
        ap.error("--grad-accum must be >= 1")
    if args.log_every < 1:
        ap.error("--log-every must be >= 1")
    if args.sample_every < 1:
        ap.error("--sample-every must be >= 1")

    torch.manual_seed(args.seed)
    device = torch.device("cuda")
    dtype = torch.bfloat16
    out_dir = Path(args.out)
    out_dir.mkdir(parents=True, exist_ok=True)

    print("[load] Anima bundle")
    bundle = build_anima(device=device, dtype=dtype)

    # Reference model: copied before LoRA is attached, then frozen.
    print("[setup] base = frozen deepcopy (for KL reference)")
    base_transformer = copy.deepcopy(bundle.transformer).to(device=device, dtype=dtype).eval()
    for p in base_transformer.parameters():
        p.requires_grad = False

    # Student: LoRA-wrapped transformer; only parameters whose name contains
    # "lora_" are trainable.
    student_transformer = attach_wide_lora(bundle.transformer, rank=args.lora_rank)
    student_transformer.to(device=device, dtype=dtype)
    for n, p in student_transformer.named_parameters():
        p.requires_grad = ("lora_" in n)
    student_params = [p for p in student_transformer.parameters() if p.requires_grad]
    print(f"[setup] student trainable: {sum(p.numel() for p in student_params)/1e6:.1f}M")
    bundle.transformer = student_transformer

    # Warm start is mandatory (--warm-lora is a required argument).
    load_warm_lora(student_transformer, args.warm_lora)

    # HPSv2 reward model, kept in float32.
    print(f"[setup] loading HPSv2 from {args.hps_weights}")
    from distill.hps_reward import HPSv2Reward
    hps = HPSv2Reward(args.hps_weights, device=device, dtype=torch.float32)

    # Fixed N-step schedule.
    sched = make_schedule(args.n_student_steps, args.sigma_shift,
                          device=device, dtype=torch.float32)
    print(f"[schedule] N={sched.num_steps} timesteps={sched.timesteps.tolist()}")

    # Caption-only dataset.
    print(f"[data] {args.dataset}")
    dataset = TextOnlyDataset(args.dataset)
    print(f" {len(dataset)} captions")
    loader = DataLoader(
        dataset, batch_size=args.batch_size, shuffle=True,
        num_workers=args.num_workers, collate_fn=text_collate,
        drop_last=True,
    )
    opt = torch.optim.AdamW(student_params, lr=args.lr, betas=(0.9, 0.999),
                            weight_decay=0.01, eps=1e-8)

    def student_v_fn(x, t, cond):
        return AnimaBundle.dit_forward(student_transformer, x, t, cond)

    def base_v_fn(x, t, cond):
        return AnimaBundle.dit_forward(base_transformer, x, t, cond)

    # Latent spatial size: resolution divided by 8.
    H_lat = args.resolution // 8
    W_lat = args.resolution // 8
    print(f"[train] steps={args.total_steps} bs={args.batch_size} N={args.n_student_steps} "
          f"K={args.K} lv={args.n_lv_samples} kl={args.kl_coeff}")

    log_path = out_dir / "draftp_log.jsonl"
    t0 = time.time()
    data_iter = iter(loader)

    def _next():
        """Return the next caption batch, restarting the loader when exhausted."""
        nonlocal data_iter
        try:
            return next(data_iter)
        except StopIteration:
            data_iter = iter(loader)
            return next(data_iter)

    # Line-buffered append; the context manager closes the file on any exit path.
    with open(log_path, "a", buffering=1) as log_f:
        for step in range(args.total_steps):
            student_transformer.train()
            opt.zero_grad()
            # Running sums so the logged values cover every micro-batch.
            sums = {"reward_mean": 0.0, "reward_std": 0.0, "kl": 0.0, "loss": 0.0}
            for _ in range(args.grad_accum):
                captions = _next()
                with torch.no_grad():
                    cond_pos = bundle.text_encode(captions)
                B = cond_pos.size(0)
                init_noise = torch.randn(B, 16, 1, H_lat, W_lat, device=device, dtype=dtype)
                # Single rollout per micro-batch (no extra last-step samples are drawn here).
                x0_final, kl = student_rollout_with_truncation(
                    student_v_fn, base_v_fn, init_noise, sched.timesteps,
                    cond_pos, args.student_cfg, args.K, capture_kl=True,
                )
                # VAE decode with gradients enabled; squeeze drops dim 2.
                img = vae_decode_with_grad(bundle, x0_final).squeeze(2)  # (B,3,H,W) in [-1,1]
                reward = hps.score(img, captions)  # (B,)
                r_mean = reward.mean()
                # DRaFT+ loss: -reward + kl_coeff * kl
                raw_loss = -r_mean + args.kl_coeff * kl
                (raw_loss / args.grad_accum).backward()

                sums["reward_mean"] += float(r_mean.detach())
                # std of a single element is NaN, which json.dumps would emit
                # as the non-standard token NaN; log 0.0 in that case.
                if reward.numel() > 1:
                    sums["reward_std"] += float(reward.std().detach())
                sums["kl"] += float(kl.detach())
                sums["loss"] += float(raw_loss.detach())
            metrics = {k: v / args.grad_accum for k, v in sums.items()}

            torch.nn.utils.clip_grad_norm_(student_params, args.grad_clip)
            opt.step()

            if step % args.log_every == 0:
                metrics["step"] = step
                metrics["elapsed"] = time.time() - t0
                log_f.write(json.dumps(metrics) + "\n")
                msg = " ".join(f"{k}={v:.4f}" if isinstance(v, float) else f"{k}={v}"
                               for k, v in metrics.items() if k != "step")
                print(f"[step {step}/{args.total_steps}] {msg}", flush=True)

            if step > 0 and step % args.sample_every == 0:
                save_lora_state(student_transformer, out_dir, f"draftp_step{step:05d}")
                print(f"[save] draftp_step{step:05d}.safetensors", flush=True)
                # Best-effort volume commit: skipped when modal is not
                # installed, and never allowed to abort training.
                try:
                    import modal
                    modal.Volume.from_name("anima-outputs").commit()
                except ImportError:
                    pass
                except Exception as exc:
                    print(f"[warn] modal volume commit failed: {exc!r}", flush=True)

    print("[done] saving final")
    save_lora_state(student_transformer, out_dir, "draftp_final")
# Script entry point: run training only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()