FridayCodehhr committed on
Commit
4a76722
·
verified ·
1 Parent(s): 0f273d1

Upload 2 files

Browse files
Files changed (2) hide show
  1. evaluate_eval.py +231 -0
  2. script.py +66 -0
evaluate_eval.py ADDED
@@ -0,0 +1,231 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import argparse
4
+ import json
5
+ from collections import defaultdict
6
+ from pathlib import Path
7
+ from typing import Dict, List, Set, Tuple
8
+
9
+
10
+ # Statement types the evaluator scores for each document.
+ TARGETS = ["balance_sheet", "profit_and_loss", "cash_flow"]
+ # Reporting scopes a statement may appear under in GT and predictions.
+ SCOPES = ["consolidated", "standalone"]
12
+
13
+
14
def load_json(p: Path):
    """Read and parse the UTF-8 JSON file at *p*."""
    return json.loads(p.read_text(encoding="utf-8"))
17
+
18
+
19
def to_set_pages(obj) -> Set[int]:
    """Normalize a GT or predicted pages value into a set of ints."""
    if obj is None:
        return set()
    if isinstance(obj, (int, float)):
        return {int(obj)}
    if isinstance(obj, str):
        # Only pure-digit strings are accepted; anything else is dropped.
        return {int(obj)} if obj.isdigit() else set()
    if isinstance(obj, (list, tuple, set)):
        pages: Set[int] = set()
        for item in obj:
            if isinstance(item, (int, float)):
                pages.add(int(item))
            elif isinstance(item, str) and item.isdigit():
                pages.add(int(item))
        return pages
    # Fallback: best-effort parse of any other iterable of int-like values.
    try:
        return {int(item) for item in obj}
    except Exception:
        return set()
36
+
37
+
38
def jaccard(a: Set[int], b: Set[int]) -> float:
    """Jaccard similarity of two page sets; two empty sets count as a perfect match."""
    if not a:
        return 1.0 if not b else 0.0
    union = a | b
    return len(a & b) / len(union) if union else 0.0
46
+
47
+
48
def precision_recall_f1(tp: int, fp: int, fn: int) -> Tuple[float, float, float]:
    """Compute (precision, recall, F1) from confusion counts; zero denominators yield 0.0."""
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    if precision + recall:
        f1 = 2 * precision * recall / (precision + recall)
    else:
        f1 = 0.0
    return precision, recall, f1
