EphAsad committed on
Commit
693f9e0
·
verified ·
1 Parent(s): 63b6104

Update training/gold_tester.py

Browse files
Files changed (1) hide show
  1. training/gold_tester.py +81 -161
training/gold_tester.py CHANGED
@@ -1,169 +1,89 @@
1
  # training/gold_tester.py
2
- # ----------------------------------------------------
3
- # Enhanced tester: audits expected fields not in schema,
4
- # adds DNase/Dnase alias and range-aware Growth Temperature matching.
5
-
6
- import json, os, time, csv
7
- from collections import Counter
8
- from typing import Dict, List, Tuple
9
- from engine.schema import SCHEMA, UNKNOWN, normalize_value, is_enum_field
 
 
 
10
  from engine.parser_rules import parse_text_rules
 
11
 
12
- REPORTS_DIR = "reports"
13
- PROPOSALS_PATH = os.path.join("data", "extended_proposals.jsonl")
14
- GOLD_PATH = os.path.join("training", "gold_tests.json")
15
 
16
- # --- helpers ---
17
- def load_gold() -> List[Dict]:
 
 
 
 
 
18
  with open(GOLD_PATH, "r", encoding="utf-8") as f:
19
- return json.load(f)
20
-
21
- def _range_overlap(a: str, b: str) -> bool:
22
- try:
23
- la, ha = [float(x) for x in a.split("//")]
24
- lb, hb = [float(x) for x in b.split("//")]
25
- return not (ha < lb or hb < la)
26
- except Exception:
27
- return False
28
-
29
- def compare_records(pred: Dict[str, str], exp: Dict[str, str]) -> Tuple[int, int, Dict[str, Tuple[str, str]]]:
30
- correct, total, errors = 0, 0, {}
31
- for field, exp_val in exp.items():
32
- total += 1
33
- p = pred.get(field, UNKNOWN)
34
- if field == "Growth Temperature":
35
- if p != UNKNOWN and exp_val != UNKNOWN and _range_overlap(p, exp_val):
36
- correct += 1
37
- continue
38
- if p == exp_val:
39
- correct += 1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  else:
41
- errors[field] = (p, exp_val)
42
- return correct, total, errors
43
-
44
- def append_proposal(record: Dict):
45
- os.makedirs(os.path.dirname(PROPOSALS_PATH), exist_ok=True)
46
- with open(PROPOSALS_PATH, "a", encoding="utf-8") as f:
47
- f.write(json.dumps(record, ensure_ascii=False) + "\n")
48
-
49
- # --- main ---
50
- def run_gold_tests(mode: str = "rules") -> Dict:
51
- tests = load_gold()
52
- ts = time.strftime("%Y%m%d_%H%M%S")
53
-
54
- per_field_counts, per_field_correct, per_field_cov = Counter(), Counter(), Counter()
55
- unknown_fields, unknown_values = Counter(), Counter()
56
- expected_unknowns = Counter()
57
- detailed_rows = []
58
- cases_with_misses = 0
59
-
60
- for case in tests:
61
- name, text, expected = case.get("name", ""), case.get("input", ""), case.get("expected", {})
62
-
63
- # normalize expected key aliases
64
- expected_norm = {}
65
- for k, v in expected.items():
66
- k2 = "DNase" if k.lower() == "dnase" else k
67
- expected_norm[k2] = v
68
- expected = expected_norm
69
-
70
- out = parse_text_rules(text)
71
- parsed = out.get("parsed_fields", {})
72
-
73
- # normalize parser output
74
- normalized_pred = {}
75
- for field, val in parsed.items():
76
- if field not in SCHEMA:
77
- unknown_fields[field] += 1
78
- append_proposal({
79
- "type": "unknown_field",
80
- "field": field,
81
- "value": val,
82
- "case_name": name,
83
- "timestamp": ts
84
- })
85
- continue
86
- normalized_pred[field] = normalize_value(field, val)
87
- if is_enum_field(field):
88
- allowed = SCHEMA[field].get("allowed", [])
89
- if normalized_pred[field] not in allowed + [UNKNOWN]:
90
- unknown_values[(field, normalized_pred[field])] += 1
91
- append_proposal({
92
- "type": "unknown_value",
93
- "field": field,
94
- "value": normalized_pred[field],
95
- "allowed": allowed,
96
- "case_name": name,
97
- "timestamp": ts
98
- })
99
-
100
- # audit expected fields not in schema
101
- for ef in expected.keys():
102
- if ef not in SCHEMA:
103
- expected_unknowns[ef] += 1
104
- append_proposal({
105
- "type": "expected_field_not_in_schema",
106
- "field": ef,
107
- "case_name": name,
108
- "timestamp": ts
109
- })
110
-
111
- correct, total, errors = compare_records(normalized_pred, expected)
112
- if errors:
113
- cases_with_misses += 1
114
-
115
- for f in expected.keys():
116
- per_field_counts[f] += 1
117
- if f in normalized_pred and normalized_pred[f] != UNKNOWN:
118
- per_field_cov[f] += 1
119
- if f not in errors:
120
- per_field_correct[f] += 1
121
-
122
- detailed_rows.append({
123
- "name": name,
124
- "parsed": json.dumps(normalized_pred, ensure_ascii=False),
125
- "expected": json.dumps(expected, ensure_ascii=False),
126
- "correct_fields": correct,
127
- "total_fields": total
128
- })
129
-
130
- # --- aggregate metrics ---
131
- per_field_metrics = []
132
- for f, tot in per_field_counts.items():
133
- acc = per_field_correct[f] / tot if tot else 0.0
134
- cov = per_field_cov[f] / tot if tot else 0.0
135
- per_field_metrics.append({"field": f, "accuracy": round(acc, 4), "coverage": round(cov, 4), "n": tot})
136
- per_field_metrics.sort(key=lambda x: x["field"])
137
-
138
- micro_acc = sum(per_field_correct.values()) / sum(per_field_counts.values()) if per_field_counts else 0.0
139
-
140
- os.makedirs(REPORTS_DIR, exist_ok=True)
141
- report = {
142
  "mode": mode,
143
- "timestamp": ts,
144
- "num_tests": len(tests),
145
- "micro_accuracy": round(micro_acc, 4),
146
- "cases_with_misses": cases_with_misses,
147
- "per_field": per_field_metrics,
148
- "unknown_fields": dict(unknown_fields),
149
- "unknown_values": {f"{k[0]}::{k[1]}": v for k, v in unknown_values.items()},
150
- "expected_unknown_fields": dict(expected_unknowns),
151
- "proposals_path": PROPOSALS_PATH
152
  }
