Spaces:
Running
Running
# engine/parser_llm.py
# ------------------------------------------------------------
# Local LLM parser for BactAI-D (T5 fine-tune, CPU-friendly)
#
# UPDATED (EphBactAID integration):
# - Default model now points to the HF fine-tune: EphAsad/EphBactAID
# - Few-shot disabled by default (the fine-tune no longer needs it)
# - Robust output parsing:
#     * Supports JSON output (legacy)
#     * Supports "Key: Value" pair output (the fine-tune's style)
# - Merge guard (optional): the LLM fills ONLY missing/Unknown fields
# - Validation/normalisation kept (PNV/Gram, sugar logic, aliases, ornithine sync)
# ------------------------------------------------------------
| from __future__ import annotations | |
| import json | |
| import os | |
| import random | |
| import re | |
| from typing import Dict, Any, List, Optional | |
| import torch | |
| from transformers import AutoTokenizer, AutoModelForSeq2SeqLM | |
| # ------------------------------------------------------------ | |
| # Model configuration | |
| # ------------------------------------------------------------ | |
| # ✅ Your fine-tuned model (can be overridden via env var) | |
# ✅ Your fine-tuned model (can be overridden via the BACTAI_LLM_PARSER_MODEL env var)
DEFAULT_MODEL = os.getenv(
    "BACTAI_LLM_PARSER_MODEL",
    "EphAsad/EphBactAID",
)
# ✅ Few-shot OFF by default now (fine-tune doesn't need it).
# BACTAI_LLM_FEWSHOT controls how many gold examples are sampled into the prompt.
MAX_FEWSHOT_EXAMPLES = int(os.getenv("BACTAI_LLM_FEWSHOT", "0"))
# Upper bound on tokens generated per call.
MAX_NEW_TOKENS = int(os.getenv("BACTAI_LLM_MAX_NEW_TOKENS", "256"))
# Debug toggle: any truthy-looking env value enables prompt/output dumps.
DEBUG_LLM = os.getenv("BACTAI_LLM_DEBUG", "0").strip().lower() in {
    "1", "true", "yes", "y", "on"
}
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Lazily-initialised module singletons (populated by _load_model and
# _load_gold_examples respectively; None until first use).
_tokenizer: Optional[AutoTokenizer] = None
_model: Optional[AutoModelForSeq2SeqLM] = None
_GOLD_EXAMPLES: Optional[List[Dict[str, Any]]] = None
| # ------------------------------------------------------------ | |
| # Allowed fields | |
| # ------------------------------------------------------------ | |
# Canonical phenotype schema: the only keys that may appear in parsed output.
# NOTE: the misspelled "Ornitihine Decarboxylase" is kept alongside the correct
# spelling on purpose; _merge_ornithine_variants keeps the two in sync.
ALL_FIELDS: List[str] = [
    "Gram Stain",
    "Shape",
    "Motility",
    "Capsule",
    "Spore Formation",
    "Haemolysis",
    "Haemolysis Type",
    "Media Grown On",
    "Colony Morphology",
    "Oxygen Requirement",
    "Growth Temperature",
    "Catalase",
    "Oxidase",
    "Indole",
    "Urease",
    "Citrate",
    "Methyl Red",
    "VP",
    "H2S",
    "DNase",
    "ONPG",
    "Coagulase",
    "Gelatin Hydrolysis",
    "Esculin Hydrolysis",
    "Nitrate Reduction",
    "NaCl Tolerant (>=6%)",
    "Lipase Test",
    "Lysine Decarboxylase",
    "Ornithine Decarboxylase",
    "Ornitihine Decarboxylase",
    "Arginine dihydrolase",
    "Glucose Fermentation",
    "Lactose Fermentation",
    "Sucrose Fermentation",
    "Maltose Fermentation",
    "Mannitol Fermentation",
    "Sorbitol Fermentation",
    "Xylose Fermentation",
    "Rhamnose Fermentation",
    "Arabinose Fermentation",
    "Raffinose Fermentation",
    "Trehalose Fermentation",
    "Inositol Fermentation",
    "Gas Production",
    "TSI Pattern",
    "Colony Pattern",
    "Pigment",
    "Motility Type",
    "Odor",
]
# Carbohydrate-fermentation fields affected by the global non-fermenter rule
# (_apply_global_sugar_logic defaults these to Negative for non-fermenters).
SUGAR_FIELDS = [
    "Glucose Fermentation",
    "Lactose Fermentation",
    "Sucrose Fermentation",
    "Maltose Fermentation",
    "Mannitol Fermentation",
    "Sorbitol Fermentation",
    "Xylose Fermentation",
    "Rhamnose Fermentation",
    "Arabinose Fermentation",
    "Raffinose Fermentation",
    "Trehalose Fermentation",
    "Inositol Fermentation",
]
# Fields whose values are constrained to Positive/Negative/Variable/Unknown:
# every schema field except the free-text / categorical ones listed below.
PNV_FIELDS = {
    f for f in ALL_FIELDS
    if f not in {
        "Media Grown On",
        "Colony Morphology",
        "Growth Temperature",
        "Gram Stain",
        "Shape",
        "Oxygen Requirement",
        "Haemolysis Type",
        "TSI Pattern",
        "Colony Pattern",
        "Motility Type",
        "Odor",
        "Pigment",
        "Gas Production",
    }
}
| # ------------------------------------------------------------ | |
| # Field alias mapping (CRITICAL) | |
| # ------------------------------------------------------------ | |
# Maps key variants the model may emit (case, synonyms, test-name phrasing)
# onto the canonical schema keys in ALL_FIELDS. Applied by
# _apply_field_aliases before value normalisation.
FIELD_ALIASES: Dict[str, str] = {
    "Gram": "Gram Stain",
    "Gram stain": "Gram Stain",
    "Gram Stain Result": "Gram Stain",
    "NaCl tolerance": "NaCl Tolerant (>=6%)",
    "NaCl Tolerant": "NaCl Tolerant (>=6%)",
    "Salt tolerance": "NaCl Tolerant (>=6%)",
    "Salt tolerant": "NaCl Tolerant (>=6%)",
    "6.5% NaCl": "NaCl Tolerant (>=6%)",
    "6% NaCl": "NaCl Tolerant (>=6%)",
    "Growth temp": "Growth Temperature",
    "Growth temperature": "Growth Temperature",
    "Temperature growth": "Growth Temperature",
    "Catalase test": "Catalase",
    "Oxidase test": "Oxidase",
    "Indole test": "Indole",
    "Urease test": "Urease",
    "Citrate test": "Citrate",
    "Glucose fermentation": "Glucose Fermentation",
    "Lactose fermentation": "Lactose Fermentation",
    "Sucrose fermentation": "Sucrose Fermentation",
    "Maltose fermentation": "Maltose Fermentation",
    "Mannitol fermentation": "Mannitol Fermentation",
    "Sorbitol fermentation": "Sorbitol Fermentation",
    "Xylose fermentation": "Xylose Fermentation",
    "Rhamnose fermentation": "Rhamnose Fermentation",
    "Arabinose fermentation": "Arabinose Fermentation",
    "Raffinose fermentation": "Raffinose Fermentation",
    "Trehalose fermentation": "Trehalose Fermentation",
    "Inositol fermentation": "Inositol Fermentation",
    # common variants from outputs
    "Voges–Proskauer Test": "VP",
    "Voges-Proskauer Test": "VP",
    "Voges–Proskauer": "VP",
    "Voges-Proskauer": "VP",
}
| # ------------------------------------------------------------ | |
| # Normalisation helpers | |
| # ------------------------------------------------------------ | |
| def _norm_str(s: Any) -> str: | |
| return str(s).strip() if s is not None else "" | |
| def _normalise_pnv_value(raw: Any) -> str: | |
| s = _norm_str(raw).lower() | |
| if not s: | |
| return "Unknown" | |
| # positive | |
| if any(x in s for x in {"positive", "pos", "+", "yes", "present", "detected", "reactive"}): | |
| return "Positive" | |
| # negative | |
| if any(x in s for x in {"negative", "neg", "-", "no", "none", "absent", "not detected", "no growth"}): | |
| return "Negative" | |
| # variable | |
| if any(x in s for x in {"variable", "mixed", "inconsistent"}): | |
| return "Variable" | |
| return "Unknown" | |
| def _normalise_gram(raw: Any) -> str: | |
| s = _norm_str(raw).lower() | |
| if "positive" in s: | |
| return "Positive" | |
| if "negative" in s: | |
| return "Negative" | |
| if "variable" in s: | |
| return "Variable" | |
| return "Unknown" | |
| def _merge_ornithine_variants(fields: Dict[str, str]) -> Dict[str, str]: | |
| v = fields.get("Ornithine Decarboxylase") or fields.get("Ornitihine Decarboxylase") | |
| if v and v != "Unknown": | |
| fields["Ornithine Decarboxylase"] = v | |
| fields["Ornitihine Decarboxylase"] = v | |
| return fields | |
| # ------------------------------------------------------------ | |
| # Sugar logic | |
| # ------------------------------------------------------------ | |
# Phrases in the source text that declare the organism a global non-fermenter.
_NON_FERMENTER_PATTERNS = re.compile(
    r"\b("
    r"non[-\s]?fermenter|"
    r"non[-\s]?fermentative|"
    r"asaccharolytic|"
    r"does not ferment (sugars|carbohydrates)|"
    r"no carbohydrate fermentation"
    r")\b",
    re.IGNORECASE,
)


def _apply_global_sugar_logic(fields: Dict[str, str], original_text: str) -> Dict[str, str]:
    """Default every sugar field to Negative when the text says the organism
    is a non-fermenter, without overriding explicit Positive/Variable results."""
    if _NON_FERMENTER_PATTERNS.search(original_text) is None:
        return fields
    for name in SUGAR_FIELDS:
        if fields.get(name) not in {"Positive", "Variable"}:
            fields[name] = "Negative"
    return fields
| # ------------------------------------------------------------ | |
| # Gold examples (kept for backwards compat; now optional) | |
| # ------------------------------------------------------------ | |
| def _get_project_root() -> str: | |
| return os.path.dirname(os.path.dirname(os.path.abspath(__file__))) | |
def _load_gold_examples() -> List[Dict[str, Any]]:
    """Load and cache data/llm_gold_examples.json.

    Returns the cached list on subsequent calls; any read/parse failure (or a
    non-list payload) yields an empty list, never an exception.
    """
    global _GOLD_EXAMPLES
    if _GOLD_EXAMPLES is None:
        path = os.path.join(_get_project_root(), "data", "llm_gold_examples.json")
        examples: List[Dict[str, Any]] = []
        try:
            with open(path, "r", encoding="utf-8") as fh:
                loaded = json.load(fh)
            if isinstance(loaded, list):
                examples = loaded
        except Exception:
            examples = []
        _GOLD_EXAMPLES = examples
    return _GOLD_EXAMPLES
| # ------------------------------------------------------------ | |
| # Prompt (supports both JSON + KV outputs; fine-tune usually KV) | |
| # ------------------------------------------------------------ | |
# Instruction preamble sent to the model on every call; declares the
# extraction contract (no invented results) and the preferred KV output form.
PROMPT_HEADER = """
You are a microbiology phenotype parser.
Task:
- Extract ONLY explicitly stated results from the input text.
- Do NOT invent results.
- If not stated, omit the field or use "Unknown".
Output format:
- Prefer "Field: Value" lines, one per line.
- You may also output JSON if instructed.
Use the exact schema keys where possible.
"""
# Prompt tail; <<PHENOTYPE>> is substituted with the input text by _build_prompt.
PROMPT_FOOTER = """
Input:
\"\"\"<<PHENOTYPE>>\"\"\"
Output:
"""
def _build_prompt(text: str) -> str:
    """Assemble the full prompt for *text*.

    Few-shot examples are disabled by default (MAX_FEWSHOT_EXAMPLES == 0) but
    the capability is retained for testing: sampled gold examples are rendered
    in "Key: Value" form to match the fine-tune's output style.
    """
    parts: List[str] = [PROMPT_HEADER]
    if MAX_FEWSHOT_EXAMPLES > 0:
        pool = _load_gold_examples()
        take = min(MAX_FEWSHOT_EXAMPLES, len(pool))
        chosen = random.sample(pool, take) if take > 0 else []
        for example in chosen:
            source_text = _norm_str(example.get("input", ""))
            expected = example.get("expected", {})
            if not isinstance(expected, dict):
                expected = {}
            kv_block = "\n".join(f"{key}: {val}" for key, val in expected.items())
            parts.append(f'Example Input:\n"""{source_text}"""\nExample Output:\n{kv_block}\n')
    parts.append(PROMPT_FOOTER.replace("<<PHENOTYPE>>", text))
    return "\n".join(parts)
| # ------------------------------------------------------------ | |
| # Model loader | |
| # ------------------------------------------------------------ | |
def _load_model() -> None:
    """Lazily load the tokenizer/model singletons onto DEVICE (idempotent)."""
    global _model, _tokenizer
    if _tokenizer is not None and _model is not None:
        return
    tokenizer = AutoTokenizer.from_pretrained(DEFAULT_MODEL)
    model = AutoModelForSeq2SeqLM.from_pretrained(DEFAULT_MODEL).to(DEVICE)
    model.eval()
    _tokenizer = tokenizer
    _model = model
| # ------------------------------------------------------------ | |
| # Output parsing helpers (JSON + KV) | |
| # ------------------------------------------------------------ | |
| _JSON_OBJECT_RE = re.compile(r"\{[\s\S]*?\}") | |
| def _extract_first_json_object(text: str) -> Dict[str, Any]: | |
| m = _JSON_OBJECT_RE.search(text) | |
| if not m: | |
| return {} | |
| try: | |
| return json.loads(m.group(0)) | |
| except Exception: | |
| return {} | |
| # Match "Key: Value" (including keys with symbols like >=6%) | |
| _KV_LINE_RE = re.compile(r"^\s*([^:\n]{2,120})\s*:\s*(.*?)\s*$") | |
| def _extract_kv_pairs(text: str) -> Dict[str, Any]: | |
| """ | |
| Parse outputs like: | |
| Gram Stain: Positive | |
| Shape: Cocci | |
| ... | |
| """ | |
| out: Dict[str, Any] = {} | |
| for line in (text or "").splitlines(): | |
| line = line.strip() | |
| if not line: | |
| continue | |
| m = _KV_LINE_RE.match(line) | |
| if not m: | |
| continue | |
| k = _norm_str(m.group(1)) | |
| v = _norm_str(m.group(2)) | |
| if not k: | |
| continue | |
| out[k] = v | |
| return out | |
def _apply_field_aliases(fields_raw: Dict[str, Any]) -> Dict[str, Any]:
    """Rewrite raw keys through FIELD_ALIASES onto canonical schema keys.

    Unrecognised keys pass through unchanged; blank keys are dropped; when two
    raw keys map to the same canonical key, the later one wins.
    """
    renamed: Dict[str, Any] = {}
    for raw_key, value in fields_raw.items():
        stripped = _norm_str(raw_key)
        if not stripped:
            continue
        renamed[FIELD_ALIASES.get(stripped, stripped)] = value
    return renamed
def _clean_and_normalise(fields_raw: Dict[str, Any], original_text: str) -> Dict[str, str]:
    """
    Restrict *fields_raw* to the schema and normalise each value.

    Gram Stain gets its dedicated normaliser, PNV fields are mapped onto the
    Positive/Negative/Variable/Unknown contract, and free-text fields are
    stripped (empty -> "Unknown"). Ornithine-spelling sync and the global
    non-fermenter sugar rule are applied afterwards.
    """
    cleaned: Dict[str, str] = {}
    for field in ALL_FIELDS:
        if field not in fields_raw:
            continue
        value = fields_raw[field]
        if field == "Gram Stain":
            normalised = _normalise_gram(value)
        elif field in PNV_FIELDS:
            normalised = _normalise_pnv_value(value)
        else:
            normalised = _norm_str(value) or "Unknown"
        cleaned[field] = normalised
    cleaned = _merge_ornithine_variants(cleaned)
    return _apply_global_sugar_logic(cleaned, original_text)
def _merge_guard_fill_only_missing(
    llm_fields: Dict[str, str],
    existing_fields: Optional[Dict[str, Any]],
) -> Dict[str, str]:
    """
    Merge guard:
    - An existing field that already holds a real (non-Unknown) value is
      never overwritten.
    - A missing/empty/Unknown existing field may be filled by the LLM value,
      but only when that value is itself meaningful (non-empty, not Unknown).
    The result contains schema keys only, with every value a non-empty string.
    """
    if not existing_fields or not isinstance(existing_fields, dict):
        return llm_fields
    merged = dict(existing_fields)  # existing values take precedence
    for field, candidate in llm_fields.items():
        if field not in ALL_FIELDS:
            continue
        current = _norm_str(merged.get(field, ""))
        if field in PNV_FIELDS:
            current_norm = _normalise_pnv_value(current)
        else:
            current_norm = current
        # A field counts as populated only if raw and normalised forms both
        # carry a real value.
        if current and current != "Unknown" and current_norm != "Unknown":
            continue
        if _norm_str(candidate) and candidate != "Unknown":
            merged[field] = candidate
    result: Dict[str, str] = {}
    for field, value in merged.items():
        if field in ALL_FIELDS:
            result[field] = _norm_str(value) or "Unknown"
    return result
| # ------------------------------------------------------------ | |
| # PUBLIC API | |
| # ------------------------------------------------------------ | |
def parse_llm(text: str, existing_fields: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
    """
    Parse phenotype text using the local seq2seq model.

    Parameters
    ----------
    text : str
        Phenotype description.
    existing_fields : dict | None
        Optional pre-parsed fields (e.g., from rules/ext).
        If provided, the LLM will ONLY fill missing/Unknown fields.

    Returns
    -------
    dict:
        {
            "parsed_fields": { ... },
            "source": "llm_parser",
            "raw": <original text>,
            "decoded": <model output> (only when DEBUG on)
        }
    """
    original = text or ""
    # Blank input: skip the model entirely and echo back any existing fields.
    if not original.strip():
        return {
            "parsed_fields": (existing_fields or {}) if isinstance(existing_fields, dict) else {},
            "source": "llm_parser",
            "raw": original,
        }
    _load_model()
    assert _tokenizer is not None and _model is not None
    prompt = _build_prompt(original)
    inputs = _tokenizer(prompt, return_tensors="pt", truncation=True).to(DEVICE)
    with torch.no_grad():
        # Greedy decoding for reproducible output.
        # NOTE(review): temperature=0.0 alongside do_sample=False draws a
        # warning on recent transformers versions — confirm it is harmless here.
        output = _model.generate(
            **inputs,
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=False,
            temperature=0.0,
        )
    decoded = _tokenizer.decode(output[0], skip_special_tokens=True)
    if DEBUG_LLM:
        print("=== LLM PROMPT (truncated) ===")
        print(prompt[:1500] + ("..." if len(prompt) > 1500 else ""))
        print("=== LLM RAW OUTPUT ===")
        print(decoded)
        print("======================")
    # 1) Try JSON extraction (legacy output format)
    parsed_obj = _extract_first_json_object(decoded)
    fields_raw = {}
    if isinstance(parsed_obj, dict) and parsed_obj:
        if "parsed_fields" in parsed_obj and isinstance(parsed_obj.get("parsed_fields"), dict):
            fields_raw = dict(parsed_obj["parsed_fields"])
        else:
            # in case model returned a flat JSON dict
            fields_raw = dict(parsed_obj)
    # 2) Fallback to KV parsing (the fine-tune's "Field: Value" style)
    if not fields_raw:
        fields_raw = _extract_kv_pairs(decoded)
    # 3) Alias map + normalise onto the schema contract
    fields_raw = _apply_field_aliases(fields_raw)
    cleaned = _clean_and_normalise(fields_raw, original)
    # 4) Merge guard (optional) - LLM fills only missing/Unknown fields
    if existing_fields is not None:
        cleaned = _merge_guard_fill_only_missing(cleaned, existing_fields)
    out = {
        "parsed_fields": cleaned,
        "source": "llm_parser",
        "raw": original,
    }
    if DEBUG_LLM:
        out["decoded"] = decoded
    return out