# BactKing / engine/bacteria_identifier.py
# Last commit: 3f68bef "Update engine/bacteria_identifier.py" (EphAsad; verified)
# engine/bacteria_identifier.py
# ------------------------------------------------------------
# Core BactAI-D identification engine.
# - Scores genera from Excel DB (core phenotype fields)
# - Integrates optional extended-test reasoning
# - Provides blended confidence and narrative reasoning
# ------------------------------------------------------------
from __future__ import annotations
import re
from dataclasses import dataclass, field
from typing import Dict, List, Any, Optional, Tuple
import pandas as pd
try:
    from engine.extended_reasoner import score_genera_from_extended
except Exception:
    # The extended reasoner is optional: if it cannot be imported for any
    # reason, identification falls back to core phenotype scoring only.
    HAS_EXTENDED_REASONER = False
else:
    HAS_EXTENDED_REASONER = True
# ------------------------------------------------------------
# Helper
# ------------------------------------------------------------
def join_with_and(items: List[str]) -> str:
    """Join phrases as an English list: "a", "a and b", "a, b and c".

    An empty (or falsy) input yields "". A single item is returned unchanged.
    """
    if not items:
        return ""
    *leading, final = items
    if not leading:
        return final
    return ", ".join(leading) + " and " + final
# ------------------------------------------------------------
# Identification Result
# ------------------------------------------------------------
@dataclass
class IdentificationResult:
    """Score, confidence metrics and narrative for one candidate genus.

    ``total_score`` is a net count of matched minus mismatched fields, as
    accumulated by ``BacteriaIdentifier.identify``. The confidence helpers
    turn it into percentages clamped to 0–100.
    """

    genus: str
    total_score: int  # +1 per matched field, -1 per mismatched field
    matched_fields: List[str] = field(default_factory=list)
    mismatched_fields: List[str] = field(default_factory=list)
    # Matched field name -> user value. The identifier may also store a
    # "next_tests" entry here, which is bookkeeping rather than evidence.
    reasoning_factors: Dict[str, Any] = field(default_factory=dict)
    total_fields_evaluated: int = 0  # fields the user gave a known value for
    total_fields_possible: int = 0  # scorable columns in the database
    extra_notes: str = ""
    extended_score: float = 0.0  # 0.0–1.0
    extended_explanation: str = ""

    # ---------- Confidence metrics ----------
    def confidence_percent(self) -> int:
        """Confidence based only on tests the user entered (0–100)."""
        if self.total_fields_evaluated <= 0:
            return 0
        pct = (self.total_score / max(1, self.total_fields_evaluated)) * 100
        return max(0, min(100, int(round(pct))))

    def true_confidence(self) -> int:
        """Confidence based on all possible fields in the DB (0–100)."""
        if self.total_fields_possible <= 0:
            return 0
        pct = (self.total_score / max(1, self.total_fields_possible)) * 100
        return max(0, min(100, int(round(pct))))

    def blended_confidence_percent(self) -> int:
        """
        Blend core confidence with extended_score (0–1).

        If there is no positive extended signal, return core confidence.
        Otherwise use a simple blend: 70% core, 30% extended signal.
        """
        core = self.confidence_percent()
        if self.extended_score <= 0:
            return core
        ext_pct = max(0.0, min(1.0, float(self.extended_score))) * 100.0
        blended = 0.7 * core + 0.3 * ext_pct
        return max(0, min(100, int(round(blended))))

    # ---------- Reasoning text ----------
    def reasoning_paragraph(self, ranked_results: Optional[List["IdentificationResult"]] = None) -> str:
        """Generate a markdown narrative explaining this identification.

        Args:
            ranked_results: Optional ranked list of results (best first). When
                given, the text compares this genus with up to two *other*
                genera from the list (this result itself is skipped).

        Returns:
            A short paragraph. The opening phrase is picked at random, so the
            text is not deterministic between calls.
        """
        import random

        # "next_tests" is attached by the identifier and is not match evidence.
        evidence = {k: v for k, v in self.reasoning_factors.items() if k != "next_tests"}
        if not self.matched_fields and not evidence:
            return "No significant biochemical or morphological matches were found."

        intro_options = [
            "Based on the observed biochemical and morphological traits,",
            "According to the provided test results,",
            "From the available laboratory findings,",
            "Considering the entered reactions and colony characteristics,",
        ]
        intro = random.choice(intro_options)

        # Fragments carry no leading "and"; join_with_and adds the final one.
        highlights: List[str] = []
        gram = evidence.get("Gram Stain")
        if gram:
            highlights.append(f"it is **Gram {str(gram).lower()}**")
        shape = evidence.get("Shape")
        if shape:
            highlights.append(f"with a **{str(shape).lower()}** morphology")
        catalase = evidence.get("Catalase")
        if catalase:
            highlights.append(f"**catalase {str(catalase).lower()}** activity")
        oxidase = evidence.get("Oxidase")
        if oxidase:
            highlights.append(f"**oxidase {str(oxidase).lower()}** reaction")
        oxy = evidence.get("Oxygen Requirement")
        if oxy:
            highlights.append(f"which prefers **{str(oxy).lower()}** conditions")
        summary = join_with_and(highlights)

        # Confidence text (core)
        core_conf = self.confidence_percent()
        if core_conf >= 70:
            confidence_text = "The confidence in this identification is high."
        elif core_conf >= 40:
            confidence_text = "The confidence in this identification is moderate."
        else:
            confidence_text = "The confidence in this identification is low."

        # Comparison vs up to two other genera (never against itself).
        comparison = ""
        if ranked_results:
            others = [
                r for r in ranked_results
                if r is not self and r.genus != self.genus
            ][:2]
            if others:
                names = join_with_and([r.genus for r in others])
                if self.total_score >= others[0].total_score:
                    comparison = f" It is **more likely** than {names}"
                    if self.matched_fields:
                        comparison += (
                            " based on stronger alignment in "
                            f"{join_with_and(self.matched_fields[:3])}"
                        )
                else:
                    comparison = f" It is **less likely** than {names}"
                    if self.mismatched_fields:
                        comparison += (
                            " due to differences in "
                            f"{join_with_and(self.mismatched_fields[:3])}"
                        )
                comparison += "."

        # Avoid "intro , the isolate..." when no highlight traits were matched.
        lead = f"{intro} {summary}," if summary else intro
        return f"{lead} the isolate most closely resembles **{self.genus}**. {confidence_text}{comparison}"
