Prompt_Squirrel_RAG / scripts /eval_probe_reliability.py
Food Desert
Consolidate probe configs and eval artifacts on main
6e50f4d
"""Measure per-tag LLM reliability for probe tags (selection-only, no retrieval).
Process:
- Use the caption as the query text.
- Ask the Stage 3 selector to choose among a fixed list of probe-tag candidates.
- Compare the selected tags against ground-truth tag presence.
This estimates whether a probe tag is worth asking the LLM about.
Outputs (existing files with the same suffix are overwritten):
- data/analysis/probe_reliability_<suffix>.csv
- data/analysis/probe_reliability_<suffix>.json
"""
from __future__ import annotations

import argparse
import csv
import json
import os
import random
import sys
from collections import Counter, defaultdict
from pathlib import Path
from typing import Dict, List, Set, Tuple
# Repository root: this script lives in <repo>/scripts/, so go up one level.
REPO = Path(__file__).resolve().parents[1]

# Make repo-local packages (psq_rag, imported inside main()) importable when
# the script is run directly.
if str(REPO) not in sys.path:
    sys.path.insert(0, str(REPO))

# NOTE: this changes the process working directory at import time, presumably
# so that repo-relative paths used downstream resolve — confirm before moving
# it into main().
os.chdir(REPO)

# Default inputs/outputs (overridable via CLI flags in main()).
EVAL_DATA_RAW = REPO / "data" / "eval_samples" / "e621_sfw_sample_1000_seed123_buffer10000.jsonl"
PROBE_SET_CSV = REPO / "data" / "simplified_probe_tags.csv"
OUT_DIR = REPO / "data" / "analysis"
def _flatten_ground_truth(tags_categorized_str: str) -> Set[str]:
if not tags_categorized_str:
return set()
try:
cats = json.loads(tags_categorized_str)
except Exception:
return set()
out: Set[str] = set()
if isinstance(cats, dict):
for vals in cats.values():
if isinstance(vals, list):
for t in vals:
if isinstance(t, str):
out.add(t.strip())
return out
def _metrics(tp: int, fp: int, fn: int) -> Tuple[float, float, float]:
p = tp / (tp + fp) if (tp + fp) > 0 else 0.0
r = tp / (tp + fn) if (tp + fn) > 0 else 0.0
f1 = (2 * p * r / (p + r)) if (p + r) > 0 else 0.0
return p, r, f1
def _zero_counts() -> Dict[str, int]:
    """Return a fresh confusion-count dict with tp/fp/fn/tn all zero."""
    return {"tp": 0, "fp": 0, "fn": 0, "tn": 0}


def _parse_flag(value: object) -> int:
    """Parse a CSV flag cell (e.g. ``needs_glossary``) into an int for sorting.

    Integer strings keep their numeric value (as ``int(value or "0")`` did).
    Non-integer text maps to 1 for "true"/"yes"/"y" and 0 otherwise, instead
    of raising ValueError.
    """
    text = str(value or "0").strip()
    try:
        return int(text)
    except ValueError:
        return 1 if text.lower() in {"true", "yes", "y"} else 0


def _load_probe_tags(probe_csv: Path) -> Tuple[List[str], Dict[str, Dict[str, str]]]:
    """Read probe tags flagged ``selected_initial`` from *probe_csv*.

    Returns:
        ``(probe_tags, tag_meta)``. ``probe_tags`` is the stripped, de-duplicated
        list in file order. ``tag_meta`` maps each tag to its first CSV row.
        De-duplication keeps a repeated tag from being offered to the selector
        twice and counted twice in the confusion tables.
    """
    with probe_csv.open("r", encoding="utf-8", newline="") as f:
        rows = list(csv.DictReader(f))
    tag_meta: Dict[str, Dict[str, str]] = {}
    for r in rows:
        if (r.get("selected_initial") or "0").strip().lower() not in {"1", "true", "yes"}:
            continue
        # Ground-truth tags are stripped, so strip probe tags too for matching.
        tag = (r.get("tag") or "").strip()
        if tag and tag not in tag_meta:
            tag_meta[tag] = r
    return list(tag_meta), tag_meta


