EphAsad commited on
Commit
a7f8964
·
verified ·
1 Parent(s): 24f813a

Update training/field_weight_trainer.py

Browse files
Files changed (1) hide show
  1. training/field_weight_trainer.py +136 -84
training/field_weight_trainer.py CHANGED
@@ -2,9 +2,22 @@
2
  # ------------------------------------------------------------
3
  # Stage 12A — Train Per-Field Parser Weights from Gold Tests
4
  #
5
- # LLM is evaluated in REPAIR MODE when enabled:
6
- # - rules + extended first
7
- # - LLM receives existing_fields
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  # ------------------------------------------------------------
9
 
10
  from __future__ import annotations
@@ -20,11 +33,11 @@ from typing import Any, Dict, List, Optional, Tuple
20
  from engine.parser_rules import parse_text_rules
21
  from engine.parser_ext import parse_text_extended
22
 
23
- # LLM parser (optional, repair-only)
24
  try:
25
  from engine.parser_llm import parse_llm as parse_text_llm_local
26
  except Exception:
27
- parse_text_llm_local = None
28
 
29
 
30
  # ------------------------------------------------------------
@@ -63,7 +76,9 @@ class FieldStats:
63
  if self.total() == 0:
64
  return 0.0
65
  denom = self.correct + self.wrong + missing_penalty * self.missing
66
- return self.correct / denom if denom > 0 else 0.0
 
 
67
 
68
 
69
  # ------------------------------------------------------------
