"""Rewrite ablation runner.

Evaluates four configurations (rewrite source ``llm`` / ``t5`` crossed with
heuristic phrase append off / on) through ``run_eval`` and writes a JSON
report plus a CSV summary table under ``data/eval_results``.
"""

from __future__ import annotations

import argparse
import csv
import json
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Sequence, Set

from scripts.eval_pipeline import run_eval

REPO_ROOT = Path(__file__).resolve().parents[1]
DEFAULT_EVAL_PATH = (
    REPO_ROOT
    / "data"
    / "eval_samples"
    / "e621_sfw_sample_1000_seed123_buffer10000_caption_evident_n30.jsonl"
)

# Column order of the CSV summary table written by main().
_CSV_FIELDNAMES = [
    "rewrite_source",
    "append_heuristic_phrases",
    "n_valid",
    "n_errors",
    "rw_P",
    "rw_R",
    "rw_F1",
    "rw_avg_pred",
    "rw_avg_gt",
    "ret_R",
    "P",
    "R",
    "F1",
    "leaf_F1",
    "t1",
    "t2",
    "t3",
    "t_total",
]


def _canon_tag(tag: str) -> str:
    """Return the canonical form of a single tag.

    Whitespace runs are collapsed, the text is lower-cased, spaces become
    underscores, and backslash-escaped parentheses are unescaped. A falsy
    *tag* yields the empty string.
    """
    t = " ".join(str(tag or "").strip().split()).lower()
    return t.replace(" ", "_").replace("\\(", "(").replace("\\)", ")")


def _parse_tag_set(text: str) -> Set[str]:
    """Split comma-separated *text* into a set of canonical, non-empty tags."""
    canon = (_canon_tag(raw) for raw in (text or "").split(","))
    return {t for t in canon if t}


def _set_metrics(pred_sets: Sequence[Set[str]], gold_sets: Sequence[Set[str]]) -> Dict[str, float]:
    """Compute per-sample set precision/recall/F1 and average them.

    Args:
        pred_sets: Predicted tag sets, one per sample.
        gold_sets: Gold tag sets, aligned index-by-index with *pred_sets*.

    Returns:
        Dict with keys ``set_precision``, ``set_recall``, ``set_f1``,
        ``avg_pred_tags`` and ``avg_gold_tags``. All values are 0.0 when
        there are no samples.

    Raises:
        ValueError: If the two sequences differ in length. Pairing them with
            ``zip`` would silently drop samples while the sums are still
            divided by ``len(pred_sets)``, skewing every average.
    """
    if len(pred_sets) != len(gold_sets):
        raise ValueError(
            f"pred_sets and gold_sets must have the same length "
            f"(got {len(pred_sets)} and {len(gold_sets)})"
        )

    n = len(pred_sets)
    if n == 0:
        return {
            "set_precision": 0.0,
            "set_recall": 0.0,
            "set_f1": 0.0,
            "avg_pred_tags": 0.0,
            "avg_gold_tags": 0.0,
        }

    p_vals: List[float] = []
    r_vals: List[float] = []
    f_vals: List[float] = []
    pred_sizes: List[int] = []
    gold_sizes: List[int] = []

    for pset, gset in zip(pred_sets, gold_sets):
        pred_sizes.append(len(pset))
        gold_sizes.append(len(gset))

        if not pset or not gset:
            # Both empty counts as a perfect match; exactly one empty scores 0.
            score = 0.0 if pset or gset else 1.0
            p_vals.append(score)
            r_vals.append(score)
            f_vals.append(score)
            continue

        tp = len(pset & gset)
        p = tp / len(pset)
        r = tp / len(gset)
        f1 = (2 * p * r / (p + r)) if (p + r) > 0 else 0.0
        p_vals.append(p)
        r_vals.append(r)
        f_vals.append(f1)

    return {
        "set_precision": sum(p_vals) / n,
        "set_recall": sum(r_vals) / n,
        "set_f1": sum(f_vals) / n,
        "avg_pred_tags": sum(pred_sizes) / n,
        "avg_gold_tags": sum(gold_sizes) / n,
    }