def _load_samples(data_path: Path, caption_field: str) -> List[dict]:
    """Load JSONL rows that have both a non-empty caption and ground-truth tags.

    Each returned dict has keys ``id``, ``caption`` and ``gt`` (a set of tags).
    """
    rows: List[dict] = []
    with data_path.open("r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue  # tolerate blank lines, e.g. a trailing newline
            row = json.loads(line)
            cap = (row.get(caption_field) or "").strip()
            if not cap:
                continue
            gt = _flatten_ground_truth(row.get("tags_ground_truth_categorized", ""))
            if not gt:
                continue
            rows.append({"id": row.get("id"), "caption": cap, "gt": gt})
    return rows


def _tally(
    conf_th: Dict[str, Dict[str, int]],
    overall_th: Dict[str, int],
    probe_tags: List[str],
    gt: Set[str],
    selected: Set[str],
) -> None:
    """Add one sample's per-tag outcome to the per-tag and overall counters."""
    for t in probe_tags:
        pred_pos = t in selected
        if t in gt:
            key = "tp" if pred_pos else "fn"
        else:
            key = "fp" if pred_pos else "tn"
        conf_th[t][key] += 1
        overall_th[key] += 1


def _build_tag_rows(
    probe_tags: List[str],
    tag_meta: Dict[str, Dict[str, str]],
    conf: Dict[str, Dict[str, Dict[str, int]]],
) -> List[Dict[str, str]]:
    """Build one CSV row per probe tag, sorted best-first by strong F1."""
    rows: List[Dict[str, str]] = []
    for t in probe_tags:
        strong = conf["strong"][t]
        row: Dict[str, str] = {
            "tag": t,
            "bundle": tag_meta[t].get("bundle", ""),
            "needs_glossary": tag_meta[t].get("needs_glossary", ""),
            # Support depends only on ground truth, so any threshold gives the same value.
            "support_pos": str(strong["tp"] + strong["fn"]),
            "support_neg": str(strong["tn"] + strong["fp"]),
        }
        for th in ("explicit", "strong"):
            c = conf[th][t]
            p, rc, f1 = _metrics(c["tp"], c["fp"], c["fn"])
            row[f"tp_{th}"] = str(c["tp"])
            row[f"fp_{th}"] = str(c["fp"])
            row[f"fn_{th}"] = str(c["fn"])
            row[f"precision_{th}"] = f"{p:.6f}"
            row[f"recall_{th}"] = f"{rc:.6f}"
            row[f"f1_{th}"] = f"{f1:.6f}"
        rows.append(row)
    # Highest strong F1 first, then larger positive support. Ties go to tags
    # NOT flagged needs_glossary, because the flag is negated under reverse=True.
    rows.sort(
        key=lambda x: (float(x["f1_strong"]), int(x["support_pos"]), -_parse_flag(x["needs_glossary"])),
        reverse=True,
    )
    return rows


_CSV_FIELDS = [
    "tag",
    "bundle",
    "needs_glossary",
    "support_pos",
    "support_neg",
    "tp_explicit",
    "fp_explicit",
    "fn_explicit",
    "precision_explicit",
    "recall_explicit",
    "f1_explicit",
    "tp_strong",
    "fp_strong",
    "fn_strong",
    "precision_strong",
    "recall_strong",
    "f1_strong",
]


def main() -> None:
    """Run the probe-reliability evaluation and write CSV/JSON reports.

    For each sampled caption, ``llm_select_indices`` picks from the full
    probe-tag list in a single shot. The picks are filtered by ``why`` rank
    into two thresholds, "explicit" and "strong" (explicit + strong_implied).
    Each filtered set is compared to ground truth to build per-tag and overall
    confusion counts.

    Raises:
        FileNotFoundError: if the probe CSV or the eval data file is missing.
        RuntimeError: if no probe tags are flagged ``selected_initial``, or no
            data row has both a caption and ground-truth tags.
    """
    ap = argparse.ArgumentParser(description="Evaluate per-tag probe reliability (selection-only).")
    ap.add_argument("--probe-csv", type=Path, default=PROBE_SET_CSV)
    ap.add_argument("--data", type=Path, default=EVAL_DATA_RAW)
    ap.add_argument("--caption-field", default="caption_cogvlm")
    ap.add_argument("--n", type=int, default=10, help="Number of samples.")
    ap.add_argument("--seed", type=int, default=42)
    ap.add_argument("--suffix", default="sanity10")
    ap.add_argument("--retries", type=int, default=2)
    ap.add_argument("--temperature", type=float, default=0.0)
    ap.add_argument("--max-tokens", type=int, default=700)
    ap.add_argument("--workers-note", default="sequential", help="for logging only; this script runs sequentially.")
    ap.add_argument("--verbose", action="store_true")
    args = ap.parse_args()

    if not args.probe_csv.is_file():
        raise FileNotFoundError(f"Probe CSV not found: {args.probe_csv}")
    if not args.data.is_file():
        raise FileNotFoundError(f"Eval data not found: {args.data}")

    # Deferred import: only reached once both input paths exist.
    from psq_rag.llm.select import llm_select_indices, WHY_RANK

    probe_tags, tag_meta = _load_probe_tags(args.probe_csv)
    if not probe_tags:
        raise RuntimeError("No probe tags found with selected_initial=1.")

    all_rows = _load_samples(args.data, args.caption_field)
    if not all_rows:
        raise RuntimeError(f"No usable rows in {args.data}.")
    random.Random(args.seed).shuffle(all_rows)
    # Always evaluate at least one row, even if --n <= 0.
    samples = all_rows[: max(1, min(args.n, len(all_rows)))]

    # Maximum accepted `why` rank for each threshold.
    thresholds = {
        "explicit": WHY_RANK["explicit"],
        "strong": WHY_RANK["strong_implied"],  # explicit + strong_implied
    }
    conf = {th: {t: _zero_counts() for t in probe_tags} for th in thresholds}
    overall = {th: _zero_counts() for th in thresholds}
    diag_rows: List[Dict[str, object]] = []
    parse_fail_count = 0
    call_exhaust_count = 0

    def _log(msg: str) -> None:
        if args.verbose:
            print(msg)

    n_tags = len(probe_tags)
    for s in samples:
        # IMPORTANT: per_phrase_k controls per-call budget when candidate strings have no sources.
        # Set it to len(probe_tags) so the model can choose all true tags if needed.
        idxs, tag_why, diag = llm_select_indices(
            query_text=s["caption"],
            candidates=probe_tags,
            max_pick=n_tags,
            log=_log,
            retries=args.retries,
            mode="single_shot",
            chunk_size=max(1, n_tags),
            per_phrase_k=max(1, n_tags),
            temperature=args.temperature,
            max_tokens=args.max_tokens,
            return_metadata=True,
            return_diagnostics=True,
            min_why=None,
        )
        diag = diag or {}
        # Out-of-range indices are ignored rather than failing the run.
        selected_all = {probe_tags[i] for i in idxs if 0 <= i < n_tags}

        attempt_fail = float(diag.get("attempt_failure_rate", 0.0))
        call_exhaust = float(diag.get("call_exhaustion_rate", 0.0))
        if attempt_fail > 0.0:
            parse_fail_count += 1
        if call_exhaust > 0.0:
            call_exhaust_count += 1
        diag_rows.append(
            {
                "sample_id": s["id"],
                "selected_any": len(selected_all),
                "attempt_failure_rate": attempt_fail,
                "call_exhaustion_rate": call_exhaust,
            }
        )

        # Keep only picks whose `why` rank passes each threshold; unknown -> 999 (rejected).
        for th, max_rank in thresholds.items():
            selected = {
                t for t in selected_all
                if WHY_RANK.get(tag_why.get(t, "other"), 999) <= max_rank
            }
            _tally(conf[th], overall[th], probe_tags, s["gt"], selected)

    out_rows = _build_tag_rows(probe_tags, tag_meta, conf)

    overall_metrics = {}
    for th in ("explicit", "strong"):
        c = overall[th]
        p, rc, f1 = _metrics(c["tp"], c["fp"], c["fn"])
        overall_metrics[th] = {
            "tp": c["tp"],
            "fp": c["fp"],
            "fn": c["fn"],
            "precision": round(p, 6),
            "recall": round(rc, 6),
            "f1": round(f1, 6),
        }

    suffix = args.suffix.strip() or f"n{len(samples)}"
    out_csv = OUT_DIR / f"probe_reliability_{suffix}.csv"
    out_json = OUT_DIR / f"probe_reliability_{suffix}.json"
    OUT_DIR.mkdir(parents=True, exist_ok=True)

    with out_csv.open("w", encoding="utf-8", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=_CSV_FIELDS)
        writer.writeheader()
        writer.writerows(out_rows)

    n_diag = len(diag_rows)  # >= 1, because at least one sample is always evaluated
    summary = {
        "settings": {
            "n": len(samples),
            "seed": args.seed,
            "caption_field": args.caption_field,
            "probe_count": len(probe_tags),
            "retries": args.retries,
            "temperature": args.temperature,
            "max_tokens": args.max_tokens,
            "workers_note": args.workers_note,
            "model_env": os.environ.get("OPENROUTER_MODEL", "meta-llama/llama-3.1-8b-instruct"),
        },
        "overall_metrics": overall_metrics,
        "diagnostics": {
            "samples_with_attempt_failures": parse_fail_count,
            "samples_with_call_exhaustion": call_exhaust_count,
            "avg_attempt_failure_rate": sum(d["attempt_failure_rate"] for d in diag_rows) / n_diag,
            "avg_call_exhaustion_rate": sum(d["call_exhaustion_rate"] for d in diag_rows) / n_diag,
        },
        "top_tags_by_f1_strong": out_rows[:20],
        "outputs": {
            "csv": str(out_csv),
            "json": str(out_json),
        },
    }
    with out_json.open("w", encoding="utf-8") as f:
        json.dump(summary, f, indent=2, ensure_ascii=False)

    print(f"Samples evaluated: {len(samples)}")
    print(f"Probe tags evaluated: {len(probe_tags)}")
    print(f"Overall strong: P={overall_metrics['strong']['precision']:.4f} "
          f"R={overall_metrics['strong']['recall']:.4f} F1={overall_metrics['strong']['f1']:.4f}")
    print(f"Diagnostics: attempt_fail_samples={parse_fail_count}, call_exhaust_samples={call_exhaust_count}")
    print(f"Outputs: {out_csv}, {out_json}")
if __name__ == "__main__":
    # main() returns None, so SystemExit(None) exits with status 0, the same as
    # falling off the end of the script; exceptions from main() still propagate.
    raise SystemExit(main())