# Prompt_Squirrel_RAG / scripts/eval_t5_rewrite_only_models.py
# Food Desert
# Roll out T5 rewrite updates, tooling, docs, and artifact ignore rules
# 34c53b5
from __future__ import annotations
import argparse
import json
import os
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Dict, Iterable, List, Sequence, Set
from psq_rag.llm.rewrite_local_t5 import local_t5_rewrite_prompt
from psq_rag.pipeline.preproc import extract_user_provided_tags_upto_3_words
# This script lives in <repo>/scripts/, so the repository root is two levels up
# from the file itself: parents[0] is scripts/, parents[1] is the repo root.
REPO_ROOT = Path(__file__).resolve().parents[1]

# Caption-evident n30 evaluation subset used when --samples is not supplied.
DEFAULT_SAMPLES = REPO_ROOT.joinpath(
    "data",
    "eval_samples",
    "e621_sfw_sample_1000_seed123_buffer10000_caption_evident_n30.jsonl",
)
def _canon_tag(tag: str) -> str:
t = " ".join(str(tag or "").strip().split()).lower()
return t.replace(" ", "_").replace("\\(", "(").replace("\\)", ")")
def _parse_tag_set(text: str) -> Set[str]:
    """Split comma-separated *text* into a set of canonical tags.

    Each piece is passed through ``_canon_tag``; pieces that canonicalize to
    the empty string are dropped. ``None``/empty input yields an empty set.
    """
    candidates = (_canon_tag(piece) for piece in (text or "").split(","))
    return {tag for tag in candidates if tag}
def _set_metrics(pred_sets: Sequence[Set[str]], gold_sets: Sequence[Set[str]]) -> Dict[str, float]:
n = len(pred_sets)
if n == 0:
return {
"n": 0,
"set_precision": 0.0,
"set_recall": 0.0,
"set_f1": 0.0,
"exact_set_match": 0.0,
"avg_pred_tags": 0.0,
"avg_gold_tags": 0.0,
}
p_vals: List[float] = []
r_vals: List[float] = []
f_vals: List[float] = []
exact = 0
pred_sizes: List[int] = []
gold_sizes: List[int] = []
for pset, gset in zip(pred_sets, gold_sets):
pred_sizes.append(len(pset))
gold_sizes.append(len(gset))
if pset == gset:
exact += 1
if not pset and not gset:
p_vals.append(1.0)
r_vals.append(1.0)
f_vals.append(1.0)
continue
if not pset or not gset:
p_vals.append(0.0)
r_vals.append(0.0)
f_vals.append(0.0)
continue
tp = len(pset & gset)
p = tp / len(pset)
r = tp / len(gset)
f = (2 * p * r / (p + r)) if (p + r) > 0 else 0.0
p_vals.append(p)
r_vals.append(r)
f_vals.append(f)
return {
"n": float(n),
"set_precision": sum(p_vals) / n,
"set_recall": sum(r_vals) / n,
"set_f1": sum(f_vals) / n,
"exact_set_match": exact / n,
"avg_pred_tags": sum(pred_sizes) / n,
"avg_gold_tags": sum(gold_sizes) / n,
}
@dataclass
class SampleRow:
    """One evaluation sample: a caption paired with its gold tag set."""

    # Value of the record's "id" field; -1 when the record has no id.
    sample_id: int
    # Stripped, non-empty caption text read from the configured caption field.
    caption: str
    # Canonicalized ground-truth tags the rewrite output is scored against.
    gold_tags: Set[str]
def _gold_tags_from_obj(obj: Dict[str, object]) -> Set[str]:
    """Extract canonical gold tags from one sample record.

    Prefers the flat ``tags_ground_truth_expanded`` list. Only when that yields
    no tags does it fall back to ``tags_ground_truth_categorized``, which may be
    either a dict of lists or a JSON string encoding such a dict.
    """
    expanded = obj.get("tags_ground_truth_expanded")
    if isinstance(expanded, list):
        gold = {tag for tag in (_canon_tag(str(t)) for t in expanded) if tag}
        if gold:
            return gold

    gold = set()
    categorized = obj.get("tags_ground_truth_categorized")
    if isinstance(categorized, str):
        try:
            categorized = json.loads(categorized)
        except ValueError:
            # Malformed embedded JSON (JSONDecodeError is a ValueError):
            # best-effort, treat as "no categorized tags".
            categorized = {}
    if isinstance(categorized, dict):
        for values in categorized.values():
            if isinstance(values, list):
                gold.update(tag for tag in (_canon_tag(str(t)) for t in values) if tag)
    return gold


