EphAsad committed on
Commit
1c100b5
·
verified ·
1 Parent(s): d358339

Update engine/parser_fusion.py

Browse files
Files changed (1) hide show
  1. engine/parser_fusion.py +480 -258
engine/parser_fusion.py CHANGED
@@ -1,334 +1,556 @@
1
- # engine/parser_fusion.py
2
  # ------------------------------------------------------------
3
- # Tri-Parser Fusion Stage 12B (Weighted, SOTA-style)
4
  #
5
- # This module combines:
6
- # - Rule parser (parser_rules.parse_text_rules)
7
- # - Extended parser (parser_ext.parse_text_extended)
8
- # - LLM parser (parser_llm.parse_llm) [optional]
9
- #
10
- # using per-field reliability weights learned in Stage 12A
11
- # and stored in:
12
- # data/field_weights.json
13
- #
14
- # Behaviour:
15
- # - For each field, gather predictions from available parsers.
16
- # - For that field, load weights:
17
- # field_weights[field] (if present)
18
- # else global weights
19
- # else equal weights across available parsers
20
- # - Discard parsers that:
21
- # * did not predict the field
22
- # * or only predicted "Unknown"
23
- # - Group by predicted value and sum the weights of parsers
24
- # that voted for each value.
25
- # - Choose the value with highest total weight.
26
- # Tie-break: prefer rules > extended > llm if needed.
27
- #
28
- # Output format:
29
- # {
30
- # "fused_fields": { field: value, ... }, # used by DB identifier AND genus ML
31
- # "by_parser": {
32
- # "rules": { ... },
33
- # "extended": { ... },
34
- # "llm": { ... } # may be empty
35
- # },
36
- # "votes": {
37
- # field_name: {
38
- # "per_parser": {
39
- # "rules": {"value": "Positive", "weight": 0.95},
40
- # "extended": {"value": "Unknown", "weight": 0.03},
41
- # ...
42
- # },
43
- # "summed": {
44
- # "Positive": 0.97,
45
- # "Negative": 0.02
46
- # },
47
- # "chosen": "Positive"
48
- # },
49
- # ...
50
- # },
51
- # "weights_meta": {
52
- # "has_weights_file": True/False,
53
- # "weights_path": "data/field_weights.json",
54
- # "meta": { ... } # from file if present
55
- # }
56
- # }
57
  # ------------------------------------------------------------
58
 
59
from __future__ import annotations

import json
import os
from typing import Any, Dict, Optional

from engine.parser_rules import parse_text_rules
from engine.parser_ext import parse_text_extended

# Optional LLM parser: if the LLM stack cannot be imported, fusion
# degrades gracefully to the two deterministic parsers.
try:
    from engine.parser_llm import parse_llm as parse_text_llm  # type: ignore
    HAS_LLM = True
except Exception:
    parse_text_llm = None  # type: ignore
    HAS_LLM = False

# Path to the per-field reliability weights learned in Stage 12A.
FIELD_WEIGHTS_PATH = os.path.join("data", "field_weights.json")

# Sentinel meaning "parser made no usable prediction for this field".
UNKNOWN = "Unknown"
PARSER_ORDER = ["rules", "extended", "llm"]  # tie-breaking priority
 
 
 
 
 
 
 
 
 
81
 
82
 
83
  # ------------------------------------------------------------
84
- # Weights loading and helpers
85
  # ------------------------------------------------------------
86
 
87
def _load_field_weights(path: str = FIELD_WEIGHTS_PATH) -> Dict[str, Any]:
    """Load the learned parser weights (Stage 12A output) from *path*.

    Expected structure::

        {
          "global": {"rules": 0.7, "extended": 0.2, "llm": 0.1},
          "fields": {"DNase": {"rules": 0.95, "extended": 0.03,
                               "llm": 0.02, "support": 123}, ...},
          "meta": {...}
        }

    Returns an empty dict when the file is absent, unreadable, or not a
    JSON object — downstream code treats {} as "use equal weights".
    """
    if not os.path.exists(path):
        return {}
    try:
        with open(path, "r", encoding="utf-8") as handle:
            loaded = json.load(handle)
    except Exception:
        # Broken/unreadable file: fall back to equal-weight behaviour.
        return {}
    return loaded if isinstance(loaded, dict) else {}