def _summarize(results) -> Dict[str, float]:
    """Aggregate per-sample evaluation results into one summary row.

    Only results whose ``error`` attribute is ``None`` contribute to the
    averages; the rest are counted in ``n_errors``. The ``rw_*`` fields
    compare each result's ``rewrite_phrases`` (canonicalised as tags) with
    its ``ground_truth_tags``.

    Args:
        results: Iterable-with-len of result objects as returned by
            ``run_eval``. The attributes read here are ``error``,
            ``rewrite_phrases``, ``ground_truth_tags``, ``stage1_time``,
            ``stage2_time``, ``stage3_time``, ``retrieval_recall``,
            ``selection_precision``, ``selection_recall``, ``selection_f1``
            and ``leaf_f1``.

    Returns:
        Summary dict; every metric is 0.0 when no result is valid.
    """
    valid = [r for r in results if r.error is None]
    if not valid:
        return {
            "n_valid": 0,
            "n_errors": len(results),
            "ret_R": 0.0,
            "P": 0.0,
            "R": 0.0,
            "F1": 0.0,
            "leaf_F1": 0.0,
            "t1": 0.0,
            "t2": 0.0,
            "t3": 0.0,
            "t_total": 0.0,
            "rw_P": 0.0,
            "rw_R": 0.0,
            "rw_F1": 0.0,
            "rw_avg_pred": 0.0,
            "rw_avg_gt": 0.0,
        }

    n = len(valid)

    def avg(xs: Sequence[float]) -> float:
        return sum(xs) / n

    pred_sets: List[Set[str]] = []
    gold_sets: List[Set[str]] = []
    for r in valid:
        # Phrases are joined and re-split on commas, so a phrase that itself
        # contains a comma contributes several tags.
        phrase_text = ", ".join(r.rewrite_phrases or [])
        pred_sets.append(_parse_tag_set(phrase_text))
        # Filter AFTER canonicalisation so whitespace-only tags cannot add an
        # empty-string member; predicted sets already drop empty tags.
        canon_gold = (_canon_tag(t) for t in (r.ground_truth_tags or set()))
        gold_sets.append({t for t in canon_gold if t})

    rewrite = _set_metrics(pred_sets, gold_sets)
    t1 = avg([r.stage1_time for r in valid])
    t2 = avg([r.stage2_time for r in valid])
    t3 = avg([r.stage3_time for r in valid])

    return {
        "n_valid": n,
        "n_errors": len(results) - n,
        "ret_R": avg([r.retrieval_recall for r in valid]),
        "P": avg([r.selection_precision for r in valid]),
        "R": avg([r.selection_recall for r in valid]),
        "F1": avg([r.selection_f1 for r in valid]),
        "leaf_F1": avg([r.leaf_f1 for r in valid]),
        "t1": t1,
        "t2": t2,
        "t3": t3,
        "t_total": t1 + t2 + t3,
        "rw_P": rewrite["set_precision"],
        "rw_R": rewrite["set_recall"],
        "rw_F1": rewrite["set_f1"],
        "rw_avg_pred": rewrite["avg_pred_tags"],
        "rw_avg_gt": rewrite["avg_gold_tags"],
    }