@@ -88,64 +103,58 @@ def _extract_text_and_expected(test_obj: Dict[str, Any]) -> Tuple[str, Dict[str,
88
  or test_obj.get("raw")
89
  or ""
90
  )
91
- text = text if isinstance(text, str) else str(text)
 
92
 
93
  expected: Dict[str, str] = {}
94
 
95
- for key in ("expected", "expected_core", "expected_extended"):
96
- block = test_obj.get(key)
97
- if isinstance(block, dict):
98
- for k, v in block.items():
99
- expected[str(k)] = str(v)
 
 
 
 
 
 
 
100
 
101
  return text, expected
102
 
103
 
104
  # ------------------------------------------------------------
105
- # Parser Execution (LLM = repair-only)
106
  # ------------------------------------------------------------
107
 
108
  def _get_parser_predictions(text: str, include_llm: bool = True) -> Dict[str, Dict[str, str]]:
109
  results: Dict[str, Dict[str, str]] = {}
110
 
111
- # 1) Rules
112
  r = parse_text_rules(text)
113
- rules_fields = dict(r.get("parsed_fields", {}))
114
- results["rules"] = rules_fields
115
 
116
- # 2) Extended
117
  e = parse_text_extended(text)
118
- ext_fields = dict(e.get("parsed_fields", {}))
119
- results["extended"] = ext_fields
120
 
121
- # 3) LLM (repair mode)
122
- llm_fields: Dict[str, str] = {}
123
  if include_llm and parse_text_llm_local is not None:
124
  try:
125
- merged_existing = {}
126
- merged_existing.update(rules_fields)
127
- merged_existing.update(ext_fields)
128
-
129
- llm_out = parse_text_llm_local(
130
- text,
131
- existing_fields=merged_existing,
132
- )
133
-
134
- if isinstance(llm_out, dict):
135
- llm_fields = dict(llm_out.get("parsed_fields", {}))
136
  except Exception:
137
- llm_fields = {}
 
138
 
139
- results["llm"] = llm_fields
140
  return results
141
 
142
 
143
  def _outcome_for_field(expected_val: str, predicted_val: Optional[str]) -> ParserOutcome:
144
  if predicted_val is None:
145
- return ParserOutcome(None, False, False, True)
146
  if predicted_val == expected_val:
147
- return ParserOutcome(predicted_val, True, False, False)
148
- return ParserOutcome(predicted_val, False, True, False)
149
 
150
 
151
  # ------------------------------------------------------------
@@ -158,6 +167,7 @@ def _compute_stats_from_gold(
158
  ):
159
  field_stats = defaultdict(lambda: defaultdict(FieldStats))
160
  global_stats = defaultdict(FieldStats)
 
161
  total_samples = 0
162
 
163
  for sample in gold_tests:
@@ -169,67 +179,85 @@ def _compute_stats_from_gold(
169
  preds = _get_parser_predictions(text, include_llm=include_llm)
170
 
171
  for field, expected_val in expected.items():
172
- for parser_name in ("rules", "extended", "llm"):
 
173
  if parser_name == "llm" and not include_llm:
174
  continue
175
 
176
  pred_val = preds.get(parser_name, {}).get(field)
 
177
  outcome = _outcome_for_field(expected_val, pred_val)
178
 
179
  fs = field_stats[field][parser_name]
180
- gs = global_stats[parser_name]
181
-
182
  if outcome.correct:
183
  fs.correct += 1
184
- gs.correct += 1
185
- elif outcome.wrong:
186
  fs.wrong += 1
187
- gs.wrong += 1
188
- else:
189
  fs.missing += 1
 
 
 
 
 
 
 
190
  gs.missing += 1
191
 
192
  return field_stats, global_stats, total_samples
193
 
194
 
195
- # ------------------------------------------------------------
196
- # Weight Construction
197
- # ------------------------------------------------------------
198
-
199
- def _normalise(weights: Dict[str, float]) -> Dict[str, float]:
200
- adjusted = {k: max(SMOOTHING, v) for k, v in weights.items()}
201
  total = sum(adjusted.values())
202
- return {k: v / total for k, v in adjusted.items()} if total > 0 else {}
 
 
 
203
 
204
 
205
- def _build_weights_json(field_stats, global_stats, total_samples, include_llm=True):
206
- raw_global = {
207
- p: stats.score(MISSING_PENALTY)
208
- for p, stats in global_stats.items()
209
- if include_llm or p != "llm"
210
- }
 
 
 
 
 
 
211
 
212
  global_weights = _normalise(raw_global)
 
 
213
  fields_block = {}
214
 
215
- for field, stats_dict in field_stats.items():
216
- raw_scores = {
217
- p: s.score(MISSING_PENALTY)
218
- for p, s in stats_dict.items()
219
- if include_llm or p != "llm"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
220
  }
221
- support = sum(s.total() for s in stats_dict.values())
222
-
223
- weights = (
224
- _normalise(raw_scores)
225
- if support >= 5
226
- else _normalise({
227
- p: 0.7 * global_weights.get(p, 0.0) + 0.3 * raw_scores.get(p, 0.0)
228
- for p in global_weights
229
- })
230
- )
231
-
232
- fields_block[field] = {**weights, "support": support}
233
 
234
  return {
235
  "global": global_weights,
@@ -237,8 +265,8 @@ def _build_weights_json(field_stats, global_stats, total_samples, include_llm=Tr
237
  "meta": {
238
  "total_samples": total_samples,
239
  "missing_penalty": MISSING_PENALTY,
 
240
  "include_llm": include_llm,
241
- "llm_mode": "repair-only",
242
  },
243
  }
244
 
@@ -252,14 +280,28 @@ def train_field_weights(
252
  out_path: str = DEFAULT_OUT_PATH,
253
  include_llm: bool = False,
254
  ):
 
255
  gold = _load_gold_tests(gold_path)
256
- field_stats, global_stats, total = _compute_stats_from_gold(gold, include_llm)
257
- weights = _build_weights_json(field_stats, global_stats, total, include_llm)
 
 
 
 
 
 
 
 
258
 
259
- os.makedirs(os.path.dirname(out_path), exist_ok=True)
 
 
 
 
260
  with open(out_path, "w", encoding="utf-8") as f:
261
- json.dump(weights, f, indent=2)
262
 
 
263
  return weights
264
 
265
 
@@ -267,12 +309,22 @@ def train_field_weights(
267
  # CLI
268
  # ------------------------------------------------------------
269
 
270
- def main():
271
- p = argparse.ArgumentParser()
 
 
272
  p.add_argument("--include-llm", action="store_true")
273
- args = p.parse_args()
274
- train_field_weights(include_llm=args.include_llm)
 
 
 
 
 
 
 
 
275
 
276
 
277
  if __name__ == "__main__":
278
- main()
 
2
  # ------------------------------------------------------------
3
  # Stage 12A — Train Per-Field Parser Weights from Gold Tests
4
  #
5
+ # Produces:
6
+ # data/field_weights.json
7
+ #
8
+ # This script computes reliability scores for:
9
+ # - parser_rules
10
+ # - parser_ext
11
+ # - parser_llm
12
+ #
13
+ # and outputs:
14
+ # {
15
+ # "global": { ... },
16
+ # "fields": { field -> weights },
17
+ # "meta": { ... }
18
+ # }
19
+ #
20
+ # These weights are used by parser_fusion (Stage 12B).
21
  # ------------------------------------------------------------
22
 
23
  from __future__ import annotations
 
33
  from engine.parser_rules import parse_text_rules
34
  from engine.parser_ext import parse_text_extended
35
 
36
+ # LLM parser (optional)
37
  try:
38
  from engine.parser_llm import parse_llm as parse_text_llm_local
39
  except Exception:
40
+ parse_text_llm_local = None # gracefully degrade if LLM unavailable
41
 
42
 
43
  # ------------------------------------------------------------
 
76
  if self.total() == 0:
77
  return 0.0
78
  denom = self.correct + self.wrong + missing_penalty * self.missing
79
+ if denom == 0:
80
+ return 0.0
81
+ return self.correct / denom
82
 
83
 
84
  # ------------------------------------------------------------
 
103
  or test_obj.get("raw")
104
  or ""
105
  )
106
+ if not isinstance(text, str):
107
+ text = str(text)
108
 
109
  expected: Dict[str, str] = {}
110
 
111
+ if isinstance(test_obj.get("expected"), dict):
112
+ for k, v in test_obj["expected"].items():
113
+ expected[str(k)] = str(v)
114
+ return text, expected
115
+
116
+ if isinstance(test_obj.get("expected_core"), dict):
117
+ for k, v in test_obj["expected_core"].items():
118
+ expected[str(k)] = str(v)
119
+
120
+ if isinstance(test_obj.get("expected_extended"), dict):
121
+ for k, v in test_obj["expected_extended"].items():
122
+ expected[str(k)] = str(v)
123
 
124
  return text, expected
125
 
126
 
127
  # ------------------------------------------------------------
128
+ # Parser Execution
129
  # ------------------------------------------------------------
130
 
131
def _get_parser_predictions(text: str, include_llm: bool = True) -> Dict[str, Dict[str, str]]:
    """Run every available parser over *text* and collect their field maps.

    Returns a mapping parser-name -> {field: value} with keys "rules",
    "extended" and "llm".  The "llm" entry is always present; it stays an
    empty dict when the LLM parser is disabled, unavailable, or raises.
    """
    predictions: Dict[str, Dict[str, str]] = {
        "rules": dict(parse_text_rules(text).get("parsed_fields", {})),
        "extended": dict(parse_text_extended(text).get("parsed_fields", {})),
    }

    llm_fields: Dict[str, str] = {}
    if include_llm and parse_text_llm_local is not None:
        try:
            llm_out = parse_text_llm_local(text)
            llm_fields = dict(llm_out.get("parsed_fields", {}))
        except Exception:
            # LLM failures must never break training; treat as "no output".
            llm_fields = {}
    predictions["llm"] = llm_fields

    return predictions
150
 
151
 
152
def _outcome_for_field(expected_val: str, predicted_val: Optional[str]) -> ParserOutcome:
    """Classify a single prediction against its gold value.

    Exactly one flag is set: ``missing`` when the parser produced nothing,
    ``correct`` on an exact string match, ``wrong`` otherwise.
    """
    if predicted_val is None:
        return ParserOutcome(prediction=None, correct=False, wrong=False, missing=True)
    matched = predicted_val == expected_val
    return ParserOutcome(
        prediction=predicted_val,
        correct=matched,
        wrong=not matched,
        missing=False,
    )
158
 
159
 
160
  # ------------------------------------------------------------
 
167
  ):
168
  field_stats = defaultdict(lambda: defaultdict(FieldStats))
169
  global_stats = defaultdict(FieldStats)
170
+
171
  total_samples = 0
172
 
173
  for sample in gold_tests:
 
179
  preds = _get_parser_predictions(text, include_llm=include_llm)
180
 
181
  for field, expected_val in expected.items():
182
+ expected_val = str(expected_val)
183
+ for parser_name in ["rules", "extended", "llm"]:
184
  if parser_name == "llm" and not include_llm:
185
  continue
186
 
187
  pred_val = preds.get(parser_name, {}).get(field)
188
+
189
  outcome = _outcome_for_field(expected_val, pred_val)
190
 
191
  fs = field_stats[field][parser_name]
 
 
192
  if outcome.correct:
193
  fs.correct += 1
194
+ if outcome.wrong:
 
195
  fs.wrong += 1
196
+ if outcome.missing:
 
197
  fs.missing += 1
198
+
199
+ gs = global_stats[parser_name]
200
+ if outcome.correct:
201
+ gs.correct += 1
202
+ if outcome.wrong:
203
+ gs.wrong += 1
204
+ if outcome.missing:
205
  gs.missing += 1
206
 
207
  return field_stats, global_stats, total_samples
208
 
209
 
210
def _normalise(weights: Dict[str, float], smoothing: float = SMOOTHING) -> Dict[str, float]:
    """Floor each weight at *smoothing*, then rescale so the values sum to 1.

    If the floored mass is still zero (only possible for an empty input,
    given a positive smoothing floor) each key receives the uniform share;
    an empty input simply yields an empty dict.
    """
    floored = {name: max(smoothing, value) for name, value in weights.items()}
    mass = sum(floored.values())
    if mass > 0:
        return {name: value / mass for name, value in floored.items()}
    # Degenerate fallback: distribute uniformly.
    count = len(floored)
    return {name: 1.0 / count for name in floored}
217
 
218
 
219
def _build_weights_json(
    field_stats,
    global_stats,
    total_samples,
    include_llm=True,
):
    """Convert accumulated stats into the weights JSON structure.

    Output shape::

        {"global": {...}, "fields": {field: {..., "support": n}}, "meta": {...}}

    Per-field weights observed fewer than 5 times are shrunk toward the
    global weights (70% global / 30% local) to avoid overfitting tiny
    samples.
    """

    def _scores(stats_map):
        # Reliability score per parser, honoring the include_llm switch.
        return {
            name: st.score(MISSING_PENALTY)
            for name, st in stats_map.items()
            if include_llm or name != "llm"
        }

    global_weights = _normalise(_scores(global_stats))

    fields_block = {}
    for field_name, stats_dict in field_stats.items():
        raw_scores = _scores(stats_dict)
        support = sum(
            st.total()
            for name, st in stats_dict.items()
            if include_llm or name != "llm"
        )

        if support >= 5:
            field_weights = _normalise(raw_scores)
        else:
            # Low support: blend local evidence with the global prior.
            local_norm = _normalise(raw_scores)
            blended = {
                p: 0.7 * global_weights[p] + 0.3 * local_norm.get(p, global_weights[p])
                for p in global_weights
            }
            field_weights = _normalise(blended)

        fields_block[field_name] = {**field_weights, "support": support}

    return {
        "global": global_weights,
        "fields": fields_block,
        "meta": {
            "total_samples": total_samples,
            "missing_penalty": MISSING_PENALTY,
            "smoothing": SMOOTHING,
            "include_llm": include_llm,
        },
    }
272
 
 
280
  out_path: str = DEFAULT_OUT_PATH,
281
  include_llm: bool = False,
282
  ):
283
+ print(f"[12A] Loading gold tests: {gold_path}")
284
  gold = _load_gold_tests(gold_path)
285
+ print(f"[12A] {len(gold)} gold samples loaded")
286
+
287
+ field_stats, global_stats, total_samples = _compute_stats_from_gold(
288
+ gold, include_llm=include_llm
289
+ )
290
+
291
+ print("[12A] Computing weights...")
292
+ weights = _build_weights_json(
293
+ field_stats, global_stats, total_samples, include_llm=include_llm
294
+ )
295
 
296
+ out_dir = os.path.dirname(out_path)
297
+ if out_dir and not os.path.exists(out_dir):
298
+ os.makedirs(out_dir, exist_ok=True)
299
+
300
+ print(f"[12A] Writing: {out_path}")
301
  with open(out_path, "w", encoding="utf-8") as f:
302
+ json.dump(weights, f, indent=2, ensure_ascii=False)
303
 
304
+ print("[12A] Done.")
305
  return weights
306
 
307
 
 
309
  # CLI
310
  # ------------------------------------------------------------
311
 
312
def _parse_args(argv=None):
    """Build and evaluate the Stage 12A command-line interface.

    *argv* defaults to ``sys.argv[1:]`` via argparse; passing an explicit
    list makes the CLI testable without touching the real process args.
    """
    parser = argparse.ArgumentParser(description="Stage 12A — Train parser weights")
    parser.add_argument("--gold", type=str, default=DEFAULT_GOLD_PATH)
    parser.add_argument("--out", type=str, default=DEFAULT_OUT_PATH)
    parser.add_argument("--include-llm", action="store_true")
    return parser.parse_args(argv)
318
+
319
+
320
def main(argv=None):
    """CLI entry point: parse arguments, then run the weight trainer."""
    options = _parse_args(argv)
    train_field_weights(
        gold_path=options.gold,
        out_path=options.out,
        include_llm=options.include_llm,
    )
327
 
328
 
329
  if __name__ == "__main__":
330
+ main()