Prompt_Squirrel_RAG / scripts /analyze_caption_evident_audit.py
Food Desert
Add eval audit tools, caption-evident set, and logging
73f56cf
"""
Analyze caption-evident tag recall against retrieved tags.
Compares tags marked caption-evident to retrieved tags (plus structural tags
unless --no-structural is given), with optional implication expansion on both
sets.
"""
from __future__ import annotations
import argparse
import json
from collections import Counter
from pathlib import Path
from typing import Dict, Iterable, Set
import sys
# Put the repository root (the parent of this script's directory) at the
# front of sys.path so the project-local `psq_rag` import that follows can be
# resolved when the script is executed directly from a checkout.
_REPO_ROOT = Path(__file__).resolve().parents[1]
_repo_root_entry = str(_REPO_ROOT)
if _repo_root_entry not in sys.path:
    sys.path.insert(0, _repo_root_entry)
from psq_rag.retrieval.state import expand_tags_via_implications
def _load_evident(path: Path) -> Dict[int, Set[str]]:
by_id: Dict[int, Set[str]] = {}
with path.open("r", encoding="utf-8") as f:
for line in f:
row = json.loads(line)
sid = row.get("id")
if sid is None:
continue
tags = set(row.get("tags_ground_truth_expanded") or [])
if tags:
by_id[int(sid)] = tags
return by_id
def _load_eval_detail(path: Path) -> Dict[int, dict]:
rows = {}
with path.open("r", encoding="utf-8") as f:
for line in f:
row = json.loads(line)
if row.get("_meta"):
continue
rows[int(row["sample_id"])] = row
return rows
def _expand(tags: Iterable[str]) -> Set[str]:
    """Expand *tags* via ``expand_tags_via_implications`` and return the expanded set.

    The helper is handed a fresh ``set`` copy of the input and must return a
    2-tuple; only its first element is kept, the second is discarded.
    """
    tag_set = set(tags)
    expanded_tags, _discarded = expand_tags_via_implications(tag_set)
    return expanded_tags
def _ratio(numerator: int, denominator: int) -> float:
    """Return numerator / denominator, or 0.0 when the denominator is zero."""
    return numerator / denominator if denominator else 0.0


def _print_tag_section(
    title: str,
    counts: Counter,
    present_counter: Counter,
    verb: str,
    n_samples: int,
    top: int,
) -> None:
    """Print *title* followed by the *top* most common tags in *counts*."""
    print(f"\n{title}")
    for tag, cnt in counts.most_common(top):
        present = present_counter[tag]
        print(f" {tag:25s} {verb} {cnt}/{present} (present {present}/{n_samples})")


def main(argv: list[str] | None = None) -> int:
    """Print a caption-evident vs. retrieval audit as CSV, followed by tag summaries.

    For every sample id present in both the evident file and the eval detail
    file, the evident tag set is compared with the retrieved tag set. The
    retrieved set includes structural tags unless ``--no-structural`` is given,
    and both sides are optionally expanded via implications. One CSV row is
    printed per sample. A TOTAL row follows, with recall/precision pooled over
    all compared samples. Then come the most-missed and most-found evident
    tags, and the tags retrieved in every sample where they were evident.

    Args:
        argv: Command-line arguments to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        Process exit status (always 0).
    """
    ap = argparse.ArgumentParser(description="Caption-evident audit vs retrieval.")
    ap.add_argument("--evident", type=str, required=True,
                    help="Caption-evident JSONL (tags_ground_truth_expanded set to evident tags).")
    ap.add_argument("--detail", type=str, required=True,
                    help="Eval detail JSONL (from eval_pipeline.py).")
    ap.add_argument("--no-structural", action="store_true",
                    help="Do not count structural tags as retrieved.")
    ap.add_argument("--expand-implications", action="store_true",
                    help="Expand both evident and retrieved tags via implications.")
    ap.add_argument("--top", type=int, default=20,
                    help="Number of tags listed in the MOST MISSED / MOST FOUND sections (default: 20).")
    args = ap.parse_args(argv)

    evident_by_id = _load_evident(Path(args.evident))
    detail_by_id = _load_eval_detail(Path(args.detail))

    hit_counter: Counter[str] = Counter()
    miss_counter: Counter[str] = Counter()
    present_counter: Counter[str] = Counter()
    total_evident = total_retrieved = total_overlap = 0
    n_samples = 0  # samples actually compared (evident AND detail row present)
    skipped_ids: list[int] = []

    print("ID,evident,retrieved,overlap,recall_evident,precision_evident,missing_evident,extra_not_evident,complete_overlap")
    for sid in sorted(evident_by_id):
        detail = detail_by_id.get(sid)
        if detail is None:
            skipped_ids.append(sid)
            continue
        ev = set(evident_by_id[sid])
        # `or []` tolerates explicit nulls in the detail row.
        retrieved = set(detail.get("retrieved_tags") or [])
        if not args.no_structural:
            retrieved |= set(detail.get("structural_tags") or [])
        if args.expand_implications:
            ev = _expand(ev)
            retrieved = _expand(retrieved)

        overlap = ev & retrieved
        missing = ev - retrieved
        extra = retrieved - ev

        # Every evident tag is either found (in overlap) or missed.
        present_counter.update(ev)
        hit_counter.update(overlap)
        miss_counter.update(missing)

        recall = _ratio(len(overlap), len(ev))
        precision = _ratio(len(overlap), len(retrieved))
        total_evident += len(ev)
        total_retrieved += len(retrieved)
        total_overlap += len(overlap)
        n_samples += 1
        complete = not missing
        print(f"{sid},{len(ev)},{len(retrieved)},{len(overlap)},{recall:.3f},{precision:.3f},{len(missing)},{len(extra)},{complete}")

    # Pooled totals; _ratio avoids ZeroDivisionError when nothing was compared
    # or nothing was retrieved, matching the per-row 0.0 convention.
    print(
        f"TOTAL,{total_evident},{total_retrieved},{total_overlap},"
        f"{_ratio(total_overlap, total_evident):.3f},"
        f"{_ratio(total_overlap, total_retrieved):.3f},"
        f"{total_evident - total_overlap},{total_retrieved - total_overlap},N/A"
    )

    if skipped_ids:
        # Report on stderr so the CSV on stdout stays machine-readable.
        print(
            f"WARNING: {len(skipped_ids)} evident sample(s) had no eval detail row "
            f"and were skipped: {skipped_ids}",
            file=sys.stderr,
        )

    _print_tag_section("MOST MISSED (caption-evident tags not retrieved):",
                       miss_counter, present_counter, "missed", n_samples, args.top)
    _print_tag_section("MOST FOUND (caption-evident tags retrieved):",
                       hit_counter, present_counter, "found", n_samples, args.top)

    always_found = [t for t, c in hit_counter.items() if c == present_counter[t]]
    if always_found:
        print("\nALWAYS FOUND WHEN EVIDENT:")
        for t in sorted(always_found):
            print(f" {t}")
    return 0
if __name__ == "__main__":
    # sys.exit raises SystemExit, so main()'s return value becomes the exit status.
    sys.exit(main())