Prompt_Squirrel_RAG / scripts /eval_rewrite_ablation.py
Food Desert
Roll out T5 rewrite updates, tooling, docs, and artifact ignore rules
34c53b5
from __future__ import annotations
import argparse
import csv
import json
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Sequence, Set
from scripts.eval_pipeline import run_eval
# Repository root: the parent of the directory that contains this script.
REPO_ROOT = Path(__file__).resolve().parents[1]

# Evaluation sample file used when --eval-path is not given on the command line.
DEFAULT_EVAL_PATH = REPO_ROOT.joinpath(
    "data",
    "eval_samples",
    "e621_sfw_sample_1000_seed123_buffer10000_caption_evident_n30.jsonl",
)
def _canon_tag(tag: str) -> str:
    """Return the canonical comparison form of a single tag.

    The input is stringified (falsy values become the empty string), split on
    whitespace runs, re-joined with underscores and lower-cased. Backslash-escaped
    parentheses (``\\(`` and ``\\)``) are then unescaped.
    """
    words = str(tag or "").split()
    underscored = "_".join(words).lower()
    unescaped_open = underscored.replace("\\(", "(")
    return unescaped_open.replace("\\)", ")")
def _parse_tag_set(text: str) -> Set[str]:
    """Parse a comma-separated tag string into a set of canonical tags.

    Every comma-delimited piece is passed through ``_canon_tag``; pieces whose
    canonical form is empty are discarded. A falsy ``text`` produces an empty set.
    """
    pieces = (text or "").split(",")
    canonical = (_canon_tag(piece) for piece in pieces)
    return {tag for tag in canonical if tag}
def _set_metrics(pred_sets: Sequence[Set[str]], gold_sets: Sequence[Set[str]]) -> Dict[str, float]:
    """Compute macro-averaged set precision/recall/F1 over paired tag sets.

    Args:
        pred_sets: Predicted tag set for each sample.
        gold_sets: Gold tag set for each sample, aligned index-by-index with
            ``pred_sets``.

    Returns:
        A dict with ``set_precision``, ``set_recall``, ``set_f1`` (per-sample
        scores averaged over all samples) and ``avg_pred_tags`` /
        ``avg_gold_tags`` (mean set sizes). All values are 0.0 when there are
        no samples.

    Raises:
        ValueError: If the two sequences have different lengths. Pairing them
            with ``zip`` would otherwise silently drop the surplus samples
            while still dividing by ``len(pred_sets)``.
    """
    n = len(pred_sets)
    if n != len(gold_sets):
        raise ValueError(
            f"pred_sets and gold_sets must have the same length, got {n} and {len(gold_sets)}"
        )
    if n == 0:
        return {
            "set_precision": 0.0,
            "set_recall": 0.0,
            "set_f1": 0.0,
            "avg_pred_tags": 0.0,
            "avg_gold_tags": 0.0,
        }

    p_vals: List[float] = []
    r_vals: List[float] = []
    f_vals: List[float] = []
    for pset, gset in zip(pred_sets, gold_sets):
        if not pset or not gset:
            # Degenerate pair: both sides empty counts as a perfect match,
            # exactly one side empty counts as a complete miss.
            p = r = f1 = 0.0 if (pset or gset) else 1.0
        else:
            tp = len(pset & gset)
            p = tp / len(pset)
            r = tp / len(gset)
            f1 = (2 * p * r / (p + r)) if (p + r) > 0 else 0.0
        p_vals.append(p)
        r_vals.append(r)
        f_vals.append(f1)

    return {
        "set_precision": sum(p_vals) / n,
        "set_recall": sum(r_vals) / n,
        "set_f1": sum(f_vals) / n,
        "avg_pred_tags": sum(len(pset) for pset in pred_sets) / n,
        "avg_gold_tags": sum(len(gset) for gset in gold_sets) / n,
    }