def _load_rows(samples_path: Path, caption_field: str) -> List[SampleRow]:
    """Load evaluation samples from a JSONL file.

    Blank lines, records with a truthy ``_meta`` field, and records whose
    caption field is missing or empty are skipped.

    Args:
        samples_path: Path to the JSONL samples file.
        caption_field: Key holding the caption text in each record.

    Returns:
        Samples in file order.

    Raises:
        ValueError: if a line is not valid JSON or is not a JSON object; the
            message includes the file path and 1-based line number.
    """
    rows: List[SampleRow] = []
    with samples_path.open("r", encoding="utf-8") as f:
        for lineno, raw_line in enumerate(f, start=1):
            line = raw_line.strip()
            if not line:
                continue
            try:
                obj = json.loads(line)
            except ValueError as err:
                raise ValueError(f"{samples_path}:{lineno}: invalid JSON record") from err
            if not isinstance(obj, dict):
                raise ValueError(
                    f"{samples_path}:{lineno}: expected a JSON object, got {type(obj).__name__}"
                )
            if obj.get("_meta"):
                continue
            caption = str(obj.get(caption_field, "") or "").strip()
            if not caption:
                continue
            raw_id = obj.get("id")
            rows.append(
                SampleRow(
                    # A missing or explicit null id falls back to -1.
                    sample_id=int(raw_id) if raw_id is not None else -1,
                    caption=caption,
                    gold_tags=_gold_tags_from_obj(obj),
                )
            )
    return rows
def _eval_model(
    model_dir: Path,
    rows: Sequence[SampleRow],
    num_beams: int,
    max_new_tokens: int,
    max_source_length: int,
) -> Dict[str, Dict[str, float]]:
    """Score one local T5 rewrite model on *rows*, with and without the heuristic.

    For each sample the caption is rewritten by ``local_t5_rewrite_prompt`` and
    the comma-separated output is parsed into a tag set (the "no_heur" run).
    The "with_heur" run additionally unions in terms extracted from the caption
    by ``extract_user_provided_tags_upto_3_words``. That happens only when the
    rewrite output is non-blank, so a blank rewrite stays empty in both runs.

    Returns:
        ``{"t5_local_rewrite_no_heur": ..., "t5_local_rewrite_with_heur": ...}``
        where each value is the dict produced by ``_set_metrics``.
    """

    def _silent(_msg: str) -> None:
        # Discard the rewriter's log messages during evaluation.
        return None

    model_dir_str = str(model_dir)
    gold_sets = [sample.gold_tags for sample in rows]
    plain_preds: List[Set[str]] = []
    boosted_preds: List[Set[str]] = []

    for sample in rows:
        rewritten = local_t5_rewrite_prompt(
            sample.caption,
            log=_silent,
            model_dir=model_dir_str,
            num_beams=num_beams,
            max_new_tokens=max_new_tokens,
            max_source_length=max_source_length,
        )
        plain = _parse_tag_set(rewritten)
        plain_preds.append(plain)
        if not rewritten.strip():
            boosted_preds.append(set(plain))
            continue
        extras = {
            _canon_tag(term)
            for term in extract_user_provided_tags_upto_3_words(sample.caption)
        }
        extras.discard("")
        boosted_preds.append(plain | extras)

    return {
        "t5_local_rewrite_no_heur": _set_metrics(plain_preds, gold_sets),
        "t5_local_rewrite_with_heur": _set_metrics(boosted_preds, gold_sets),
    }
def _iter_model_dirs(raw_models: Iterable[str]) -> List[Path]:
dirs: List[Path] = []
for raw in raw_models:
p = Path(raw)
if not p.is_absolute():
p = (REPO_ROOT / p).resolve()
dirs.append(p)
return dirs
def _resolve_repo_path(path: Path) -> Path:
    """Return *path* unchanged if absolute, else resolved against ``REPO_ROOT``."""
    return path if path.is_absolute() else (REPO_ROOT / path).resolve()


