"""
Analyze caption-evident tag recall against retrieved tags.

Compares tags marked caption-evident to retrieved tags (optionally including
structural tags), with optional implication expansion on both sets.
"""

from __future__ import annotations

import argparse
import json
from collections import Counter
from pathlib import Path
from typing import Dict, Iterable, Iterator, Set
import sys

_REPO_ROOT = Path(__file__).resolve().parents[1]
if str(_REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(_REPO_ROOT))

from psq_rag.retrieval.state import expand_tags_via_implications


def _iter_jsonl(path: Path) -> Iterator[dict]:
    """Yield one parsed JSON object per non-blank line of the JSONL file at *path*."""
    with path.open("r", encoding="utf-8") as f:
        for line in f:
            # Skip blank lines (e.g. a trailing newline) instead of letting
            # json.loads fail on an empty string.
            if not line.strip():
                continue
            yield json.loads(line)


def _load_evident(path: Path) -> Dict[int, Set[str]]:
    """Load caption-evident tags keyed by sample id.

    Reads each row's ``id`` and ``tags_ground_truth_expanded`` fields. Rows
    with no ``id``, or with a missing/null/empty tag list, are skipped.
    """
    by_id: Dict[int, Set[str]] = {}
    for row in _iter_jsonl(path):
        sid = row.get("id")
        if sid is None:
            continue
        tags = set(row.get("tags_ground_truth_expanded") or [])
        if tags:
            by_id[int(sid)] = tags
    return by_id


def _load_eval_detail(path: Path) -> Dict[int, dict]:
    """Load eval-detail rows keyed by ``sample_id``.

    Rows with a truthy ``_meta`` field are skipped. If a sample_id appears
    more than once, the last row wins.
    """
    rows: Dict[int, dict] = {}
    for row in _iter_jsonl(path):
        if row.get("_meta"):
            continue
        rows[int(row["sample_id"])] = row
    return rows


def _expand(tags: Iterable[str]) -> Set[str]:
    """Return *tags* expanded via ``expand_tags_via_implications``.

    Only the first element of the pair returned by that helper is used.
    """
    expanded, _ = expand_tags_via_implications(set(tags))
    return expanded


def _ratio(num: int, den: int) -> float:
    """Return ``num / den``, or 0.0 when *den* is 0 (avoids ZeroDivisionError)."""
    return num / den if den else 0.0


def _print_tag_table(
    title: str,
    counter: Counter,
    present_counter: Counter,
    verb: str,
    n_samples: int,
    limit: int = 20,
) -> None:
    """Print the *limit* most common tags in *counter* under *title*.

    Each line shows how often the tag was *verb* relative to how often it was
    evident, plus how many of the *n_samples* evaluated samples it appeared in.
    """
    print(title)
    for tag, cnt in counter.most_common(limit):
        present = present_counter[tag]
        print(f" {tag:25s} {verb} {cnt}/{present} (present {present}/{n_samples})")


def main() -> int:
    """Compare caption-evident tags with retrieved tags and print a CSV audit.

    Output, in order:
      * a CSV header plus one row per sample id present in BOTH inputs;
      * a TOTAL row aggregated over those samples (ratios are 0.000 when a
        denominator is zero);
      * the 20 most-missed and 20 most-found evident tags;
      * the tags that were retrieved every time they were evident.

    Returns 0 as the process exit status.
    """
    ap = argparse.ArgumentParser(description="Caption-evident audit vs retrieval.")
    ap.add_argument(
        "--evident",
        type=str,
        required=True,
        help="Caption-evident JSONL (tags_ground_truth_expanded set to evident tags).",
    )
    ap.add_argument(
        "--detail",
        type=str,
        required=True,
        help="Eval detail JSONL (from eval_pipeline.py).",
    )
    ap.add_argument(
        "--no-structural",
        action="store_true",
        help="Do not count structural tags as retrieved.",
    )
    ap.add_argument(
        "--expand-implications",
        action="store_true",
        help="Expand both evident and retrieved tags via implications.",
    )
    args = ap.parse_args()

    evident_by_id = _load_evident(Path(args.evident))
    detail_by_id = _load_eval_detail(Path(args.detail))

    hit_counter: Counter[str] = Counter()
    miss_counter: Counter[str] = Counter()
    present_counter: Counter[str] = Counter()

    print("ID,evident,retrieved,overlap,recall_evident,precision_evident,missing_evident,extra_not_evident,complete_overlap")
    total_evident = total_retrieved = total_overlap = 0
    # Number of samples actually evaluated (present in both inputs); used as
    # the denominator for the "present" column instead of a hard-coded count.
    n_samples = 0
    for sid in sorted(evident_by_id):
        detail = detail_by_id.get(sid)
        if detail is None:
            continue
        n_samples += 1

        ev = set(evident_by_id[sid])
        # `or []` also tolerates keys that are present but null.
        retrieved = set(detail.get("retrieved_tags") or [])
        if not args.no_structural:
            retrieved |= set(detail.get("structural_tags") or [])
        if args.expand_implications:
            ev = _expand(ev)
            retrieved = _expand(retrieved)

        overlap = ev & retrieved
        missing = ev - retrieved
        extra = retrieved - ev

        for t in ev:
            present_counter[t] += 1
            if t in retrieved:
                hit_counter[t] += 1
            else:
                miss_counter[t] += 1

        recall = _ratio(len(overlap), len(ev))
        precision = _ratio(len(overlap), len(retrieved))
        total_evident += len(ev)
        total_retrieved += len(retrieved)
        total_overlap += len(overlap)
        complete = len(missing) == 0
        print(f"{sid},{len(ev)},{len(retrieved)},{len(overlap)},{recall:.3f},{precision:.3f},{len(missing)},{len(extra)},{complete}")

    # Guarded: with no overlapping samples or no retrieved tags, the original
    # division raised ZeroDivisionError.
    total_recall = _ratio(total_overlap, total_evident)
    total_precision = _ratio(total_overlap, total_retrieved)
    print(f"TOTAL,{total_evident},{total_retrieved},{total_overlap},{total_recall:.3f},{total_precision:.3f},{total_evident - total_overlap},{total_retrieved - total_overlap},N/A")

    _print_tag_table(
        "\nMOST MISSED (caption-evident tags not retrieved):",
        miss_counter,
        present_counter,
        "missed",
        n_samples,
    )
    _print_tag_table(
        "\nMOST FOUND (caption-evident tags retrieved):",
        hit_counter,
        present_counter,
        "found",
        n_samples,
    )

    # Tags retrieved in every sample where they were caption-evident.
    always_found = [t for t, c in hit_counter.items() if c == present_counter[t]]
    if always_found:
        print("\nALWAYS FOUND WHEN EVIDENT:")
        for t in sorted(always_found):
            print(f" {t}")
    return 0


if __name__ == "__main__":
    raise SystemExit(main())