def _summarize(results) -> Dict[str, float]:
    """Aggregate per-sample evaluation results into one summary row.

    Results whose ``error`` attribute is ``None`` are treated as valid; all
    others are only counted in ``n_errors``. For the valid results the function
    averages the retrieval/selection/leaf scores and the three stage timings,
    and scores the rewrite phrases against the ground-truth tags via
    ``_set_metrics`` (both sides canonicalized with ``_canon_tag``).

    When no result is valid, every metric is reported as 0.0.
    """
    ok = [res for res in results if res.error is None]
    count = len(ok)

    if count == 0:
        metric_keys = (
            "ret_R",
            "P",
            "R",
            "F1",
            "leaf_F1",
            "t1",
            "t2",
            "t3",
            "t_total",
            "rw_P",
            "rw_R",
            "rw_F1",
            "rw_avg_pred",
            "rw_avg_gt",
        )
        empty_summary = {"n_valid": 0, "n_errors": len(results)}
        empty_summary.update(dict.fromkeys(metric_keys, 0.0))
        return empty_summary

    def mean_of(attr):
        # Arithmetic mean of one numeric attribute across the valid results.
        return sum(getattr(res, attr) for res in ok) / count

    # Rewrite quality: phrases are joined and re-parsed so each one is
    # canonicalized the same way as a comma-separated tag string.
    predicted = [_parse_tag_set(", ".join(res.rewrite_phrases or [])) for res in ok]
    expected = [
        {_canon_tag(tag) for tag in (res.ground_truth_tags or set()) if tag}
        for res in ok
    ]
    rewrite_scores = _set_metrics(predicted, expected)

    stage1 = mean_of("stage1_time")
    stage2 = mean_of("stage2_time")
    stage3 = mean_of("stage3_time")

    summary = {
        "n_valid": count,
        "n_errors": len(results) - count,
        "ret_R": mean_of("retrieval_recall"),
        "P": mean_of("selection_precision"),
        "R": mean_of("selection_recall"),
        "F1": mean_of("selection_f1"),
        "leaf_F1": mean_of("leaf_f1"),
        "t1": stage1,
        "t2": stage2,
        "t3": stage3,
        "t_total": stage1 + stage2 + stage3,
    }
    summary["rw_P"] = rewrite_scores["set_precision"]
    summary["rw_R"] = rewrite_scores["set_recall"]
    summary["rw_F1"] = rewrite_scores["set_f1"]
    summary["rw_avg_pred"] = rewrite_scores["avg_pred_tags"]
    summary["rw_avg_gt"] = rewrite_scores["avg_gold_tags"]
    return summary