def _build_arg_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for this script."""
    ap = argparse.ArgumentParser(description="Rewrite-only T5 evaluation on caption-evident n30 set.")
    ap.add_argument("--samples", type=Path, default=DEFAULT_SAMPLES)
    ap.add_argument("--caption-field", type=str, default="caption_cogvlm")
    ap.add_argument("--limit", type=int, default=30)
    ap.add_argument("--num-beams", type=int, default=4)
    ap.add_argument("--max-new-tokens", type=int, default=128)
    ap.add_argument("--max-source-length", type=int, default=160)
    ap.add_argument("--model-dir", action="append", required=True,
                    help="Model directory; repeat for multiple models.")
    ap.add_argument("--out-json", type=Path, default=REPO_ROOT / "data" / "analysis")
    return ap


def _check_model_dirs(model_dirs: Sequence[Path]) -> None:
    """Fail fast if any model directory or its ``model.safetensors`` is missing.

    Raises:
        FileNotFoundError: naming the first offending directory.
    """
    for d in model_dirs:
        if not d.is_dir():
            raise FileNotFoundError(f"Model directory not found: {d}")
        if not (d / "model.safetensors").is_file():
            raise FileNotFoundError(f"model.safetensors missing in: {d}")


def _resolve_out_path(out_arg: Path, stamp: datetime) -> Path:
    """Choose the report path and create its parent directory.

    An ``--out-json`` value ending in ``.json`` is used as the file itself.
    Any other value is treated as a directory that receives a file named
    with *stamp*.
    """
    out_base = _resolve_repo_path(out_arg)
    if out_base.suffix.lower() == ".json":
        out_base.parent.mkdir(parents=True, exist_ok=True)
        return out_base
    out_base.mkdir(parents=True, exist_ok=True)
    ts = stamp.strftime("%Y%m%d_%H%M%S")
    return out_base / f"rewrite_only_compare_n30_t5_sweep_{ts}.json"


def main() -> int:
    """Evaluate each ``--model-dir`` on the samples file and write a JSON report.

    Returns:
        Process exit code (0 on success).

    Raises:
        FileNotFoundError: if the samples file, a model directory, or a
            model's ``model.safetensors`` is missing.
    """
    args = _build_arg_parser().parse_args()

    samples_path = _resolve_repo_path(args.samples)
    if not samples_path.is_file():
        raise FileNotFoundError(f"Samples file not found: {samples_path}")
    model_dirs = _iter_model_dirs(args.model_dir)
    # Validate every model up front so a bad path fails before any evaluation runs.
    _check_model_dirs(model_dirs)

    rows = _load_rows(samples_path, args.caption_field)
    if args.limit > 0:
        rows = rows[: args.limit]

    # Clamp generation settings once, so the values used for evaluation and the
    # values recorded in the report's meta cannot diverge.
    num_beams = max(1, args.num_beams)
    max_new_tokens = max(8, args.max_new_tokens)
    max_source_length = max(16, args.max_source_length)

    result_rows = []
    for d in model_dirs:
        metrics = _eval_model(
            d,
            rows=rows,
            num_beams=num_beams,
            max_new_tokens=max_new_tokens,
            max_source_length=max_source_length,
        )
        result_rows.append({"model_dir": str(d), "metrics": metrics})
        no_h = metrics["t5_local_rewrite_no_heur"]
        with_h = metrics["t5_local_rewrite_with_heur"]
        print(
            f"{d.name}: "
            f"no_heur R={no_h['set_recall']:.4f} F1={no_h['set_f1']:.4f} "
            f"| with_heur R={with_h['set_recall']:.4f} F1={with_h['set_f1']:.4f}"
        )

    # Capture a single instant so the filename and meta timestamp agree.
    finished = datetime.now()
    out_path = _resolve_out_path(args.out_json, finished)
    payload = {
        "meta": {
            "timestamp": finished.isoformat(),
            "samples_path": str(samples_path),
            "caption_field": args.caption_field,
            "n_samples": len(rows),
            "num_beams": num_beams,
            "max_new_tokens": max_new_tokens,
            "max_source_length": max_source_length,
            "cuda_visible_devices": os.environ.get("CUDA_VISIBLE_DEVICES"),
        },
        "rows": result_rows,
    }
    out_path.write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8")
    print(f"Saved: {out_path}")
    return 0
if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    exit_status = main()
    raise SystemExit(exit_status)