| |
"""
Anima self-distillation dataset generator
=========================================

Spawns ComfyUI and generates images for every combination of
gen_prompts.txt x N seeds.
Each image gets a same-named .txt file holding
"quality-tag prefix + generation prompt".
A later clean_captions stage removes the quality tags to complete the
training pairs.

Usage:
    python generate_dataset.py \\
        --prompts /workspace/scripts/gen_prompts.txt \\
        --workflow /workspace/scripts/anima_workflow.json \\
        --out /dataset/raw \\
        --seeds-per-prompt 50 \\
        --comfy-dir /workspace/ComfyUI
"""
| import argparse |
| import json |
| import os |
| import random |
| import shutil |
| import signal |
| import subprocess |
| import sys |
| import time |
| import urllib.request |
| import urllib.parse |
| import urllib.error |
| from pathlib import Path |
|
|
# Address the spawned ComfyUI server listens on (passed as --listen/--port);
# every HTTP helper in this file builds its URL from these two values.
COMFY_HOST = "127.0.0.1"
COMFY_PORT = 8188


# Default value of --quality-prefix: tags prepended to every prompt unless
# --no-prefix is given.
QUALITY_PREFIX = "masterpiece, best quality, score_7, safe"


# (width, height) candidates; one is drawn at random per submission unless
# --fixed-aspect is set.
ASPECT_RATIOS = [
    (1024, 1024),
    (1152, 896),
    (896, 1152),
    (1216, 832),
    (832, 1216),
    (1344, 768),
    (768, 1344),
]
|
|
|
|
def http_post(path: str, payload: dict) -> dict:
    """POST ``payload`` as JSON to the ComfyUI server and decode the JSON reply.

    Args:
        path: URL path (including the leading slash) appended to
            ``http://COMFY_HOST:COMFY_PORT``.
        payload: JSON-serialisable request body.

    Returns:
        The response body parsed as JSON.
    """
    url = f"http://{COMFY_HOST}:{COMFY_PORT}{path}"
    body = json.dumps(payload).encode()
    request = urllib.request.Request(
        url,
        data=body,
        headers={"Content-Type": "application/json"},
    )
    with urllib.request.urlopen(request, timeout=120) as response:
        raw = response.read()
    return json.loads(raw)
|
|
|
|
def http_get(path: str, binary: bool = False):
    """GET ``path`` from the ComfyUI server.

    Args:
        path: URL path (including the leading slash and any query string)
            appended to ``http://COMFY_HOST:COMFY_PORT``.
        binary: When true, return the raw response bytes; otherwise parse
            the body as JSON.

    Returns:
        ``bytes`` if ``binary`` is set, else the decoded JSON value.
    """
    url = f"http://{COMFY_HOST}:{COMFY_PORT}{path}"
    with urllib.request.urlopen(url, timeout=120) as response:
        body = response.read()
    if binary:
        return body
    return json.loads(body)
|
|
|
|
def wait_for_comfy(timeout: int = 300):
    """Block until the ComfyUI server answers ``/system_stats``.

    The endpoint is probed every 2 seconds until it responds.

    Args:
        timeout: Maximum number of seconds to keep probing.

    Raises:
        SystemExit: If the server did not respond within ``timeout`` seconds.
    """
    # Monotonic clock: the deadline must not move if the wall clock is adjusted.
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        try:
            http_get("/system_stats")
        except OSError:
            # OSError covers URLError, ConnectionError and the socket timeout
            # raised while reading a response, which urllib does not wrap in
            # URLError. A probe that hangs is therefore retried, not fatal.
            time.sleep(2)
        else:
            return
    raise SystemExit(f"ComfyUI server did not become ready within {timeout}s")
