|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from __future__ import annotations |
|
|
|
|
|
import json |
|
|
import os |
|
|
import random |
|
|
import re |
|
|
from typing import Dict, Any, List, Optional |
|
|
|
|
|
import torch |
|
|
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Model identifier handed to AutoTokenizer/AutoModelForSeq2SeqLM.from_pretrained;
# override it with the BACTAI_LLM_PARSER_MODEL environment variable.
DEFAULT_MODEL = os.environ.get("BACTAI_LLM_PARSER_MODEL", "EphAsad/BactAID-v2")
|
|
|
|
|
|
|
|
MAX_FEWSHOT_EXAMPLES = int(os.getenv("BACTAI_LLM_FEWSHOT", "0")) |
|
|
|
|
|
MAX_NEW_TOKENS = int(os.getenv("BACTAI_LLM_MAX_NEW_TOKENS", "256")) |
|
|
|
|
|
DEBUG_LLM = os.getenv("BACTAI_LLM_DEBUG", "0").strip().lower() in { |
|
|
"1", "true", "yes", "y", "on" |
|
|
} |
|
|
|
|
|
DEVICE = "cuda" if torch.cuda.is_available() else "cpu" |
|
|
|
|
|
_tokenizer: Optional[AutoTokenizer] = None |
|
|
_model: Optional[AutoModelForSeq2SeqLM] = None |
|
|
_GOLD_EXAMPLES: Optional[List[Dict[str, Any]]] = None |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Canonical schema keys; _clean_and_normalise emits parsed fields in this order.
ALL_FIELDS: List[str] = [
    # Morphology, staining and culture
    "Gram Stain", "Shape", "Motility", "Capsule", "Spore Formation",
    "Haemolysis", "Haemolysis Type", "Media Grown On", "Colony Morphology",
    "Oxygen Requirement", "Growth Temperature",
    # Biochemical tests
    "Catalase", "Oxidase", "Indole", "Urease", "Citrate", "Methyl Red", "VP",
    "H2S", "DNase", "ONPG", "Coagulase", "Gelatin Hydrolysis",
    "Esculin Hydrolysis", "Nitrate Reduction", "NaCl Tolerant (>=6%)",
    "Lipase Test", "Lysine Decarboxylase",
    # Both spellings are listed; _merge_ornithine_variants copies a known value between them.
    "Ornithine Decarboxylase", "Ornitihine Decarboxylase",
    "Arginine dihydrolase",
    # Carbohydrate fermentation
    "Glucose Fermentation", "Lactose Fermentation", "Sucrose Fermentation",
    "Maltose Fermentation", "Mannitol Fermentation", "Sorbitol Fermentation",
    "Xylose Fermentation", "Rhamnose Fermentation", "Arabinose Fermentation",
    "Raffinose Fermentation", "Trehalose Fermentation", "Inositol Fermentation",
    # Other observations
    "Gas Production", "TSI Pattern", "Colony Pattern", "Pigment",
    "Motility Type", "Odor",
]
|
|
|
|
|
# Carbohydrate-fermentation fields (same order as in ALL_FIELDS); the global
# non-fermenter rule in _apply_global_sugar_logic iterates over these.
SUGAR_FIELDS = [
    f"{sugar} Fermentation"
    for sugar in (
        "Glucose", "Lactose", "Sucrose", "Maltose", "Mannitol", "Sorbitol",
        "Xylose", "Rhamnose", "Arabinose", "Raffinose", "Trehalose", "Inositol",
    )
]
|
|
|
|
|
# Fields whose values are normalised to Positive / Negative / Variable / Unknown.
# Everything in ALL_FIELDS except the descriptive or categorical fields below;
# "Gram Stain" is excluded because it has its own normaliser (_normalise_gram).
PNV_FIELDS = set(ALL_FIELDS).difference(
    {
        "Media Grown On",
        "Colony Morphology",
        "Growth Temperature",
        "Gram Stain",
        "Shape",
        "Oxygen Requirement",
        "Haemolysis Type",
        "TSI Pattern",
        "Colony Pattern",
        "Motility Type",
        "Odor",
        "Pigment",
        "Gas Production",
    }
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
FIELD_ALIASES: Dict[str, str] = { |
|
|
"Gram": "Gram Stain", |
|
|
"Gram stain": "Gram Stain", |
|
|
"Gram Stain Result": "Gram Stain", |
|
|
|
|
|
"NaCl tolerance": "NaCl Tolerant (>=6%)", |
|
|
"NaCl Tolerant": "NaCl Tolerant (>=6%)", |
|
|
"Salt tolerance": "NaCl Tolerant (>=6%)", |
|
|
"Salt tolerant": "NaCl Tolerant (>=6%)", |
|
|
"6.5% NaCl": "NaCl Tolerant (>=6%)", |
|
|
"6% NaCl": "NaCl Tolerant (>=6%)", |
|
|
|
|
|
"Growth temp": "Growth Temperature", |
|
|
"Growth temperature": "Growth Temperature", |
|
|
"Temperature growth": "Growth Temperature", |
|
|
|
|
|
"Catalase test": "Catalase", |
|
|
"Oxidase test": "Oxidase", |
|
|
"Indole test": "Indole", |
|
|
"Urease test": "Urease", |
|
|
"Citrate test": "Citrate", |
|
|
|
|
|
"Glucose fermentation": "Glucose Fermentation", |
|
|
"Lactose fermentation": "Lactose Fermentation", |
|
|
"Sucrose fermentation": "Sucrose Fermentation", |
|
|
"Maltose fermentation": "Maltose Fermentation", |
|
|
"Mannitol fermentation": "Mannitol Fermentation", |
|
|
"Sorbitol fermentation": "Sorbitol Fermentation", |
|
|
"Xylose fermentation": "Xylose Fermentation", |
|
|
"Rhamnose fermentation": "Rhamnose Fermentation", |
|
|
"Arabinose fermentation": "Arabinose Fermentation", |
|
|
"Raffinose fermentation": "Raffinose Fermentation", |
|
|
"Trehalose fermentation": "Trehalose Fermentation", |
|
|
"Inositol fermentation": "Inositol Fermentation", |
|
|
|
|
|
|
|
|
"Voges–Proskauer Test": "VP", |
|
|
"Voges-Proskauer Test": "VP", |
|
|
"Voges–Proskauer": "VP", |
|
|
"Voges-Proskauer": "VP", |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _norm_str(s: Any) -> str: |
|
|
return str(s).strip() if s is not None else "" |
|
|
|
|
|
|
|
|
def _normalise_pnv_value(raw: Any) -> str: |
|
|
s = _norm_str(raw).lower() |
|
|
if not s: |
|
|
return "Unknown" |
|
|
|
|
|
|
|
|
if any(x in s for x in {"positive", "pos", "+", "yes", "present", "detected", "reactive"}): |
|
|
return "Positive" |
|
|
|
|
|
|
|
|
if any(x in s for x in {"negative", "neg", "-", "no", "none", "absent", "not detected", "no growth"}): |
|
|
return "Negative" |
|
|
|
|
|
|
|
|
if any(x in s for x in {"variable", "mixed", "inconsistent"}): |
|
|
return "Variable" |
|
|
|
|
|
return "Unknown" |
|
|
|
|
|
|
|
|
def _normalise_gram(raw: Any) -> str: |
|
|
s = _norm_str(raw).lower() |
|
|
if "positive" in s: |
|
|
return "Positive" |
|
|
if "negative" in s: |
|
|
return "Negative" |
|
|
if "variable" in s: |
|
|
return "Variable" |
|
|
return "Unknown" |
|
|
|
|
|
|
|
|
def _merge_ornithine_variants(fields: Dict[str, str]) -> Dict[str, str]: |
|
|
v = fields.get("Ornithine Decarboxylase") or fields.get("Ornitihine Decarboxylase") |
|
|
if v and v != "Unknown": |
|
|
fields["Ornithine Decarboxylase"] = v |
|
|
fields["Ornitihine Decarboxylase"] = v |
|
|
return fields |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
_NON_FERMENTER_PATTERNS = re.compile( |
|
|
r"\b(" |
|
|
r"non[-\s]?fermenter|" |
|
|
r"non[-\s]?fermentative|" |
|
|
r"asaccharolytic|" |
|
|
r"does not ferment (sugars|carbohydrates)|" |
|
|
r"no carbohydrate fermentation" |
|
|
r")\b", |
|
|
re.IGNORECASE, |
|
|
) |
|
|
|
|
|
|
|
|
def _apply_global_sugar_logic(fields: Dict[str, str], original_text: str) -> Dict[str, str]:
    """Force unstated sugars to "Negative" for organisms described as non-fermenters.

    When *original_text* matches _NON_FERMENTER_PATTERNS, every SUGAR_FIELDS
    entry that is not already "Positive" or "Variable" is set to "Negative".
    This includes sugars missing from *fields*. Otherwise nothing changes.
    *fields* is modified in place and also returned.
    """
    if _NON_FERMENTER_PATTERNS.search(original_text):
        keep = {"Positive", "Variable"}
        for sugar in SUGAR_FIELDS:
            if fields.get(sugar) not in keep:
                fields[sugar] = "Negative"
    return fields
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _get_project_root() -> str: |
|
|
return os.path.dirname(os.path.dirname(os.path.abspath(__file__))) |
|
|
|
|
|
|
|
|
def _load_gold_examples() -> List[Dict[str, Any]]:
    """Return the few-shot gold examples, reading them from disk on first use.

    The examples come from ``<project root>/data/llm_gold_examples.json``, which
    is expected to contain a JSON list of objects. Entries that are not dicts
    are dropped, so callers can safely use ``.get`` on each one.

    A missing, unreadable or malformed file yields ``[]``, because few-shot
    prompting is optional. The reason is printed when DEBUG_LLM is on. The
    result, including an empty one, is cached in ``_GOLD_EXAMPLES``, so later
    calls skip the disk read.
    """
    global _GOLD_EXAMPLES
    if _GOLD_EXAMPLES is not None:
        return _GOLD_EXAMPLES

    path = os.path.join(_get_project_root(), "data", "llm_gold_examples.json")
    try:
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
    except (OSError, ValueError) as err:
        # OSError: missing/unreadable file. ValueError also covers
        # json.JSONDecodeError and UnicodeDecodeError.
        if DEBUG_LLM:
            print(f"[llm_parser] gold examples unavailable ({path}): {err}")
        data = []

    entries = data if isinstance(data, list) else []
    _GOLD_EXAMPLES = [ex for ex in entries if isinstance(ex, dict)]
    return _GOLD_EXAMPLES
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Instruction preamble; _build_prompt places it first in every prompt.
# NOTE: the blank lines inside these triple-quoted strings are part of the text
# sent to the model, so reformatting them changes the model input.
PROMPT_HEADER = """


You are a microbiology phenotype parser.





Task:


- Extract ONLY explicitly stated results from the input text.


- Do NOT invent results.


- If not stated, omit the field or use "Unknown".





Output format:


- Prefer "Field: Value" lines, one per line.


- You may also output JSON if instructed.





Use the exact schema keys where possible.


"""


# Closing section of every prompt. _build_prompt replaces <<PHENOTYPE>> with the
# input text; the escaped quotes yield a literal triple-quote delimiter around it.

PROMPT_FOOTER = """


Input:


\"\"\"<<PHENOTYPE>>\"\"\"





Output:


"""
|
|
|
|
|
|
|
|
def _build_prompt(text: str) -> str:
    """Assemble the complete model prompt for one phenotype description.

    Layout, joined with newlines:

    1. PROMPT_HEADER;
    2. when MAX_FEWSHOT_EXAMPLES > 0, up to that many randomly sampled gold
       examples, each rendered as an "Example Input" / "Example Output" pair
       whose output lists the example's "expected" dict as ``key: value`` lines;
    3. PROMPT_FOOTER with ``<<PHENOTYPE>>`` replaced by *text*.

    Examples are re-sampled on every call, so with few-shot enabled the same
    input can produce different prompts. Gold entries that are not dicts are
    skipped instead of crashing on ``.get``.
    """
    blocks: List[str] = [PROMPT_HEADER]

    if MAX_FEWSHOT_EXAMPLES > 0:
        examples = [ex for ex in _load_gold_examples() if isinstance(ex, dict)]
        n = min(MAX_FEWSHOT_EXAMPLES, len(examples))
        # random.sample(..., 0) returns [], so an empty gold set adds nothing.
        for ex in random.sample(examples, n):
            inp = _norm_str(ex.get("input", ""))
            expected = ex.get("expected", {})
            if not isinstance(expected, dict):
                # A malformed "expected" value yields an empty output section.
                expected = {}
            kv_lines = "\n".join(f"{k}: {v}" for k, v in expected.items())
            blocks.append(f'Example Input:\n"""{inp}"""\nExample Output:\n{kv_lines}\n')

    blocks.append(PROMPT_FOOTER.replace("<<PHENOTYPE>>", text))
    return "\n".join(blocks)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _load_model() -> None:
    """Populate the module-level tokenizer and seq2seq model from DEFAULT_MODEL.

    Does nothing once both globals are set. Otherwise it loads the tokenizer,
    then loads the model, moves it to DEVICE and switches it to eval mode.
    """
    global _model, _tokenizer
    if _model is None or _tokenizer is None:
        _tokenizer = AutoTokenizer.from_pretrained(DEFAULT_MODEL)
        loaded = AutoModelForSeq2SeqLM.from_pretrained(DEFAULT_MODEL)
        _model = loaded.to(DEVICE)
        _model.eval()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
_JSON_OBJECT_RE = re.compile(r"\{[\s\S]*?\}") |
|
|
|
|
|
|
|
|
def _extract_first_json_object(text: str) -> Dict[str, Any]: |
|
|
m = _JSON_OBJECT_RE.search(text) |
|
|
if not m: |
|
|
return {} |
|
|
try: |
|
|
return json.loads(m.group(0)) |
|
|
except Exception: |
|
|
return {} |
|
|
|
|
|
|
|
|
|
|
|
_KV_LINE_RE = re.compile(r"^\s*([^:\n]{2,120})\s*:\s*(.*?)\s*$") |
|
|
|
|
|
|
|
|
def _extract_kv_pairs(text: str) -> Dict[str, Any]: |
|
|
""" |
|
|
Parse outputs like: |
|
|
Gram Stain: Positive |
|
|
Shape: Cocci |
|
|
... |
|
|
""" |
|
|
out: Dict[str, Any] = {} |
|
|
for line in (text or "").splitlines(): |
|
|
line = line.strip() |
|
|
if not line: |
|
|
continue |
|
|
m = _KV_LINE_RE.match(line) |
|
|
if not m: |
|
|
continue |
|
|
k = _norm_str(m.group(1)) |
|
|
v = _norm_str(m.group(2)) |
|
|
if not k: |
|
|
continue |
|
|
out[k] = v |
|
|
return out |
|
|
|
|
|
|
|
|
def _apply_field_aliases(fields_raw: Dict[str, Any]) -> Dict[str, Any]:
    """Rename raw model keys to canonical schema keys where a mapping is known.

    Each key is trimmed, then resolved in this order:

    1. exact match in FIELD_ALIASES;
    2. exact match in ALL_FIELDS (kept as-is);
    3. case- and whitespace-insensitive match in FIELD_ALIASES
       (e.g. "gram stain" -> "Gram Stain");
    4. case- and whitespace-insensitive match in ALL_FIELDS
       (e.g. "Arginine Dihydrolase" -> "Arginine dihydrolase").

    Keys matching none of these are kept, trimmed, unchanged; this function does
    no schema filtering. Blank keys are dropped. When several raw keys resolve
    to the same field, the one iterated last wins.
    """

    def fold(name: str) -> str:
        # Collapse whitespace runs and lower-case for tolerant matching.
        return " ".join(name.split()).lower()

    schema = set(ALL_FIELDS)
    folded_aliases = {fold(alias): target for alias, target in FIELD_ALIASES.items()}
    folded_schema = {fold(field): field for field in ALL_FIELDS}

    out: Dict[str, Any] = {}
    for k, v in fields_raw.items():
        key = _norm_str(k)
        if not key:
            continue
        if key in FIELD_ALIASES:
            mapped = FIELD_ALIASES[key]
        elif key in schema:
            mapped = key
        else:
            folded = fold(key)
            mapped = folded_aliases.get(folded) or folded_schema.get(folded) or key
        out[mapped] = v
    return out
|
|
|
|
|
|
|
|
def _clean_and_normalise(fields_raw: Dict[str, Any], original_text: str) -> Dict[str, str]:
    """
    Keep only allowed fields and normalise values into your contract.

    Only ALL_FIELDS keys present in *fields_raw* survive, in ALL_FIELDS order.
    Values are normalised as follows:

    - "Gram Stain" goes through _normalise_gram;
    - PNV_FIELDS go through _normalise_pnv_value;
    - anything else is trimmed, with blanks becoming "Unknown".

    The result then has its ornithine spellings synced and the global
    non-fermenter sugar rule applied, using *original_text*.
    """

    def normalise(field: str, value: Any) -> str:
        if field == "Gram Stain":
            return _normalise_gram(value)
        if field in PNV_FIELDS:
            return _normalise_pnv_value(value)
        return _norm_str(value) or "Unknown"

    cleaned: Dict[str, str] = {
        field: normalise(field, fields_raw[field])
        for field in ALL_FIELDS
        if field in fields_raw
    }
    cleaned = _merge_ornithine_variants(cleaned)
    return _apply_global_sugar_logic(cleaned, original_text)
|
|
|
|
|
|
|
|
def _merge_guard_fill_only_missing(
    llm_fields: Dict[str, str],
    existing_fields: Optional[Dict[str, Any]],
) -> Dict[str, str]:
    """
    Merge guard:
    - If an existing field is present and not Unknown -> do NOT overwrite.
    - If existing is missing/Unknown -> allow llm value (if not Unknown).

    An existing value counts as Unknown when any of these holds:

    - it is blank;
    - it equals "unknown" in any letter case;
    - for PNV_FIELDS, _normalise_pnv_value maps it to "Unknown".

    Returns a new dict restricted to ALL_FIELDS keys, with blank values replaced
    by "Unknown". Keys outside the schema, from either input, are dropped. When
    *existing_fields* is None, empty or not a dict, *llm_fields* is returned
    as-is (same object, no filtering).
    """
    if not existing_fields or not isinstance(existing_fields, dict):
        return llm_fields

    schema = set(ALL_FIELDS)
    merged: Dict[str, Any] = dict(existing_fields)
    for field, llm_val in llm_fields.items():
        if field not in schema:
            continue

        existing_val = _norm_str(merged.get(field, ""))
        fillable = (
            not existing_val
            # Case-insensitive so a lowercase "unknown" from an upstream parser is still fillable.
            or existing_val.lower() == "unknown"
            or (field in PNV_FIELDS and _normalise_pnv_value(existing_val) == "Unknown")
        )
        # Never replace anything with an empty or "Unknown" LLM value.
        if fillable and _norm_str(llm_val) and llm_val != "Unknown":
            merged[field] = llm_val

    return {k: _norm_str(v) or "Unknown" for k, v in merged.items() if k in schema}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _run_generation(prompt: str) -> str:
    """Greedy-decode *prompt* with the loaded model; return the text without special tokens.

    Expects _load_model() to have been called first.
    """
    # Narrows the Optional globals; both are set once _load_model() has returned.
    assert _tokenizer is not None and _model is not None

    # NOTE(review): truncation=True trims over-long prompts from
    # tokenizer.truncation_side ("right" by default in transformers), which would
    # cut the phenotype text at the END of the prompt rather than the header or
    # few-shot examples — confirm prompts stay within the model's max length.
    inputs = _tokenizer(prompt, return_tensors="pt", truncation=True).to(DEVICE)

    with torch.no_grad():
        # Greedy decoding. `temperature` is deliberately not passed: it only
        # affects sampling, and recent transformers versions warn when it is
        # combined with do_sample=False.
        output = _model.generate(
            **inputs,
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=False,
        )

    return _tokenizer.decode(output[0], skip_special_tokens=True)


def _raw_fields_from_output(decoded: str) -> Dict[str, Any]:
    """Pull raw field/value pairs out of the model's decoded output.

    The first JSON object in the output wins: its "parsed_fields" dict when
    present, otherwise the object itself. If that yields nothing, the output is
    parsed as "Field: Value" lines instead.
    """
    parsed_obj = _extract_first_json_object(decoded)
    fields_raw: Dict[str, Any] = {}
    if isinstance(parsed_obj, dict) and parsed_obj:
        nested = parsed_obj.get("parsed_fields")
        fields_raw = dict(nested) if isinstance(nested, dict) else dict(parsed_obj)

    if not fields_raw:
        fields_raw = _extract_kv_pairs(decoded)
    return fields_raw


def parse_llm(text: str, existing_fields: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
    """
    Parse phenotype text using local seq2seq model.

    Parameters
    ----------
    text : str
        phenotype description. ``None`` or whitespace-only text skips the
        model entirely.

    existing_fields : dict | None
        Optional pre-parsed fields (e.g., from rules/ext).
        If provided, LLM will ONLY fill missing/Unknown fields.

    Returns
    -------
    dict:
        {
            "parsed_fields": { ... },
            "source": "llm_parser",
            "raw": <original text>,
            "decoded": <model output> (only when DEBUG on)
        }

    For blank input, "parsed_fields" is a shallow copy of *existing_fields*
    (or ``{}`` when it is not a dict), returned without normalisation.
    """
    original = text or ""
    if not original.strip():
        return {
            # Copy so that mutating the result cannot alter the caller's dict.
            "parsed_fields": dict(existing_fields) if isinstance(existing_fields, dict) else {},
            "source": "llm_parser",
            "raw": original,
        }

    _load_model()
    prompt = _build_prompt(original)
    decoded = _run_generation(prompt)

    if DEBUG_LLM:
        print("=== LLM PROMPT (truncated) ===")
        print(prompt[:1500] + ("..." if len(prompt) > 1500 else ""))
        print("=== LLM RAW OUTPUT ===")
        print(decoded)
        print("======================")

    fields_raw = _apply_field_aliases(_raw_fields_from_output(decoded))
    cleaned = _clean_and_normalise(fields_raw, original)

    # Pre-parsed values take precedence; the LLM only fills gaps.
    if existing_fields is not None:
        cleaned = _merge_guard_fill_only_missing(cleaned, existing_fields)

    out: Dict[str, Any] = {
        "parsed_fields": cleaned,
        "source": "llm_parser",
        "raw": original,
    }
    if DEBUG_LLM:
        out["decoded"] = decoded
    return out