def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the rewrite ablation script."""
    ap = argparse.ArgumentParser(description="Run n30 rewrite ablation: LLM vs T5, heuristic phrase append off/on")
    ap.add_argument("--eval-path", type=Path, default=DEFAULT_EVAL_PATH)
    ap.add_argument("--caption-field", type=str, default="caption_cogvlm")
    ap.add_argument("--n", type=int, default=30)
    ap.add_argument("--seed", type=int, default=42)
    ap.add_argument("--workers", type=int, default=1)
    ap.add_argument("--mode", type=str, default="chunked_map_union", choices=["single_shot", "chunked_map_union"])
    ap.add_argument("--chunk-size", type=int, default=60)
    ap.add_argument("--per-phrase-k", type=int, default=2)
    ap.add_argument("--per-phrase-final-k", type=int, default=1)
    ap.add_argument("--min-why", type=str, default="strong_implied")
    # Both inference switches default to on; the --no-* variants turn them off.
    ap.add_argument("--infer-structural", action="store_true", default=True)
    ap.add_argument("--no-infer-structural", dest="infer_structural", action="store_false")
    ap.add_argument("--infer-probe", action="store_true", default=True)
    ap.add_argument("--no-infer-probe", dest="infer_probe", action="store_false")
    ap.add_argument("--t5-model-dir", type=str, default="models/finetune/t5-rewrite")
    ap.add_argument("--t5-num-beams", type=int, default=4)
    ap.add_argument("--t5-max-new-tokens", type=int, default=128)
    return ap


def _write_reports(json_out: Path, csv_out: Path, payload: Dict[str, Any], rows: List[Dict[str, Any]]) -> None:
    """Write the full payload as JSON and the summary rows as CSV."""
    with json_out.open("w", encoding="utf-8") as f:
        json.dump(payload, f, ensure_ascii=False, indent=2)

    fieldnames = [
        "rewrite_source",
        "append_heuristic_phrases",
        "n_valid",
        "n_errors",
        "rw_P",
        "rw_R",
        "rw_F1",
        "rw_avg_pred",
        "rw_avg_gt",
        "ret_R",
        "P",
        "R",
        "F1",
        "leaf_F1",
        "t1",
        "t2",
        "t3",
        "t_total",
    ]
    with csv_out.open("w", encoding="utf-8", newline="") as f:
        w = csv.DictWriter(f, fieldnames=fieldnames)
        w.writeheader()
        w.writerows(rows)


def main() -> int:
    """Run the four rewrite-ablation configurations and save the reports.

    Each combination of rewrite source (``llm`` / ``t5``) and heuristic phrase
    appending (off / on) is passed to ``run_eval`` with otherwise identical
    settings taken from the command line. The per-configuration summaries are
    printed and then written, together with the run metadata and the failed
    samples, to timestamped JSON and CSV files under ``data/eval_results``.

    Returns:
        0 on success, for use as the process exit status.

    Raises:
        FileNotFoundError: If the evaluation file does not exist.
    """
    args = _build_arg_parser().parse_args()

    # A relative --eval-path is resolved against the repository root, not the CWD.
    eval_path = args.eval_path if args.eval_path.is_absolute() else (REPO_ROOT / args.eval_path).resolve()
    if not eval_path.is_file():
        raise FileNotFoundError(f"Eval path not found: {eval_path}")

    configs = [
        {"rewrite_source": "llm", "append_heuristic_phrases": False},
        {"rewrite_source": "llm", "append_heuristic_phrases": True},
        {"rewrite_source": "t5", "append_heuristic_phrases": False},
        {"rewrite_source": "t5", "append_heuristic_phrases": True},
    ]

    rows: List[Dict[str, Any]] = []
    details: Dict[str, Any] = {}
    for cfg in configs:
        name = f"{cfg['rewrite_source']}_heur_{'on' if cfg['append_heuristic_phrases'] else 'off'}"
        print("\n" + "=" * 80)
        print(f"Running config: {name}")
        print("=" * 80)

        results = run_eval(
            n_samples=args.n,
            caption_field=args.caption_field,
            skip_rewrite=False,
            allow_nsfw=False,
            mode=args.mode,
            chunk_size=args.chunk_size,
            per_phrase_k=args.per_phrase_k,
            per_phrase_final_k=args.per_phrase_final_k,
            temperature=0.0,
            max_tokens=512,
            verbose=False,
            shuffle=True,
            seed=args.seed,
            workers=args.workers,
            # The literal string "none" on the command line disables the filter.
            min_why=None if args.min_why == "none" else args.min_why,
            eval_path=str(eval_path),
            expand_implications=False,
            infer_structural=args.infer_structural,
            infer_probe=args.infer_probe,
            rewrite_source=cfg["rewrite_source"],
            t5_model_dir=args.t5_model_dir,
            t5_num_beams=args.t5_num_beams,
            t5_max_new_tokens=args.t5_max_new_tokens,
            append_heuristic_phrases=cfg["append_heuristic_phrases"],
        )

        summary = _summarize(results)
        summary.update(cfg)
        rows.append(summary)
        details[name] = {
            "summary": summary,
            # Use the same "error is not None" criterion as _summarize so the
            # listed errors always match the n_errors count in the summary.
            "errors": [
                {
                    "id": r.sample_id,
                    "error": r.error,
                    "issues": r.issues,
                }
                for r in results
                if r.error is not None
            ],
        }
        print(json.dumps(summary, ensure_ascii=False, indent=2))

    out_dir = REPO_ROOT / "data" / "eval_results"
    out_dir.mkdir(parents=True, exist_ok=True)

    # Capture the time once so the file names and the recorded timestamp agree.
    now = datetime.now()
    ts = now.strftime("%Y%m%d_%H%M%S")
    json_out = out_dir / f"rewrite_ablation_n{args.n}_{ts}.json"
    csv_out = out_dir / f"rewrite_ablation_n{args.n}_{ts}.csv"

    payload = {
        "meta": {
            "timestamp": now.isoformat(),
            "eval_path": str(eval_path),
            "caption_field": args.caption_field,
            "n": args.n,
            "seed": args.seed,
            "workers": args.workers,
            "mode": args.mode,
            "chunk_size": args.chunk_size,
            "per_phrase_k": args.per_phrase_k,
            "per_phrase_final_k": args.per_phrase_final_k,
            "min_why": args.min_why,
            "infer_structural": args.infer_structural,
            "infer_probe": args.infer_probe,
            "t5_model_dir": args.t5_model_dir,
            "t5_num_beams": args.t5_num_beams,
            "t5_max_new_tokens": args.t5_max_new_tokens,
        },
        "rows": rows,
        "details": details,
    }
    _write_reports(json_out, csv_out, payload, rows)

    print(f"\nSaved ablation JSON: {json_out}")
    print(f"Saved ablation CSV: {csv_out}")
    return 0
if __name__ == "__main__":
    # Run the script and use main()'s return value as the process exit status.
    exit_status = main()
    raise SystemExit(exit_status)