118
 
 
119
 
120
# Loaded once at import time; {} when the weights file is missing/broken.
FIELD_WEIGHTS_RAW: Dict[str, Any] = _load_field_weights()
HAS_WEIGHTS_FILE: bool = bool(FIELD_WEIGHTS_RAW)
122
 
 
 
 
123
 
124
- def _normalise_scores(scores: Dict[str, float]) -> Dict[str, float]:
125
- """
126
- Normalise parser -> score into weights summing to 1.
127
- If all scores are zero or dict is empty, return equal weights.
128
- """
129
- cleaned = {k: max(0.0, float(v)) for k, v in scores.items()}
130
- total = sum(cleaned.values())
131
 
132
- if total <= 0:
133
- n = len(cleaned) or 1
134
- return {k: 1.0 / n for k in cleaned}
 
135
 
136
- return {k: v / total for k, v in cleaned.items()}
 
 
137
 
 
 
138
 
139
- def _get_base_weights_for_parsers(include_llm: bool) -> Dict[str, float]:
140
- """
141
- Equal-weight distribution across available parsers.
142
- Used when no learned weights are available.
143
- """
144
- parsers = ["rules", "extended"]
145
- if include_llm:
146
- parsers.append("llm")
147
 
148
- n = len(parsers) or 1
149
- return {p: 1.0 / n for p in parsers}
150
 
151
 
152
def _get_weights_for_field(field_name: str, include_llm: bool) -> Dict[str, float]:
    """Resolve normalised parser weights for one field.

    Lookup priority:
      1) per-field entry in FIELD_WEIGHTS_RAW["fields"]
      2) FIELD_WEIGHTS_RAW["global"]
      3) equal weights over available parsers

    The 'llm' weight is dropped when include_llm is False, and the
    result is always re-normalised to sum to 1.
    """
    if not FIELD_WEIGHTS_RAW:
        return _normalise_scores(_get_base_weights_for_parsers(include_llm))

    known = ("rules", "extended", "llm")
    candidates: Dict[str, float] = {}

    fields_block = FIELD_WEIGHTS_RAW.get("fields", {}) or {}
    per_field = fields_block.get(field_name)
    if isinstance(per_field, dict):
        candidates = {p: float(v) for p, v in per_field.items() if p in known}

    if not candidates:
        global_block = FIELD_WEIGHTS_RAW.get("global", {}) or {}
        if isinstance(global_block, dict):
            candidates = {p: float(v) for p, v in global_block.items() if p in known}

    if not candidates:
        candidates = _get_base_weights_for_parsers(include_llm)

    if not include_llm:
        candidates.pop("llm", None)

    # The learned weights may have contained only 'llm'; re-seed if empty.
    if not candidates:
        candidates = _get_base_weights_for_parsers(include_llm=False)

    return _normalise_scores(candidates)
 
 
 
 
 
 
 
 
194
 
 
 
195
 
196
- # ------------------------------------------------------------
197
- # Fusion logic
198
- # ------------------------------------------------------------
199
 
200
- def _clean_pred_value(val: Optional[str]) -> Optional[str]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
201
  """
202
- Treat None, empty string, or explicit "Unknown" as missing.
203
  """
204
- if val is None:
205
- return None
206
 
207
- s = str(val).strip()
208
- if not s:
209
- return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
210
 
211
- if s.lower() == UNKNOWN.lower():
212
- return None
 
 
 
 
213
 
214
- return s
215
 
 
 
 
216
 
