File size: 4,573 Bytes
73f56cf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
"""
Analyze caption-evident tag recall against retrieved tags.

Compares the tags marked as caption-evident against the retrieved tags
(optionally including structural tags), with optional implication expansion
applied to both sets.
"""
from __future__ import annotations

import argparse
import json
from collections import Counter
from pathlib import Path
from typing import Dict, Iterable, Set
import sys

# Prepend the parent of this script's directory to sys.path (once) so that the
# project-local import below can be resolved when the script is run directly.
_REPO_ROOT = Path(__file__).resolve().parents[1]
if not any(entry == str(_REPO_ROOT) for entry in sys.path):
    sys.path.insert(0, str(_REPO_ROOT))

from psq_rag.retrieval.state import expand_tags_via_implications


def _load_evident(path: Path) -> Dict[int, Set[str]]:
    by_id: Dict[int, Set[str]] = {}
    with path.open("r", encoding="utf-8") as f:
        for line in f:
            row = json.loads(line)
            sid = row.get("id")
            if sid is None:
                continue
            tags = set(row.get("tags_ground_truth_expanded") or [])
            if tags:
                by_id[int(sid)] = tags
    return by_id


def _load_eval_detail(path: Path) -> Dict[int, dict]:
    rows = {}
    with path.open("r", encoding="utf-8") as f:
        for line in f:
            row = json.loads(line)
            if row.get("_meta"):
                continue
            rows[int(row["sample_id"])] = row
    return rows


def _expand(tags: Iterable[str]) -> Set[str]:
    """Expand *tags* with ``expand_tags_via_implications``.

    That helper returns a pair; only its first element (the expanded tag set)
    is kept and returned here.
    """
    tag_set = set(tags)
    expanded_tags, _unused = expand_tags_via_implications(tag_set)
    return expanded_tags


def main() -> int:
    """Print a caption-evident vs. retrieval audit report.

    For every sample id present in both the ``--evident`` and ``--detail``
    files, compares the evident tag set with the retrieved tag set (plus the
    row's ``structural_tags`` unless ``--no-structural``), optionally expanding
    both via implications, and prints one CSV row. A ``TOTAL`` row and
    per-tag "most missed" / "most found" / "always found" summaries follow.
    Evident samples with no matching detail row are skipped and counted on
    stderr.

    Returns:
        Process exit code (always 0).
    """
    ap = argparse.ArgumentParser(description="Caption-evident audit vs retrieval.")
    ap.add_argument("--evident", type=str, required=True,
                    help="Caption-evident JSONL (tags_ground_truth_expanded set to evident tags).")
    ap.add_argument("--detail", type=str, required=True,
                    help="Eval detail JSONL (from eval_pipeline.py).")
    ap.add_argument("--no-structural", action="store_true",
                    help="Do not count structural tags as retrieved.")
    ap.add_argument("--expand-implications", action="store_true",
                    help="Expand both evident and retrieved tags via implications.")
    args = ap.parse_args()

    evident_by_id = _load_evident(Path(args.evident))
    detail_by_id = _load_eval_detail(Path(args.detail))

    hit_counter = Counter()
    miss_counter = Counter()
    present_counter = Counter()

    print("ID,evident,retrieved,overlap,recall_evident,precision_evident,missing_evident,extra_not_evident,complete_overlap")

    total_evident = total_retrieved = total_overlap = 0
    # Number of samples actually compared; used as the denominator of the
    # per-tag "present" ratio instead of a hard-coded sample count.
    n_samples = 0
    n_skipped = 0

    for sid in sorted(evident_by_id):
        detail = detail_by_id.get(sid)
        if detail is None:
            n_skipped += 1
            continue
        n_samples += 1
        ev = set(evident_by_id[sid])
        retrieved = set(detail.get("retrieved_tags", []))
        if not args.no_structural:
            retrieved |= set(detail.get("structural_tags", []))

        if args.expand_implications:
            ev = _expand(ev)
            retrieved = _expand(retrieved)

        overlap = ev & retrieved
        missing = ev - retrieved
        extra = retrieved - ev

        for t in ev:
            present_counter[t] += 1
            if t in retrieved:
                hit_counter[t] += 1
            else:
                miss_counter[t] += 1

        recall = len(overlap) / len(ev) if ev else 0.0
        precision = len(overlap) / len(retrieved) if retrieved else 0.0
        total_evident += len(ev)
        total_retrieved += len(retrieved)
        total_overlap += len(overlap)
        complete = len(missing) == 0
        print(f"{sid},{len(ev)},{len(retrieved)},{len(overlap)},{recall:.3f},{precision:.3f},{len(missing)},{len(extra)},{complete}")

    # Guard against ZeroDivisionError when no samples matched or nothing was
    # retrieved; report 0.0 like the per-sample rows do.
    total_recall = total_overlap / total_evident if total_evident else 0.0
    total_precision = total_overlap / total_retrieved if total_retrieved else 0.0
    print(f"TOTAL,{total_evident},{total_retrieved},{total_overlap},{total_recall:.3f},{total_precision:.3f},{total_evident - total_overlap},{total_retrieved - total_overlap},N/A")

    if n_skipped:
        print(f"note: skipped {n_skipped} evident sample(s) with no matching eval detail row",
              file=sys.stderr)

    print("\nMOST MISSED (caption-evident tags not retrieved):")
    for tag, cnt in miss_counter.most_common(20):
        present = present_counter[tag]
        print(f"  {tag:25s} missed {cnt}/{present} (present {present}/{n_samples})")

    print("\nMOST FOUND (caption-evident tags retrieved):")
    for tag, cnt in hit_counter.most_common(20):
        present = present_counter[tag]
        print(f"  {tag:25s} found {cnt}/{present} (present {present}/{n_samples})")

    always_found = [t for t, c in hit_counter.items() if c == present_counter[t]]
    if always_found:
        print("\nALWAYS FOUND WHEN EVIDENT:")
        for t in sorted(always_found):
            print(f"  {t}")

    return 0


if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    sys.exit(main())