#!/usr/bin/env python3
"""
Anima DRaFT+ / AlignProp training with HPSv2 reward
====================================================

Distilled-LoRA fine-tuning aimed at **quality improvement** (not speed-up).
Starting from an already-distilled student LoRA (warm start), the LoRA is
trained further in the direction that maximizes HPSv2 (Human Preference Score).

Algorithm (DRaFT-K with KL regularization):
  1. caption -> cond_pos (no_grad)
  2. init noise -> N-step rollout with the student
     - first N-K steps: no_grad
     - last K steps: grad on (K=1 is reported as best in the paper)
  3. final x0 -> VAE decode (grad on) -> image
  4. reward = HPSv2(image, prompt)
  5. KL term: ||v_pred - v_pred_base||^2, where v_pred_base comes from a frozen
     deep copy of the transformer taken before LoRA is attached
  6. loss = -reward + kl_coeff * KL

NOTE: ``--n-lv-samples`` (DRaFT-LV) and ``--student-cfg`` are accepted on the
command line but are not applied by the code in this file; a single rollout
per micro-batch is performed.
"""
from __future__ import annotations

import argparse
import copy
import json
import os
import sys
import time
from pathlib import Path

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader

sys.path.insert(0, str(Path(__file__).resolve().parent.parent))

from distill.anima_loader import build_anima, AnimaBundle
from distill.dmd2_trainer import attach_wide_lora
from distill.train_traj import (
    TextOnlyDataset, text_collate, load_warm_lora, save_lora_state,
)
from distill.traj_scheduler import make_schedule


def vae_decode_with_grad(bundle: AnimaBundle, latents: torch.Tensor) -> torch.Tensor:
    """Decode latents with the bundle's VAE while keeping autograd enabled.

    Per the original author's note, ``vae_decode`` in ``anima_loader`` is
    wrapped in ``@no_grad``, hence this separate implementation.

    Args:
        bundle: Anima bundle providing ``vae.model`` and ``vae_scale``.
        latents: Latent tensor; it is cast to the VAE parameter dtype.

    Returns:
        Whatever ``bundle.vae.model.decode(latents, bundle.vae_scale)`` returns.
    """
    vae_dtype = next(bundle.vae.model.parameters()).dtype
    latents = latents.to(dtype=vae_dtype)
    return bundle.vae.model.decode(latents, bundle.vae_scale)