217
def parse_text_fused(text: str, use_llm: Optional[bool] = None) -> Dict[str, Any]:
    """
    Main tri-parser fusion entrypoint.

    Runs the rule/extended (and optionally LLM) parsers, then performs a
    per-field weighted vote using Stage 12A reliability weights.

    Parameters
    ----------
    text : str
    use_llm : bool or None
        True  -> include LLM
        False -> exclude LLM
        None  -> include if available

    Returns
    -------
    Dict[str, Any]
        Full fusion output including votes and per-parser breakdowns:
        keys "fused_fields", "by_parser", "votes", "weights_meta".
    """
    original = text or ""
    include_llm = HAS_LLM if use_llm is None else bool(use_llm)

    rules_out = parse_text_rules(original) or {}
    ext_out = parse_text_extended(original) or {}

    rules_fields = dict(rules_out.get("parsed_fields", {}))
    ext_fields = dict(ext_out.get("parsed_fields", {}))

    # LLM is best-effort: any failure silently degrades to two parsers.
    llm_fields: Dict[str, Any] = {}
    if include_llm and parse_text_llm is not None:
        try:
            llm_out = parse_text_llm(original)
            if isinstance(llm_out, dict):
                if "parsed_fields" in llm_out:
                    llm_fields = dict(llm_out.get("parsed_fields", {}))
                else:
                    # Tolerate a flat {field: value} dict from the LLM.
                    llm_fields = {str(k): v for k, v in llm_out.items()}
        except Exception:
            llm_fields = {}
    else:
        include_llm = False

    by_parser: Dict[str, Dict[str, Any]] = {
        "rules": rules_fields,
        "extended": ext_fields,
        "llm": llm_fields if include_llm else {},
    }

    # Vote on every field that at least one parser mentioned.
    candidate_fields = (
        set(rules_fields.keys())
        | set(ext_fields.keys())
        | set(llm_fields.keys())
    )

    fused_fields: Dict[str, Any] = {}
    votes_debug: Dict[str, Any] = {}

    for field in sorted(candidate_fields):
        weights = _get_weights_for_field(field, include_llm)

        # None here means "parser did not predict / said Unknown".
        parser_preds: Dict[str, Optional[str]] = {
            "rules": _clean_pred_value(rules_fields.get(field)),
            "extended": _clean_pred_value(ext_fields.get(field)),
            "llm": _clean_pred_value(llm_fields.get(field)) if include_llm else None,
        }

        per_parser_info: Dict[str, Any] = {}
        value_scores: Dict[str, float] = {}

        # Sum each parser's weight onto the value it voted for.
        for parser_name in PARSER_ORDER:
            if parser_name == "llm" and not include_llm:
                continue

            pred = parser_preds.get(parser_name)
            w = float(weights.get(parser_name, 0.0))

            per_parser_info[parser_name] = {
                "value": pred if pred is not None else UNKNOWN,
                "weight": w,
            }

            if pred is not None:
                value_scores[pred] = value_scores.get(pred, 0.0) + w

        if not value_scores:
            fused_value = UNKNOWN
        else:
            max_score = max(value_scores.values())
            best_values = [v for v, s in value_scores.items() if s == max_score]

            if len(best_values) == 1:
                fused_value = best_values[0]
            else:
                # Tie-break by parser priority: rules > extended > llm.
                fused_value = best_values[0]
                for parser_name in PARSER_ORDER:
                    if parser_name == "llm" and not include_llm:
                        continue
                    if parser_preds.get(parser_name) in best_values:
                        fused_value = parser_preds[parser_name]  # type: ignore
                        break

        fused_fields[field] = fused_value
        votes_debug[field] = {
            "per_parser": per_parser_info,
            "summed": value_scores,
            "chosen": fused_value,
        }

    weights_meta = {
        "has_weights_file": HAS_WEIGHTS_FILE,
        "weights_path": FIELD_WEIGHTS_PATH,
        "meta": FIELD_WEIGHTS_RAW.get("meta", {}) if HAS_WEIGHTS_FILE else {},
    }

    return {
        "fused_fields": fused_fields,
        "by_parser": by_parser,
        "votes": votes_debug,
        "weights_meta": weights_meta,
    }
 
1
+ # engine/parser_llm.py
2
  # ------------------------------------------------------------
3
+ # Local LLM parser for BactAI-D (T5 fine-tune, CPU-friendly)
4
  #
5
+ # UPDATED (EphBactAID integration):
6
+ # - Default model now points to your HF fine-tune: EphAsad/EphBactAID
7
+ # - Few-shot disabled by default (your fine-tune no longer needs it)
8
+ # - Robust output parsing:
9
+ # * Supports JSON output (legacy)
10
+ # * Supports "Key: Value" pairs output (your fine-tune style)
11
+ # - Merge guard (optional): LLM fills ONLY missing/Unknown fields
12
+ # - Validation/normalisation kept (PNV/Gram, sugar logic, aliases, ornithine sync)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  # ------------------------------------------------------------
14
 