153
- json_path = os.path.join(REPORTS_DIR, f"gold_report_{mode}_{ts}.json")
154
- csv_fields = os.path.join(REPORTS_DIR, f"gold_fields_{mode}_{ts}.csv")
155
- csv_cases = os.path.join(REPORTS_DIR, f"gold_cases_{mode}_{ts}.csv")
156
-
157
- with open(json_path, "w", encoding="utf-8") as f:
158
- json.dump(report, f, indent=2, ensure_ascii=False)
159
- with open(csv_fields, "w", newline="", encoding="utf-8") as f:
160
- import csv
161
- w = csv.DictWriter(f, fieldnames=["field", "accuracy", "coverage", "n"])
162
- w.writeheader()
163
- w.writerows(per_field_metrics)
164
- with open(csv_cases, "w", newline="", encoding="utf-8") as f:
165
- w = csv.DictWriter(f, fieldnames=["name", "parsed", "expected", "correct_fields", "total_fields"])
166
- w.writeheader()
167
- w.writerows(detailed_rows)
168
-
169
- return {"summary": report, "paths": {"json_report": json_path, "csv_fields": csv_fields, "csv_cases": csv_cases}}
 
1
  # training/gold_tester.py
2
+ # ------------------------------------------------------------
3
+ # Stage 10A: Evaluate parsers on gold tests.
4
+ # This MUST NOT crash during import.
5
+ # ------------------------------------------------------------
6
+
7
+ from __future__ import annotations
8
+
9
+ import json
10
+ import os
11
+ from typing import Dict, Any, List
12
+
13
  from engine.parser_rules import parse_text_rules