def student_rollout_with_truncation(
    student_v_fn,
    base_v_fn_for_kl,
    init_noise: torch.Tensor,
    schedule_ts: torch.Tensor,     # (N+1,)
    cond_pos: torch.Tensor,
    student_cfg: float,
    K: int,
    capture_kl: bool = True,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Run an N-step Euler rollout, enabling grad only for the last K steps.

    Args:
        student_v_fn: Callable ``(x, t, cond) -> v`` for the trainable student.
        base_v_fn_for_kl: Callable ``(x, t, cond) -> v`` for the frozen
            reference; evaluated under ``no_grad`` on the final step only.
        init_noise: Starting sample; its batch size, device and dtype are
            used for the whole rollout.
        schedule_ts: Timesteps, length N+1.
        cond_pos: Conditioning passed through to both velocity functions.
        student_cfg: Accepted for interface compatibility; not used here.
        K: Number of trailing steps executed with grad enabled.
        capture_kl: When true, compute the regularizer on the final step.

    Returns:
        ``(x0_final, kl_loss)``. ``kl_loss`` is the mean squared difference
        between student and base velocity at the final step, or a zero scalar
        if it was not captured.
    """
    B = init_noise.size(0)
    device = init_noise.device
    dtype = init_noise.dtype
    N = len(schedule_ts) - 1
    truncate_idx = N - K

    x = init_noise
    kl_loss = torch.zeros((), device=device)
    for i in range(N):
        t_cur = schedule_ts[i].expand(B).to(dtype=dtype)
        t_next = schedule_ts[i + 1]
        is_grad_step = (i >= truncate_idx)
        ctx = torch.enable_grad() if is_grad_step else torch.no_grad()
        with ctx:
            v = student_v_fn(x, t_cur, cond_pos)
            if is_grad_step and capture_kl and i == N - 1:
                # Final step: student (LoRA) velocity vs. frozen base velocity.
                with torch.no_grad():
                    v_base = base_v_fn_for_kl(x, t_cur, cond_pos)
                kl_loss = ((v - v_base.detach()).float() ** 2).mean()
            # Euler update x_{i+1} = x_i + (t_{i+1} - t_i) * v.
            dt = (t_next - schedule_ts[i]).to(device=device, dtype=dtype)
            x = x + dt * v
    return x, kl_loss


def _commit_modal_volume() -> None:
    """Best-effort commit of the ``anima-outputs`` Modal volume.

    A missing ``modal`` package is skipped silently; any other failure is
    reported as a warning instead of being swallowed, and never aborts training.
    """
    try:
        import modal
    except ImportError:
        return
    try:
        modal.Volume.from_name("anima-outputs").commit()
    except Exception as exc:  # best-effort: a failed commit must not stop training
        print(f"[warn] modal volume commit failed: {exc!r}", flush=True)


def main():
    """Parse CLI arguments and run the DRaFT+ training loop.

    Side effects: creates ``--out``, appends JSON lines to
    ``draftp_log.jsonl`` there, and saves LoRA checkpoints every
    ``--sample-every`` steps plus a final ``draftp_final`` checkpoint.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--dataset", required=True, type=str, help="caption-only dir (画像不要)")
    ap.add_argument("--out", required=True, type=str)
    ap.add_argument("--warm-lora", required=True, type=str,
                    help="必須: 既に蒸留された student LoRA (例 ① Z-Image の出力)")
    ap.add_argument("--hps-weights", default="/models/hpsv2/HPS_v2_compressed.pt",
                    help="HPSv2 weights path、無ければ OpenCLIP baseline fallback")
    ap.add_argument("--total-steps", type=int, default=1500)
    ap.add_argument("--batch-size", type=int, default=2)
    ap.add_argument("--grad-accum", type=int, default=1)
    ap.add_argument("--n-student-steps", type=int, default=8)
    ap.add_argument("--K", type=int, default=1, help="gradient truncation depth, paper best=1")
    ap.add_argument("--n-lv-samples", type=int, default=2,
                    help="DRaFT-LV: extra noise samples at last step, averaged")
    ap.add_argument("--resolution", type=int, default=768)
    ap.add_argument("--student-cfg", type=float, default=1.0)
    ap.add_argument("--sigma-shift", type=float, default=3.0)
    ap.add_argument("--lr", type=float, default=1e-4)
    ap.add_argument("--kl-coeff", type=float, default=0.2,
                    help="NeMo DRaFT+ default、reward hacking 防止")
    ap.add_argument("--lora-rank", type=int, default=32)
    ap.add_argument("--grad-clip", type=float, default=1.0)
    ap.add_argument("--log-every", type=int, default=5)
    ap.add_argument("--sample-every", type=int, default=200)
    ap.add_argument("--num-workers", type=int, default=2)
    ap.add_argument("--seed", type=int, default=42)
    args = ap.parse_args()

    # With K < 1 no rollout step runs with grad, so the LoRA weights would
    # receive no gradient. Fail before loading any model.
    if args.K < 1:
        ap.error("--K must be >= 1")

    torch.manual_seed(args.seed)
    device = torch.device("cuda")
    dtype = torch.bfloat16

    out_dir = Path(args.out)
    out_dir.mkdir(parents=True, exist_ok=True)

    print("[load] Anima bundle")
    bundle = build_anima(device=device, dtype=dtype)

    # base = frozen deep copy, used as the reference for the KL term.
    # It must be copied before LoRA is attached to the transformer.
    print("[setup] base = frozen deepcopy (for KL reference)")
    base_transformer = copy.deepcopy(bundle.transformer).to(device=device, dtype=dtype).eval()
    for p in base_transformer.parameters():
        p.requires_grad = False

    # student = wide LoRA; only parameters whose name contains "lora_" train.
    student_transformer = attach_wide_lora(bundle.transformer, rank=args.lora_rank)
    student_transformer.to(device=device, dtype=dtype)
    for n, p in student_transformer.named_parameters():
        p.requires_grad = ("lora_" in n)
    student_params = [p for p in student_transformer.parameters() if p.requires_grad]
    print(f"[setup] student trainable: {sum(p.numel() for p in student_params)/1e6:.1f}M")
    bundle.transformer = student_transformer

    # Warm start is mandatory.
    load_warm_lora(student_transformer, args.warm_lora)

    # HPSv2 reward model (imported lazily, as in the original script).
    print(f"[setup] loading HPSv2 from {args.hps_weights}")
    from distill.hps_reward import HPSv2Reward
    hps = HPSv2Reward(args.hps_weights, device=device, dtype=torch.float32)

    # Fixed N-step schedule.
    sched = make_schedule(args.n_student_steps, args.sigma_shift, device=device, dtype=torch.float32)
    print(f"[schedule] N={sched.num_steps} timesteps={sched.timesteps.tolist()}")

    # Caption-only dataset.
    print(f"[data] {args.dataset}")
    dataset = TextOnlyDataset(args.dataset)
    print(f" {len(dataset)} captions")
    loader = DataLoader(
        dataset, batch_size=args.batch_size, shuffle=True,
        num_workers=args.num_workers, collate_fn=text_collate, drop_last=True,
    )

    opt = torch.optim.AdamW(student_params, lr=args.lr, betas=(0.9, 0.999),
                            weight_decay=0.01, eps=1e-8)

    def student_v_fn(x, t, cond):
        return AnimaBundle.dit_forward(student_transformer, x, t, cond)

    def base_v_fn(x, t, cond):
        return AnimaBundle.dit_forward(base_transformer, x, t, cond)

    H_lat = args.resolution // 8
    W_lat = args.resolution // 8

    print(f"[train] steps={args.total_steps} bs={args.batch_size} N={args.n_student_steps} "
          f"K={args.K} lv={args.n_lv_samples} kl={args.kl_coeff}")

    log_path = out_dir / "draftp_log.jsonl"
    data_iter = iter(loader)

    def _next():
        # Restart the loader transparently when an epoch is exhausted.
        nonlocal data_iter
        try:
            return next(data_iter)
        except StopIteration:
            data_iter = iter(loader)
            return next(data_iter)

    # The context manager guarantees the log file is closed even if a training
    # step raises (e.g. CUDA OOM).
    with open(log_path, "a", buffering=1, encoding="utf-8") as log_f:
        t0 = time.time()
        for step in range(args.total_steps):
            student_transformer.train()
            opt.zero_grad()
            # NOTE: metrics reflect the last micro-batch of the accumulation only.
            metrics = {}

            for _ in range(args.grad_accum):
                captions = _next()
                with torch.no_grad():
                    cond_pos = bundle.text_encode(captions)
                B = cond_pos.size(0)
                init_noise = torch.randn(B, 16, 1, H_lat, W_lat, device=device, dtype=dtype)

                # Single truncated rollout. The DRaFT-LV averaging over
                # n_lv_samples named by the CLI flag is not performed here.
                x0_final, kl = student_rollout_with_truncation(
                    student_v_fn, base_v_fn,
                    init_noise, sched.timesteps, cond_pos,
                    args.student_cfg, args.K, capture_kl=True,
                )

                # VAE decode with grad; the author notes (B,3,H,W) in [-1,1].
                img = vae_decode_with_grad(bundle, x0_final).squeeze(2)
                reward = hps.score(img, captions)    # (B,) per the author
                r_mean = reward.mean()

                # DRaFT+ loss: -reward + kl_coeff * kl
                loss = (-r_mean + args.kl_coeff * kl) / args.grad_accum
                loss.backward()

                # Sample std of a single element is NaN, which would put the
                # non-standard token NaN into the JSONL log; report 0 instead.
                if reward.numel() > 1:
                    reward_std = float(reward.std().detach())
                else:
                    reward_std = 0.0

                metrics = {
                    "reward_mean": float(r_mean.detach()),
                    "reward_std": reward_std,
                    "kl": float(kl.detach()),
                    "loss": float((-r_mean + args.kl_coeff * kl).detach()),
                }

            torch.nn.utils.clip_grad_norm_(student_params, args.grad_clip)
            opt.step()

            if step % args.log_every == 0:
                metrics["step"] = step
                metrics["elapsed"] = time.time() - t0
                log_f.write(json.dumps(metrics) + "\n")
                msg = " ".join(f"{k}={v:.4f}" if isinstance(v, float) else f"{k}={v}"
                               for k, v in metrics.items() if k != "step")
                print(f"[step {step}/{args.total_steps}] {msg}", flush=True)

            if step > 0 and step % args.sample_every == 0:
                save_lora_state(student_transformer, out_dir, f"draftp_step{step:05d}")
                print(f"[save] draftp_step{step:05d}.safetensors", flush=True)
                _commit_modal_volume()

        print("[done] saving final")
        save_lora_state(student_transformer, out_dir, "draftp_final")


if __name__ == "__main__":
    main()