|
|
|
|
def submit_and_wait(workflow: dict, poll_interval: float = 0.5, timeout: int = 600):
    """Queue ``workflow`` on ComfyUI and poll until it finishes.

    Args:
        workflow: API-format workflow sent as the ``prompt`` field of
            ``POST /prompt``.
        poll_interval: Seconds to sleep between ``/history`` polls.
        timeout: Maximum number of seconds to wait for completion.

    Returns:
        The history entry for the submitted prompt.

    Raises:
        RuntimeError: If the server reports the prompt as failed, or it did
            not complete within ``timeout`` seconds.
    """
    resp = http_post("/prompt", {"prompt": workflow})
    prompt_id = resp["prompt_id"]

    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        history = http_get(f"/history/{prompt_id}")
        entry = history.get(prompt_id)
        if entry is not None:
            status = entry.get("status") or {}
            if status.get("completed"):
                return entry
            # ComfyUI stores a failed prompt in history with completed=False
            # and status_str="error"; without this check a failure would be
            # polled until the full timeout elapsed.
            if status.get("status_str") == "error":
                detail = _execution_error_detail(status)
                raise RuntimeError(
                    f"Generation failed (prompt_id={prompt_id}): {detail}"
                )
        time.sleep(poll_interval)
    raise RuntimeError(f"Generation timed out (prompt_id={prompt_id})")


def _execution_error_detail(status: dict) -> str:
    """Best-effort extraction of the error message from a history status dict."""
    # NOTE(review): assumes messages are [type, data] pairs as emitted by
    # ComfyUI; anything else falls through to the generic text.
    for message in status.get("messages") or []:
        if (
            isinstance(message, (list, tuple))
            and len(message) >= 2
            and message[0] == "execution_error"
            and isinstance(message[1], dict)
        ):
            text = message[1].get("exception_message")
            if text:
                return str(text).strip()
    return "server reported an execution error"
|
|
|
|
def fetch_image(filename: str, subfolder: str = "", typ: str = "output") -> bytes:
    """Download one image from ComfyUI's ``/view`` endpoint as raw bytes.

    Args:
        filename: Value sent as the ``filename`` query parameter.
        subfolder: Value sent as the ``subfolder`` query parameter.
        typ: Value sent as the ``type`` query parameter.

    Returns:
        The response body as bytes.
    """
    params = {"filename": filename, "subfolder": subfolder, "type": typ}
    query = urllib.parse.urlencode(params)
    return http_get(f"/view?{query}", binary=True)
|
|
|
|
def patch_workflow(
    template: dict, prompt: str, width: int, height: int, seed: int,
    override_steps: int = 0, override_cfg: float = 0.0, batch_size: int = 1,
) -> dict:
    """Return a copy of ``template`` with the per-submission values filled in.

    The template is copied through a JSON round trip, so it is never mutated.
    A top-level ``_comment`` entry is dropped from the copy. Node ``"5"``
    receives the prompt text, node ``"7"`` the width, height and batch size,
    and node ``"8"`` the seed (presumably the text-encode, empty-latent and
    sampler nodes of the workflow — verify against the workflow JSON).

    Args:
        template: Workflow dict loaded from the workflow JSON file.
        prompt: Text written to node ``"5"``.
        width: Width written to node ``"7"``.
        height: Height written to node ``"7"``.
        seed: Seed written to node ``"8"``.
        override_steps: Replaces node ``"8"`` ``steps`` when greater than 0.
        override_cfg: Replaces node ``"8"`` ``cfg`` when greater than 0.
        batch_size: Batch size written to node ``"7"``.

    Returns:
        The patched workflow dict.
    """
    patched = json.loads(json.dumps(template))
    patched.pop("_comment", None)

    patched["5"]["inputs"]["text"] = prompt

    latent_inputs = patched["7"]["inputs"]
    latent_inputs.update(width=width, height=height, batch_size=batch_size)

    sampler_inputs = patched["8"]["inputs"]
    sampler_inputs["seed"] = seed
    # A value of 0 (or below) means "keep what the workflow file specifies".
    if override_steps > 0:
        sampler_inputs["steps"] = override_steps
    if override_cfg > 0:
        sampler_inputs["cfg"] = override_cfg
    return patched
