EphAsad commited on
Commit
341d967
·
verified ·
1 Parent(s): 584a0a9

Update engine/parser_fusion.py

Browse files
Files changed (1) hide show
  1. engine/parser_fusion.py +338 -333
engine/parser_fusion.py CHANGED
@@ -1,334 +1,339 @@
1
- # engine/parser_fusion.py
2
- # ------------------------------------------------------------
3
- # Tri-Parser Fusion — Stage 12B (Weighted, SOTA-style)
4
- #
5
- # This module combines:
6
- # - Rule parser (parser_rules.parse_text_rules)
7
- # - Extended parser (parser_ext.parse_text_extended)
8
- # - LLM parser (parser_llm.parse_llm) [optional]
9
- #
10
- # using per-field reliability weights learned in Stage 12A
11
- # and stored in:
12
- # data/field_weights.json
13
- #
14
- # Behaviour:
15
- # - For each field, gather predictions from available parsers.
16
- # - For that field, load weights:
17
- # field_weights[field] (if present)
18
- # else global weights
19
- # else equal weights across available parsers
20
- # - Discard parsers that:
21
- # * did not predict the field
22
- # * or only predicted "Unknown"
23
- # - Group by predicted value and sum the weights of parsers
24
- # that voted for each value.
25
- # - Choose the value with highest total weight.
26
- # Tie-break: prefer rules > extended > llm if needed.
27
- #
28
- # Output format:
29
- # {
30
- # "fused_fields": { field: value, ... }, # used by DB identifier AND genus ML
31
- # "by_parser": {
32
- # "rules": { ... },
33
- # "extended": { ... },
34
- # "llm": { ... } # may be empty
35
- # },
36
- # "votes": {
37
- # field_name: {
38
- # "per_parser": {
39
- # "rules": {"value": "Positive", "weight": 0.95},
40
- # "extended": {"value": "Unknown", "weight": 0.03},
41
- # ...
42
- # },
43
- # "summed": {
44
- # "Positive": 0.97,
45
- # "Negative": 0.02
46
- # },
47
- # "chosen": "Positive"
48
- # },
49
- # ...
50
- # },
51
- # "weights_meta": {
52
- # "has_weights_file": True/False,
53
- # "weights_path": "data/field_weights.json",
54
- # "meta": { ... } # from file if present
55
- # }
56
- # }
57
- # ------------------------------------------------------------
58
-
59
- from __future__ import annotations
60
-
61
- import json
62
- import os
63
- from typing import Any, Dict, Optional
64
-
65
- from engine.parser_rules import parse_text_rules
66
- from engine.parser_ext import parse_text_extended
67
-
68
- # Optional LLM parser
69
- try:
70
- from engine.parser_llm import parse_llm as parse_text_llm # type: ignore
71
- HAS_LLM = True
72
- except Exception:
73
- parse_text_llm = None # type: ignore
74
- HAS_LLM = False
75
-
76
- # Path to learned weights
77
- FIELD_WEIGHTS_PATH = os.path.join("data", "field_weights.json")
78
-
79
- UNKNOWN = "Unknown"
80
- PARSER_ORDER = ["rules", "extended", "llm"] # tie-breaking priority
81
-
82
-
83
- # ------------------------------------------------------------
84
- # Weights loading and helpers
85
- # ------------------------------------------------------------
86
-
87
- def _load_field_weights(path: str = FIELD_WEIGHTS_PATH) -> Dict[str, Any]:
88
- """
89
- Load the JSON weights file produced by Stage 12A.
90
-
91
- Expected structure:
92
- {
93
- "global": { "rules": 0.7, "extended": 0.2, "llm": 0.1 },
94
- "fields": {
95
- "DNase": {
96
- "rules": 0.95,
97
- "extended": 0.03,
98
- "llm": 0.02,
99
- "support": 123
100
- },
101
- ...
102
- },
103
- "meta": { ... }
104
- }
105
-
106
- If the file is missing or broken, fall back to empty dict,
107
- triggering equal-weight behaviour later.
108
- """
109
- if not os.path.exists(path):
110
- return {}
111
-
112
- try:
113
- with open(path, "r", encoding="utf-8") as f:
114
- obj = json.load(f)
115
- return obj if isinstance(obj, dict) else {}
116
- except Exception:
117
- return {}
118
-
119
-
120
- FIELD_WEIGHTS_RAW: Dict[str, Any] = _load_field_weights()
121
- HAS_WEIGHTS_FILE: bool = bool(FIELD_WEIGHTS_RAW)
122
-
123
-
124
- def _normalise_scores(scores: Dict[str, float]) -> Dict[str, float]:
125
- """
126
- Normalise parser -> score into weights summing to 1.
127
- If all scores are zero or dict is empty, return equal weights.
128
- """
129
- cleaned = {k: max(0.0, float(v)) for k, v in scores.items()}
130
- total = sum(cleaned.values())
131
-
132
- if total <= 0:
133
- n = len(cleaned) or 1
134
- return {k: 1.0 / n for k in cleaned}
135
-
136
- return {k: v / total for k, v in cleaned.items()}
137
-
138
-
139
- def _get_base_weights_for_parsers(include_llm: bool) -> Dict[str, float]:
140
- """
141
- Equal-weight distribution across available parsers.
142
- Used when no learned weights are available.
143
- """
144
- parsers = ["rules", "extended"]
145
- if include_llm:
146
- parsers.append("llm")
147
-
148
- n = len(parsers) or 1
149
- return {p: 1.0 / n for p in parsers}
150
-
151
-
152
- def _get_weights_for_field(field_name: str, include_llm: bool) -> Dict[str, float]:
153
- """
154
- Get weights for a specific field.
155
-
156
- Priority:
157
- 1) FIELD_WEIGHTS_RAW["fields"][field_name]
158
- 2) FIELD_WEIGHTS_RAW["global"]
159
- 3) Equal weights
160
-
161
- Always:
162
- - Drop 'llm' if include_llm == False
163
- - Normalise
164
- """
165
- if not FIELD_WEIGHTS_RAW:
166
- return _normalise_scores(_get_base_weights_for_parsers(include_llm))
167
-
168
- fields_block = FIELD_WEIGHTS_RAW.get("fields", {}) or {}
169
- global_block = FIELD_WEIGHTS_RAW.get("global", {}) or {}
170
-
171
- raw: Dict[str, float] = {}
172
-
173
- field_entry = fields_block.get(field_name)
174
- if isinstance(field_entry, dict):
175
- for k, v in field_entry.items():
176
- if k in ("rules", "extended", "llm"):
177
- raw[k] = float(v)
178
-
179
- if not raw and isinstance(global_block, dict):
180
- for k, v in global_block.items():
181
- if k in ("rules", "extended", "llm"):
182
- raw[k] = float(v)
183
-
184
- if not raw:
185
- raw = _get_base_weights_for_parsers(include_llm)
186
-
187
- if not include_llm:
188
- raw.pop("llm", None)
189
-
190
- if not raw:
191
- raw = _get_base_weights_for_parsers(include_llm=False)
192
-
193
- return _normalise_scores(raw)
194
-
195
-
196
- # ------------------------------------------------------------
197
- # Fusion logic
198
- # ------------------------------------------------------------
199
-
200
- def _clean_pred_value(val: Optional[str]) -> Optional[str]:
201
- """
202
- Treat None, empty string, or explicit "Unknown" as missing.
203
- """
204
- if val is None:
205
- return None
206
-
207
- s = str(val).strip()
208
- if not s:
209
- return None
210
-
211
- if s.lower() == UNKNOWN.lower():
212
- return None
213
-
214
- return s
215
-
216
-
217
- def parse_text_fused(text: str, use_llm: Optional[bool] = None) -> Dict[str, Any]:
218
- """
219
- Main tri-parser fusion entrypoint.
220
-
221
- Parameters
222
- ----------
223
- text : str
224
- use_llm : bool or None
225
- True -> include LLM
226
- False -> exclude LLM
227
- None -> include if available
228
-
229
- Returns
230
- -------
231
- Dict[str, Any]
232
- Full fusion output including votes and per-parser breakdowns.
233
- """
234
- original = text or ""
235
- include_llm = HAS_LLM if use_llm is None else bool(use_llm)
236
-
237
- rules_out = parse_text_rules(original) or {}
238
- ext_out = parse_text_extended(original) or {}
239
-
240
- rules_fields = dict(rules_out.get("parsed_fields", {}))
241
- ext_fields = dict(ext_out.get("parsed_fields", {}))
242
-
243
- llm_fields: Dict[str, Any] = {}
244
- if include_llm and parse_text_llm is not None:
245
- try:
246
- llm_out = parse_text_llm(original)
247
- if isinstance(llm_out, dict):
248
- if "parsed_fields" in llm_out:
249
- llm_fields = dict(llm_out.get("parsed_fields", {}))
250
- else:
251
- llm_fields = {str(k): v for k, v in llm_out.items()}
252
- except Exception:
253
- llm_fields = {}
254
- else:
255
- include_llm = False
256
-
257
- by_parser: Dict[str, Dict[str, Any]] = {
258
- "rules": rules_fields,
259
- "extended": ext_fields,
260
- "llm": llm_fields if include_llm else {},
261
- }
262
-
263
- candidate_fields = (
264
- set(rules_fields.keys())
265
- | set(ext_fields.keys())
266
- | set(llm_fields.keys())
267
- )
268
-
269
- fused_fields: Dict[str, Any] = {}
270
- votes_debug: Dict[str, Any] = {}
271
-
272
- for field in sorted(candidate_fields):
273
- weights = _get_weights_for_field(field, include_llm)
274
-
275
- parser_preds: Dict[str, Optional[str]] = {
276
- "rules": _clean_pred_value(rules_fields.get(field)),
277
- "extended": _clean_pred_value(ext_fields.get(field)),
278
- "llm": _clean_pred_value(llm_fields.get(field)) if include_llm else None,
279
- }
280
-
281
- per_parser_info: Dict[str, Any] = {}
282
- value_scores: Dict[str, float] = {}
283
-
284
- for parser_name in PARSER_ORDER:
285
- if parser_name == "llm" and not include_llm:
286
- continue
287
-
288
- pred = parser_preds.get(parser_name)
289
- w = float(weights.get(parser_name, 0.0))
290
-
291
- per_parser_info[parser_name] = {
292
- "value": pred if pred is not None else UNKNOWN,
293
- "weight": w,
294
- }
295
-
296
- if pred is not None:
297
- value_scores[pred] = value_scores.get(pred, 0.0) + w
298
-
299
- if not value_scores:
300
- fused_value = UNKNOWN
301
- else:
302
- max_score = max(value_scores.values())
303
- best_values = [v for v, s in value_scores.items() if s == max_score]
304
-
305
- if len(best_values) == 1:
306
- fused_value = best_values[0]
307
- else:
308
- fused_value = best_values[0]
309
- for parser_name in PARSER_ORDER:
310
- if parser_name == "llm" and not include_llm:
311
- continue
312
- if parser_preds.get(parser_name) in best_values:
313
- fused_value = parser_preds[parser_name] # type: ignore
314
- break
315
-
316
- fused_fields[field] = fused_value
317
- votes_debug[field] = {
318
- "per_parser": per_parser_info,
319
- "summed": value_scores,
320
- "chosen": fused_value,
321
- }
322
-
323
- weights_meta = {
324
- "has_weights_file": HAS_WEIGHTS_FILE,
325
- "weights_path": FIELD_WEIGHTS_PATH,
326
- "meta": FIELD_WEIGHTS_RAW.get("meta", {}) if HAS_WEIGHTS_FILE else {},
327
- }
328
-
329
- return {
330
- "fused_fields": fused_fields,
331
- "by_parser": by_parser,
332
- "votes": votes_debug,
333
- "weights_meta": weights_meta,
 
 
 
 
 
334
  }
 
