EphAsad committed on
Commit
8e61f92
·
verified ·
1 Parent(s): 949ad27

Update engine/parser_fusion.py

Browse files
Files changed (1) hide show
  1. engine/parser_fusion.py +269 -267
engine/parser_fusion.py CHANGED
@@ -3,120 +3,118 @@
3
  # Tri-Parser Fusion — Stage 12B (Weighted, SOTA-style)
4
  #
5
  # This module combines:
6
- #   - Rule parser (parser_rules.parse_text_rules)
7
- #   - Extended parser (parser_ext.parse_text_extended)
8
- #   - LLM parser (parser_llm.parse_llm)    [optional]
9
  #
10
  # using per-field reliability weights learned in Stage 12A
11
  # and stored in:
12
- #   data/field_weights.json
13
  #
14
  # Behaviour:
15
- #   - For each field, gather predictions from available parsers.
16
- #   - For that field, load weights:
17
- #          field_weights[field]  (if present)
18
- #          else global weights
19
- #          else equal weights across available parsers
20
- #   - Discard parsers that:
21
- #          * did not predict the field
22
- #          * or only predicted "Unknown"
23
- #   - Group by predicted value and sum the weights of parsers
24
- #     that voted for each value.
25
- #   - Choose the value with highest total weight.
26
- #     Tie-break: prefer rules > extended > llm if needed.
27
  #
28
  # Output format:
29
- #   {
30
- #     "fused_fields": { field: value, ... },   # used by DB identifier AND genus ML
31
- #     "by_parser": {
32
- #       "rules": { ... },
33
- #       "extended": { ... },
34
- #       "llm": { ... }   # may be empty
35
- #     },
36
- #     "votes": {
37
- #       field_name: {
38
- #         "per_parser": {
39
- #           "rules": {"value": "Positive", "weight": 0.95},
40
- #           "extended": {"value": "Unknown", "weight": 0.03},
41
- #           ...
42
- #         },
43
- #         "summed": {
44
- #           "Positive": 0.97,
45
- #           "Negative": 0.02
46
- #         },
47
- #         "chosen": "Positive"
48
- #       },
49
- #       ...
50
- #     },
51
- #     "weights_meta": {
52
- #       "has_weights_file": True/False,
53
- #       "weights_path": "data/field_weights.json",
54
- #       "meta": { ... }  # from file if present
55
- #     }
56
- #   }
57
- # ------------------------------------------------------------
58
-
59
- from __future__ import annotations
60
 
61
  import json
62
  import os
63
- from typing import Any, Dict, Optional
64
 
65
  from engine.parser_rules import parse_text_rules
66
- from engine.parser_ext import parse_text_extended
67
 
68
  # Optional LLM parser
69
  try:
70
-     from engine.parser_llm import parse_llm as parse_text_llm  # type: ignore
71
-     HAS_LLM = True
72
  except Exception:
73
-     parse_text_llm = None  # type: ignore
74
-     HAS_LLM = False
75
 
76
  # Path to learned weights
77
- FIELD_WEIGHTS_PATH = os.path.join("data", "field_weights.json")
78
 
79
  UNKNOWN = "Unknown"
80
- PARSER_ORDER = ["rules", "extended", "llm"]  # used for tie-breaking
81
 
82
 
83
  # ------------------------------------------------------------
84
  # Weights loading and helpers
85
- # ------------------------------------------------------------
86
 
87
  def _load_field_weights(path: str = FIELD_WEIGHTS_PATH) -> Dict[str, Any]:
