Update training/gold_tester.py
Browse files — training/gold_tester.py (+81, −161)
training/gold_tester.py
CHANGED
|
@@ -1,169 +1,89 @@
|
|
| 1 |
# training/gold_tester.py
|
| 2 |
-
# ----------------------------------------------------
|
| 3 |
-
#
|
| 4 |
-
#
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
from
|
| 8 |
-
|
| 9 |
-
|
|
|
|
|
|
|
|
|
|
| 10 |
from engine.parser_rules import parse_text_rules
|
|
|
|
| 11 |
|
| 12 |
-
REPORTS_DIR = "reports"
|
| 13 |
-
PROPOSALS_PATH = os.path.join("data", "extended_proposals.jsonl")
|
| 14 |
-
GOLD_PATH = os.path.join("training", "gold_tests.json")
|
| 15 |
|
| 16 |
-
|
| 17 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
with open(GOLD_PATH, "r", encoding="utf-8") as f:
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
else:
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
cases_with_misses = 0
|
| 59 |
-
|
| 60 |
-
for case in tests:
|
| 61 |
-
name, text, expected = case.get("name", ""), case.get("input", ""), case.get("expected", {})
|
| 62 |
-
|
| 63 |
-
# normalize expected key aliases
|
| 64 |
-
expected_norm = {}
|
| 65 |
-
for k, v in expected.items():
|
| 66 |
-
k2 = "DNase" if k.lower() == "dnase" else k
|
| 67 |
-
expected_norm[k2] = v
|
| 68 |
-
expected = expected_norm
|
| 69 |
-
|
| 70 |
-
out = parse_text_rules(text)
|
| 71 |
-
parsed = out.get("parsed_fields", {})
|
| 72 |
-
|
| 73 |
-
# normalize parser output
|
| 74 |
-
normalized_pred = {}
|
| 75 |
-
for field, val in parsed.items():
|
| 76 |
-
if field not in SCHEMA:
|
| 77 |
-
unknown_fields[field] += 1
|
| 78 |
-
append_proposal({
|
| 79 |
-
"type": "unknown_field",
|
| 80 |
-
"field": field,
|
| 81 |
-
"value": val,
|
| 82 |
-
"case_name": name,
|
| 83 |
-
"timestamp": ts
|
| 84 |
-
})
|
| 85 |
-
continue
|
| 86 |
-
normalized_pred[field] = normalize_value(field, val)
|
| 87 |
-
if is_enum_field(field):
|
| 88 |
-
allowed = SCHEMA[field].get("allowed", [])
|
| 89 |
-
if normalized_pred[field] not in allowed + [UNKNOWN]:
|
| 90 |
-
unknown_values[(field, normalized_pred[field])] += 1
|
| 91 |
-
append_proposal({
|
| 92 |
-
"type": "unknown_value",
|
| 93 |
-
"field": field,
|
| 94 |
-
"value": normalized_pred[field],
|
| 95 |
-
"allowed": allowed,
|
| 96 |
-
"case_name": name,
|
| 97 |
-
"timestamp": ts
|
| 98 |
-
})
|
| 99 |
-
|
| 100 |
-
# audit expected fields not in schema
|
| 101 |
-
for ef in expected.keys():
|
| 102 |
-
if ef not in SCHEMA:
|
| 103 |
-
expected_unknowns[ef] += 1
|
| 104 |
-
append_proposal({
|
| 105 |
-
"type": "expected_field_not_in_schema",
|
| 106 |
-
"field": ef,
|
| 107 |
-
"case_name": name,
|
| 108 |
-
"timestamp": ts
|
| 109 |
-
})
|
| 110 |
-
|
| 111 |
-
correct, total, errors = compare_records(normalized_pred, expected)
|
| 112 |
-
if errors:
|
| 113 |
-
cases_with_misses += 1
|
| 114 |
-
|
| 115 |
-
for f in expected.keys():
|
| 116 |
-
per_field_counts[f] += 1
|
| 117 |
-
if f in normalized_pred and normalized_pred[f] != UNKNOWN:
|
| 118 |
-
per_field_cov[f] += 1
|
| 119 |
-
if f not in errors:
|
| 120 |
-
per_field_correct[f] += 1
|
| 121 |
-
|
| 122 |
-
detailed_rows.append({
|
| 123 |
-
"name": name,
|
| 124 |
-
"parsed": json.dumps(normalized_pred, ensure_ascii=False),
|
| 125 |
-
"expected": json.dumps(expected, ensure_ascii=False),
|
| 126 |
-
"correct_fields": correct,
|
| 127 |
-
"total_fields": total
|
| 128 |
-
})
|
| 129 |
-
|
| 130 |
-
# --- aggregate metrics ---
|
| 131 |
-
per_field_metrics = []
|
| 132 |
-
for f, tot in per_field_counts.items():
|
| 133 |
-
acc = per_field_correct[f] / tot if tot else 0.0
|
| 134 |
-
cov = per_field_cov[f] / tot if tot else 0.0
|
| 135 |
-
per_field_metrics.append({"field": f, "accuracy": round(acc, 4), "coverage": round(cov, 4), "n": tot})
|
| 136 |
-
per_field_metrics.sort(key=lambda x: x["field"])
|
| 137 |
-
|
| 138 |
-
micro_acc = sum(per_field_correct.values()) / sum(per_field_counts.values()) if per_field_counts else 0.0
|
| 139 |
-
|
| 140 |
-
os.makedirs(REPORTS_DIR, exist_ok=True)
|
| 141 |
-
report = {
|
| 142 |
"mode": mode,
|
| 143 |
-
"
|
| 144 |
-
"
|
| 145 |
-
"
|
| 146 |
-
"
|
| 147 |
-
"
|
| 148 |
-
"
|
| 149 |
-
"unknown_values": {f"{k[0]}::{k[1]}": v for k, v in unknown_values.items()},
|
| 150 |
-
"expected_unknown_fields": dict(expected_unknowns),
|
| 151 |
-
"proposals_path": PROPOSALS_PATH
|
| 152 |
}
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
csv_cases = os.path.join(REPORTS_DIR, f"gold_cases_{mode}_{ts}.csv")
|
| 156 |
-
|
| 157 |
-
with open(json_path, "w", encoding="utf-8") as f:
|
| 158 |
-
json.dump(report, f, indent=2, ensure_ascii=False)
|
| 159 |
-
with open(csv_fields, "w", newline="", encoding="utf-8") as f:
|
| 160 |
-
import csv
|
| 161 |
-
w = csv.DictWriter(f, fieldnames=["field", "accuracy", "coverage", "n"])
|
| 162 |
-
w.writeheader()
|
| 163 |
-
w.writerows(per_field_metrics)
|
| 164 |
-
with open(csv_cases, "w", newline="", encoding="utf-8") as f:
|
| 165 |
-
w = csv.DictWriter(f, fieldnames=["name", "parsed", "expected", "correct_fields", "total_fields"])
|
| 166 |
-
w.writeheader()
|
| 167 |
-
w.writerows(detailed_rows)
|
| 168 |
-
|
| 169 |
-
return {"summary": report, "paths": {"json_report": json_path, "csv_fields": csv_fields, "csv_cases": csv_cases}}
|
|
|
|
| 1 |
# training/gold_tester.py
|
| 2 |
+
# ------------------------------------------------------------
|
| 3 |
+
# Stage 10A: Evaluate parsers on gold tests.
|
| 4 |
+
# This MUST NOT crash during import.
|
| 5 |
+
# ------------------------------------------------------------
|
| 6 |
+
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
import json
|
| 10 |
+
import os
|
| 11 |
+
from typing import Dict, Any, List
|
| 12 |
+
|
| 13 |
from engine.parser_rules import parse_text_rules
|
| 14 |
+
from engine.parser_ext import parse_text_extended
|
| 15 |
|
|
|
|
|
|
|
|
|
|
| 16 |
|
| 17 |
+
GOLD_PATH = "training/gold_tests.json"
|
| 18 |
+
REPORT_DIR = "reports"
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def _load_gold_tests() -> List[Dict[str, Any]]:
|
| 22 |
+
if not os.path.exists(GOLD_PATH):
|
| 23 |
+
return []
|
| 24 |
with open(GOLD_PATH, "r", encoding="utf-8") as f:
|
| 25 |
+
try:
|
| 26 |
+
data = json.load(f)
|
| 27 |
+
return data if isinstance(data, list) else []
|
| 28 |
+
except Exception:
|
| 29 |
+
return []
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def run_gold_tests(mode: str = "rules") -> Dict[str, Any]:
    """Evaluate a parser mode against the gold test cases.

    Args:
        mode: ``"rules"`` uses the rule parser only; ``"rules+extended"``
            merges the extended parser's fields over the rule parser's
            (extended wins on key conflicts); any other value yields empty
            parses, scoring every expected field as missed.

    Returns:
        ``{"summary": {...}}`` with run-level counts and the overall
        field-level accuracy. ``wrong_cases`` lists the indices of tests
        that missed at least one expected field. The summary schema is
        identical whether or not any gold tests were found.
    """
    gold_tests = _load_gold_tests()

    wrong_cases: List[int] = []
    total_correct = 0
    total_fields = 0

    if gold_tests:
        os.makedirs(REPORT_DIR, exist_ok=True)

        for idx, test in enumerate(gold_tests):
            text = test.get("input", "")
            expected = test.get("expected", {})

            if mode == "rules":
                parsed = parse_text_rules(text).get("parsed_fields", {})
            elif mode == "rules+extended":
                rule_fields = parse_text_rules(text).get("parsed_fields", {})
                ext_fields = parse_text_extended(text).get("parsed_fields", {})
                # Extended fields take precedence over rule fields on conflict.
                parsed = {**rule_fields, **ext_fields}
            else:
                parsed = {}

            # Field-by-field comparison, case- and whitespace-insensitive.
            correct_count = 0
            for key, val in expected.items():
                total_fields += 1
                if key in parsed and str(parsed[key]).strip().lower() == str(val).strip().lower():
                    correct_count += 1

            total_correct += correct_count
            if correct_count < len(expected):
                wrong_cases.append(idx)

    accuracy = total_correct / total_fields if total_fields else 0.0

    # Single summary construction for both the empty and non-empty paths,
    # so consumers always see the same set of keys (incl. "wrong_cases").
    summary = {
        "mode": mode,
        "tests": len(gold_tests),
        "total_correct": total_correct,
        "total_fields": total_fields,
        "overall_accuracy": accuracy,
        "wrong_cases": wrong_cases,
        "proposals_path": "data/extended_proposals.jsonl",
    }
    return {"summary": summary}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|