15
from __future__ import annotations

import json
import os
import random
import re
from typing import Dict, Any, List, Optional

# Heavy ML dependencies — imported lazily by callers via the try/except
# in parser_fusion; importing this module requires torch + transformers.
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
25
 
 
 
26
 
27
# ------------------------------------------------------------
# Model configuration
# ------------------------------------------------------------

# Fine-tuned model id; override via the BACTAI_LLM_PARSER_MODEL env var.
DEFAULT_MODEL = os.getenv(
    "BACTAI_LLM_PARSER_MODEL",
    "EphAsad/EphBactAID",
)

# Few-shot OFF by default now (fine-tune doesn't need it).
MAX_FEWSHOT_EXAMPLES = int(os.getenv("BACTAI_LLM_FEWSHOT", "0"))

# Generation budget for the decoder.
MAX_NEW_TOKENS = int(os.getenv("BACTAI_LLM_MAX_NEW_TOKENS", "256"))

# Truthy values for the debug flag (prints prompt + raw model output).
DEBUG_LLM = os.getenv("BACTAI_LLM_DEBUG", "0").strip().lower() in {
    "1", "true", "yes", "y", "on"
}

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Lazily-initialised singletons, populated by _load_model() and
# _load_gold_examples() on first use.
_tokenizer: Optional[AutoTokenizer] = None
_model: Optional[AutoModelForSeq2SeqLM] = None
_GOLD_EXAMPLES: Optional[List[Dict[str, Any]]] = None
51
 
52
 
53
  # ------------------------------------------------------------
54
+ # Allowed fields
55
  # ------------------------------------------------------------
56
 
57
# Canonical output schema: the only keys the LLM parser will emit.
ALL_FIELDS: List[str] = [
    "Gram Stain",
    "Shape",
    "Motility",
    "Capsule",
    "Spore Formation",
    "Haemolysis",
    "Haemolysis Type",
    "Media Grown On",
    "Colony Morphology",
    "Oxygen Requirement",
    "Growth Temperature",
    "Catalase",
    "Oxidase",
    "Indole",
    "Urease",
    "Citrate",
    "Methyl Red",
    "VP",
    "H2S",
    "DNase",
    "ONPG",
    "Coagulase",
    "Gelatin Hydrolysis",
    "Esculin Hydrolysis",
    "Nitrate Reduction",
    "NaCl Tolerant (>=6%)",
    "Lipase Test",
    "Lysine Decarboxylase",
    "Ornithine Decarboxylase",
    # Misspelled variant kept deliberately alongside the correct key;
    # _merge_ornithine_variants keeps the two in sync.
    "Ornitihine Decarboxylase",
    "Arginine dihydrolase",
    "Glucose Fermentation",
    "Lactose Fermentation",
    "Sucrose Fermentation",
    "Maltose Fermentation",
    "Mannitol Fermentation",
    "Sorbitol Fermentation",
    "Xylose Fermentation",
    "Rhamnose Fermentation",
    "Arabinose Fermentation",
    "Raffinose Fermentation",
    "Trehalose Fermentation",
    "Inositol Fermentation",
    "Gas Production",
    "TSI Pattern",
    "Colony Pattern",
    "Pigment",
    "Motility Type",
    "Odor",
]

# Sugar fermentation fields affected by the global non-fermenter rule
# in _apply_global_sugar_logic.
SUGAR_FIELDS = [
    "Glucose Fermentation",
    "Lactose Fermentation",
    "Sucrose Fermentation",
    "Maltose Fermentation",
    "Mannitol Fermentation",
    "Sorbitol Fermentation",
    "Xylose Fermentation",
    "Rhamnose Fermentation",
    "Arabinose Fermentation",
    "Raffinose Fermentation",
    "Trehalose Fermentation",
    "Inositol Fermentation",
]

# Fields whose values normalise to Positive/Negative/Variable/Unknown;
# the excluded set below holds free-text / categorical fields.
PNV_FIELDS = {
    f for f in ALL_FIELDS
    if f not in {
        "Media Grown On",
        "Colony Morphology",
        "Growth Temperature",
        "Gram Stain",
        "Shape",
        "Oxygen Requirement",
        "Haemolysis Type",
        "TSI Pattern",
        "Colony Pattern",
        "Motility Type",
        "Odor",
        "Pigment",
        "Gas Production",
    }
}
142
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
143
 