88
-     """
89
-     Load the JSON weights file produced by Stage 12A.
90
-
91
-     Expected structure:
92
-       {
93
-         "global": { "rules": 0.7, "extended": 0.2, "llm": 0.1 },
94
-         "fields": {
95
-           "DNase": {
96
-             "rules": 0.95,
97
-             "extended": 0.03,
98
-             "llm": 0.02,
99
-             "support": 123
100
-           },
101
-           ...
102
-         },
103
-         "meta": { ... }
104
-       }
105
-
106
-     If the file is missing or broken, we fall back to an empty dict,
107
-     which triggers equal-weight behaviour later.
108
-     """
109
-     if not os.path.exists(path):
110
-         return {}
111
-
112
-     try:
113
-         with open(path, "r", encoding="utf-8") as f:
114
-             obj = json.load(f)
115
-         if isinstance(obj, dict):
116
-             return obj
117
-         return {}
118
-     except Exception:
119
-         return {}
120
 
121
 
122
  FIELD_WEIGHTS_RAW: Dict[str, Any] = _load_field_weights()
@@ -124,209 +122,213 @@ HAS_WEIGHTS_FILE: bool = bool(FIELD_WEIGHTS_RAW)
124
 
125
 
126
  def _normalise_scores(scores: Dict[str, float]) -> Dict[str, float]:
127
-     """
128
-     Normalise a dict of parser -> score into weights summing to 1.
129
-     If all scores are zero or dict is empty, return equal weights.
130
-     """
131
-     cleaned = {k: max(0.0, float(v)) for k, v in scores.items()}
132
-     total = sum(cleaned.values())
133
 
134
-     if total <= 0:
135
-         n = len(cleaned) or 1
136
-         return {k: 1.0 / n for k in cleaned}
137
 
138
-     return {k: v / total for k, v in cleaned.items()}
139
 
140
 
141
  def _get_base_weights_for_parsers(include_llm: bool) -> Dict[str, float]:
142
-     """
143
-     Get a naive equal-weight distribution across available parsers.
144
-     Used when no learned weights are available.
145
-     """
146
-     parsers = ["rules", "extended"]
147
-     if include_llm:
148
-         parsers.append("llm")
149
-     n = len(parsers) or 1
150
-     return {p: 1.0 / n for p in parsers}
 
151
 
152
 
153
  def _get_weights_for_field(field_name: str, include_llm: bool) -> Dict[str, float]:
154
-     """
155
-     Get weights for a specific field.
156
 
157
-     Priority:
158
-       1) If FIELD_WEIGHTS_RAW has a 'fields[field_name]' entry,
159
-          use that.
160
-       2) Else if FIELD_WEIGHTS_RAW has 'global', use that.
161
-       3) Else equal weights.
162
 
163
-     In all cases:
164
-       - Drop 'llm' if include_llm == False
165
-       - Normalise
166
-     """
167
-     if not FIELD_WEIGHTS_RAW:
168
-         base = _get_base_weights_for_parsers(include_llm)
169
-         return _normalise_scores(base)
170
 
171
-     fields_block = FIELD_WEIGHTS_RAW.get("fields", {}) or {}
172
-     global_block = FIELD_WEIGHTS_RAW.get("global", {}) or {}
173
 
174
-     raw: Dict[str, float] = {}
175
 
176
-     field_entry = fields_block.get(field_name)
177
-     if isinstance(field_entry, dict):
178
-         for k, v in field_entry.items():
179
-             if k in ("rules", "extended", "llm"):
180
-                 raw[k] = float(v)
181
 
182
-     if not raw and isinstance(global_block, dict):
183
-         for k, v in global_block.items():
184
-             if k in ("rules", "extended", "llm"):
185
-                 raw[k] = float(v)
186
 
187
-     if not raw:
188
-         raw = _get_base_weights_for_parsers(include_llm)
189
 
190
-     if not include_llm and "llm" in raw:
191
-         raw.pop("llm", None)
192
 
193
-     if not raw:
194
-         raw = _get_base_weights_for_parsers(include_llm=False)
195
 
196
-     return _normalise_scores(raw)
197
 
198
 
199
  # ------------------------------------------------------------
200
  # Fusion logic
201
- # ------------------------------------------------------------
202
 
203
  def _clean_pred_value(val: Optional[str]) -> Optional[str]:
