"""Evaluate predicted financial-statement page ranges against ground truth.

Compares per-statement page sets (balance sheet, P&L, cash flow) between
GT JSON files and classifier output, reporting per-file Jaccard scores and
aggregate page-level precision/recall/F1 per (statement, scope).
"""
from __future__ import annotations

import argparse
import json
from pathlib import Path
from typing import Dict, Set, Tuple

TARGETS = ["balance_sheet", "profit_and_loss", "cash_flow"]
SCOPES = ["consolidated", "standalone"]


def load_json(p: Path):
    with open(p, "r", encoding="utf-8") as fh:
        return json.load(fh)


def to_set_pages(obj) -> Set[int]:
    """Normalize a GT or predicted pages value into a set of ints."""
    if obj is None:
        return set()
    if isinstance(obj, (int, float)):
        return {int(obj)}
    if isinstance(obj, str):
        return {int(obj)} if obj.isdigit() else set()
    if isinstance(obj, (list, tuple, set)):
        return {
            int(x)
            for x in obj
            if isinstance(x, (int, float)) or (isinstance(x, str) and x.isdigit())
        }
    # Fallback: attempt to parse any other iterable.
    try:
        return {int(x) for x in obj}
    except Exception:
        return set()


def jaccard(a: Set[int], b: Set[int]) -> float:
    """Jaccard similarity; two empty sets count as a perfect match."""
    if not a and not b:
        return 1.0
    union = len(a | b)
    return len(a & b) / union if union else 0.0


def precision_recall_f1(tp: int, fp: int, fn: int) -> Tuple[float, float, float]:
    p = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    r = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    f1 = 2 * p * r / (p + r) if (p + r) > 0 else 0.0
    return p, r, f1


def evaluate_file(gt_path: Path, pred_path: Path) -> Dict:
    gt = load_json(gt_path)
    pred = load_json(pred_path)

    per_stmt_scores = {}
    per_stmt_counts = {}
    # Page-level confusion counts aggregated by (statement, scope).
    counts = {
        (stmt, scope): {"tp": 0, "fp": 0, "fn": 0}
        for stmt in TARGETS
        for scope in SCOPES
    }

    for stmt in TARGETS:
        # GT sometimes uses 'pnl' as a synonym for 'profit_and_loss'.
        raw_gt = None
        if stmt in gt:
            raw_gt = gt.get(stmt)
        elif stmt == "profit_and_loss" and "pnl" in gt:
            raw_gt = gt.get("pnl")

        # Normalize GT scopes -> page sets.
        gt_scopes: Dict[str, Set[int]] = {}
        if isinstance(raw_gt, dict):
            for scope in SCOPES:
                if scope in raw_gt and raw_gt[scope]:
                    gt_scopes[scope] = to_set_pages(raw_gt[scope])
        elif isinstance(raw_gt, list):
            # GT given as a bare page list (no scope): treat as 'consolidated'.
            gt_scopes["consolidated"] = to_set_pages(raw_gt)

        # Predictions: a list of blocks per statement.
        pred_blocks = pred.get(stmt) or []
        pred_by_scope: Dict[str, Set[int]] = {
            "consolidated": set(),
            "standalone": set(),
            "unknown": set(),
        }
        for b in pred_blocks:
            if not isinstance(b, dict):
                continue
            scope = (b.get("scope") or "unknown").lower()
            # Try explicit 'pages' first, then a 'start_page'..'end_page' range.
            pages = to_set_pages(b.get("pages") or [])
            if not pages:
                sp = b.get("start_page")
                ep = b.get("end_page")
                if isinstance(sp, int) and isinstance(ep, int):
                    pages = set(range(sp, ep + 1))
            if scope not in pred_by_scope:
                pred_by_scope[scope] = set()
            pred_by_scope[scope] |= pages
        pred_any_scope = set().union(*pred_by_scope.values())

        # Scoring logic per statement.
        stmt_scores = []
        if gt_scopes:
            if all(s in gt_scopes for s in SCOPES):
                # GT has both scopes: score each separately and average.
                for scope in SCOPES:
                    gt_pages = gt_scopes.get(scope, set())
                    pred_pages = pred_by_scope.get(scope, set())
                    stmt_scores.append(jaccard(gt_pages, pred_pages))
                    # Update page-level TP/FP/FN counts.
                    counts[(stmt, scope)]["tp"] += len(gt_pages & pred_pages)
                    counts[(stmt, scope)]["fp"] += len(pred_pages - gt_pages)
                    counts[(stmt, scope)]["fn"] += len(gt_pages - pred_pages)
            else:
                # Single scope in GT: compare GT pages to all predicted pages
                # regardless of predicted scope (scope-agnostic).
                gt_scope = next(iter(gt_scopes))
                gt_pages = gt_scopes[gt_scope]
                pred_pages = pred_any_scope
                stmt_scores.append(jaccard(gt_pages, pred_pages))
                # For counting, attribute predicted pages to the GT scope.
                counts[(stmt, gt_scope)]["tp"] += len(gt_pages & pred_pages)
                counts[(stmt, gt_scope)]["fp"] += len(pred_pages - gt_pages)
                counts[(stmt, gt_scope)]["fn"] += len(gt_pages - pred_pages)
        else:
            # No GT for this statement: Jaccard stays neutral (1.0, since there
            # is nothing to predict), but any predicted pages are counted as
            # false positives under 'consolidated', so they hurt precision only.
            pred_count = len(pred_any_scope)
            if pred_count > 0:
                counts[(stmt, "consolidated")]["fp"] += pred_count
            stmt_scores.append(1.0)

        per_stmt_scores[stmt] = sum(stmt_scores) / max(1, len(stmt_scores))
        # Store a copy of this statement's counts per scope.
        per_stmt_counts[stmt] = {s: counts[(stmt, s)].copy() for s in SCOPES}

    return {
        "gt_path": str(gt_path),
        "pred_path": str(pred_path),
        "per_stmt_scores": per_stmt_scores,
        "per_stmt_counts": per_stmt_counts,
        "counts": counts,
    }


def main():
    ap = argparse.ArgumentParser()
    ap.add_argument(
        "--split",
        default="eval",
        help="Which split folder under dataset/ to use (default: eval)",
    )
    args = ap.parse_args()

    base = Path("./dataset")
    split = base / args.split
    gt_dir = split / "GTs"
    pred_dir = split / "classifier_output"
    if not gt_dir.exists():
        raise FileNotFoundError(f"GTs dir not found: {gt_dir}")
    if not pred_dir.exists():
        raise FileNotFoundError(f"Predictions dir not found: {pred_dir}")

    gt_files = sorted(p for p in gt_dir.iterdir() if p.suffix.lower() == ".json")
    if not gt_files:
        print("No GT files found.")
        return

    total_counts = {
        (stmt, scope): {"tp": 0, "fp": 0, "fn": 0}
        for stmt in TARGETS
        for scope in SCOPES
    }
    per_file_scores = []
    for gt_p in gt_files:
        stem = gt_p.stem
        pred_p = pred_dir / f"{stem}.json"
        if not pred_p.exists():
            print(f"WARN: prediction missing for {stem}, skipping")
            continue
        res = evaluate_file(gt_p, pred_p)
        per_file_scores.append((stem, res["per_stmt_scores"]))
        # Accumulate confusion counts across files.
        for k, v in res["counts"].items():
            total_counts[k]["tp"] += v["tp"]
            total_counts[k]["fp"] += v["fp"]
            total_counts[k]["fn"] += v["fn"]
        # Per-file breakdown.
        print(f"\nFile: {stem}")
        for stmt, score in res["per_stmt_scores"].items():
            print(f"  {stmt}: Jaccard={score:.3f}")

    # Aggregate metrics per (statement, scope).
    print("\n=== Aggregate metrics ===")
    for stmt in TARGETS:
        for scope in SCOPES:
            tp = total_counts[(stmt, scope)]["tp"]
            fp = total_counts[(stmt, scope)]["fp"]
            fn = total_counts[(stmt, scope)]["fn"]
            p, r, f1 = precision_recall_f1(tp, fp, fn)
            print(
                f"{stmt}/{scope}: TP={tp} FP={fp} FN={fn} "
                f"P={p:.3f} R={r:.3f} F1={f1:.3f}"
            )

    # Mean Jaccard across files and statements.
    all_scores = [
        per[stmt]
        for _, per in per_file_scores
        for stmt in TARGETS
        if stmt in per
    ]
    mean_jaccard = sum(all_scores) / len(all_scores) if all_scores else 0.0
    print(
        f"\nMean per-statement Jaccard (averaged over files and statements): "
        f"{mean_jaccard:.3f}"
    )


if __name__ == "__main__":
    main()
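
# ---------------------------------------------------------------------------
# Illustrative input shapes. These are assumptions inferred from the parsing
# logic above, not a published schema; the file names and the script name
# below are hypothetical. A GT file may use either a scope-keyed dict or a
# bare page list, and 'pnl' as a synonym for 'profit_and_loss':
#
#   dataset/eval/GTs/acme_2023.json
#   {
#     "balance_sheet": {"consolidated": [12, 13], "standalone": [45]},
#     "pnl": [14, 15],
#     "cash_flow": {"consolidated": [16]}
#   }
#
# A prediction file carries a list of blocks per statement; each block has a
# 'scope' plus either explicit 'pages' or a 'start_page'/'end_page' range:
#
#   dataset/eval/classifier_output/acme_2023.json
#   {
#     "balance_sheet": [{"scope": "consolidated", "pages": [12, 13]}],
#     "profit_and_loss": [{"scope": "unknown", "start_page": 14, "end_page": 15}],
#     "cash_flow": []
#   }
#
# Example run (main() assumes the dataset/<split>/ layout shown above):
#   python evaluate.py --split eval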