|
|
|
|
| def setup_model_symlinks(comfy_dir: Path, models_dir: Path, loras_dir: Path | None = None): |
| """/models/checkpoints/*.safetensors を ComfyUI の各 subdir に symlink。 |
| loras_dir 指定時は /models/loras/*.safetensors も lora subdir に symlink。 |
| 既知の text_encoder / vae 以外は diffusion_models 扱い (Phase A distilled 等も拾う)""" |
| text_encoder_names = {"qwen_3_06b_base.safetensors"} |
| vae_names = {"qwen_image_vae.safetensors"} |
|
|
| def classify(fname: str) -> str: |
| if fname in text_encoder_names: |
| return "text_encoders" |
| if fname in vae_names: |
| return "vae" |
| return "diffusion_models" |
|
|
| for src in models_dir.glob("*.safetensors"): |
| subdir = classify(src.name) |
| dst_dir = comfy_dir / "models" / subdir |
| dst_dir.mkdir(parents=True, exist_ok=True) |
| dst = dst_dir / src.name |
| if dst.is_symlink() or dst.exists(): |
| dst.unlink() |
| dst.symlink_to(src) |
| print(f"[symlink] {dst} -> {src}") |
|
|
| |
| if loras_dir and loras_dir.exists(): |
| dst_dir = comfy_dir / "models" / "loras" |
| dst_dir.mkdir(parents=True, exist_ok=True) |
| for src in loras_dir.glob("*.safetensors"): |
| dst = dst_dir / src.name |
| if dst.is_symlink() or dst.exists(): |
| dst.unlink() |
| dst.symlink_to(src) |
| print(f"[symlink] {dst} -> {src}") |
|
|
|
|
def load_progress(out_dir: Path) -> set[str]:
    """Return the stems of all ``*.png`` files directly inside ``out_dir``.

    The caller uses these stems to recognise samples that were already
    generated by an earlier run.
    """
    return {png.stem for png in out_dir.glob("*.png")}