14
+ from engine.parser_ext import parse_text_extended
15
 
 
 
 
16
 
17
# Location of the gold test cases: a JSON file holding a list of
# {"input": <text>, "expected": {<field>: <value>, ...}} dicts.
GOLD_PATH = "training/gold_tests.json"
# Reports directory (created by run_gold_tests; no files are written
# to it in this module — presumably reserved for later stages).
REPORT_DIR = "reports"
19
+
20
+
21
+ def _load_gold_tests() -> List[Dict[str, Any]]:
22
+ if not os.path.exists(GOLD_PATH):
23
+ return []
24
  with open(GOLD_PATH, "r", encoding="utf-8") as f:
25
+ try:
26
+ data = json.load(f)
27
+ return data if isinstance(data, list) else []
28
+ except Exception:
29
+ return []
30
+
31
+
32
def run_gold_tests(
    mode: str = "rules",
    tests: Optional[List[Dict[str, Any]]] = None,
) -> Dict[str, Any]:
    """Evaluate the configured parser(s) against the gold test cases.

    Args:
        mode: "rules" runs the rule parser alone; "rules+extended"
            overlays the extended parser's fields on the rule parser's;
            any other value scores every case against an empty parse.
        tests: Optional pre-loaded test cases (each a dict with "input"
            and "expected" keys). Defaults to loading from GOLD_PATH
            via _load_gold_tests().

    Returns:
        {"summary": {...}} with run-level counts, field-level accuracy,
        and the indices of cases that had at least one wrong field.
    """
    gold_tests = _load_gold_tests() if tests is None else tests
    if not gold_tests:
        # Keep the summary shape identical to the populated path —
        # including "wrong_cases" — so callers never hit a KeyError
        # just because the gold file was empty or missing.
        return {
            "summary": {
                "mode": mode,
                "tests": 0,
                "total_correct": 0,
                "total_fields": 0,
                "overall_accuracy": 0.0,
                "wrong_cases": [],
                "proposals_path": "data/extended_proposals.jsonl",
            }
        }

    # Ensure the reports directory exists (no files are written here;
    # presumably later stages write into it — preserved for compatibility).
    os.makedirs(REPORT_DIR, exist_ok=True)

    wrong_cases: List[int] = []
    total_correct = 0
    total_fields = 0

    for idx, test in enumerate(gold_tests):
        text = test.get("input", "")
        expected = test.get("expected", {})

        if mode == "rules":
            parsed = parse_text_rules(text).get("parsed_fields", {})
        elif mode == "rules+extended":
            # Extended fields take precedence over rule fields on conflict.
            rule_fields = parse_text_rules(text).get("parsed_fields", {})
            ext_fields = parse_text_extended(text).get("parsed_fields", {})
            parsed = {**rule_fields, **ext_fields}
        else:
            # Unknown mode: nothing is parsed, so every field scores 0.
            parsed = {}

        # Compare field-by-field; case-insensitive and whitespace-tolerant.
        correct_count = 0
        for key, val in expected.items():
            total_fields += 1
            if key in parsed and str(parsed[key]).strip().lower() == str(val).strip().lower():
                correct_count += 1

        total_correct += correct_count
        if correct_count < len(expected):
            wrong_cases.append(idx)

    accuracy = total_correct / total_fields if total_fields else 0.0

    summary = {
        "mode": mode,
        "tests": len(gold_tests),
        "total_correct": total_correct,
        "total_fields": total_fields,
        "overall_accuracy": accuracy,
        "wrong_cases": wrong_cases,
        "proposals_path": "data/extended_proposals.jsonl",
    }

    return {"summary": summary}