File size: 8,527 Bytes
4a76722
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
from __future__ import annotations

import argparse
import json
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Set, Tuple


TARGETS = ["balance_sheet", "profit_and_loss", "cash_flow"]
SCOPES = ["consolidated", "standalone"]


def load_json(p: Path):
    """Read and parse the UTF-8 encoded JSON file at *p*."""
    return json.loads(p.read_text(encoding="utf-8"))


def to_set_pages(obj) -> Set[int]:
    """Coerce a GT or predicted pages value into a set of ints.

    Accepts None, a single number, a digit string, or a container of
    numbers / digit strings; unparseable items contribute nothing.
    """
    if obj is None:
        return set()
    if isinstance(obj, (int, float)):
        return {int(obj)}
    if isinstance(obj, str):
        return {int(obj)} if obj.isdigit() else set()
    if isinstance(obj, (list, tuple, set)):
        pages: Set[int] = set()
        for item in obj:
            if isinstance(item, (int, float)):
                pages.add(int(item))
            elif isinstance(item, str) and item.isdigit():
                pages.add(int(item))
        return pages
    # Last resort: assume an iterable of int-convertible values.
    try:
        return {int(item) for item in obj}
    except Exception:
        return set()


def jaccard(a: Set[int], b: Set[int]) -> float:
    """Jaccard similarity of two page sets; two empty sets count as 1.0."""
    if not a:
        return 1.0 if not b else 0.0
    union = a | b
    return len(a & b) / len(union) if union else 0.0


def precision_recall_f1(tp: int, fp: int, fn: int) -> Tuple[float, float, float]:
    """Derive (precision, recall, F1) from confusion counts; zero denominators yield 0.0."""
    predicted = tp + fp
    actual = tp + fn
    precision = tp / predicted if predicted else 0.0
    recall = tp / actual if actual else 0.0
    denom = precision + recall
    f1 = (2 * precision * recall / denom) if denom else 0.0
    return precision, recall, f1


def evaluate_file(gt_path: Path, pred_path: Path) -> Dict:
    """Score one prediction file against its ground-truth file.

    For every statement in TARGETS, GT pages are compared to predicted
    pages with Jaccard similarity, and page-level TP/FP/FN counts are
    accumulated per (statement, scope) pair.

    Scoring rules:
      * GT has both scopes -> each scope is scored against the same-scope
        predictions and the per-scope Jaccards are averaged.
      * GT has a single scope -> its pages are compared scope-agnostically
        against the union of all predicted pages; counts are attributed
        to the GT scope.
      * GT has no entry for the statement -> neutral score of 1.0, but
        any predicted pages are counted as false positives (under
        'consolidated').

    Args:
        gt_path: Path to the ground-truth JSON file.
        pred_path: Path to the classifier-output JSON file.

    Returns:
        Dict with keys 'gt_path', 'pred_path', 'per_stmt_scores'
        (mean Jaccard per statement) and 'counts' (confusion counts keyed
        by (statement, scope) tuples).
    """
    gt = load_json(gt_path)
    pred = load_json(pred_path)

    per_stmt_scores: Dict[str, float] = {}

    # Page-level confusion counts aggregated by (statement, scope).
    # (The unused gt_key_map and per_stmt_counts locals from the previous
    # version were dead code and have been removed.)
    counts = {(stmt, scope): {"tp": 0, "fp": 0, "fn": 0} for stmt in TARGETS for scope in SCOPES}

    for stmt in TARGETS:
        # GT sometimes stores profit & loss under the shorthand key 'pnl'.
        raw_gt = None
        if stmt in gt:
            raw_gt = gt.get(stmt)
        elif stmt == "profit_and_loss" and "pnl" in gt:
            raw_gt = gt.get("pnl")

        # Normalize GT into {scope: page_set}.
        gt_scopes: Dict[str, Set[int]] = {}
        if isinstance(raw_gt, dict):
            for scope in SCOPES:
                if scope in raw_gt and raw_gt[scope]:
                    gt_scopes[scope] = to_set_pages(raw_gt[scope])
        elif isinstance(raw_gt, list):
            # A bare list carries no scope; treat it as 'consolidated'.
            gt_scopes["consolidated"] = to_set_pages(raw_gt)

        # Normalize predicted blocks into {scope: page_set}; blocks with
        # no recognized scope fall under 'unknown'.
        pred_blocks = pred.get(stmt) or []
        pred_by_scope: Dict[str, Set[int]] = {"consolidated": set(), "standalone": set(), "unknown": set()}
        for block in pred_blocks:
            if not isinstance(block, dict):
                continue
            scope = (block.get("scope") or "unknown").lower()

            # Prefer an explicit 'pages' list; fall back to an inclusive
            # start_page..end_page range.
            pages = to_set_pages(block.get("pages") or [])
            if not pages:
                sp = block.get("start_page")
                ep = block.get("end_page")
                if isinstance(sp, int) and isinstance(ep, int):
                    pages = set(range(sp, ep + 1))

            pred_by_scope.setdefault(scope, set())
            pred_by_scope[scope] |= pages

        pred_any_scope = set().union(*pred_by_scope.values())

        stmt_scores: List[float] = []
        if gt_scopes:
            if all(s in gt_scopes for s in SCOPES):
                # Both scopes present in GT: score each separately.
                for scope in SCOPES:
                    gt_pages = gt_scopes.get(scope, set())
                    pred_pages = pred_by_scope.get(scope, set())

                    stmt_scores.append(jaccard(gt_pages, pred_pages))

                    counts[(stmt, scope)]["tp"] += len(gt_pages & pred_pages)
                    counts[(stmt, scope)]["fp"] += len(pred_pages - gt_pages)
                    counts[(stmt, scope)]["fn"] += len(gt_pages - pred_pages)
            else:
                # Single GT scope: compare against all predicted pages,
                # ignoring predicted scope labels.
                gt_scope = next(iter(gt_scopes))
                gt_pages = gt_scopes[gt_scope]
                pred_pages = pred_any_scope
                stmt_scores.append(jaccard(gt_pages, pred_pages))

                # Attribute the counts to the GT scope.
                counts[(stmt, gt_scope)]["tp"] += len(gt_pages & pred_pages)
                counts[(stmt, gt_scope)]["fp"] += len(pred_pages - gt_pages)
                counts[(stmt, gt_scope)]["fn"] += len(gt_pages - pred_pages)
        else:
            # No GT for this statement: neutral score, but any predicted
            # pages are false positives (counted under 'consolidated').
            if pred_any_scope:
                counts[(stmt, "consolidated")]["fp"] += len(pred_any_scope)
            stmt_scores.append(1.0)

        per_stmt_scores[stmt] = sum(stmt_scores) / max(1, len(stmt_scores))

    return {
        "gt_path": str(gt_path),
        "pred_path": str(pred_path),
        "per_stmt_scores": per_stmt_scores,
        "counts": counts,
    }