# ------------------------------------------------------------
# Bacteria Identifier
# ------------------------------------------------------------
class BacteriaIdentifier:
    """
    Main engine to match bacterial genus based on biochemical & morphological data.

    ``db`` has a ``"Genus"`` column, an optional ``"Extra Notes"`` column,
    and one column per test. A cell may list several alternative values
    separated by ``;``, ``/`` or ``,``.
    """

    # Columns that describe a row rather than record a test result.
    NON_TEST_COLUMNS = frozenset({"Genus", "Extra Notes"})
    # A mismatch on any of these fields rules the genus out entirely.
    HARD_EXCLUSION_FIELDS = frozenset({"Gram Stain", "Shape", "Spore Formation"})
    # Sentinel score returned by compare_field for a hard exclusion.
    HARD_EXCLUSION = -999

    # Separators between alternative values inside one cell.
    _OPTION_SPLIT = re.compile(r"[;/,]")
    # Matches text that ends in a negation immediately before a contained
    # term. Without it, "motile" would match "non-motile" and "aerobic" would
    # match "anaerobic". "an" only counts when fused to the term ("anaerobe").
    _NEGATED_PREFIX = re.compile(r"(?:^|[^a-z0-9])(?:(?:non|not|no)[\s\-]*|an)$")

    def __init__(self, db: pd.DataFrame):
        # Blank cells become "" so they are treated as "no data" below.
        self.db: pd.DataFrame = db.fillna("")
        self.db_columns = list(self.db.columns)

    # ---------- Small helpers ----------
    @staticmethod
    def _is_known(value: Any) -> bool:
        """Return True when *value* is present, non-blank and not "Unknown"."""
        if value is None:
            return False
        text = str(value).strip()
        return text != "" and text.lower() != "unknown"

    @classmethod
    def _split_options(cls, text: str) -> List[str]:
        """Split a cell into lower-cased, stripped, non-empty alternatives."""
        return [p.strip().lower() for p in cls._OPTION_SPLIT.split(text) if p.strip()]

    @classmethod
    def _contains_term(cls, needle: str, haystack: str) -> bool:
        """True if *needle* occurs in *haystack* without a negating prefix."""
        start = haystack.find(needle)
        while start != -1:
            if not cls._NEGATED_PREFIX.search(haystack[:start]):
                return True
            start = haystack.find(needle, start + 1)
        return False

    @classmethod
    def _options_overlap(cls, a: str, b: str) -> bool:
        """Flexible match: equal, or one contains the other (non-negated)."""
        return a == b or cls._contains_term(a, b) or cls._contains_term(b, a)

    # ---------- Field comparison ----------
    def compare_field(self, db_val: Any, user_val: Any, field_name: str) -> int:
        """
        Compare one test field between database and user input.

        Returns:
            +1 match
            -1 mismatch
             0 unknown / ignored: value missing, blank or "Unknown" on either
               side, "variable" on either side, or unparsable temperature
            -999 (HARD_EXCLUSION): mismatch on a hard-exclusion field; the
               caller should stop comparing this genus.
        """
        # Missing data on either side is no evidence for or against.
        if not self._is_known(user_val) or not self._is_known(db_val):
            return 0
        user_str = str(user_val).strip()
        db_str = str(db_val).strip()

        db_options = self._split_options(db_str)
        user_options = self._split_options(user_str)

        # "variable" on either side: don't penalize.
        if "variable" in db_options or "variable" in user_options:
            return 0

        # Growth Temperature stored as a range "low//high". The user may give
        # "37" or "37//37"; the first number is used.
        if field_name == "Growth Temperature" and "//" in db_str:
            try:
                low_s, high_s = db_str.split("//", 1)
                low, high = float(low_s), float(high_s)
                temp = float(user_str.split("//", 1)[0])
            except ValueError:
                return 0
            return 1 if low <= temp <= high else -1

        if any(self._options_overlap(u, d) for u in user_options for d in db_options):
            return 1
        if field_name in self.HARD_EXCLUSION_FIELDS:
            return self.HARD_EXCLUSION
        return -1

    # ---------- Next-test suggestions ----------
    def suggest_next_tests(
        self,
        top_results: List["IdentificationResult"],
        user_input: Dict[str, Any],
        max_tests: int = 3,
    ) -> List[str]:
        """
        Suggest tests that best differentiate the top matches.

        Only columns the user has not yet given a known value for are
        considered. A column qualifies when the top (at most five) genera
        have more than one distinct non-blank value for it. Columns keep
        their database order, and the list is cut to ``max_tests``.
        """
        if not top_results:
            return []

        top_names = {str(r.genus).strip() for r in top_results[:5]}
        # Select the relevant rows once instead of rescanning the DB per column.
        top_rows = [
            row for _, row in self.db.iterrows()
            if str(row.get("Genus", "")).strip() in top_names
        ]

        varying_fields: List[str] = []
        for column in self.db_columns:
            if column in self.NON_TEST_COLUMNS:
                continue
            if self._is_known(user_input.get(column)):
                continue  # already answered by the user
            values = {str(row.get(column, "")).strip().lower() for row in top_rows}
            values.discard("")
            if len(values) > 1:
                varying_fields.append(column)

        return varying_fields[:max_tests]

    # ---------- Main identification routine ----------
    @staticmethod
    def _extended_scores(user_input: Dict[str, Any]) -> Tuple[Dict[str, float], str]:
        """Best-effort extended scores by genus plus their explanation.

        Returns ``({}, "")`` when the extended reasoner is unavailable or fails.
        """
        if not HAS_EXTENDED_REASONER:
            return {}, ""
        try:
            ranked_ext, explanation = score_genera_from_extended(user_input)
            scores = {str(genus): float(score) for genus, score in ranked_ext}
        except Exception:
            # Optional signal: never let it break core identification.
            return {}, ""
        return scores, explanation or ""

    def _score_row(
        self,
        row: pd.Series,
        genus: str,
        user_input: Dict[str, Any],
        test_columns: List[str],
        total_fields_possible: int,
    ) -> Optional["IdentificationResult"]:
        """Score one database row; return None if a hard exclusion fires."""
        total_score = 0
        matched_fields: List[str] = []
        mismatched_fields: List[str] = []
        reasoning_factors: Dict[str, Any] = {}
        total_fields_evaluated = 0

        for column in test_columns:
            user_val = user_input.get(column, "")
            score = self.compare_field(row.get(column, ""), user_val, column)
            if self._is_known(user_val):
                total_fields_evaluated += 1
            if score == self.HARD_EXCLUSION:
                return None
            if score == 1:
                total_score += 1
                matched_fields.append(column)
                reasoning_factors[column] = user_val
            elif score == -1:
                total_score -= 1
                mismatched_fields.append(column)

        return IdentificationResult(
            genus=genus,
            total_score=total_score,
            matched_fields=matched_fields,
            mismatched_fields=mismatched_fields,
            reasoning_factors=reasoning_factors,
            total_fields_evaluated=total_fields_evaluated,
            total_fields_possible=total_fields_possible,
            extra_notes=str(row.get("Extra Notes", "")).strip(),
        )

    def identify(self, user_input: Dict[str, Any]) -> List["IdentificationResult"]:
        """
        Compare user input to database and rank possible genera.

        Genera with a hard-exclusion mismatch are dropped. The remaining
        results are sorted by core ``total_score`` (descending), and
        extended scores are attached when available. Up to 10 results are
        returned. The top five also get ``reasoning_factors["next_tests"]``,
        a comma-separated string of suggested tests.
        """
        test_columns = [c for c in self.db_columns if c not in self.NON_TEST_COLUMNS]
        total_fields_possible = len(test_columns)
        extended_scores, extended_explanation = self._extended_scores(user_input)

        results: List[IdentificationResult] = []
        for _, row in self.db.iterrows():
            genus = str(row.get("Genus", "")).strip()
            if not genus:
                continue
            result = self._score_row(row, genus, user_input, test_columns, total_fields_possible)
            if result is None:
                continue  # hard-excluded genus
            if genus in extended_scores:
                result.extended_score = extended_scores[genus]
                result.extended_explanation = extended_explanation
            results.append(result)

        # Sort by core score descending
        results.sort(key=lambda r: r.total_score, reverse=True)

        top = results[:5]
        if top:
            next_tests_str = ", ".join(self.suggest_next_tests(top, user_input))
            for r in top:
                r.reasoning_factors["next_tests"] = next_tests_str

        return results[:10]