| |
| """ |
| Anima DRaFT+ / AlignProp training with HPSv2 reward |
| ==================================================== |
| |
| これは **品質向上** (速度向上ではない) 用の蒸留 LoRA fine-tuning。 |
| 既に蒸留された student LoRA を warm-start として、HPSv2 (Human Preference Score) |
| を maximize する方向に追加学習する。 |
| |
| Algorithm (DRaFT-K with KL regularization; the DRaFT-LV option --n-lv-samples is not implemented yet): |
| 1. caption → cond_pos (no_grad) |
| 2. init noise → student で N step rollout |
| - 前 N-K step: no_grad |
| - 後 K step: grad on (K=1 が paper の best) |
| 3. final x0 → VAE decode (grad on) → image |
| 4. reward = HPSv2(image, prompt) |
| 5. KL term: ||v_pred - v_pred_base||² at the final rollout step, where v_pred_base comes from a frozen deepcopy of the base transformer taken before the LoRA is attached |
| 6. loss = -reward + kl_coeff * KL |
| """ |
| from __future__ import annotations |
| import argparse |
| import copy |
| import json |
| import os |
| import sys |
| import time |
| from pathlib import Path |
|
|
| import torch |
| import torch.nn.functional as F |
| from torch.utils.data import DataLoader |
|
|
| sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) |
| from distill.anima_loader import build_anima, AnimaBundle |
| from distill.dmd2_trainer import attach_wide_lora |
| from distill.train_traj import ( |
| TextOnlyDataset, text_collate, load_warm_lora, save_lora_state, |
| ) |
| from distill.traj_scheduler import make_schedule |
|
|
|
|
def vae_decode_with_grad(bundle: AnimaBundle, latents: torch.Tensor) -> torch.Tensor:
    """Decode latents through the VAE while keeping the autograd graph.

    A separate implementation is needed because the loader's own decode helper
    runs without gradients. The latents are cast to the dtype of the VAE
    model's first parameter before calling ``decode`` with the bundle's scale.
    """
    vae_model = bundle.vae.model
    target_dtype = next(iter(vae_model.parameters())).dtype
    return vae_model.decode(latents.to(dtype=target_dtype), bundle.vae_scale)