204
-     """
205
-     Treat None, "", or explicit "Unknown" as missing for fusion.
206
-     """
207
-     if val is None:
208
-         return None
209
-     s = str(val).strip()
210
-     if not s:
211
-         return None
212
-     if s.lower() == UNKNOWN.lower():
213
-         return None
214
-     return s
 
 
 
215
 
216
 
217
  def parse_text_fused(text: str, use_llm: Optional[bool] = None) -> Dict[str, Any]:
218
-     """
219
-     Main tri-fusion entrypoint.
220
-
221
-     Parameters
222
-     ----------
223
-     text : str
224
-     use_llm : bool or None
225
-         If True include LLM.
226
-         If False skip LLM.
227
-         If None include if HAS_LLM.
228
-
229
-     Returns:
230
-       full fusion output including votes + per-parser summaries.
231
-     """
232
-     original = text or ""
233
-     include_llm = HAS_LLM if use_llm is None else bool(use_llm)
234
-
235
-     rules_out = parse_text_rules(original) or {}
236
-     ext_out = parse_text_extended(original) or {}
237
-
238
-     rules_fields = dict(rules_out.get("parsed_fields", {}))
239
-     ext_fields = dict(ext_out.get("parsed_fields", {}))
240
-
241
-     llm_fields: Dict[str, Any] = {}
242
-     if include_llm and parse_text_llm is not None:
243
-         try:
244
-             llm_out = parse_text_llm(original)
245
-             if isinstance(llm_out, dict):
246
-                 if "parsed_fields" in llm_out:
247
-                     llm_fields = dict(llm_out.get("parsed_fields", {}))
248
-                 else:
249
-                     llm_fields = {str(k): v for k, v in llm_out.items()}
250
-         except Exception:
251
-             llm_fields = {}
252
-     else:
253
-         include_llm = False
254
-
255
-     by_parser: Dict[str, Dict[str, Any]] = {
256
-         "rules": rules_fields,
257
-         "extended": ext_fields,
258
-         "llm": llm_fields if include_llm else {},
259
-     }
260
-
261
-     candidate_fields = set(rules_fields.keys()) | set(ext_fields.keys()) | set(llm_fields.keys())
262
-
263
-     fused_fields: Dict[str, Any] = {}
264
-     votes_debug: Dict[str, Any] = {}
265
-
266
-     for field in sorted(candidate_fields):
267
-         weights = _get_weights_for_field(field, include_llm=include_llm)
268
-
269
-         parser_preds: Dict[str, Optional[str]] = {
270
-             "rules": _clean_pred_value(rules_fields.get(field)),
271
-             "extended": _clean_pred_value(ext_fields.get(field)),
272
-             "llm": _clean_pred_value(llm_fields.get(field)) if include_llm else None,
273
-         }
274
-
275
-         per_parser_info: Dict[str, Any] = {}
276
-         value_scores: Dict[str, float] = {}
277
-
278
-         for parser_name in PARSER_ORDER:
279
-             if parser_name == "llm" and not include_llm:
280
-                 continue
281
-
282
-             pred = parser_preds.get(parser_name)
283
-             w = float(weights.get(parser_name, 0.0))
284
-
285
-             per_parser_info[parser_name] = {
286
-                 "value": pred if pred is not None else UNKNOWN,
287
-                 "weight": w,
288
-             }
289
-
290
-             if pred is None:
291
-                 continue
292
-
293
-             value_scores[pred] = value_scores.get(pred, 0.0) + w
294
-
295
-         if not value_scores:
296
-             fused_value = UNKNOWN
297
-         else:
298
-             max_score = max(value_scores.values())
299
-             best_values = [v for v, s in value_scores.items() if s == max_score]
300
-
301
-             if len(best_values) == 1:
302
-                 fused_value = best_values[0]
303
-             else:
304
-                 fused_value = best_values[0]
305
-                 for parser_name in PARSER_ORDER:
306
-                     if parser_name == "llm" and not include_llm:
307
-                         continue
308
-                     pred = parser_preds.get(parser_name)
309
-                     if pred in best_values:
310
-                         fused_value = pred
311
-                         break
312
-
313
-         fused_fields[field] = fused_value
314
-
315
-         votes_debug[field] = {
316
-             "per_parser": per_parser_info,
317
-             "summed": value_scores,
318
-             "chosen": fused_value,
319
-         }
320
-
321
-     weights_meta = {
322
-         "has_weights_file": HAS_WEIGHTS_FILE,
323
-         "weights_path": FIELD_WEIGHTS_PATH,
324
-         "meta": FIELD_WEIGHTS_RAW.get("meta", {}) if HAS_WEIGHTS_FILE else {},
325
-     }
326
-
327
-     return {
328
-         "fused_fields": fused_fields,
329
-         "by_parser": by_parser,
330
-         "votes": votes_debug,
331
-         "weights_meta": weights_meta,
332
-     }
 
 
 