def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the ablation script."""
    ap = argparse.ArgumentParser(description="Run n30 rewrite ablation: LLM vs T5, heuristic phrase append off/on")
    ap.add_argument("--eval-path", type=Path, default=DEFAULT_EVAL_PATH)
    ap.add_argument("--caption-field", type=str, default="caption_cogvlm")
    ap.add_argument("--n", type=int, default=30)
    ap.add_argument("--seed", type=int, default=42)
    ap.add_argument("--workers", type=int, default=1)
    ap.add_argument("--mode", type=str, default="chunked_map_union", choices=["single_shot", "chunked_map_union"])
    ap.add_argument("--chunk-size", type=int, default=60)
    ap.add_argument("--per-phrase-k", type=int, default=2)
    ap.add_argument("--per-phrase-final-k", type=int, default=1)
    ap.add_argument("--min-why", type=str, default="strong_implied")
    # Both inference flags default to on; the --no-* variants switch them off.
    ap.add_argument("--infer-structural", action="store_true", default=True)
    ap.add_argument("--no-infer-structural", dest="infer_structural", action="store_false")
    ap.add_argument("--infer-probe", action="store_true", default=True)
    ap.add_argument("--no-infer-probe", dest="infer_probe", action="store_false")
    ap.add_argument("--t5-model-dir", type=str, default="models/finetune/t5-rewrite")
    ap.add_argument("--t5-num-beams", type=int, default=4)
    ap.add_argument("--t5-max-new-tokens", type=int, default=128)
    return ap


def main() -> int:
    """Run all four ablation configurations and write the JSON/CSV reports.

    Returns:
        0 on completion (used as the process exit status).

    Raises:
        FileNotFoundError: If the resolved evaluation file does not exist.
    """
    args = _build_arg_parser().parse_args()

    # Relative paths are resolved against the repository root, not the CWD.
    eval_path = args.eval_path if args.eval_path.is_absolute() else (REPO_ROOT / args.eval_path).resolve()
    if not eval_path.is_file():
        raise FileNotFoundError(f"Eval path not found: {eval_path}")

    configs = [
        {"rewrite_source": "llm", "append_heuristic_phrases": False},
        {"rewrite_source": "llm", "append_heuristic_phrases": True},
        {"rewrite_source": "t5", "append_heuristic_phrases": False},
        {"rewrite_source": "t5", "append_heuristic_phrases": True},
    ]

    rows: List[Dict[str, Any]] = []
    details: Dict[str, Any] = {}

    for cfg in configs:
        name = f"{cfg['rewrite_source']}_heur_{'on' if cfg['append_heuristic_phrases'] else 'off'}"
        print("\n" + "=" * 80)
        print(f"Running config: {name}")
        print("=" * 80)

        results = run_eval(
            n_samples=args.n,
            caption_field=args.caption_field,
            skip_rewrite=False,
            allow_nsfw=False,
            mode=args.mode,
            chunk_size=args.chunk_size,
            per_phrase_k=args.per_phrase_k,
            per_phrase_final_k=args.per_phrase_final_k,
            temperature=0.0,
            max_tokens=512,
            verbose=False,
            shuffle=True,
            seed=args.seed,
            workers=args.workers,
            # The literal string "none" on the command line disables the filter.
            min_why=None if args.min_why == "none" else args.min_why,
            eval_path=str(eval_path),
            expand_implications=False,
            infer_structural=args.infer_structural,
            infer_probe=args.infer_probe,
            rewrite_source=cfg["rewrite_source"],
            t5_model_dir=args.t5_model_dir,
            t5_num_beams=args.t5_num_beams,
            t5_max_new_tokens=args.t5_max_new_tokens,
            append_heuristic_phrases=cfg["append_heuristic_phrases"],
        )

        summary = _summarize(results)
        summary.update(cfg)
        rows.append(summary)
        details[name] = {
            "summary": summary,
            # Same predicate as _summarize ("error is None" means valid), so
            # this list always has exactly n_errors entries.
            "errors": [
                {
                    "id": r.sample_id,
                    "error": r.error,
                    "issues": r.issues,
                }
                for r in results
                if r.error is not None
            ],
        }
        print(json.dumps(summary, ensure_ascii=False, indent=2))

    out_dir = REPO_ROOT / "data" / "eval_results"
    out_dir.mkdir(parents=True, exist_ok=True)

    # Capture the clock once so the file-name stamp and the recorded
    # timestamp refer to the same instant.
    now = datetime.now()
    ts = now.strftime("%Y%m%d_%H%M%S")
    json_out = out_dir / f"rewrite_ablation_n{args.n}_{ts}.json"
    csv_out = out_dir / f"rewrite_ablation_n{args.n}_{ts}.csv"

    payload = {
        "meta": {
            "timestamp": now.isoformat(),
            "eval_path": str(eval_path),
            "caption_field": args.caption_field,
            "n": args.n,
            "seed": args.seed,
            "workers": args.workers,
            "mode": args.mode,
            "chunk_size": args.chunk_size,
            "per_phrase_k": args.per_phrase_k,
            "per_phrase_final_k": args.per_phrase_final_k,
            "min_why": args.min_why,
            "infer_structural": args.infer_structural,
            "infer_probe": args.infer_probe,
            "t5_model_dir": args.t5_model_dir,
            "t5_num_beams": args.t5_num_beams,
            "t5_max_new_tokens": args.t5_max_new_tokens,
        },
        "rows": rows,
        "details": details,
    }
    with json_out.open("w", encoding="utf-8") as f:
        json.dump(payload, f, ensure_ascii=False, indent=2)

    with csv_out.open("w", encoding="utf-8", newline="") as f:
        w = csv.DictWriter(f, fieldnames=_CSV_FIELDNAMES)
        w.writeheader()
        w.writerows(rows)

    print(f"\nSaved ablation JSON: {json_out}")
    print(f"Saved ablation CSV: {csv_out}")
    return 0


if __name__ == "__main__":
    raise SystemExit(main())