|
|
|
|
def student_rollout_with_truncation(
    student_v_fn,
    base_v_fn_for_kl,
    init_noise: torch.Tensor,
    schedule_ts: torch.Tensor,
    cond_pos: torch.Tensor,
    student_cfg: float,
    K: int,
    capture_kl: bool = True,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Euler-integrate the student velocity field, keeping gradients only for the last K steps.

    Steps ``0 .. N-K-1`` run with autograd disabled. The final ``K`` steps run with
    autograd enabled, so the returned sample is differentiable w.r.t. the
    student's parameters only through those trailing steps (DRaFT-K
    truncation). When ``capture_kl`` is set, the final step's student velocity
    is also compared with ``base_v_fn_for_kl``, which is evaluated without
    gradients, to form the KL-style regularizer.

    Args:
        student_v_fn: callable ``(x, t, cond) -> v`` for the trainable student.
        base_v_fn_for_kl: callable ``(x, t, cond) -> v`` for the reference model.
        init_noise: starting latent; its batch size, dtype and device drive the rollout.
        schedule_ts: ``N + 1`` timesteps; step ``i`` moves from ``schedule_ts[i]``
            to ``schedule_ts[i + 1]``.
        cond_pos: conditioning passed unchanged to both velocity callables.
        student_cfg: currently unused; no classifier-free guidance is applied.
        K: number of trailing steps that keep gradients. ``K <= 0`` gives a fully
            no-grad rollout (and a zero KL term). ``K >= N`` makes every step
            differentiable.
        capture_kl: whether to compute the regularizer at the final step.

    Returns:
        ``(x_final, kl_loss)``. ``kl_loss`` is the float32 mean squared
        difference between student and base velocity at the last step, or a
        zero scalar on the latent's device when it is not captured.
    """
    batch = init_noise.size(0)
    device = init_noise.device
    dtype = init_noise.dtype
    n_steps = len(schedule_ts) - 1
    first_grad_step = n_steps - K

    x = init_noise
    kl_loss = torch.zeros((), device=device)

    for i in range(n_steps):
        t_from = schedule_ts[i]
        t_to = schedule_ts[i + 1]
        # Put both the per-sample timestep and the step size on the latent's
        # device/dtype; the schedule tensor may live elsewhere.
        t_cur = t_from.expand(batch).to(device=device, dtype=dtype)
        dt = (t_to - t_from).to(device=device, dtype=dtype)
        is_grad_step = i >= first_grad_step
        with torch.set_grad_enabled(is_grad_step):
            v = student_v_fn(x, t_cur, cond_pos)
            if is_grad_step and capture_kl and i == n_steps - 1:
                # The reference velocity is a fixed target: build no graph through it.
                with torch.no_grad():
                    v_base = base_v_fn_for_kl(x, t_cur, cond_pos)
                kl_loss = ((v - v_base.detach()).float() ** 2).mean()
            x = x + dt * v
    return x, kl_loss
|
|
|
|
def _commit_modal_volume(volume_name: str = "anima-outputs") -> None:
    """Best-effort commit of a Modal volume so saved checkpoints are persisted.

    Does nothing when the ``modal`` package cannot be imported (not running on
    Modal). Any other failure is printed and swallowed, because a failed
    commit should not abort a training run.
    """
    try:
        import modal
    except ImportError:
        return
    try:
        modal.Volume.from_name(volume_name).commit()
    except Exception as exc:  # deliberate best-effort boundary; report, don't crash
        print(f"[warn] modal volume commit failed: {exc!r}", flush=True)


def main():
    """Train a warm-started student LoRA to maximize the HPSv2 reward (DRaFT+-style).

    The function:
      1. Parses and validates the CLI arguments.
      2. Loads the Anima bundle.
      3. Keeps a frozen deepcopy of the transformer as the KL reference.
      4. Attaches the student LoRA and loads the warm start.
      5. Runs ``--total-steps`` optimizer steps. Each step encodes captions, rolls
         the student out with gradients through the last ``--K`` steps, decodes
         with gradients, scores with HPSv2 and minimizes ``-reward + kl_coeff * KL``.

    Metrics are averaged over the ``--grad-accum`` micro-batches and appended to
    ``<out>/draftp_log.jsonl``. LoRA checkpoints are written to ``<out>``.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--dataset", required=True, type=str,
                    help="caption-only dir (画像不要)")
    ap.add_argument("--out", required=True, type=str)
    ap.add_argument("--warm-lora", required=True, type=str,
                    help="必須: 既に蒸留された student LoRA (例 ① Z-Image の出力)")
    ap.add_argument("--hps-weights", default="/models/hpsv2/HPS_v2_compressed.pt",
                    help="HPSv2 weights path、無ければ OpenCLIP baseline fallback")
    ap.add_argument("--total-steps", type=int, default=1500)
    ap.add_argument("--batch-size", type=int, default=2)
    ap.add_argument("--grad-accum", type=int, default=1)
    ap.add_argument("--n-student-steps", type=int, default=8)
    ap.add_argument("--K", type=int, default=1, help="gradient truncation depth, paper best=1")
    ap.add_argument("--n-lv-samples", type=int, default=2,
                    help="DRaFT-LV: extra noise samples at last step, averaged")
    ap.add_argument("--resolution", type=int, default=768)
    ap.add_argument("--student-cfg", type=float, default=1.0)
    ap.add_argument("--sigma-shift", type=float, default=3.0)
    ap.add_argument("--lr", type=float, default=1e-4)
    ap.add_argument("--kl-coeff", type=float, default=0.2,
                    help="NeMo DRaFT+ default、reward hacking 防止")
    ap.add_argument("--lora-rank", type=int, default=32)
    ap.add_argument("--grad-clip", type=float, default=1.0)
    ap.add_argument("--log-every", type=int, default=5)
    ap.add_argument("--sample-every", type=int, default=200)
    ap.add_argument("--num-workers", type=int, default=2)
    ap.add_argument("--seed", type=int, default=42)
    args = ap.parse_args()

    # Validate cheaply before the expensive model load.
    if args.K < 1:
        ap.error("--K must be >= 1 (with K=0 no rollout step carries gradients)")
    if args.grad_accum < 1:
        ap.error("--grad-accum must be >= 1")
    if args.log_every < 1 or args.sample_every < 1:
        ap.error("--log-every and --sample-every must be >= 1")
    if args.n_lv_samples > 1:
        print(f"[warn] --n-lv-samples={args.n_lv_samples} is not implemented; "
              "each sample uses a single rollout (plain DRaFT-K)")
    if args.student_cfg != 1.0:
        print(f"[warn] --student-cfg={args.student_cfg} is ignored: "
              "the rollout applies no classifier-free guidance")

    torch.manual_seed(args.seed)
    device = torch.device("cuda")
    dtype = torch.bfloat16
    out_dir = Path(args.out)
    out_dir.mkdir(parents=True, exist_ok=True)

    print("[load] Anima bundle")
    bundle = build_anima(device=device, dtype=dtype)

    # KL reference: a frozen copy of the transformer, taken BEFORE the LoRA is
    # attached and the warm start is loaded.
    # NOTE(review): the anchor is therefore the raw base model, not the
    # warm-start student; confirm this is the intended reference.
    print("[setup] base = frozen deepcopy (for KL reference)")
    base_transformer = copy.deepcopy(bundle.transformer).to(device=device, dtype=dtype).eval()
    base_transformer.requires_grad_(False)

    # Student: wide LoRA on the live transformer; only LoRA weights are trainable.
    student_transformer = attach_wide_lora(bundle.transformer, rank=args.lora_rank)
    student_transformer.to(device=device, dtype=dtype)
    for name, param in student_transformer.named_parameters():
        param.requires_grad = ("lora_" in name)
    student_params = [p for p in student_transformer.parameters() if p.requires_grad]
    print(f"[setup] student trainable: {sum(p.numel() for p in student_params)/1e6:.1f}M")
    bundle.transformer = student_transformer

    load_warm_lora(student_transformer, args.warm_lora)

    print(f"[setup] loading HPSv2 from {args.hps_weights}")
    from distill.hps_reward import HPSv2Reward
    hps = HPSv2Reward(args.hps_weights, device=device, dtype=torch.float32)

    sched = make_schedule(args.n_student_steps, args.sigma_shift,
                          device=device, dtype=torch.float32)
    print(f"[schedule] N={sched.num_steps} timesteps={sched.timesteps.tolist()}")
    n_rollout_steps = len(sched.timesteps) - 1
    if args.K > n_rollout_steps:
        print(f"[warn] --K={args.K} exceeds the {n_rollout_steps} rollout steps; "
              "gradients will flow through every step")

    print(f"[data] {args.dataset}")
    dataset = TextOnlyDataset(args.dataset)
    print(f" {len(dataset)} captions")
    if len(dataset) < args.batch_size:
        # With drop_last=True the loader would be empty and _next() would fail
        # with a bare StopIteration.
        raise ValueError(
            f"dataset has {len(dataset)} captions but --batch-size={args.batch_size}; "
            "with drop_last=True the loader would yield no batches"
        )
    loader = DataLoader(
        dataset, batch_size=args.batch_size, shuffle=True,
        num_workers=args.num_workers, collate_fn=text_collate,
        drop_last=True,
    )

    opt = torch.optim.AdamW(student_params, lr=args.lr, betas=(0.9, 0.999),
                            weight_decay=0.01, eps=1e-8)

    def student_v_fn(x, t, cond):
        return AnimaBundle.dit_forward(student_transformer, x, t, cond)

    def base_v_fn(x, t, cond):
        return AnimaBundle.dit_forward(base_transformer, x, t, cond)

    # Latent spatial size: the code assumes an 8x VAE downsampling factor.
    H_lat = args.resolution // 8
    W_lat = args.resolution // 8

    print(f"[train] steps={args.total_steps} bs={args.batch_size} N={args.n_student_steps} "
          f"K={args.K} lv={args.n_lv_samples} kl={args.kl_coeff}")
    log_path = out_dir / "draftp_log.jsonl"
    t0 = time.time()
    data_iter = iter(loader)

    def _next():
        """Return the next caption batch, restarting (reshuffling) the loader at epoch end."""
        nonlocal data_iter
        try:
            return next(data_iter)
        except StopIteration:
            data_iter = iter(loader)
            return next(data_iter)

    # Line-buffered append so partial logs survive a crash; `with` closes it on errors.
    with open(log_path, "a", buffering=1, encoding="utf-8") as log_f:
        for step in range(args.total_steps):
            student_transformer.train()
            opt.zero_grad()
            totals = {"reward_mean": 0.0, "reward_std": 0.0, "kl": 0.0, "loss": 0.0}
            for _ in range(args.grad_accum):
                captions = _next()
                with torch.no_grad():
                    cond_pos = bundle.text_encode(captions)
                B = cond_pos.size(0)
                init_noise = torch.randn(B, 16, 1, H_lat, W_lat, device=device, dtype=dtype)

                # Truncated rollout: gradients only through the last K steps.
                x0_final, kl = student_rollout_with_truncation(
                    student_v_fn, base_v_fn, init_noise, sched.timesteps,
                    cond_pos, args.student_cfg, args.K, capture_kl=True,
                )

                # Decode with gradients, drop the singleton temporal axis, then score.
                img = vae_decode_with_grad(bundle, x0_final).squeeze(2)
                reward = hps.score(img, captions)
                r_mean = reward.mean()

                micro_loss = -r_mean + args.kl_coeff * kl
                (micro_loss / args.grad_accum).backward()

                # The std of a single sample is NaN, which would also be written as
                # invalid JSON; report 0 for a batch of one instead.
                r_std = reward.std() if reward.numel() > 1 else reward.new_zeros(())
                totals["reward_mean"] += float(r_mean.detach())
                totals["reward_std"] += float(r_std.detach())
                totals["kl"] += float(kl.detach())
                totals["loss"] += float(micro_loss.detach())

            grad_norm = torch.nn.utils.clip_grad_norm_(student_params, args.grad_clip)
            opt.step()

            if step % args.log_every == 0:
                # Average over micro-batches rather than keeping only the last one.
                metrics = {k: v / args.grad_accum for k, v in totals.items()}
                metrics["grad_norm"] = float(grad_norm)
                metrics["step"] = step
                metrics["elapsed"] = time.time() - t0
                log_f.write(json.dumps(metrics) + "\n")
                msg = " ".join(f"{k}={v:.4f}" if isinstance(v, float) else f"{k}={v}"
                               for k, v in metrics.items() if k != "step")
                print(f"[step {step}/{args.total_steps}] {msg}", flush=True)

            if step > 0 and step % args.sample_every == 0:
                save_lora_state(student_transformer, out_dir, f"draftp_step{step:05d}")
                print(f"[save] draftp_step{step:05d}.safetensors", flush=True)
                _commit_modal_volume()

    print("[done] saving final")
    save_lora_state(student_transformer, out_dir, "draftp_final")
    # Commit the final checkpoint too, not just the intermediate ones.
    _commit_modal_volume()


if __name__ == "__main__":
    main()
|
|