144
# ------------------------------------------------------------
# Field alias mapping (CRITICAL)
# ------------------------------------------------------------

# Maps model-output key variants onto the canonical ALL_FIELDS keys;
# applied by _apply_field_aliases before schema filtering.
FIELD_ALIASES: Dict[str, str] = {
    "Gram": "Gram Stain",
    "Gram stain": "Gram Stain",
    "Gram Stain Result": "Gram Stain",

    "NaCl tolerance": "NaCl Tolerant (>=6%)",
    "NaCl Tolerant": "NaCl Tolerant (>=6%)",
    "Salt tolerance": "NaCl Tolerant (>=6%)",
    "Salt tolerant": "NaCl Tolerant (>=6%)",
    "6.5% NaCl": "NaCl Tolerant (>=6%)",
    "6% NaCl": "NaCl Tolerant (>=6%)",

    "Growth temp": "Growth Temperature",
    "Growth temperature": "Growth Temperature",
    "Temperature growth": "Growth Temperature",

    "Catalase test": "Catalase",
    "Oxidase test": "Oxidase",
    "Indole test": "Indole",
    "Urease test": "Urease",
    "Citrate test": "Citrate",

    "Glucose fermentation": "Glucose Fermentation",
    "Lactose fermentation": "Lactose Fermentation",
    "Sucrose fermentation": "Sucrose Fermentation",
    "Maltose fermentation": "Maltose Fermentation",
    "Mannitol fermentation": "Mannitol Fermentation",
    "Sorbitol fermentation": "Sorbitol Fermentation",
    "Xylose fermentation": "Xylose Fermentation",
    "Rhamnose fermentation": "Rhamnose Fermentation",
    "Arabinose fermentation": "Arabinose Fermentation",
    "Raffinose fermentation": "Raffinose Fermentation",
    "Trehalose fermentation": "Trehalose Fermentation",
    "Inositol fermentation": "Inositol Fermentation",

    # common variants from outputs
    "Voges–Proskauer Test": "VP",
    "Voges-Proskauer Test": "VP",
    "Voges–Proskauer": "VP",
    "Voges-Proskauer": "VP",
}
189
+
190
+
191
+ # ------------------------------------------------------------
192
+ # Normalisation helpers
193
+ # ------------------------------------------------------------
194
+
195
+ def _norm_str(s: Any) -> str:
196
+ return str(s).strip() if s is not None else ""
197
+
198
+
199
+ def _normalise_pnv_value(raw: Any) -> str:
200
+ s = _norm_str(raw).lower()
201
+ if not s:
202
+ return "Unknown"
203
+
204
+ # positive
205
+ if any(x in s for x in {"positive", "pos", "+", "yes", "present", "detected", "reactive"}):
206
+ return "Positive"
207
+
208
+ # negative
209
+ if any(x in s for x in {"negative", "neg", "-", "no", "none", "absent", "not detected", "no growth"}):
210
+ return "Negative"
211
+
212
+ # variable
213
+ if any(x in s for x in {"variable", "mixed", "inconsistent"}):
214
+ return "Variable"
215
+
216
+ return "Unknown"
217
+
218
+
219
+ def _normalise_gram(raw: Any) -> str:
220
+ s = _norm_str(raw).lower()
221
+ if "positive" in s:
222
+ return "Positive"
223
+ if "negative" in s:
224
+ return "Negative"
225
+ if "variable" in s:
226
+ return "Variable"
227
+ return "Unknown"
228
+
229
+
230
+ def _merge_ornithine_variants(fields: Dict[str, str]) -> Dict[str, str]:
231
+ v = fields.get("Ornithine Decarboxylase") or fields.get("Ornitihine Decarboxylase")
232
+ if v and v != "Unknown":
233
+ fields["Ornithine Decarboxylase"] = v
234
+ fields["Ornitihine Decarboxylase"] = v
235
+ return fields
236
+
237
+
238
+ # ------------------------------------------------------------
239
+ # Sugar logic
240
+ # ------------------------------------------------------------
241
+
242
# Phrases declaring the organism a global non-fermenter.
_NON_FERMENTER_PATTERNS = re.compile(
    r"\b("
    r"non[-\s]?fermenter|"
    r"non[-\s]?fermentative|"
    r"asaccharolytic|"
    r"does not ferment (sugars|carbohydrates)|"
    r"no carbohydrate fermentation"
    r")\b",
    re.IGNORECASE,
)