3
  # Tri-Parser Fusion — Stage 12B (Weighted, SOTA-style)
4
  #
5
  # This module combines:
6
+ # - Rule parser (parser_rules.parse_text_rules)
7
+ # - Extended parser (parser_ext.parse_text_extended)
8
+ # - LLM parser (parser_llm.parse_llm) [optional]
9
  #
10
  # using per-field reliability weights learned in Stage 12A
11
  # and stored in:
12
+ # data/field_weights.json
13
  #
14
  # Behaviour:
15
+ # - For each field, gather predictions from available parsers.
16
+ # - For that field, load weights:
17
+ # field_weights[field] (if present)
18
+ # else global weights
19
+ # else equal weights across available parsers
20
+ # - Discard parsers that:
21
+ # * did not predict the field
22
+ # * or only predicted "Unknown"
23
+ # - Group by predicted value and sum the weights of parsers
24
+ # that voted for each value.
25
+ # - Choose the value with highest total weight.
26
+ # Tie-break: prefer rules > extended > llm if needed.
27
  #
28
  # Output format:
29
+ # {
30
+ # "fused_fields": { field: value, ... }, # used by DB identifier AND genus ML
31
+ # "by_parser": {
32
+ # "rules": { ... },
33
+ # "extended": { ... },
34
+ # "llm": { ... } # may be empty
35
+ # },
36
+ # "votes": {
37
+ # field_name: {
38
+ # "per_parser": {
39
+ # "rules": {"value": "Positive", "weight": 0.95},
40
+ # "extended": {"value": "Unknown", "weight": 0.03},
41
+ # ...
42
+ # },
43
+ # "summed": {
44
+ # "Positive": 0.97,
45
+ # "Negative": 0.02
46
+ # },
47
+ # "chosen": "Positive"
48
+ # },
49
+ # ...
50
+ # },
51
+ # "weights_meta": {
52
+ # "has_weights_file": True/False,
53
+ # "weights_path": "data/field_weights.json",
54
+ # "meta": { ... } # from file if present
55
+ # }
56
+ # }
57
+ # ------------------------------------------------------------
58
+
59
+ from __future__ import annotations
60
 
61
  import json
62
  import os
63
+ from typing import Any, Dict, Optional
64
 
65
  from engine.parser_rules import parse_text_rules
66
+ from engine.parser_ext import parse_text_extended
67
 
68
  # Optional LLM parser
69
  try:
70
+ from engine.parser_llm import parse_llm as parse_text_llm # type: ignore
71
+ HAS_LLM = True
72
  except Exception:
73
+ parse_text_llm = None # type: ignore
74
+ HAS_LLM = False
75
 
76
  # Path to learned weights
77
+ FIELD_WEIGHTS_PATH = os.path.join("data", "field_weights.json")
78
 
79
  UNKNOWN = "Unknown"
80
+ PARSER_ORDER = ["rules", "extended", "llm"] # tie-breaking priority
81
 
82
 
83
  # ------------------------------------------------------------
84
  # Weights loading and helpers
85
+ # ------------------------------------------------------------
86
 