53
+
54
+
55
def evaluate_file(gt_path: Path, pred_path: Path) -> Dict:
    """Score one prediction file against its ground-truth file.

    For each statement type in TARGETS, compares the GT page set with the
    predicted page set (per scope when the GT distinguishes scopes).

    Returns a dict with keys "gt_path", "pred_path", "per_stmt_scores"
    (statement -> mean Jaccard) and "counts" (page-level TP/FP/FN keyed
    by (statement, scope) tuples).

    Fix vs. previous revision: removed unused locals ``gt_key_map`` and
    ``per_stmt_counts`` (the latter was computed but never returned).
    """
    gt = load_json(gt_path)
    pred = load_json(pred_path)

    per_stmt_scores = {}

    # Page-level confusion counts aggregated by (statement, scope).
    counts = {(stmt, scope): {"tp": 0, "fp": 0, "fn": 0} for stmt in TARGETS for scope in SCOPES}

    for stmt in TARGETS:
        # GT sometimes uses the 'pnl' key for profit_and_loss.
        raw_gt = None
        if stmt in gt:
            raw_gt = gt.get(stmt)
        elif stmt == "profit_and_loss" and "pnl" in gt:
            raw_gt = gt.get("pnl")

        # Normalize GT scopes -> page sets.
        gt_scopes: Dict[str, Set[int]] = {}
        if isinstance(raw_gt, dict):
            for scope in SCOPES:
                if scope in raw_gt and raw_gt[scope]:
                    gt_scopes[scope] = to_set_pages(raw_gt[scope])
        elif isinstance(raw_gt, list):
            # A GT list with no scope information is treated as 'consolidated'.
            gt_scopes["consolidated"] = to_set_pages(raw_gt)

        # Predictions: collect predicted pages per scope for this statement.
        pred_blocks = pred.get(stmt) or []
        pred_by_scope: Dict[str, Set[int]] = {"consolidated": set(), "standalone": set(), "unknown": set()}
        for b in pred_blocks:
            if not isinstance(b, dict):
                continue
            scope = (b.get("scope") or "unknown").lower()

            # Try an explicit 'pages' list first, then a start_page..end_page range.
            pages = to_set_pages(b.get("pages") or [])
            if not pages:
                sp = b.get("start_page")
                ep = b.get("end_page")
                if isinstance(sp, int) and isinstance(ep, int):
                    pages = set(range(sp, ep + 1))

            pred_by_scope.setdefault(scope, set())
            pred_by_scope[scope] |= pages

        pred_any_scope = set().union(*pred_by_scope.values())

        # Scoring logic per statement.
        stmt_scores = []
        if gt_scopes:
            if all(s in gt_scopes for s in SCOPES):
                # GT has both scopes: score each scope separately and average.
                for scope in SCOPES:
                    gt_pages = gt_scopes.get(scope, set())
                    pred_pages = pred_by_scope.get(scope, set())

                    stmt_scores.append(jaccard(gt_pages, pred_pages))

                    # Page-level TP/FP/FN for this (statement, scope).
                    counts[(stmt, scope)]["tp"] += len(gt_pages & pred_pages)
                    counts[(stmt, scope)]["fp"] += len(pred_pages - gt_pages)
                    counts[(stmt, scope)]["fn"] += len(gt_pages - pred_pages)
            else:
                # Single scope in GT: compare GT pages to predictions from
                # any scope (scope-agnostic).
                gt_scope = next(iter(gt_scopes))
                gt_pages = gt_scopes[gt_scope]
                pred_pages = pred_any_scope
                stmt_scores.append(jaccard(gt_pages, pred_pages))

                # For counting, attribute predicted pages to the GT scope.
                counts[(stmt, gt_scope)]["tp"] += len(gt_pages & pred_pages)
                counts[(stmt, gt_scope)]["fp"] += len(pred_pages - gt_pages)
                counts[(stmt, gt_scope)]["fn"] += len(gt_pages - pred_pages)
        else:
            # No GT for this statement: the score is neutral (nothing to
            # predict), but any predicted pages are false positives, counted
            # under 'consolidated'.
            if pred_any_scope:
                counts[(stmt, "consolidated")]["fp"] += len(pred_any_scope)
            stmt_scores.append(1.0)  # neutral / perfect since nothing to predict

        per_stmt_scores[stmt] = sum(stmt_scores) / max(1, len(stmt_scores))

    return {
        "gt_path": str(gt_path),
        "pred_path": str(pred_path),
        "per_stmt_scores": per_stmt_scores,
        "counts": counts,
    }