def _apply_global_sugar_logic(fields: Dict[str, str], original_text: str) -> Dict[str, str]:
    """Default every sugar field to Negative for declared non-fermenters.

    When the input text matches a non-fermenter phrase, each field in
    SUGAR_FIELDS is forced to "Negative" unless a parser explicitly set
    it to Positive or Variable. The dict is mutated and returned.
    """
    if _NON_FERMENTER_PATTERNS.search(original_text) is None:
        return fields

    protected = {"Positive", "Variable"}
    for sugar in SUGAR_FIELDS:
        if fields.get(sugar) not in protected:
            fields[sugar] = "Negative"

    return fields
264
+
265
+
266
+ # ------------------------------------------------------------
267
+ # Gold examples (kept for backwards compat; now optional)
268
+ # ------------------------------------------------------------
269
+
270
+ def _get_project_root() -> str:
271
+ return os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
272
 
273
+
274
def _load_gold_examples() -> List[Dict[str, Any]]:
    """Load and cache few-shot gold examples from data/llm_gold_examples.json.

    Returns an empty list (cached) when the file is missing, unreadable,
    or not a JSON list.
    """
    global _GOLD_EXAMPLES
    if _GOLD_EXAMPLES is None:
        path = os.path.join(_get_project_root(), "data", "llm_gold_examples.json")
        examples: List[Dict[str, Any]] = []
        try:
            with open(path, "r", encoding="utf-8") as handle:
                payload = json.load(handle)
            if isinstance(payload, list):
                examples = payload
        except Exception:
            examples = []
        _GOLD_EXAMPLES = examples
    return _GOLD_EXAMPLES
288
 
 
 
289
 
290
+ # ------------------------------------------------------------
291
+ # Prompt (supports both JSON + KV outputs; fine-tune usually KV)
292
+ # ------------------------------------------------------------
293
 
294
# NOTE(review): the leading/trailing newlines and exact wording are part
# of the runtime prompt string — do not reformat.
PROMPT_HEADER = """
You are a microbiology phenotype parser.

Task:
- Extract ONLY explicitly stated results from the input text.
- Do NOT invent results.
- If not stated, omit the field or use "Unknown".

Output format:
- Prefer "Field: Value" lines, one per line.
- You may also output JSON if instructed.

Use the exact schema keys where possible.
"""

# <<PHENOTYPE>> is substituted with the input text by _build_prompt.
PROMPT_FOOTER = """
Input:
\"\"\"<<PHENOTYPE>>\"\"\"

Output:
"""
315
 
316
 
317
def _build_prompt(text: str) -> str:
    """Assemble the instruction prompt for *text*.

    Few-shot examples are included only when MAX_FEWSHOT_EXAMPLES > 0
    (off by default; the fine-tune no longer needs them). Examples are
    rendered in "Field: Value" style to match the fine-tune output.
    """
    sections: List[str] = [PROMPT_HEADER]

    if MAX_FEWSHOT_EXAMPLES > 0:
        pool = _load_gold_examples()
        count = min(MAX_FEWSHOT_EXAMPLES, len(pool))
        chosen = random.sample(pool, count) if count > 0 else []
        for example in chosen:
            example_input = _norm_str(example.get("input", ""))
            expected = example.get("expected", {})
            if not isinstance(expected, dict):
                expected = {}
            kv_lines = "\n".join(f"{k}: {v}" for k, v in expected.items())
            sections.append(
                f'Example Input:\n"""{example_input}"""\nExample Output:\n{kv_lines}\n'
            )

    sections.append(PROMPT_FOOTER.replace("<<PHENOTYPE>>", text))
    return "\n".join(sections)
 
 
 
 
336
 
 
 
337
 
338
+ # ------------------------------------------------------------
339
+ # Model loader
340
+ # ------------------------------------------------------------
341
 