87
def _load_field_weights(path: str = FIELD_WEIGHTS_PATH) -> Dict[str, Any]:
    """
    Read the Stage 12A weights file and return its parsed contents.

    Expected layout::

        {
          "global": { "rules": 0.7, "extended": 0.2, "llm": 0.1 },
          "fields": {
            "DNase": { "rules": 0.95, "extended": 0.03, "llm": 0.02, "support": 123 },
            ...
          },
          "meta": { ... }
        }

    A missing, unreadable, or non-dict file yields ``{}``, which makes the
    rest of the module fall back to equal-weight behaviour.
    """
    if not os.path.exists(path):
        return {}

    try:
        with open(path, "r", encoding="utf-8") as handle:
            loaded = json.load(handle)
    except Exception:
        # Broken JSON / IO error: behave exactly like a missing file.
        return {}

    if not isinstance(loaded, dict):
        return {}
    return loaded
 
 
118
 
119
 
120
  FIELD_WEIGHTS_RAW: Dict[str, Any] = _load_field_weights()
 
122
 
123
 
124
  def _normalise_scores(scores: Dict[str, float]) -> Dict[str, float]:
125
+ """
126
+ Normalise parser -> score into weights summing to 1.
127
+ If all scores are zero or dict is empty, return equal weights.
128
+ """
129
+ cleaned = {k: max(0.0, float(v)) for k, v in scores.items()}
130
+ total = sum(cleaned.values())
131
 
132
+ if total <= 0:
133
+ n = len(cleaned) or 1
134
+ return {k: 1.0 / n for k in cleaned}
135
 
136
+ return {k: v / total for k, v in cleaned.items()}
137
 
138
 
139
  def _get_base_weights_for_parsers(include_llm: bool) -> Dict[str, float]:
140
+ """
141
+ Equal-weight distribution across available parsers.
142
+ Used when no learned weights are available.
143
+ """
144
+ parsers = ["rules", "extended"]
145
+ if include_llm:
146
+ parsers.append("llm")
147
+
148
+ n = len(parsers) or 1
149
+ return {p: 1.0 / n for p in parsers}
150
 
151
 
152
def _get_weights_for_field(field_name: str, include_llm: bool) -> Dict[str, float]:
    """
    Resolve the fusion weights to use for one field.

    Lookup priority:
      1) per-field entry in ``FIELD_WEIGHTS_RAW["fields"][field_name]``
      2) ``FIELD_WEIGHTS_RAW["global"]``
      3) equal weights across the available parsers

    In every case the ``llm`` weight is dropped when ``include_llm`` is
    False, and the result is normalised to sum to 1.
    """
    if not FIELD_WEIGHTS_RAW:
        return _normalise_scores(_get_base_weights_for_parsers(include_llm))

    known = ("rules", "extended", "llm")
    raw: Dict[str, float] = {}

    per_field = (FIELD_WEIGHTS_RAW.get("fields", {}) or {}).get(field_name)
    if isinstance(per_field, dict):
        raw = {k: float(v) for k, v in per_field.items() if k in known}

    if not raw:
        fallback = FIELD_WEIGHTS_RAW.get("global", {}) or {}
        if isinstance(fallback, dict):
            raw = {k: float(v) for k, v in fallback.items() if k in known}

    if not raw:
        raw = _get_base_weights_for_parsers(include_llm)

    if not include_llm:
        raw.pop("llm", None)

    # If the weights file only carried an 'llm' weight and we just dropped
    # it, rebuild from the non-LLM base distribution.
    if not raw:
        raw = _get_base_weights_for_parsers(include_llm=False)

    return _normalise_scores(raw)
194
 
195
 
196
  # ------------------------------------------------------------
197
  # Fusion logic
198
+ # ------------------------------------------------------------
199
 
200
def _clean_pred_value(val: Optional[str]) -> Optional[str]:
    """
    Normalise a raw parser prediction for voting.

    None, empty/whitespace-only strings, and the literal "Unknown"
    (case-insensitive) all count as "no prediction" and map to None;
    anything else is returned stripped.
    """
    if val is None:
        return None

    text = str(val).strip()
    if not text or text.lower() == UNKNOWN.lower():
        return None

    return text