def main():
    """CLI entry point: score every GT/prediction pair under dataset/<split>."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--split", default="eval", help="Which split folder under dataset/ to use (default: eval)")
    args = parser.parse_args()

    split_dir = Path("./dataset") / args.split
    gt_dir = split_dir / "GTs"
    pred_dir = split_dir / "classifier_output"

    if not gt_dir.exists():
        raise FileNotFoundError(f"GTs dir not found: {gt_dir}")
    if not pred_dir.exists():
        raise FileNotFoundError(f"Predictions dir not found: {pred_dir}")

    gt_files = sorted(p for p in gt_dir.iterdir() if p.suffix.lower() == ".json")
    if not gt_files:
        print("No GT files found.")
        return

    total_counts = {(stmt, scope): {"tp": 0, "fp": 0, "fn": 0} for stmt in TARGETS for scope in SCOPES}
    per_file_scores = []

    for gt_path in gt_files:
        stem = gt_path.stem
        pred_path = pred_dir / f"{stem}.json"
        if not pred_path.exists():
            print(f"WARN: prediction missing for {stem}, skipping")
            continue

        result = evaluate_file(gt_path, pred_path)
        per_file_scores.append((stem, result["per_stmt_scores"]))

        # Fold this file's confusion counts into the running totals.
        for key, c in result["counts"].items():
            bucket = total_counts[key]
            bucket["tp"] += c["tp"]
            bucket["fp"] += c["fp"]
            bucket["fn"] += c["fn"]

        # Per-file breakdown.
        print(f"\nFile: {stem}")
        for stmt, score in result["per_stmt_scores"].items():
            print(f"  {stmt}: Jaccard={score:.3f}")

    # Aggregate P/R/F1 per (statement, scope).
    print("\n=== Aggregate metrics ===")
    stmt_scope_results: Dict[Tuple[str, str], Tuple[float, float, float]] = {}
    for stmt in TARGETS:
        for scope in SCOPES:
            c = total_counts[(stmt, scope)]
            tp, fp, fn = c["tp"], c["fp"], c["fn"]
            p, r, f1 = precision_recall_f1(tp, fp, fn)
            stmt_scope_results[(stmt, scope)] = (p, r, f1)
            print(f"{stmt}/{scope}: TP={tp} FP={fp} FN={fn}  P={p:.3f} R={r:.3f} F1={f1:.3f}")

    # Mean Jaccard over every (file, statement) pair that was scored.
    all_scores = [per[stmt] for _, per in per_file_scores for stmt in TARGETS if stmt in per]
    mean_jaccard = sum(all_scores) / len(all_scores) if all_scores else 0.0
    print(f"\nMean per-statement Jaccard (averaged over files and statements): {mean_jaccard:.3f}")


if __name__ == "__main__":
    main()