Spaces:
Running
Running
| """ | |
| Analyze caption-evident tag recall against retrieved tags. | |
| Compares tags marked caption-evident to retrieved tags (optionally + structural), | |
| with optional implication expansion on both sets. | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import json | |
| from collections import Counter | |
| from pathlib import Path | |
| from typing import Dict, Iterable, Set | |
| import sys | |
# Put the parent of the directory containing this file at the front of
# sys.path (only if it is not already there) so that the project-local
# `psq_rag` import that follows resolves when this script is run directly.
_REPO_ROOT = Path(__file__).resolve().parents[1]
if all(entry != str(_REPO_ROOT) for entry in sys.path):
    sys.path.insert(0, str(_REPO_ROOT))
| from psq_rag.retrieval.state import expand_tags_via_implications | |
def _load_evident(path: Path) -> Dict[int, Set[str]]:
    """Load caption-evident tag sets keyed by sample id.

    Each non-blank line of *path* is parsed as a JSON object. Rows without an
    ``id`` are ignored, as are rows whose ``tags_ground_truth_expanded`` list
    is missing, null, or empty.

    Args:
        path: Caption-evident JSONL file.

    Returns:
        Mapping from ``int(id)`` to the set of evident tags. If an id occurs
        on several non-empty rows, the last one wins.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON.
    """
    by_id: Dict[int, Set[str]] = {}
    with path.open("r", encoding="utf-8") as f:
        for line in f:
            # Blank / whitespace-only lines (e.g. trailing newlines) would make
            # json.loads raise; skip them instead of aborting the whole audit.
            if not line.strip():
                continue
            row = json.loads(line)
            sid = row.get("id")
            if sid is None:
                continue
            tags = set(row.get("tags_ground_truth_expanded") or [])
            if tags:
                by_id[int(sid)] = tags
    return by_id
def _load_eval_detail(path: Path) -> Dict[int, dict]:
    """Load per-sample evaluation rows keyed by ``sample_id``.

    Each non-blank line of *path* is parsed as a JSON object. Rows with a
    truthy ``_meta`` field are skipped; every other row is stored whole.

    Args:
        path: Eval detail JSONL file.

    Returns:
        Mapping from ``int(sample_id)`` to the parsed row dict. If a
        ``sample_id`` repeats, the last row wins.

    Raises:
        KeyError: if a non-``_meta`` row has no ``sample_id``.
        json.JSONDecodeError: if a non-blank line is not valid JSON.
    """
    rows: Dict[int, dict] = {}
    with path.open("r", encoding="utf-8") as f:
        for line in f:
            # Skip blank / whitespace-only lines; json.loads would raise on them.
            if not line.strip():
                continue
            row = json.loads(line)
            if row.get("_meta"):
                continue
            rows[int(row["sample_id"])] = row
    return rows
def _expand(tags: Iterable[str]) -> Set[str]:
    """Expand *tags* with ``expand_tags_via_implications``.

    The input is copied into a fresh set before the call. The helper returns
    a two-item result; only the first item is returned here and the second
    is discarded.
    """
    tag_set = set(tags)
    expanded_tags, _discarded = expand_tags_via_implications(tag_set)
    return expanded_tags
def _ratio(numerator: int, denominator: int) -> float:
    """Return ``numerator / denominator``, or ``0.0`` when *denominator* is 0."""
    return numerator / denominator if denominator else 0.0


def _print_tag_table(
    title: str,
    counter: Counter,
    verb: str,
    present_counter: Counter,
    n_samples: int,
    limit: int = 20,
) -> None:
    """Print the *limit* most common tags of *counter* with presence counts.

    Each line shows how many times the tag was ``verb`` (e.g. missed/found)
    out of the samples where it was evident, and how many of the *n_samples*
    scored samples it was evident in.
    """
    print(title)
    for tag, cnt in counter.most_common(limit):
        present = present_counter[tag]
        print(f" {tag:25s} {verb} {cnt}/{present} (present {present}/{n_samples})")


def main() -> int:
    """Audit retrieved tags against caption-evident tags.

    Prints a CSV-style table to stdout. The table has one row per sample id
    that appears in both the evident file and the detail file, plus a TOTAL
    row. After the table come the 20 most-missed and most-found evident tags,
    and the tags that were retrieved every time they were evident. Evident
    samples with no detail row are skipped, and a warning with their count is
    written to stderr.

    Returns:
        Process exit code (always 0).
    """
    ap = argparse.ArgumentParser(description="Caption-evident audit vs retrieval.")
    ap.add_argument("--evident", type=str, required=True,
                    help="Caption-evident JSONL (tags_ground_truth_expanded set to evident tags).")
    ap.add_argument("--detail", type=str, required=True,
                    help="Eval detail JSONL (from eval_pipeline.py).")
    ap.add_argument("--no-structural", action="store_true",
                    help="Do not count structural tags as retrieved.")
    ap.add_argument("--expand-implications", action="store_true",
                    help="Expand both evident and retrieved tags via implications.")
    args = ap.parse_args()

    evident_by_id = _load_evident(Path(args.evident))
    detail_by_id = _load_eval_detail(Path(args.detail))

    hit_counter: Counter = Counter()      # evident tag -> #samples where it was retrieved
    miss_counter: Counter = Counter()     # evident tag -> #samples where it was not retrieved
    present_counter: Counter = Counter()  # evident tag -> #samples where it was evident

    print("ID,evident,retrieved,overlap,recall_evident,precision_evident,missing_evident,extra_not_evident,complete_overlap")
    total_evident = total_retrieved = total_overlap = 0
    # Number of samples actually scored; the denominator for "present x/N"
    # (previously hard-coded to 10 regardless of input size).
    n_samples = 0
    skipped = 0
    for sid in sorted(evident_by_id):
        detail = detail_by_id.get(sid)
        if detail is None:
            skipped += 1
            continue
        n_samples += 1

        ev = set(evident_by_id[sid])
        # `or []` also tolerates an explicit JSON null for these fields.
        retrieved = set(detail.get("retrieved_tags") or [])
        if not args.no_structural:
            retrieved |= set(detail.get("structural_tags") or [])
        if args.expand_implications:
            ev = _expand(ev)
            retrieved = _expand(retrieved)

        overlap = ev & retrieved
        missing = ev - retrieved
        extra = retrieved - ev
        present_counter.update(ev)
        hit_counter.update(overlap)
        miss_counter.update(missing)

        recall = _ratio(len(overlap), len(ev))
        precision = _ratio(len(overlap), len(retrieved))
        total_evident += len(ev)
        total_retrieved += len(retrieved)
        total_overlap += len(overlap)
        complete = not missing
        print(f"{sid},{len(ev)},{len(retrieved)},{len(overlap)},{recall:.3f},{precision:.3f},{len(missing)},{len(extra)},{complete}")

    # Guarded ratios: with no matched samples (or nothing retrieved) the plain
    # division raised ZeroDivisionError.
    total_recall = _ratio(total_overlap, total_evident)
    total_precision = _ratio(total_overlap, total_retrieved)
    print(f"TOTAL,{total_evident},{total_retrieved},{total_overlap},{total_recall:.3f},{total_precision:.3f},{total_evident - total_overlap},{total_retrieved - total_overlap},N/A")

    if skipped:
        print(f"WARNING: skipped {skipped} evident sample(s) with no matching eval detail row.",
              file=sys.stderr)

    _print_tag_table("\nMOST MISSED (caption-evident tags not retrieved):",
                     miss_counter, "missed", present_counter, n_samples)
    _print_tag_table("\nMOST FOUND (caption-evident tags retrieved):",
                     hit_counter, "found", present_counter, n_samples)

    always_found = sorted(t for t, c in hit_counter.items() if c == present_counter[t])
    if always_found:
        print("\nALWAYS FOUND WHEN EVIDENT:")
        for t in always_found:
            print(f" {t}")
    return 0
if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    _exit_code = main()
    raise SystemExit(_exit_code)