215
 
216
 
217
def parse_text_fused(text: str, use_llm: Optional[bool] = None) -> Dict[str, Any]:
    """
    Run all available parsers on *text* and fuse their field predictions.

    Parameters
    ----------
    text : str
        Raw input text handed to each parser (``None`` is treated as "").
    use_llm : bool or None
        True forces the LLM parser on, False forces it off, None defers
        to ``HAS_LLM`` (whether the optional import succeeded).

    Returns
    -------
    Dict[str, Any]
        Keys ``fused_fields``, ``by_parser``, ``votes`` and
        ``weights_meta`` — the fused values plus full per-parser and
        per-field voting detail.
    """
    source = text or ""
    include_llm = bool(use_llm) if use_llm is not None else HAS_LLM

    rules_fields = dict((parse_text_rules(source) or {}).get("parsed_fields", {}))
    ext_fields = dict((parse_text_extended(source) or {}).get("parsed_fields", {}))

    llm_fields: Dict[str, Any] = {}
    if include_llm and parse_text_llm is not None:
        try:
            llm_raw = parse_text_llm(source)
            if isinstance(llm_raw, dict):
                if "parsed_fields" in llm_raw:
                    llm_fields = dict(llm_raw.get("parsed_fields", {}))
                else:
                    # Flat dict output: coerce keys to str, keep values as-is.
                    llm_fields = {str(key): val for key, val in llm_raw.items()}
        except Exception:
            # Best-effort: a failing LLM parser never breaks fusion.
            llm_fields = {}
    else:
        include_llm = False

    by_parser: Dict[str, Dict[str, Any]] = {
        "rules": rules_fields,
        "extended": ext_fields,
        "llm": llm_fields if include_llm else {},
    }

    fused_fields: Dict[str, Any] = {}
    votes_debug: Dict[str, Any] = {}

    all_fields = set(rules_fields) | set(ext_fields) | set(llm_fields)
    active_parsers = [p for p in PARSER_ORDER if include_llm or p != "llm"]

    for field in sorted(all_fields):
        weights = _get_weights_for_field(field, include_llm)

        preds: Dict[str, Optional[str]] = {
            "rules": _clean_pred_value(rules_fields.get(field)),
            "extended": _clean_pred_value(ext_fields.get(field)),
            "llm": _clean_pred_value(llm_fields.get(field)) if include_llm else None,
        }

        per_parser_info: Dict[str, Any] = {}
        value_scores: Dict[str, float] = {}

        for parser_name in active_parsers:
            pred = preds.get(parser_name)
            weight = float(weights.get(parser_name, 0.0))

            per_parser_info[parser_name] = {
                "value": UNKNOWN if pred is None else pred,
                "weight": weight,
            }
            if pred is not None:
                value_scores[pred] = value_scores.get(pred, 0.0) + weight

        if not value_scores:
            chosen = UNKNOWN
        else:
            top = max(value_scores.values())
            leaders = [v for v, s in value_scores.items() if s == top]
            chosen = leaders[0]
            if len(leaders) > 1:
                # Tie-break: earlier parsers in PARSER_ORDER win.
                for parser_name in active_parsers:
                    if preds.get(parser_name) in leaders:
                        chosen = preds[parser_name]  # type: ignore[assignment]
                        break

        fused_fields[field] = chosen
        votes_debug[field] = {
            "per_parser": per_parser_info,
            "summed": value_scores,
            "chosen": chosen,
        }

    return {
        "fused_fields": fused_fields,
        "by_parser": by_parser,
        "votes": votes_debug,
        "weights_meta": {
            "has_weights_file": HAS_WEIGHTS_FILE,
            "weights_path": FIELD_WEIGHTS_PATH,
            "meta": FIELD_WEIGHTS_RAW.get("meta", {}) if HAS_WEIGHTS_FILE else {},
        },
    }