342
def _load_model() -> None:
    """Lazily initialise the tokenizer/model singletons on first use.

    Subsequent calls are no-ops once both globals are populated.
    """
    global _model, _tokenizer
    if _model is None or _tokenizer is None:
        _tokenizer = AutoTokenizer.from_pretrained(DEFAULT_MODEL)
        _model = AutoModelForSeq2SeqLM.from_pretrained(DEFAULT_MODEL).to(DEVICE)
        _model.eval()
 
350
 
 
 
351
 
352
+ # ------------------------------------------------------------
353
+ # Output parsing helpers (JSON + KV)
354
+ # ------------------------------------------------------------
355
+
356
+ _JSON_OBJECT_RE = re.compile(r"\{[\s\S]*?\}")
357
 
 
 
358
 
359
+ def _extract_first_json_object(text: str) -> Dict[str, Any]:
360
+ m = _JSON_OBJECT_RE.search(text)
361
+ if not m:
362
+ return {}
363
+ try:
364
+ return json.loads(m.group(0))
365
+ except Exception:
366
+ return {}
367
+
368
 
369
+ # Match "Key: Value" (including keys with symbols like >=6%)
370
+ _KV_LINE_RE = re.compile(r"^\s*([^:\n]{2,120})\s*:\s*(.*?)\s*$")
371
 
 
 
 
372
 
373
+ def _extract_kv_pairs(text: str) -> Dict[str, Any]:
374
+ """
375
+ Parse outputs like:
376
+ Gram Stain: Positive
377
+ Shape: Cocci
378
+ ...
379
+ """
380
+ out: Dict[str, Any] = {}
381
+ for line in (text or "").splitlines():
382
+ line = line.strip()
383
+ if not line:
384
+ continue
385
+ m = _KV_LINE_RE.match(line)
386
+ if not m:
387
+ continue
388
+ k = _norm_str(m.group(1))
389
+ v = _norm_str(m.group(2))
390
+ if not k:
391
+ continue
392
+ out[k] = v
393
+ return out
394
+
395
+
396
def _apply_field_aliases(fields_raw: Dict[str, Any]) -> Dict[str, Any]:
    """Rename known key variants onto the canonical schema keys.

    Keys absent from FIELD_ALIASES pass through unchanged; empty keys
    are dropped.
    """
    canonical: Dict[str, Any] = {}
    for raw_key, value in fields_raw.items():
        key = _norm_str(raw_key)
        if key:
            canonical[FIELD_ALIASES.get(key, key)] = value
    return canonical
405
+
406
+
407
def _clean_and_normalise(fields_raw: Dict[str, Any], original_text: str) -> Dict[str, str]:
    """
    Keep only allowed fields and normalise values into your contract.

    Schema-filtered from ALL_FIELDS (aliases must already be applied),
    then value-normalised: Gram Stain via _normalise_gram, PNV fields
    via _normalise_pnv_value, everything else stripped text or "Unknown".
    Finally syncs ornithine spellings and applies the global
    non-fermenter sugar rule against *original_text*.
    """
    cleaned: Dict[str, str] = {}

    # Only accept keys that match schema (or aliases already applied)
    for field in ALL_FIELDS:
        if field not in fields_raw:
            continue

        raw_val = fields_raw[field]

        if field == "Gram Stain":
            cleaned[field] = _normalise_gram(raw_val)
        elif field in PNV_FIELDS:
            cleaned[field] = _normalise_pnv_value(raw_val)
        else:
            # Free-text/categorical field: keep stripped text as-is.
            cleaned[field] = _norm_str(raw_val) or "Unknown"

    cleaned = _merge_ornithine_variants(cleaned)
    cleaned = _apply_global_sugar_logic(cleaned, original_text)
    return cleaned
