EphAsad committed on
Commit
584a0a9
·
verified ·
1 Parent(s): e6d653e

Update engine/parser_llm.py

Browse files
Files changed (1) hide show
  1. engine/parser_llm.py +556 -435
engine/parser_llm.py CHANGED
@@ -1,435 +1,556 @@
1
- # engine/parser_llm.py
2
- # ------------------------------------------------------------
3
- # Local LLM parser for BactAI-D (Flan-T5, CPU-friendly)
4
- # Third parser head: repair & recovery
5
- #
6
- # Drop-in patched version:
7
- # - Few-shot examples increased (configurable via env)
8
- # - Field alias mapping (prevents silent field drops)
9
- # - Non-greedy JSON extraction (prevents regex over-capture)
10
- # - Improved P/N/V normalization (Flan phrasing coverage)
11
- # - Prompt refined for "extract/clarify" (reduces Unknown collapse)
12
- # - Debug prints (toggle via env var)
13
- # - Sugar logic scaffold preserved
14
- # ------------------------------------------------------------
15
-
16
- from __future__ import annotations
17
-
18
- import json
19
- import os
20
- import random
21
- import re
22
- from typing import Dict, Any, List, Optional
23
-
24
- import torch
25
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
26
-
27
-
28
- # ------------------------------------------------------------
29
- # Model configuration
30
- # ------------------------------------------------------------
31
-
32
- DEFAULT_MODEL = os.getenv(
33
- "BACTAI_LLM_PARSER_MODEL",
34
- "google/flan-t5-base",
35
- )
36
-
37
- MAX_FEWSHOT_EXAMPLES = int(os.getenv("BACTAI_LLM_FEWSHOT", "25"))
38
- MAX_NEW_TOKENS = int(os.getenv("BACTAI_LLM_MAX_NEW_TOKENS", "128"))
39
-
40
- DEBUG_LLM = os.getenv("BACTAI_LLM_DEBUG", "1").strip().lower() in {
41
- "1", "true", "yes", "y", "on"
42
- }
43
-
44
- DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
45
-
46
- _tokenizer: Optional[AutoTokenizer] = None
47
- _model: Optional[AutoModelForSeq2SeqLM] = None
48
- _GOLD_EXAMPLES: Optional[List[Dict[str, Any]]] = None
49
-
50
-
51
- # ------------------------------------------------------------
52
- # Allowed fields
53
- # ------------------------------------------------------------
54
-
55
- ALL_FIELDS: List[str] = [
56
- "Gram Stain",
57
- "Shape",
58
- "Motility",
59
- "Capsule",
60
- "Spore Formation",
61
- "Haemolysis",
62
- "Haemolysis Type",
63
- "Media Grown On",
64
- "Colony Morphology",
65
- "Oxygen Requirement",
66
- "Growth Temperature",
67
- "Catalase",
68
- "Oxidase",
69
- "Indole",
70
- "Urease",
71
- "Citrate",
72
- "Methyl Red",
73
- "VP",
74
- "H2S",
75
- "DNase",
76
- "ONPG",
77
- "Coagulase",
78
- "Gelatin Hydrolysis",
79
- "Esculin Hydrolysis",
80
- "Nitrate Reduction",
81
- "NaCl Tolerant (>=6%)",
82
- "Lipase Test",
83
- "Lysine Decarboxylase",
84
- "Ornithine Decarboxylase",
85
- "Ornitihine Decarboxylase",
86
- "Arginine dihydrolase",
87
- "Glucose Fermentation",
88
- "Lactose Fermentation",
89
- "Sucrose Fermentation",
90
- "Maltose Fermentation",
91
- "Mannitol Fermentation",
92
- "Sorbitol Fermentation",
93
- "Xylose Fermentation",
94
- "Rhamnose Fermentation",
95
- "Arabinose Fermentation",
96
- "Raffinose Fermentation",
97
- "Trehalose Fermentation",
98
- "Inositol Fermentation",
99
- "Gas Production",
100
- "TSI Pattern",
101
- "Colony Pattern",
102
- "Pigment",
103
- "Motility Type",
104
- "Odor",
105
- ]
106
-
107
- SUGAR_FIELDS = [
108
- "Glucose Fermentation",
109
- "Lactose Fermentation",
110
- "Sucrose Fermentation",
111
- "Maltose Fermentation",
112
- "Mannitol Fermentation",
113
- "Sorbitol Fermentation",
114
- "Xylose Fermentation",
115
- "Rhamnose Fermentation",
116
- "Arabinose Fermentation",
117
- "Raffinose Fermentation",
118
- "Trehalose Fermentation",
119
- "Inositol Fermentation",
120
- ]
121
-
122
- PNV_FIELDS = {
123
- f for f in ALL_FIELDS
124
- if f not in {
125
- "Media Grown On",
126
- "Colony Morphology",
127
- "Growth Temperature",
128
- "Gram Stain",
129
- "Shape",
130
- "Oxygen Requirement",
131
- "Haemolysis Type",
132
- }
133
- }
134
-
135
-
136
- # ------------------------------------------------------------
137
- # Field alias mapping (CRITICAL)
138
- # ------------------------------------------------------------
139
-
140
- FIELD_ALIASES: Dict[str, str] = {
141
- "Gram": "Gram Stain",
142
- "Gram stain": "Gram Stain",
143
- "Gram Stain Result": "Gram Stain",
144
-
145
- "NaCl tolerance": "NaCl Tolerant (>=6%)",
146
- "NaCl Tolerant": "NaCl Tolerant (>=6%)",
147
- "Salt tolerance": "NaCl Tolerant (>=6%)",
148
- "Salt tolerant": "NaCl Tolerant (>=6%)",
149
- "6.5% NaCl": "NaCl Tolerant (>=6%)",
150
- "6% NaCl": "NaCl Tolerant (>=6%)",
151
-
152
- "Growth temp": "Growth Temperature",
153
- "Growth temperature": "Growth Temperature",
154
- "Temperature growth": "Growth Temperature",
155
-
156
- "Catalase test": "Catalase",
157
- "Oxidase test": "Oxidase",
158
- "Indole test": "Indole",
159
- "Urease test": "Urease",
160
- "Citrate test": "Citrate",
161
-
162
- "Glucose fermentation": "Glucose Fermentation",
163
- "Lactose fermentation": "Lactose Fermentation",
164
- "Sucrose fermentation": "Sucrose Fermentation",
165
- "Maltose fermentation": "Maltose Fermentation",
166
- "Mannitol fermentation": "Mannitol Fermentation",
167
- "Sorbitol fermentation": "Sorbitol Fermentation",
168
- "Xylose fermentation": "Xylose Fermentation",
169
- "Rhamnose fermentation": "Rhamnose Fermentation",
170
- "Arabinose fermentation": "Arabinose Fermentation",
171
- "Raffinose fermentation": "Raffinose Fermentation",
172
- "Trehalose fermentation": "Trehalose Fermentation",
173
- "Inositol fermentation": "Inositol Fermentation",
174
- }
175
-
176
-
177
- # ------------------------------------------------------------
178
- # Normalisation helpers
179
- # ------------------------------------------------------------
180
-
181
- def _norm_str(s: Any) -> str:
182
- return str(s).strip() if s is not None else ""
183
-
184
-
185
- def _normalise_pnv_value(raw: Any) -> str:
186
- s = _norm_str(raw).lower()
187
- if not s:
188
- return "Unknown"
189
-
190
- if any(x in s for x in {"positive", "pos", "+", "yes", "present", "detected", "reactive"}):
191
- return "Positive"
192
-
193
- if any(x in s for x in {"negative", "neg", "-", "no", "none", "absent", "not detected", "no growth"}):
194
- return "Negative"
195
-
196
- if any(x in s for x in {"variable", "mixed", "inconsistent"}):
197
- return "Variable"
198
-
199
- return "Unknown"
200
-
201
-
202
- def _normalise_gram(raw: Any) -> str:
203
- s = _norm_str(raw).lower()
204
- if "positive" in s:
205
- return "Positive"
206
- if "negative" in s:
207
- return "Negative"
208
- if "variable" in s:
209
- return "Variable"
210
- return "Unknown"
211
-
212
-
213
- def _merge_ornithine_variants(fields: Dict[str, str]) -> Dict[str, str]:
214
- v = fields.get("Ornithine Decarboxylase") or fields.get("Ornitihine Decarboxylase")
215
- if v and v != "Unknown":
216
- fields["Ornithine Decarboxylase"] = v
217
- fields["Ornitihine Decarboxylase"] = v
218
- return fields
219
-
220
-
221
- # ------------------------------------------------------------
222
- # Sugar logic
223
- # ------------------------------------------------------------
224
-
225
- _NON_FERMENTER_PATTERNS = re.compile(
226
- r"\b("
227
- r"non[-\s]?fermenter|"
228
- r"non[-\s]?fermentative|"
229
- r"asaccharolytic|"
230
- r"does not ferment (sugars|carbohydrates)|"
231
- r"no carbohydrate fermentation"
232
- r")\b",
233
- re.IGNORECASE,
234
- )
235
-
236
-
237
- def _apply_global_sugar_logic(fields: Dict[str, str], original_text: str) -> Dict[str, str]:
238
- if not _NON_FERMENTER_PATTERNS.search(original_text):
239
- return fields
240
-
241
- for sugar in SUGAR_FIELDS:
242
- if fields.get(sugar) in {"Positive", "Variable"}:
243
- continue
244
- fields[sugar] = "Negative"
245
-
246
- return fields
247
-
248
-
249
- # ------------------------------------------------------------
250
- # Gold examples
251
- # ------------------------------------------------------------
252
-
253
- def _get_project_root() -> str:
254
- return os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
255
-
256
-
257
- def _load_gold_examples() -> List[Dict[str, Any]]:
258
- global _GOLD_EXAMPLES
259
- if _GOLD_EXAMPLES is not None:
260
- return _GOLD_EXAMPLES
261
-
262
- path = os.path.join(_get_project_root(), "data", "llm_gold_examples.json")
263
- try:
264
- with open(path, "r", encoding="utf-8") as f:
265
- data = json.load(f)
266
- _GOLD_EXAMPLES = data if isinstance(data, list) else []
267
- except Exception:
268
- _GOLD_EXAMPLES = []
269
-
270
- return _GOLD_EXAMPLES
271
-
272
-
273
- # ------------------------------------------------------------
274
- # Prompt
275
- # ------------------------------------------------------------
276
-
277
- PROMPT_HEADER = """
278
- You are a microbiology expert assisting an automated phenotype parser.
279
-
280
- Your task is to EXTRACT OR CLARIFY phenotypic and biochemical test results
281
- from the input text.
282
-
283
- Rules:
284
- - Return ONLY valid JSON
285
- - Do NOT invent results
286
- - If a result is unclear or not stated, use "Unknown"
287
- - Prefer explicit statements over assumptions
288
-
289
- Output format:
290
- {
291
- "parsed_fields": {
292
- "Field Name": "Value",
293
- ...
294
- }
295
- }
296
- """
297
-
298
- PROMPT_FOOTER = """
299
- Now process the following phenotype description.
300
-
301
- Input:
302
- \"\"\"<<PHENOTYPE>>\"\"\"
303
-
304
- Return ONLY the JSON object.
305
- """
306
-
307
-
308
- def _build_prompt(text: str) -> str:
309
- examples = _load_gold_examples()
310
- n = min(MAX_FEWSHOT_EXAMPLES, len(examples))
311
- sampled = random.sample(examples, n) if n > 0 else []
312
-
313
- blocks: List[str] = [PROMPT_HEADER]
314
-
315
- for ex in sampled:
316
- inp = _norm_str(ex.get("input", ""))
317
- exp = ex.get("expected", {})
318
- if not isinstance(exp, dict):
319
- exp = {}
320
-
321
- blocks.append(
322
- f'Input:\n"""{inp}"""\nOutput:\n'
323
- f'{json.dumps({"parsed_fields": exp}, ensure_ascii=False)}\n'
324
- )
325
-
326
- blocks.append(PROMPT_FOOTER.replace("<<PHENOTYPE>>", text))
327
- return "\n".join(blocks)
328
-
329
-
330
- # ------------------------------------------------------------
331
- # Model loader
332
- # ------------------------------------------------------------
333
-
334
- def _load_model() -> None:
335
- global _model, _tokenizer
336
- if _model is not None and _tokenizer is not None:
337
- return
338
-
339
- _tokenizer = AutoTokenizer.from_pretrained(DEFAULT_MODEL)
340
- _model = AutoModelForSeq2SeqLM.from_pretrained(DEFAULT_MODEL).to(DEVICE)
341
- _model.eval()
342
-
343
-
344
- # ------------------------------------------------------------
345
- # JSON extraction (non-greedy)
346
- # ------------------------------------------------------------
347
-
348
- _JSON_OBJECT_RE = re.compile(r"\{[\s\S]*?\}")
349
-
350
-
351
- def _extract_first_json_object(text: str) -> Dict[str, Any]:
352
- m = _JSON_OBJECT_RE.search(text)
353
- if not m:
354
- return {}
355
- try:
356
- return json.loads(m.group(0))
357
- except Exception:
358
- return {}
359
-
360
-
361
- def _apply_field_aliases(fields_raw: Dict[str, Any]) -> Dict[str, Any]:
362
- out: Dict[str, Any] = {}
363
- for k, v in fields_raw.items():
364
- key = _norm_str(k)
365
- if not key:
366
- continue
367
- mapped = FIELD_ALIASES.get(key, key)
368
- out[mapped] = v
369
- return out
370
-
371
-
372
- # ------------------------------------------------------------
373
- # PUBLIC API
374
- # ------------------------------------------------------------
375
-
376
- def parse_llm(text: str) -> Dict[str, Any]:
377
- original = text or ""
378
- if not original.strip():
379
- return {
380
- "parsed_fields": {},
381
- "source": "llm_parser",
382
- "raw": original,
383
- }
384
-
385
- _load_model()
386
- assert _tokenizer is not None and _model is not None
387
-
388
- prompt = _build_prompt(original)
389
- inputs = _tokenizer(prompt, return_tensors="pt", truncation=True).to(DEVICE)
390
-
391
- with torch.no_grad():
392
- output = _model.generate(
393
- **inputs,
394
- max_new_tokens=MAX_NEW_TOKENS,
395
- do_sample=False,
396
- temperature=0.0,
397
- )
398
-
399
- decoded = _tokenizer.decode(output[0], skip_special_tokens=True)
400
-
401
- if DEBUG_LLM:
402
- print("=== LLM RAW OUTPUT ===")
403
- print(decoded)
404
- print("======================")
405
-
406
- parsed_obj = _extract_first_json_object(decoded)
407
- fields_raw = parsed_obj.get("parsed_fields", {}) if isinstance(parsed_obj, dict) else {}
408
- if not isinstance(fields_raw, dict):
409
- fields_raw = {}
410
-
411
- fields_raw = _apply_field_aliases(fields_raw)
412
-
413
- cleaned: Dict[str, str] = {}
414
-
415
- for field in ALL_FIELDS:
416
- if field not in fields_raw:
417
- continue
418
-
419
- raw_val = fields_raw[field]
420
-
421
- if field == "Gram Stain":
422
- cleaned[field] = _normalise_gram(raw_val)
423
- elif field in PNV_FIELDS:
424
- cleaned[field] = _normalise_pnv_value(raw_val)
425
- else:
426
- cleaned[field] = _norm_str(raw_val) or "Unknown"
427
-
428
- cleaned = _merge_ornithine_variants(cleaned)
429
- cleaned = _apply_global_sugar_logic(cleaned, original)
430
-
431
- return {
432
- "parsed_fields": cleaned,
433
- "source": "llm_parser",
434
- "raw": original,
435
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # engine/parser_llm.py
2
+ # ------------------------------------------------------------
3
+ # Local LLM parser for BactAI-D (T5 fine-tune, CPU-friendly)
4
+ #
5
+ # UPDATED (EphBactAID integration):
6
+ # - Default model now points to your HF fine-tune: EphAsad/EphBactAID
7
+ # - Few-shot disabled by default (your fine-tune no longer needs it)
8
+ # - Robust output parsing:
9
+ # * Supports JSON output (legacy)
10
+ # * Supports "Key: Value" pairs output (your fine-tune style)
11
+ # - Merge guard (optional): LLM fills ONLY missing/Unknown fields
12
+ # - Validation/normalisation kept (PNV/Gram, sugar logic, aliases, ornithine sync)
13
+ # ------------------------------------------------------------
14
+
15
+ from __future__ import annotations
16
+
17
+ import json
18
+ import os
19
+ import random
20
+ import re
21
+ from typing import Dict, Any, List, Optional
22
+
23
+ import torch
24
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
25
+
26
+
27
+ # ------------------------------------------------------------
28
+ # Model configuration
29
+ # ------------------------------------------------------------
30
+
31
+ # ✅ Your fine-tuned model (can be overridden via env var)
32
+ DEFAULT_MODEL = os.getenv(
33
+ "BACTAI_LLM_PARSER_MODEL",
34
+ "EphAsad/EphBactAID",
35
+ )
36
+
37
+ # Few-shot OFF by default now (fine-tune doesn't need it)
38
+ MAX_FEWSHOT_EXAMPLES = int(os.getenv("BACTAI_LLM_FEWSHOT", "0"))
39
+
40
+ MAX_NEW_TOKENS = int(os.getenv("BACTAI_LLM_MAX_NEW_TOKENS", "256"))
41
+
42
+ DEBUG_LLM = os.getenv("BACTAI_LLM_DEBUG", "0").strip().lower() in {
43
+ "1", "true", "yes", "y", "on"
44
+ }
45
+
46
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
47
+
48
+ _tokenizer: Optional[AutoTokenizer] = None
49
+ _model: Optional[AutoModelForSeq2SeqLM] = None
50
+ _GOLD_EXAMPLES: Optional[List[Dict[str, Any]]] = None
51
+
52
+
53
+ # ------------------------------------------------------------
54
+ # Allowed fields
55
+ # ------------------------------------------------------------
56
+
57
# Sugar panel, used to build both SUGAR_FIELDS and the fermentation aliases.
_SUGARS: List[str] = [
    "Glucose", "Lactose", "Sucrose", "Maltose", "Mannitol", "Sorbitol",
    "Xylose", "Rhamnose", "Arabinose", "Raffinose", "Trehalose", "Inositol",
]

SUGAR_FIELDS = [f"{sugar} Fermentation" for sugar in _SUGARS]

ALL_FIELDS: List[str] = [
    "Gram Stain",
    "Shape",
    "Motility",
    "Capsule",
    "Spore Formation",
    "Haemolysis",
    "Haemolysis Type",
    "Media Grown On",
    "Colony Morphology",
    "Oxygen Requirement",
    "Growth Temperature",
    "Catalase",
    "Oxidase",
    "Indole",
    "Urease",
    "Citrate",
    "Methyl Red",
    "VP",
    "H2S",
    "DNase",
    "ONPG",
    "Coagulase",
    "Gelatin Hydrolysis",
    "Esculin Hydrolysis",
    "Nitrate Reduction",
    "NaCl Tolerant (>=6%)",
    "Lipase Test",
    "Lysine Decarboxylase",
    "Ornithine Decarboxylase",
    # Misspelled twin kept deliberately: legacy data uses this exact key.
    "Ornitihine Decarboxylase",
    "Arginine dihydrolase",
    *SUGAR_FIELDS,
    "Gas Production",
    "TSI Pattern",
    "Colony Pattern",
    "Pigment",
    "Motility Type",
    "Odor",
]

# Free-text fields that must NOT be collapsed to Positive/Negative/Variable.
_NON_PNV_FIELDS = {
    "Media Grown On",
    "Colony Morphology",
    "Growth Temperature",
    "Gram Stain",
    "Shape",
    "Oxygen Requirement",
    "Haemolysis Type",
    "TSI Pattern",
    "Colony Pattern",
    "Motility Type",
    "Odor",
    "Pigment",
    "Gas Production",
}

PNV_FIELDS = {field for field in ALL_FIELDS if field not in _NON_PNV_FIELDS}


# ------------------------------------------------------------
# Field alias mapping (CRITICAL)
# ------------------------------------------------------------

FIELD_ALIASES: Dict[str, str] = {
    "Gram": "Gram Stain",
    "Gram stain": "Gram Stain",
    "Gram Stain Result": "Gram Stain",

    "NaCl tolerance": "NaCl Tolerant (>=6%)",
    "NaCl Tolerant": "NaCl Tolerant (>=6%)",
    "Salt tolerance": "NaCl Tolerant (>=6%)",
    "Salt tolerant": "NaCl Tolerant (>=6%)",
    "6.5% NaCl": "NaCl Tolerant (>=6%)",
    "6% NaCl": "NaCl Tolerant (>=6%)",

    "Growth temp": "Growth Temperature",
    "Growth temperature": "Growth Temperature",
    "Temperature growth": "Growth Temperature",

    "Catalase test": "Catalase",
    "Oxidase test": "Oxidase",
    "Indole test": "Indole",
    "Urease test": "Urease",
    "Citrate test": "Citrate",

    # common variants from outputs
    "Voges–Proskauer Test": "VP",
    "Voges-Proskauer Test": "VP",
    "Voges–Proskauer": "VP",
    "Voges-Proskauer": "VP",
}

# Lower-cased "<sugar> fermentation" -> canonical "<Sugar> Fermentation".
FIELD_ALIASES.update(
    {f"{sugar} fermentation": f"{sugar} Fermentation" for sugar in _SUGARS}
)
189
+
190
+
191
+ # ------------------------------------------------------------
192
+ # Normalisation helpers
193
+ # ------------------------------------------------------------
194
+
195
+ def _norm_str(s: Any) -> str:
196
+ return str(s).strip() if s is not None else ""
197
+
198
+
199
+ def _normalise_pnv_value(raw: Any) -> str:
200
+ s = _norm_str(raw).lower()
201
+ if not s:
202
+ return "Unknown"
203
+
204
+ # positive
205
+ if any(x in s for x in {"positive", "pos", "+", "yes", "present", "detected", "reactive"}):
206
+ return "Positive"
207
+
208
+ # negative
209
+ if any(x in s for x in {"negative", "neg", "-", "no", "none", "absent", "not detected", "no growth"}):
210
+ return "Negative"
211
+
212
+ # variable
213
+ if any(x in s for x in {"variable", "mixed", "inconsistent"}):
214
+ return "Variable"
215
+
216
+ return "Unknown"
217
+
218
+
219
def _normalise_gram(raw: Any) -> str:
    """Map a free-text Gram result onto Positive/Negative/Variable/Unknown."""
    lowered = _norm_str(raw).lower()
    # Check order matters: Positive wins if several labels appear.
    for label in ("Positive", "Negative", "Variable"):
        if label.lower() in lowered:
            return label
    return "Unknown"
228
+
229
+
230
+ def _merge_ornithine_variants(fields: Dict[str, str]) -> Dict[str, str]:
231
+ v = fields.get("Ornithine Decarboxylase") or fields.get("Ornitihine Decarboxylase")
232
+ if v and v != "Unknown":
233
+ fields["Ornithine Decarboxylase"] = v
234
+ fields["Ornitihine Decarboxylase"] = v
235
+ return fields
236
+
237
+
238
+ # ------------------------------------------------------------
239
+ # Sugar logic
240
+ # ------------------------------------------------------------
241
+
242
# Phrases that declare the organism a global non-fermenter.
_NON_FERMENTER_PATTERNS = re.compile(
    r"\b("
    r"non[-\s]?fermenter|"
    r"non[-\s]?fermentative|"
    r"asaccharolytic|"
    r"does not ferment (sugars|carbohydrates)|"
    r"no carbohydrate fermentation"
    r")\b",
    re.IGNORECASE,
)


def _apply_global_sugar_logic(fields: Dict[str, str], original_text: str) -> Dict[str, str]:
    """If the text declares a non-fermenter, force undecided sugars to Negative."""
    if _NON_FERMENTER_PATTERNS.search(original_text):
        for sugar_field in SUGAR_FIELDS:
            # Explicit Positive/Variable results win over the global rule.
            if fields.get(sugar_field) not in {"Positive", "Variable"}:
                fields[sugar_field] = "Negative"
    return fields
264
+
265
+
266
+ # ------------------------------------------------------------
267
+ # Gold examples (kept for backwards compat; now optional)
268
+ # ------------------------------------------------------------
269
+
270
+ def _get_project_root() -> str:
271
+ return os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
272
+
273
+
274
def _load_gold_examples() -> List[Dict[str, Any]]:
    """Load data/llm_gold_examples.json once; cache in _GOLD_EXAMPLES.

    Any read/parse failure, or a non-list payload, yields an empty list.
    """
    global _GOLD_EXAMPLES
    if _GOLD_EXAMPLES is None:
        examples_path = os.path.join(_get_project_root(), "data", "llm_gold_examples.json")
        payload = None
        try:
            with open(examples_path, "r", encoding="utf-8") as fh:
                payload = json.load(fh)
        except Exception:
            # Best-effort: missing/corrupt file just disables few-shot.
            payload = None
        _GOLD_EXAMPLES = payload if isinstance(payload, list) else []
    return _GOLD_EXAMPLES
288
+
289
+
290
+ # ------------------------------------------------------------
291
+ # Prompt (supports both JSON + KV outputs; fine-tune usually KV)
292
+ # ------------------------------------------------------------
293
+
294
# Verbatim instruction preamble placed before any few-shot examples.
# NOTE: this is a runtime string sent to the model — do not reword casually.
PROMPT_HEADER = """
You are a microbiology phenotype parser.

Task:
- Extract ONLY explicitly stated results from the input text.
- Do NOT invent results.
- If not stated, omit the field or use "Unknown".

Output format:
- Prefer "Field: Value" lines, one per line.
- You may also output JSON if instructed.

Use the exact schema keys where possible.
"""

# Prompt trailer; the <<PHENOTYPE>> placeholder is substituted with the
# caller's input text (see _build_prompt).
PROMPT_FOOTER = """
Input:
\"\"\"<<PHENOTYPE>>\"\"\"

Output:
"""
315
+
316
+
317
def _build_prompt(text: str) -> str:
    """Assemble header (+ optional few-shot examples) + footer with *text* inlined."""
    blocks: List[str] = [PROMPT_HEADER]

    # Few-shot is disabled by default (MAX_FEWSHOT_EXAMPLES == 0); the
    # capability is retained for experiments.
    if MAX_FEWSHOT_EXAMPLES > 0:
        pool = _load_gold_examples()
        count = min(MAX_FEWSHOT_EXAMPLES, len(pool))
        chosen = random.sample(pool, count) if count > 0 else []
        for example in chosen:
            example_input = _norm_str(example.get("input", ""))
            expected = example.get("expected", {})
            if not isinstance(expected, dict):
                expected = {}
            # Examples are rendered in "Key: Value" style to match the fine-tune.
            kv_lines = "\n".join(f"{k}: {v}" for k, v in expected.items())
            blocks.append(f'Example Input:\n"""{example_input}"""\nExample Output:\n{kv_lines}\n')

    blocks.append(PROMPT_FOOTER.replace("<<PHENOTYPE>>", text))
    return "\n".join(blocks)
336
+
337
+
338
+ # ------------------------------------------------------------
339
+ # Model loader
340
+ # ------------------------------------------------------------
341
+
342
def _load_model() -> None:
    """Lazily load tokenizer and model into module globals (idempotent)."""
    global _model, _tokenizer
    # Both already initialised -> nothing to do.
    if _model is not None and _tokenizer is not None:
        return

    _tokenizer = AutoTokenizer.from_pretrained(DEFAULT_MODEL)
    # DEVICE is resolved once at import time ("cuda" when available, else "cpu").
    _model = AutoModelForSeq2SeqLM.from_pretrained(DEFAULT_MODEL).to(DEVICE)
    # Inference mode (disables dropout etc.).
    _model.eval()
350
+
351
+
352
+ # ------------------------------------------------------------
353
+ # Output parsing helpers (JSON + KV)
354
+ # ------------------------------------------------------------
355
+
356
+ _JSON_OBJECT_RE = re.compile(r"\{[\s\S]*?\}")
357
+
358
+
359
+ def _extract_first_json_object(text: str) -> Dict[str, Any]:
360
+ m = _JSON_OBJECT_RE.search(text)
361
+ if not m:
362
+ return {}
363
+ try:
364
+ return json.loads(m.group(0))
365
+ except Exception:
366
+ return {}
367
+
368
+
369
+ # Match "Key: Value" (including keys with symbols like >=6%)
370
# One "Key: Value" line; keys may contain symbols such as ">=6%" but no colon.
_KV_LINE_RE = re.compile(r"^\s*([^:\n]{2,120})\s*:\s*(.*?)\s*$")


def _extract_kv_pairs(text: str) -> Dict[str, Any]:
    """
    Parse line-oriented model output such as:
        Gram Stain: Positive
        Shape: Cocci
    Non-matching and blank lines are skipped; later duplicates win.
    """
    pairs: Dict[str, Any] = {}
    for raw_line in (text or "").splitlines():
        candidate = raw_line.strip()
        if not candidate:
            continue
        hit = _KV_LINE_RE.match(candidate)
        if hit is None:
            continue
        key = _norm_str(hit.group(1))
        value = _norm_str(hit.group(2))
        if key:
            pairs[key] = value
    return pairs
394
+
395
+
396
def _apply_field_aliases(fields_raw: Dict[str, Any]) -> Dict[str, Any]:
    """Rename known alias keys to canonical schema keys; drop empty keys."""
    remapped: Dict[str, Any] = {}
    for raw_key, value in fields_raw.items():
        key = _norm_str(raw_key)
        if key:
            # Unknown keys pass through unchanged; schema filtering happens later.
            remapped[FIELD_ALIASES.get(key, key)] = value
    return remapped
405
+
406
+
407
def _clean_and_normalise(fields_raw: Dict[str, Any], original_text: str) -> Dict[str, str]:
    """
    Keep only schema fields (ALL_FIELDS) and normalise their values:
    Gram gets Gram-specific handling, PNV fields collapse to
    Positive/Negative/Variable/Unknown, everything else is kept as trimmed
    text (empty -> "Unknown").  Then sync ornithine variants and apply the
    global non-fermenter sugar rule.
    """
    cleaned: Dict[str, str] = {}

    for field in ALL_FIELDS:
        try:
            raw_val = fields_raw[field]
        except KeyError:
            continue

        if field == "Gram Stain":
            value = _normalise_gram(raw_val)
        elif field in PNV_FIELDS:
            value = _normalise_pnv_value(raw_val)
        else:
            value = _norm_str(raw_val) or "Unknown"
        cleaned[field] = value

    return _apply_global_sugar_logic(_merge_ornithine_variants(cleaned), original_text)
430
+
431
+
432
def _merge_guard_fill_only_missing(
    llm_fields: Dict[str, str],
    existing_fields: Optional[Dict[str, Any]],
) -> Dict[str, str]:
    """
    Merge guard: LLM values may only FILL gaps, never overwrite.

    - An existing non-empty, non-Unknown value is kept as-is.
    - A missing/empty/Unknown slot takes the LLM value when that value is
      itself meaningful (non-empty and not "Unknown").
    - The result is restricted to schema keys and coerced to strings.
    """
    if not existing_fields or not isinstance(existing_fields, dict):
        return llm_fields

    merged = dict(existing_fields)
    for field, llm_value in llm_fields.items():
        if field not in ALL_FIELDS:
            continue

        current = _norm_str(merged.get(field, ""))
        # For PNV fields judge "known-ness" on the normalised form too.
        current_norm = _normalise_pnv_value(current) if field in PNV_FIELDS else current

        slot_open = (not current) or current == "Unknown" or current_norm == "Unknown"
        if not slot_open:
            continue

        if _norm_str(llm_value) and llm_value != "Unknown":
            merged[field] = llm_value

    # Final pass: schema keys only, values as non-empty strings.
    return {
        field: (_norm_str(value) or "Unknown")
        for field, value in merged.items()
        if field in ALL_FIELDS
    }
466
+
467
+
468
+ # ------------------------------------------------------------
469
+ # PUBLIC API
470
+ # ------------------------------------------------------------
471
+
472
def parse_llm(text: str, existing_fields: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
    """
    Parse phenotype text using local seq2seq model.

    Parameters
    ----------
    text : str
        phenotype description

    existing_fields : dict | None
        Optional pre-parsed fields (e.g., from rules/ext).
        If provided, LLM will ONLY fill missing/Unknown fields.

    Returns
    -------
    dict:
        {
          "parsed_fields": { ... },
          "source": "llm_parser",
          "raw": <original text>,
          "decoded": <model output> (only when DEBUG on)
        }
    """
    original = text or ""
    # Empty input: skip the model entirely; echo back existing fields if any.
    if not original.strip():
        return {
            "parsed_fields": (existing_fields or {}) if isinstance(existing_fields, dict) else {},
            "source": "llm_parser",
            "raw": original,
        }

    _load_model()
    assert _tokenizer is not None and _model is not None

    prompt = _build_prompt(original)
    inputs = _tokenizer(prompt, return_tensors="pt", truncation=True).to(DEVICE)

    # Greedy decoding for deterministic output.
    # NOTE(review): temperature=0.0 is ignored when do_sample=False, and some
    # transformers versions warn about it — consider dropping the argument.
    with torch.no_grad():
        output = _model.generate(
            **inputs,
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=False,
            temperature=0.0,
        )

    decoded = _tokenizer.decode(output[0], skip_special_tokens=True)

    if DEBUG_LLM:
        print("=== LLM PROMPT (truncated) ===")
        print(prompt[:1500] + ("..." if len(prompt) > 1500 else ""))
        print("=== LLM RAW OUTPUT ===")
        print(decoded)
        print("======================")

    # 1) Try JSON extraction (legacy)
    parsed_obj = _extract_first_json_object(decoded)
    fields_raw: Dict[str, Any] = {}

    if isinstance(parsed_obj, dict) and parsed_obj:
        if "parsed_fields" in parsed_obj and isinstance(parsed_obj.get("parsed_fields"), dict):
            fields_raw = dict(parsed_obj["parsed_fields"])
        else:
            # in case model returned a flat JSON dict
            fields_raw = dict(parsed_obj)

    # 2) Fallback to KV parsing (your fine-tune style)
    if not fields_raw:
        fields_raw = _extract_kv_pairs(decoded)

    # 3) Alias map + normalise
    fields_raw = _apply_field_aliases(fields_raw)
    cleaned = _clean_and_normalise(fields_raw, original)

    # 4) Merge guard (optional) - fill only missing/Unknown
    if existing_fields is not None:
        cleaned = _merge_guard_fill_only_missing(cleaned, existing_fields)

    out = {
        "parsed_fields": cleaned,
        "source": "llm_parser",
        "raw": original,
    }
    # Expose the raw decode only in debug mode, for inspection.
    if DEBUG_LLM:
        out["decoded"] = decoded
    return out