163
+
164
+
165
def main():
    """CLI entry point: score all prediction files in a dataset split.

    Reads GT JSONs from dataset/<split>/GTs and predictions from
    dataset/<split>/classifier_output, prints a per-file Jaccard breakdown,
    aggregate per-(statement, scope) precision/recall/F1, and the mean
    per-statement Jaccard.

    Fix vs. previous revision: removed unused local ``stmt_scope_results``
    (populated but never read).
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--split", default="eval", help="Which split folder under dataset/ to use (default: eval)")
    args = ap.parse_args()

    base = Path("./dataset")
    split = base / args.split
    gt_dir = split / "GTs"
    pred_dir = split / "classifier_output"

    if not gt_dir.exists():
        raise FileNotFoundError(f"GTs dir not found: {gt_dir}")
    if not pred_dir.exists():
        raise FileNotFoundError(f"Predictions dir not found: {pred_dir}")

    gt_files = sorted(p for p in gt_dir.iterdir() if p.suffix.lower() == ".json")
    if not gt_files:
        print("No GT files found.")
        return

    total_counts = {(stmt, scope): {"tp": 0, "fp": 0, "fn": 0} for stmt in TARGETS for scope in SCOPES}
    per_file_scores = []

    for gt_p in gt_files:
        stem = gt_p.stem
        pred_p = pred_dir / f"{stem}.json"
        if not pred_p.exists():
            print(f"WARN: prediction missing for {stem}, skipping")
            continue
        res = evaluate_file(gt_p, pred_p)
        per_file_scores.append((stem, res["per_stmt_scores"]))

        # Accumulate page-level confusion counts across files.
        for k, v in res["counts"].items():
            total_counts[k]["tp"] += v["tp"]
            total_counts[k]["fp"] += v["fp"]
            total_counts[k]["fn"] += v["fn"]

        # Per-file breakdown.
        print(f"\nFile: {stem}")
        for stmt, score in res["per_stmt_scores"].items():
            print(f" {stmt}: Jaccard={score:.3f}")

    # Aggregate precision/recall/F1 per (statement, scope).
    print("\n=== Aggregate metrics ===")
    for stmt in TARGETS:
        for scope in SCOPES:
            tp = total_counts[(stmt, scope)]["tp"]
            fp = total_counts[(stmt, scope)]["fp"]
            fn = total_counts[(stmt, scope)]["fn"]
            p, r, f1 = precision_recall_f1(tp, fp, fn)
            print(f"{stmt}/{scope}: TP={tp} FP={fp} FN={fn} P={p:.3f} R={r:.3f} F1={f1:.3f}")

    # Mean Jaccard across all files and statements.
    all_scores = [per[stmt] for _, per in per_file_scores for stmt in TARGETS if stmt in per]
    mean_jaccard = sum(all_scores) / len(all_scores) if all_scores else 0.0
    print(f"\nMean per-statement Jaccard (averaged over files and statements): {mean_jaccard:.3f}")
228
+
229
+
230
# Script entry point: run the evaluation CLI.
if __name__ == "__main__":
    main()
script.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import subprocess
2
+ from pathlib import Path
3
+ import sys
4
+ import shutil
5
+ import tqdm
6
+
7
# Directory containing this script.
BASE = Path(__file__).resolve().parents[0]
# Root of the dataset tree holding split folders (e.g. dataset/eval).
DATASET_DIR = BASE / "dataset"
# Directory containing the classifier entry point (main.py is run with this cwd).
GPT_DIR = BASE / "gpt"
10
+
11
+
12
def find_split_dir(name: str = "eval") -> Path:
    """Return the dataset split directory ``dataset/<name>``.

    Args:
        name: Split folder name under ``dataset/`` (default: "eval").
              Generalized from a hard-coded local so callers can select
              other splits (e.g. "test") without editing this function.

    Raises:
        FileNotFoundError: If the split directory does not exist.
    """
    p = DATASET_DIR / name
    if p.is_dir():
        return p
    # Previous message claimed several candidate names were tried
    # ("val, eval, validation") although only one path is checked.
    raise FileNotFoundError(f"Split directory not found: {p}")
18
+
19
+
20
def run_for_pdf(pdf_path: Path, out_path: Path) -> int:
    """Run gpt/main.py on one PDF and return the subprocess exit code.

    On failure, the child's captured stdout/stderr are echoed for debugging.
    """
    # Make sure the output directory exists before the child tries to write.
    out_path.parent.mkdir(parents=True, exist_ok=True)

    cmd = [sys.executable, "main.py", "--pdf", str(pdf_path), "--out", str(out_path)]
    print(f"Running: {' '.join(cmd)} (cwd={GPT_DIR})")
    completed = subprocess.run(cmd, cwd=str(GPT_DIR), capture_output=True, text=True)
    if completed.returncode == 0:
        print(f"OK: saved -> {out_path}")
    else:
        print(f"ERROR: gpt/main.py failed for {pdf_path.name} (rc={completed.returncode})")
        print(completed.stdout)
        print(completed.stderr)
    return completed.returncode
34
+
35
+
36
def main():
    """Run the classifier over every PDF in the split, skipping PDFs that
    already have an output file in classifier_output.

    Fix vs. previous revision: the already-done check re-listed the output
    directory (wrapped in a redundant ``list()``) on every loop iteration,
    making the skip test O(n) per PDF; the stems are now snapshotted into a
    set once before the loop.
    """
    split_dir = find_split_dir()
    pdf_dir = split_dir / "PDFs"
    if not pdf_dir.exists():
        raise FileNotFoundError(f"PDFs directory not found: {pdf_dir}")

    out_dir = split_dir / "classifier_output"
    out_dir.mkdir(parents=True, exist_ok=True)

    pdf_files = sorted(p for p in pdf_dir.iterdir() if p.suffix.lower() == ".pdf")
    if not pdf_files:
        print(f"No PDF files found in {pdf_dir}")
        return

    print(f"Found {len(pdf_files)} PDFs in {pdf_dir}; outputs -> {out_dir}")

    # Snapshot existing output stems once; membership tests are then O(1).
    done_stems = {p.stem for p in out_dir.iterdir()}

    failures = 0
    for pdf in tqdm.tqdm(pdf_files, total=len(pdf_files)):
        stem = pdf.stem
        if stem in done_stems:
            # Output already produced by a previous run; skip.
            continue
        out_path = out_dir / f"{stem}.json"
        rc = run_for_pdf(pdf, out_path)
        if rc != 0:
            failures += 1

    # NOTE(review): "Processed" reports the total PDF count, including
    # skipped ones — kept as-is to preserve output format.
    print(f"\nDone. Processed: {len(pdf_files)} failures: {failures}")
63
+
64
+
65
# Script entry point: batch-run the classifier over the split's PDFs.
if __name__ == "__main__":
    main()
+ main()