Spaces:
Running
Running
File size: 4,573 Bytes
"""
Analyze caption-evident tag recall against retrieved tags.
Compares tags marked caption-evident to retrieved tags (plus structural tags
unless --no-structural is given), with optional implication expansion on both
sets (--expand-implications).
"""
from __future__ import annotations
import argparse
import json
from collections import Counter
from pathlib import Path
from typing import Dict, Iterable, Set
import sys
# Two levels up from this file is treated as the repository root; it is put at
# the front of sys.path (once) before the project import that follows.
_REPO_ROOT = Path(__file__).resolve().parents[1]
if sys.path.count(str(_REPO_ROOT)) == 0:
    sys.path.insert(0, str(_REPO_ROOT))
from psq_rag.retrieval.state import expand_tags_via_implications
def _load_evident(path: Path) -> Dict[int, Set[str]]:
    """Load the caption-evident tag set of every sample from a JSONL file.

    Each non-blank line is parsed as one JSON object. The sample id is read
    from the ``id`` key and the tags from ``tags_ground_truth_expanded``.

    Args:
        path: JSONL file to read (decoded as UTF-8).

    Returns:
        Mapping of integer sample id to its set of evident tags. Rows without
        an ``id`` and rows whose tag list is missing, null, or empty are left
        out. If an id occurs more than once, the last row with tags wins.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    by_id: Dict[int, Set[str]] = {}
    with path.open("r", encoding="utf-8") as f:
        for line in f:
            # Blank or whitespace-only lines are not records; skipping them
            # avoids a JSONDecodeError on files with spacer/trailing lines.
            if not line.strip():
                continue
            row = json.loads(line)
            sid = row.get("id")
            if sid is None:
                continue
            # `or []` also covers an explicit JSON null.
            tags = set(row.get("tags_ground_truth_expanded") or [])
            if tags:
                by_id[int(sid)] = tags
    return by_id
def _load_eval_detail(path: Path) -> Dict[int, dict]:
    """Load per-sample evaluation rows from a JSONL file.

    Each non-blank line is parsed as one JSON object. Rows carrying a truthy
    ``_meta`` key are skipped; every other row is stored under its
    ``sample_id`` converted to ``int``.

    Args:
        path: JSONL file to read (decoded as UTF-8).

    Returns:
        Mapping of integer sample id to the full row dict. If an id occurs
        more than once, the last row wins.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a non-meta row has no ``sample_id``.
    """
    rows: Dict[int, dict] = {}
    with path.open("r", encoding="utf-8") as f:
        for line in f:
            # Blank or whitespace-only lines are not records; skipping them
            # avoids a JSONDecodeError on files with spacer/trailing lines.
            if not line.strip():
                continue
            row = json.loads(line)
            if row.get("_meta"):
                continue
            rows[int(row["sample_id"])] = row
    return rows
def _expand(tags: Iterable[str]) -> Set[str]:
    """Return ``tags`` expanded through ``expand_tags_via_implications``.

    The input is materialised as a set before the call. The helper returns a
    pair; only its first element is returned here and the second is dropped.
    """
    tag_set = set(tags)
    closure, _unused = expand_tags_via_implications(tag_set)
    return closure
def main(argv: Iterable[str] | None = None) -> int:
    """Run the caption-evident audit and print the report to stdout.

    For every sample id present in both input files, the evident tag set is
    compared with the retrieved tag set (plus structural tags unless
    ``--no-structural`` is given; both sets expanded via implications when
    ``--expand-implications`` is given). Output is a CSV-style table with one
    row per sample and a ``TOTAL`` row, followed by the 20 most-missed tags,
    the 20 most-found tags, and the tags that were found every time they
    were evident.

    Args:
        argv: Command-line arguments to parse. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``.

    Returns:
        Always ``0``.
    """
    ap = argparse.ArgumentParser(description="Caption-evident audit vs retrieval.")
    ap.add_argument("--evident", type=str, required=True,
                    help="Caption-evident JSONL (tags_ground_truth_expanded set to evident tags).")
    ap.add_argument("--detail", type=str, required=True,
                    help="Eval detail JSONL (from eval_pipeline.py).")
    ap.add_argument("--no-structural", action="store_true",
                    help="Do not count structural tags as retrieved.")
    ap.add_argument("--expand-implications", action="store_true",
                    help="Expand both evident and retrieved tags via implications.")
    args = ap.parse_args(None if argv is None else list(argv))

    evident_by_id = _load_evident(Path(args.evident))
    detail_by_id = _load_eval_detail(Path(args.detail))

    hit_counter = Counter()      # tag -> samples where it was evident and retrieved
    miss_counter = Counter()     # tag -> samples where it was evident but not retrieved
    present_counter = Counter()  # tag -> samples where it was evident
    n_samples = 0                # samples actually compared (present in both files)

    print("ID,evident,retrieved,overlap,recall_evident,precision_evident,missing_evident,extra_not_evident,complete_overlap")
    total_evident = total_retrieved = total_overlap = 0
    for sid in sorted(evident_by_id):
        detail = detail_by_id.get(sid)
        if detail is None:
            # No eval row for this sample; it is left out of every statistic.
            continue
        n_samples += 1

        ev = set(evident_by_id[sid])
        # `or []` tolerates keys that are present but null in the detail row.
        retrieved = set(detail.get("retrieved_tags") or [])
        if not args.no_structural:
            retrieved |= set(detail.get("structural_tags") or [])
        if args.expand_implications:
            ev = _expand(ev)
            retrieved = _expand(retrieved)

        overlap = ev & retrieved
        missing = ev - retrieved
        extra = retrieved - ev

        for t in ev:
            present_counter[t] += 1
            if t in retrieved:
                hit_counter[t] += 1
            else:
                miss_counter[t] += 1

        recall = len(overlap) / len(ev) if ev else 0.0
        precision = len(overlap) / len(retrieved) if retrieved else 0.0
        total_evident += len(ev)
        total_retrieved += len(retrieved)
        total_overlap += len(overlap)
        complete = len(missing) == 0
        print(f"{sid},{len(ev)},{len(retrieved)},{len(overlap)},{recall:.3f},{precision:.3f},{len(missing)},{len(extra)},{complete}")

    # Guard the aggregate ratios the same way as the per-sample ones, so an
    # empty comparison (no shared ids, or no tags) reports 0.000, not a crash.
    total_recall = total_overlap / total_evident if total_evident else 0.0
    total_precision = total_overlap / total_retrieved if total_retrieved else 0.0
    print(f"TOTAL,{total_evident},{total_retrieved},{total_overlap},{total_recall:.3f},{total_precision:.3f},{total_evident-total_overlap},{total_retrieved-total_overlap},N/A")

    # The "present" denominator is the number of samples compared, not a
    # fixed 10.
    print("\nMOST MISSED (caption-evident tags not retrieved):")
    for tag, cnt in miss_counter.most_common(20):
        present = present_counter[tag]
        print(f" {tag:25s} missed {cnt}/{present} (present {present}/{n_samples})")

    print("\nMOST FOUND (caption-evident tags retrieved):")
    for tag, cnt in hit_counter.most_common(20):
        present = present_counter[tag]
        print(f" {tag:25s} found {cnt}/{present} (present {present}/{n_samples})")

    always_found = [t for t, c in hit_counter.items() if c == present_counter[t]]
    if always_found:
        print("\nALWAYS FOUND WHEN EVIDENT:")
        for t in sorted(always_found):
            print(f" {t}")
    return 0
# Script entry point: main()'s return value becomes the process exit status.
if __name__ == "__main__":
    sys.exit(main())
|