|
|
|
|
def _build_arg_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the dataset generator."""
    ap = argparse.ArgumentParser()
    ap.add_argument("--prompts", required=True, type=Path)
    ap.add_argument("--workflow", required=True, type=Path)
    ap.add_argument("--out", required=True, type=Path)
    ap.add_argument("--comfy-dir", required=True, type=Path)
    ap.add_argument("--models-dir", default=Path("/models/checkpoints"), type=Path)
    ap.add_argument("--loras-dir", default=Path("/models/loras"), type=Path,
                    help="ComfyUI に symlink する LoRA ディレクトリ。存在しなければスキップ。")
    ap.add_argument("--seeds-per-prompt", type=int, default=50)
    ap.add_argument("--base-seed", type=int, default=42,
                    help="ベース seed。実 seed = base_seed * 10000 + prompt_idx * 100 + seed_idx")
    ap.add_argument("--max-images", type=int, default=0,
                    help="0 = 制限なし (prompts * seeds_per_prompt 全部)")
    ap.add_argument("--start-from", type=int, default=0,
                    help="prompt index を途中から再開する場合")
    ap.add_argument("--prompt-end", type=int, default=0,
                    help="処理する prompt index の終端 (排他的)。0 = 末尾まで。"
                         "並列実行時は [start-from, prompt-end) を各 worker に割り当て。")
    ap.add_argument("--quality-prefix", default=QUALITY_PREFIX,
                    help="Anima 公式推奨の prefix。先頭に付与される。")
    ap.add_argument("--no-prefix", action="store_true",
                    help="quality-prefix を付与しない (比較用)")
    ap.add_argument("--override-steps", type=int, default=0,
                    help="workflow の KSampler.steps を上書き (0 = workflow 既定)")
    ap.add_argument("--override-cfg", type=float, default=0.0,
                    help="workflow の KSampler.cfg を上書き (0 = workflow 既定)")
    ap.add_argument("--method-label", default="",
                    help="蒸留手法ラベル (例 z_image_traj_imitation, civitai_anima_turbo, base)。"
                         "metadata に含めて条件間で識別可能にする")
    ap.add_argument("--fixed-aspect", default="",
                    help="例: '1024x1024'。指定で aspect ratio random shuffle を無効化。")
    ap.add_argument("--batch-size", type=int, default=1,
                    help="1 submission あたり生成する画像数。B200 なら 8 程度まで OK。"
                         "同一 batch 内は同じ prompt/aspect、seed は base_seed + offset 0..N-1。")
    return ap


def _parse_fixed_aspect(spec: str) -> tuple[int, int]:
    """Parse a ``WIDTHxHEIGHT`` string such as ``1024x1024``; raise ValueError if malformed."""
    w, h = spec.lower().split("x")
    width, height = int(w), int(h)
    if width <= 0 or height <= 0:
        raise ValueError(spec)
    return width, height


def _extract_wf_meta(wf: dict) -> dict:
    """Collect sampler, scheduler, LoRA, sigma-shift and UNet settings from a workflow dict."""
    info = {
        "sampler": None, "scheduler": None,
        "lora_name": None, "lora_strength": None,
        "sigma_shift": None, "unet_name": None,
    }
    for node in wf.values():
        # Skips non-node entries such as a top-level "_comment" string.
        if not isinstance(node, dict):
            continue
        ct = node.get("class_type")
        ins = node.get("inputs", {})
        if ct == "KSampler":
            info["sampler"] = ins.get("sampler_name")
            info["scheduler"] = ins.get("scheduler")
        elif ct == "LoraLoaderModelOnly":
            info["lora_name"] = ins.get("lora_name")
            info["lora_strength"] = ins.get("strength_model")
        elif ct == "ModelSamplingAuraFlow":
            info["sigma_shift"] = ins.get("shift")
        elif ct == "UNETLoader":
            info["unet_name"] = ins.get("unet_name")
    return info


def _write_summary(out_dir: Path, args: argparse.Namespace, wf_meta: dict) -> None:
    """Aggregate ``_times.jsonl`` into ``_summary.json``; no-op if no timings exist."""
    import statistics

    times_path = out_dir / "_times.jsonl"
    if not times_path.exists():
        return
    records = [
        json.loads(line)
        for line in times_path.read_text(encoding="utf-8").splitlines()
        if line.strip()
    ]
    times = [r["gen_time_s"] for r in records]
    summary = {
        "method_label": args.method_label or None,
        "num_images": len(times),
        "total_time_s": round(sum(times), 3),
        "mean_time_s": round(statistics.mean(times), 3) if times else 0,
        "median_time_s": round(statistics.median(times), 3) if times else 0,
        "min_time_s": round(min(times), 3) if times else 0,
        "max_time_s": round(max(times), 3) if times else 0,
        "steps": args.override_steps if args.override_steps > 0 else None,
        "cfg": args.override_cfg if args.override_cfg > 0 else None,
        "sampler": wf_meta["sampler"],
        "scheduler": wf_meta["scheduler"],
        "sigma_shift": wf_meta["sigma_shift"],
        "lora_name": wf_meta["lora_name"],
        "lora_strength": wf_meta["lora_strength"],
        "unet_name": wf_meta["unet_name"],
    }
    summary_path = out_dir / "_summary.json"
    summary_path.write_text(json.dumps(summary, indent=2), encoding="utf-8")
    print(f"[summary] mean={summary['mean_time_s']}s "
          f"median={summary['median_time_s']}s "
          f"min={summary['min_time_s']}s max={summary['max_time_s']}s "
          f"-> {summary_path}")


def _stop_server(server: subprocess.Popen) -> None:
    """Send SIGTERM to the ComfyUI process, escalating to kill after 30 seconds."""
    print("[comfy] stopping server")
    server.send_signal(signal.SIGTERM)
    try:
        server.wait(timeout=30)
    except subprocess.TimeoutExpired:
        server.kill()
        # Reap the killed process so it does not linger as a zombie.
        server.wait()


def main() -> None:
    """Generate the dataset: spawn ComfyUI, render every prompt x seed, save results.

    For each generated image ``<stem>.png`` the caption (``<stem>.txt``) and
    a metadata record (``<stem>.json``) are written to ``--out``; a line is
    also appended to ``_times.jsonl``. Stems whose PNG already exists are
    skipped, so an interrupted run can be resumed. ``_summary.json`` is
    written at the end, including when ``--max-images`` stops the run early.
    The ComfyUI server is always shut down on exit.
    """
    ap = _build_arg_parser()
    args = ap.parse_args()

    # Validate --fixed-aspect once, before the server is started, instead of
    # failing in the middle of a run.
    fixed_size = None
    if args.fixed_aspect:
        try:
            fixed_size = _parse_fixed_aspect(args.fixed_aspect)
        except ValueError:
            ap.error(f"--fixed-aspect must look like 1024x1024, got {args.fixed_aspect!r}")

    args.out.mkdir(parents=True, exist_ok=True)

    setup_model_symlinks(args.comfy_dir, args.models_dir, args.loras_dir)

    # One prompt per line; blank lines and lines starting with '#' are ignored.
    raw_lines = args.prompts.read_text(encoding="utf-8").splitlines()
    prompts = [
        line.strip() for line in raw_lines
        if line.strip() and not line.strip().startswith("#")
    ]
    print(f"[prompts] {len(prompts)} loaded from {args.prompts}")

    workflow_template = json.loads(args.workflow.read_text(encoding="utf-8"))

    wf_meta = _extract_wf_meta(workflow_template)
    print(f"[wf-meta] {wf_meta}")

    print(f"[comfy] starting server in {args.comfy_dir}")
    comfy_args = [
        sys.executable, "main.py",
        "--listen", COMFY_HOST,
        "--port", str(COMFY_PORT),
        "--disable-auto-launch",
    ]
    # Only request sage attention when the package is importable here.
    try:
        import sageattention  # noqa: F401
        comfy_args.append("--use-sage-attention")
        print("[comfy] sageattention 検出 → --use-sage-attention で起動")
    except ImportError:
        print("[comfy] sageattention 未インストール、torch SDPA fallback")

    server = subprocess.Popen(
        comfy_args,
        cwd=str(args.comfy_dir),
        stdout=None,
        stderr=subprocess.STDOUT,
    )
    try:
        wait_for_comfy(timeout=600)
        print("[comfy] server ready")

        done = load_progress(args.out)
        print(f"[recover] {len(done)} images already in {args.out}")

        total = 0
        rng = random.Random(args.base_seed)
        prompt_end = args.prompt_end if args.prompt_end > 0 else len(prompts)
        bs = max(1, args.batch_size)
        limit_reached = False
        for p_idx, raw_prompt in enumerate(prompts):
            if p_idx < args.start_from or p_idx >= prompt_end:
                continue

            if args.no_prefix:
                full_caption = raw_prompt
            else:
                full_caption = f"{args.quality_prefix}, {raw_prompt}"

            # Seeds of one prompt are processed in chunks of up to `bs` images.
            for chunk_start in range(0, args.seeds_per_prompt, bs):
                chunk_end = min(chunk_start + bs, args.seeds_per_prompt)
                cur_bs = chunk_end - chunk_start
                stems = [f"p{p_idx:04d}_s{si:03d}" for si in range(chunk_start, chunk_end)]

                if all(s in done for s in stems):
                    continue
                # Checked per chunk, so the final count may exceed the limit
                # by up to cur_bs - 1. Break rather than return so the summary
                # below is still written.
                if args.max_images and total >= args.max_images:
                    print(f"[stop] max_images={args.max_images} reached")
                    limit_reached = True
                    break

                if fixed_size is not None:
                    width, height = fixed_size
                else:
                    width, height = rng.choice(ASPECT_RATIOS)

                chunk_seed = args.base_seed * 10000 + p_idx * 100 + chunk_start
                wf = patch_workflow(
                    workflow_template, full_caption, width, height, chunk_seed,
                    override_steps=args.override_steps,
                    override_cfg=args.override_cfg,
                    batch_size=cur_bs,
                )

                t0 = time.time()
                try:
                    hist = submit_and_wait(wf, timeout=600)
                except Exception as e:
                    # Best effort: one failed submission must not end the run.
                    print(f"[fail] p{p_idx} s{chunk_start}..{chunk_end} batch={cur_bs}: {e}")
                    continue

                # Images are read from the outputs of workflow node "10".
                outputs = hist.get("outputs", {}).get("10", {}).get("images", [])
                if len(outputs) != cur_bs:
                    print(f"[warn] p{p_idx} chunk={chunk_start} expected {cur_bs} got {len(outputs)}")
                elapsed = time.time() - t0

                for offset, img in enumerate(outputs):
                    s_idx = chunk_start + offset
                    stem = stems[offset] if offset < len(stems) else f"p{p_idx:04d}_s{s_idx:03d}"
                    if stem in done:
                        continue
                    seed = args.base_seed * 10000 + p_idx * 100 + s_idx
                    try:
                        img_bytes = fetch_image(
                            img["filename"], img.get("subfolder", ""), img.get("type", "output")
                        )
                    except OSError as e:
                        # Same tolerance as a failed submission: skip this
                        # image; a later resumed run will regenerate it.
                        print(f"[fail] {stem}: could not fetch image: {e}")
                        continue

                    meta = {
                        "method_label": args.method_label or None,
                        "gen_time_s": round(elapsed / cur_bs, 3),
                        "batch_size": cur_bs,
                        "batch_wall_s": round(elapsed, 3),
                        "width": width, "height": height,
                        "steps": args.override_steps if args.override_steps > 0 else None,
                        "cfg": args.override_cfg if args.override_cfg > 0 else None,
                        "seed": int(seed),
                        "prompt_index": p_idx, "seed_index": s_idx,
                        "sampler": wf_meta["sampler"],
                        "scheduler": wf_meta["scheduler"],
                        "sigma_shift": wf_meta["sigma_shift"],
                        "lora_name": wf_meta["lora_name"],
                        "lora_strength": wf_meta["lora_strength"],
                        "unet_name": wf_meta["unet_name"],
                        "workflow_path": str(args.workflow),
                    }
                    # The PNG is written last: resume keys on *.png, so an
                    # existing PNG must imply caption and metadata exist too.
                    (args.out / f"{stem}.txt").write_text(full_caption, encoding="utf-8")
                    (args.out / f"{stem}.json").write_text(json.dumps(meta), encoding="utf-8")
                    (args.out / f"{stem}.png").write_bytes(img_bytes)
                    total += 1
                    with open(args.out / "_times.jsonl", "a", encoding="utf-8") as times_file:
                        times_file.write(json.dumps({"stem": stem, **meta}) + "\n")

                if total % 10 == 0 or total <= cur_bs:
                    per_img = elapsed / max(1, cur_bs)
                    print(f"[gen] {total} done | p{p_idx}/{len(prompts)} s{chunk_start}-{chunk_end-1} "
                          f"| {width}x{height} | bs={cur_bs} | {elapsed:.1f}s ({per_img:.2f}s/img)")

            if limit_reached:
                break

        # The summary is best-effort: a failure here must not fail the run.
        try:
            _write_summary(args.out, args, wf_meta)
        except Exception as e:
            print(f"[summary] failed: {e}")

        print(f"[done] generated {total} images in {args.out}")
    finally:
        _stop_server(server)
|
|
|
|
# Run the generator only when this file is executed as a script.
if __name__ == "__main__":
    main()
|
|