1
+ # engine/parser_fusion.py
2
+ # ------------------------------------------------------------
3
+ # Tri-Parser Fusion — Stage 12B (Weighted, SOTA-style)
4
+ #
5
+ # This module combines:
6
+ # - Rule parser (parser_rules.parse_text_rules)
7
+ # - Extended parser (parser_ext.parse_text_extended)
8
+ # - LLM parser (parser_llm.parse_llm) [optional]
9
+ #
10
+ # using per-field reliability weights learned in Stage 12A
11
+ # and stored in:
12
+ # data/field_weights.json
13
+ #
14
+ # Behaviour:
15
+ # - For each field, gather predictions from available parsers.
16
+ # - For that field, load weights:
17
+ # field_weights[field] (if present)
18
+ # else global weights
19
+ # else equal weights across available parsers
20
+ # - Discard parsers that:
21
+ # * did not predict the field
22
+ # * or only predicted "Unknown"
23
+ # - Group by predicted value and sum the weights of parsers
24
+ # that voted for each value.
25
+ # - Choose the value with highest total weight.
26
+ # Tie-break: prefer rules > extended > llm if needed.
27
+ #
28
+ # Output format:
29
+ # {
30
+ # "fused_fields": { field: value, ... }, # used by DB identifier AND genus ML
31
+ # "by_parser": {
32
+ # "rules": { ... },
33
+ # "extended": { ... },
34
+ # "llm": { ... } # may be empty
35
+ # },
36
+ # "votes": {
37
+ # field_name: {
38
+ # "per_parser": {
39
+ # "rules": {"value": "Positive", "weight": 0.95},
40
+ # "extended": {"value": "Unknown", "weight": 0.03},
41
+ # ...
42
+ # },
43
+ # "summed": {
44
+ # "Positive": 0.97,
45
+ # "Negative": 0.02
46
+ # },
47
+ # "chosen": "Positive"
48
+ # },
49
+ # ...
50
+ # },
51
+ # "weights_meta": {
52
+ # "has_weights_file": True/False,
53
+ # "weights_path": "data/field_weights.json",
54
+ # "meta": { ... } # from file if present
55
+ # }
56
+ # }
57
+ # ------------------------------------------------------------
58
+
59
+ from __future__ import annotations
60
+
61
+ import json
62
+ import os
63
+ from typing import Any, Dict, Optional
64
+
65
+ from engine.parser_rules import parse_text_rules
66
+ from engine.parser_ext import parse_text_extended
67
+
68
+ # Optional LLM parser
69
+ try:
70
+ from engine.parser_llm import parse_llm as parse_text_llm # type: ignore
71
+ HAS_LLM = True
72
+ except Exception:
73
+ parse_text_llm = None # type: ignore
74
+ HAS_LLM = False
75
+
76
+ # Path to learned weights
77
+ FIELD_WEIGHTS_PATH = os.path.join("data", "field_weights.json")
78
+
79
+ UNKNOWN = "Unknown"
80
+ PARSER_ORDER = ["rules", "extended", "llm"] # tie-breaking priority
81
+
82
+
83
+ # ------------------------------------------------------------
84
+ # Weights loading and helpers
85
+ # ------------------------------------------------------------
86
+
87
+ def _load_field_weights(path: str = FIELD_WEIGHTS_PATH) -> Dict[str, Any]:
88
+ """
89
+ Load the JSON weights file produced by Stage 12A.
90
+
91
+ Expected structure:
92
+ {
93
+ "global": { "rules": 0.7, "extended": 0.2, "llm": 0.1 },
94
+ "fields": {
95
+ "DNase": {
96
+ "rules": 0.95,
97
+ "extended": 0.03,
98
+ "llm": 0.02,
99
+ "support": 123
100
+ },
101
+ ...
102
+ },
103
+ "meta": { ... }
104
+ }
105
+
106
+ If the file is missing or broken, fall back to empty dict,
107
+ triggering equal-weight behaviour later.
108
+ """
109
+ if not os.path.exists(path):
110
+ return {}
111
+
112
+ try:
113
+ with open(path, "r", encoding="utf-8") as f:
114
+ obj = json.load(f)
115
+ return obj if isinstance(obj, dict) else {}
116
+ except Exception:
117
+ return {}
118
+
119
+
120
+ FIELD_WEIGHTS_RAW: Dict[str, Any] = _load_field_weights()
121
+ HAS_WEIGHTS_FILE: bool = bool(FIELD_WEIGHTS_RAW)
122
+
123
+
124
+ def _normalise_scores(scores: Dict[str, float]) -> Dict[str, float]:
125
+ """
126
+ Normalise parser -> score into weights summing to 1.
127
+ If all scores are zero or dict is empty, return equal weights.
128
+ """
129
+ cleaned = {k: max(0.0, float(v)) for k, v in scores.items()}
130
+ total = sum(cleaned.values())
131
+
132
+ if total <= 0:
133
+ n = len(cleaned) or 1
134
+ return {k: 1.0 / n for k in cleaned}
135
+
136
+ return {k: v / total for k, v in cleaned.items()}
137
+
138
+
139
+ def _get_base_weights_for_parsers(include_llm: bool) -> Dict[str, float]:
140
+ """
141
+ Equal-weight distribution across available parsers.
142
+ Used when no learned weights are available.
143
+ """
144
+ parsers = ["rules", "extended"]
145
+ if include_llm:
146
+ parsers.append("llm")
147
+
148
+ n = len(parsers) or 1
149
+ return {p: 1.0 / n for p in parsers}
150
+
151
+
152
+ def _get_weights_for_field(field_name: str, include_llm: bool) -> Dict[str, float]:
153
+ """
154
+ Get weights for a specific field.
155
+
156
+ Priority:
157
+ 1) FIELD_WEIGHTS_RAW["fields"][field_name]
158
+ 2) FIELD_WEIGHTS_RAW["global"]
159
+ 3) Equal weights
160
+
161
+ Always:
162
+ - Drop 'llm' if include_llm == False
163
+ - Normalise
164
+ """
165
+ if not FIELD_WEIGHTS_RAW:
166
+ return _normalise_scores(_get_base_weights_for_parsers(include_llm))
167
+
168
+ fields_block = FIELD_WEIGHTS_RAW.get("fields", {}) or {}
169
+ global_block = FIELD_WEIGHTS_RAW.get("global", {}) or {}
170
+
171
+ raw: Dict[str, float] = {}
172
+
173
+ field_entry = fields_block.get(field_name)
174
+ if isinstance(field_entry, dict):
175
+ for k, v in field_entry.items():
176
+ if k in ("rules", "extended", "llm"):
177
+ raw[k] = float(v)
178
+
179
+ if not raw and isinstance(global_block, dict):
180
+ for k, v in global_block.items():
181
+ if k in ("rules", "extended", "llm"):
182
+ raw[k] = float(v)
183
+
184
+ if not raw:
185
+ raw = _get_base_weights_for_parsers(include_llm)
186
+
187
+ if not include_llm:
188
+ raw.pop("llm", None)
189
+
190
+ if not raw:
191
+ raw = _get_base_weights_for_parsers(include_llm=False)
192
+
193
+ return _normalise_scores(raw)
194
+
195
+
196
+ # ------------------------------------------------------------
197
+ # Fusion logic
198
+ # ------------------------------------------------------------
199
+
200
+ def _clean_pred_value(val: Optional[str]) -> Optional[str]:
201
+ """
202
+ Treat None, empty string, or explicit "Unknown" as missing.
203
+ """
204
+ if val is None:
205
+ return None
206
+
207
+ s = str(val).strip()
208
+ if not s:
209
+ return None
210
+
211
+ if s.lower() == UNKNOWN.lower():
212
+ return None
213
+
214
+ return s
215
+
216
+
217
+ def parse_text_fused(text: str, use_llm: Optional[bool] = None) -> Dict[str, Any]:
218
+ """
219
+ Main tri-parser fusion entrypoint.
220
+
221
+ Parameters
222
+ ----------
223
+ text : str
224
+ use_llm : bool or None
225
+ True -> include LLM
226
+ False -> exclude LLM
227
+ None -> include if available
228
+
229
+ Returns
230
+ -------
231
+ Dict[str, Any]
232
+ Full fusion output including votes and per-parser breakdowns.
233
+ """
234
+ original = text or ""
235
+ include_llm = HAS_LLM if use_llm is None else bool(use_llm)
236
+
237
+ rules_out = parse_text_rules(original) or {}
238
+ ext_out = parse_text_extended(original) or {}
239
+
240
+ rules_fields = dict(rules_out.get("parsed_fields", {}))
241
+ ext_fields = dict(ext_out.get("parsed_fields", {}))
242
+
243
+ llm_fields: Dict[str, Any] = {}
244
+ if include_llm and parse_text_llm is not None:
245
+ try:
246
+ merged_existing = {}
247
+ merged_existing.update(rules_fields)
248
+ merged_existing.update(ext_fields)
249
+
250
+ llm_out = parse_text_llm(original, existing_fields=merged_existing)
251
+
252
+ if isinstance(llm_out, dict):
253
+ if "parsed_fields" in llm_out:
254
+ llm_fields = dict(llm_out.get("parsed_fields", {}))
255
+ else:
256
+ llm_fields = {str(k): v for k, v in llm_out.items()}
257
+ except Exception:
258
+ llm_fields = {}
259
+ else:
260
+ include_llm = False
261
+
262
+ by_parser: Dict[str, Dict[str, Any]] = {
263
+ "rules": rules_fields,
264
+ "extended": ext_fields,
265
+ "llm": llm_fields if include_llm else {},
266
+ }
267
+
268
+ candidate_fields = (
269
+ set(rules_fields.keys())
270
+ | set(ext_fields.keys())
271
+ | set(llm_fields.keys())
272
+ )
273
+
274
+ fused_fields: Dict[str, Any] = {}
275
+ votes_debug: Dict[str, Any] = {}
276
+
277
+ for field in sorted(candidate_fields):
278
+ weights = _get_weights_for_field(field, include_llm)
279
+
280
+ parser_preds: Dict[str, Optional[str]] = {
281
+ "rules": _clean_pred_value(rules_fields.get(field)),
282
+ "extended": _clean_pred_value(ext_fields.get(field)),
283
+ "llm": _clean_pred_value(llm_fields.get(field)) if include_llm else None,
284
+ }
285
+
286
+ per_parser_info: Dict[str, Any] = {}
287
+ value_scores: Dict[str, float] = {}
288
+
289
+ for parser_name in PARSER_ORDER:
290
+ if parser_name == "llm" and not include_llm:
291
+ continue
292
+
293
+ pred = parser_preds.get(parser_name)
294
+ w = float(weights.get(parser_name, 0.0))
295
+
296
+ per_parser_info[parser_name] = {
297
+ "value": pred if pred is not None else UNKNOWN,
298
+ "weight": w,
299
+ }
300
+
301
+ if pred is not None:
302
+ value_scores[pred] = value_scores.get(pred, 0.0) + w
303
+
304
+ if not value_scores:
305
+ fused_value = UNKNOWN
306
+ else:
307
+ max_score = max(value_scores.values())
308
+ best_values = [v for v, s in value_scores.items() if s == max_score]
309
+
310
+ if len(best_values) == 1:
311
+ fused_value = best_values[0]
312
+ else:
313
+ fused_value = best_values[0]
314
+ for parser_name in PARSER_ORDER:
315
+ if parser_name == "llm" and not include_llm:
316
+ continue
317
+ if parser_preds.get(parser_name) in best_values:
318
+ fused_value = parser_preds[parser_name] # type: ignore
319
+ break
320
+
321
+ fused_fields[field] = fused_value
322
+ votes_debug[field] = {
323
+ "per_parser": per_parser_info,
324
+ "summed": value_scores,
325
+ "chosen": fused_value,
326
+ }
327
+
328
+ weights_meta = {
329
+ "has_weights_file": HAS_WEIGHTS_FILE,
330
+ "weights_path": FIELD_WEIGHTS_PATH,
331
+ "meta": FIELD_WEIGHTS_RAW.get("meta", {}) if HAS_WEIGHTS_FILE else {},
332
+ }
333
+
334
+ return {
335
+ "fused_fields": fused_fields,
336
+ "by_parser": by_parser,
337
+ "votes": votes_debug,
338
+ "weights_meta": weights_meta,
339
  }