430
+
431
+
432
def _merge_guard_fill_only_missing(
    llm_fields: Dict[str, str],
    existing_fields: Optional[Dict[str, Any]],
) -> Dict[str, str]:
    """
    Merge guard:
      - If an existing field is present and not Unknown -> do NOT overwrite.
      - If existing is missing/Unknown -> allow llm value (if not Unknown).

    Returns a new dict restricted to ALL_FIELDS keys with string values;
    when *existing_fields* is falsy/non-dict, *llm_fields* is returned
    unchanged (no schema re-filtering in that path).
    """
    if not existing_fields or not isinstance(existing_fields, dict):
        return llm_fields

    out = dict(existing_fields)  # start with existing
    for k, v in llm_fields.items():
        if k not in ALL_FIELDS:
            continue
        existing_val = _norm_str(out.get(k, ""))
        # PNV fields are compared after normalisation so variant spellings
        # of "unknown-ish" values still count as fillable.
        existing_norm = _normalise_pnv_value(existing_val) if k in PNV_FIELDS else existing_val

        # Treat empty/Unknown as fillable
        fillable = (not existing_val) or (existing_val == "Unknown") or (existing_norm == "Unknown")
        if not fillable:
            continue

        # Only fill if LLM has something meaningful
        if _norm_str(v) and v != "Unknown":
            out[k] = v

    # Ensure we return only schema keys and strings
    final: Dict[str, str] = {}
    for k, v in out.items():
        if k in ALL_FIELDS:
            final[k] = _norm_str(v) or "Unknown"
    return final
466
 
 
467
 
468
+ # ------------------------------------------------------------
469
+ # PUBLIC API
470
+ # ------------------------------------------------------------
471
 
472
def parse_llm(text: str, existing_fields: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
    """
    Parse phenotype text using local seq2seq model.

    Pipeline: build prompt -> greedy generation -> JSON extraction
    (legacy) with "Key: Value" fallback -> alias mapping -> schema
    filtering/normalisation -> optional merge guard.

    Parameters
    ----------
    text : str
        phenotype description

    existing_fields : dict | None
        Optional pre-parsed fields (e.g., from rules/ext).
        If provided, LLM will ONLY fill missing/Unknown fields.

    Returns
    -------
    dict:
        {
          "parsed_fields": { ... },
          "source": "llm_parser",
          "raw": <original text>,
          "decoded": <model output> (only when DEBUG on)
        }
    """
    original = text or ""
    # Blank input: skip model entirely and echo existing fields (if any).
    if not original.strip():
        return {
            "parsed_fields": (existing_fields or {}) if isinstance(existing_fields, dict) else {},
            "source": "llm_parser",
            "raw": original,
        }

    _load_model()
    assert _tokenizer is not None and _model is not None

    prompt = _build_prompt(original)
    inputs = _tokenizer(prompt, return_tensors="pt", truncation=True).to(DEVICE)

    with torch.no_grad():
        # NOTE(review): temperature=0.0 alongside do_sample=False is
        # redundant (greedy decoding ignores it) and newer transformers
        # versions warn about it — confirm before removing.
        output = _model.generate(
            **inputs,
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=False,
            temperature=0.0,
        )

    decoded = _tokenizer.decode(output[0], skip_special_tokens=True)

    if DEBUG_LLM:
        print("=== LLM PROMPT (truncated) ===")
        print(prompt[:1500] + ("..." if len(prompt) > 1500 else ""))
        print("=== LLM RAW OUTPUT ===")
        print(decoded)
        print("======================")

    # 1) Try JSON extraction (legacy)
    parsed_obj = _extract_first_json_object(decoded)
    fields_raw = {}

    if isinstance(parsed_obj, dict) and parsed_obj:
        if "parsed_fields" in parsed_obj and isinstance(parsed_obj.get("parsed_fields"), dict):
            fields_raw = dict(parsed_obj["parsed_fields"])
        else:
            # in case model returned a flat JSON dict
            fields_raw = dict(parsed_obj)

    # 2) Fallback to KV parsing (your fine-tune style)
    if not fields_raw:
        fields_raw = _extract_kv_pairs(decoded)

    # 3) Alias map + normalise
    fields_raw = _apply_field_aliases(fields_raw)
    cleaned = _clean_and_normalise(fields_raw, original)

    # 4) Merge guard (optional) - fill only missing/Unknown
    if existing_fields is not None:
        cleaned = _merge_guard_fill_only_missing(cleaned, existing_fields)

    out = {
        "parsed_fields": cleaned,
        "source": "llm_parser",
        "raw": original,
    }
    if DEBUG_LLM:
        out["decoded"] = decoded
    return out
+ return out