# engine/bacteria_identifier.py # ------------------------------------------------------------ # Core BactAI-D identification engine. # - Scores genera from Excel DB (core phenotype fields) # - Integrates optional extended-test reasoning # - Provides blended confidence and narrative reasoning # ------------------------------------------------------------ from __future__ import annotations import re from dataclasses import dataclass, field from typing import Dict, List, Any, Optional, Tuple import pandas as pd try: from engine.extended_reasoner import score_genera_from_extended HAS_EXTENDED_REASONER = True except Exception: HAS_EXTENDED_REASONER = False # ------------------------------------------------------------ # Helper # ------------------------------------------------------------ def join_with_and(items: List[str]) -> str: if not items: return "" if len(items) == 1: return items[0] return ", ".join(items[:-1]) + " and " + items[-1] # ------------------------------------------------------------ # Identification Result # ------------------------------------------------------------ @dataclass class IdentificationResult: genus: str total_score: int matched_fields: List[str] = field(default_factory=list) mismatched_fields: List[str] = field(default_factory=list) reasoning_factors: Dict[str, Any] = field(default_factory=dict) total_fields_evaluated: int = 0 total_fields_possible: int = 0 extra_notes: str = "" extended_score: float = 0.0 # 0.0–1.0 extended_explanation: str = "" # ---------- Confidence metrics ---------- def confidence_percent(self) -> int: """Confidence based only on tests the user entered.""" if self.total_fields_evaluated <= 0: return 0 pct = (self.total_score / max(1, self.total_fields_evaluated)) * 100 return max(0, min(100, int(round(pct)))) def true_confidence(self) -> int: """Confidence based on all possible fields in the DB.""" if self.total_fields_possible <= 0: return 0 pct = (self.total_score / max(1, self.total_fields_possible)) * 100 return max(0, min(100, int(round(pct)))) def blended_confidence_percent(self) -> int: """ Blend core confidence with extended_score (0–1). If no extended signal, return core confidence. Simple blend: 70% core, 30% extended signal. """ core = self.confidence_percent() if self.extended_score <= 0: return core ext_pct = max(0.0, min(1.0, float(self.extended_score))) * 100.0 blended = 0.7 * core + 0.3 * ext_pct return max(0, min(100, int(round(blended)))) # ---------- Reasoning text ---------- def reasoning_paragraph(self, ranked_results: Optional[List["IdentificationResult"]] = None) -> str: """Generate a narrative explanation from core matches.""" if not self.matched_fields and not self.reasoning_factors: return "No significant biochemical or morphological matches were found." intro_options = [ "Based on the observed biochemical and morphological traits,", "According to the provided test results,", "From the available laboratory findings,", "Considering the entered reactions and colony characteristics,", ] import random intro = random.choice(intro_options) highlights = [] gram = self.reasoning_factors.get("Gram Stain") if gram: highlights.append(f"it is **Gram {str(gram).lower()}**") shape = self.reasoning_factors.get("Shape") if shape: highlights.append(f"with a **{str(shape).lower()}** morphology") catalase = self.reasoning_factors.get("Catalase") if catalase: highlights.append(f"and **catalase {str(catalase).lower()}** activity") oxidase = self.reasoning_factors.get("Oxidase") if oxidase: highlights.append(f"and **oxidase {str(oxidase).lower()}** reaction") oxy = self.reasoning_factors.get("Oxygen Requirement") if oxy: highlights.append(f"which prefers **{str(oxy).lower()}** conditions") if len(highlights) > 1: summary = ", ".join(highlights[:-1]) + " and " + highlights[-1] else: summary = "".join(highlights) # Confidence text (core) core_conf = self.confidence_percent() if core_conf >= 70: confidence_text = "The confidence in this identification is high." elif core_conf >= 40: confidence_text = "The confidence in this identification is moderate." else: confidence_text = "The confidence in this identification is low." # Comparison vs other top results comparison = "" if ranked_results and len(ranked_results) > 1: close_others = ranked_results[1:3] other_names = [r.genus for r in close_others] if other_names: if self.total_score >= close_others[0].total_score: comparison = ( f" It is **more likely** than {join_with_and(other_names)} " f"based on stronger alignment in {join_with_and(self.matched_fields[:3])}." ) else: comparison = ( f" It is **less likely** than {join_with_and(other_names)} " f"due to differences in {join_with_and(self.mismatched_fields[:3])}." ) return f"{intro} {summary}, the isolate most closely resembles **{self.genus}**. {confidence_text}{comparison}" # ------------------------------------------------------------ # Bacteria Identifier # ------------------------------------------------------------ class BacteriaIdentifier: """ Main engine to match bacterial genus based on biochemical & morphological data. """ def __init__(self, db: pd.DataFrame): self.db: pd.DataFrame = db.fillna("") self.db_columns = list(self.db.columns) # ---------- Field comparison ---------- def compare_field(self, db_val: Any, user_val: Any, field_name: str) -> int: """ Compare one test field between database and user input. Returns: +1 match -1 mismatch 0 unknown / ignored Return -999 to indicate a hard exclusion (stop comparing this genus). """ if user_val is None: return 0 user_str = str(user_val).strip() if user_str == "" or user_str.lower() == "unknown": return 0 # ignore unknown/empty db_str = str(db_val).strip() db_l = db_str.lower() user_l = user_str.lower() hard_exclusions = {"Gram Stain", "Shape", "Spore Formation"} # Split multi-value fields on ; or / or , db_options = [p.strip().lower() for p in re.split(r"[;/,]", db_str) if p.strip()] user_options = [p.strip().lower() for p in re.split(r"[;/,]", user_str) if p.strip()] # "variable" logic: if either is variable, don't penalize if "variable" in db_options or "variable" in user_options: return 0 # Growth Temperature as range "low//high", user enters single numeric or similar if field_name == "Growth Temperature": try: if "//" in db_str: low_s, high_s = db_str.split("//", 1) low = float(low_s) high = float(high_s) # user may have given "37//37" or "37" etc. if "//" in user_str: ut = float(user_str.split("//", 1)[0]) else: ut = float(user_str) if low <= ut <= high: return 1 else: return -1 except Exception: return 0 # Flexible overlap match match_found = False for u in user_options: for d in db_options: if not d or not u: continue if u == d: match_found = True break if u in d or d in u: match_found = True break if match_found: break if match_found: return 1 if field_name in hard_exclusions: return -999 # hard mismatch return -1 # ---------- Next-test suggestions ---------- def suggest_next_tests( self, top_results: List[IdentificationResult], user_input: Dict[str, Any], max_tests: int = 3, ) -> List[str]: """ Suggest tests that best differentiate top matches and haven't already been entered or marked 'Unknown' by the user. """ if not top_results: return [] # Only consider first 3–5 genera top_names = {r.genus for r in top_results[:5]} varying_fields: List[str] = [] for field in self.db_columns: if field == "Genus": continue # Skip fields user already filled with a known value u_val = user_input.get(field, "") if isinstance(u_val, str) and u_val.lower() not in {"", "unknown"}: continue # Check if this field differs meaningfully between top genera values_for_field = set() for _, row in self.db.iterrows(): g = row.get("Genus", "") if g in top_names: v = str(row.get(field, "")).strip().lower() if v: values_for_field.add(v) if len(values_for_field) > 1: varying_fields.append(field) # simple deterministic: take first few return varying_fields[:max_tests] # ---------- Main identification routine ---------- def identify(self, user_input: Dict[str, Any]) -> List[IdentificationResult]: """ Compare user input to database and rank possible genera. Integrates extended signals when available. """ results: List[IdentificationResult] = [] total_fields_possible = len([c for c in self.db_columns if c != "Genus"]) # Pre-compute extended scores if extended_reasoner is available extended_scores: Dict[str, float] = {} extended_explanation: str = "" if HAS_EXTENDED_REASONER: try: ranked_ext, explanation = score_genera_from_extended(user_input) extended_explanation = explanation or "" for genus, score in ranked_ext: extended_scores[str(genus)] = float(score) except Exception: extended_scores = {} extended_explanation = "" for _, row in self.db.iterrows(): genus = str(row.get("Genus", "")).strip() if not genus: continue total_score = 0 matched_fields: List[str] = [] mismatched_fields: List[str] = [] reasoning_factors: Dict[str, Any] = {} total_fields_evaluated = 0 hard_excluded = False for field in self.db_columns: if field == "Genus": continue db_val = row.get(field, "") user_val = user_input.get(field, "") score = self.compare_field(db_val, user_val, field) if user_val is not None and str(user_val).strip() != "" and str(user_val).strip().lower() != "unknown": total_fields_evaluated += 1 if score == -999: hard_excluded = True total_score = -999 break elif score == 1: total_score += 1 matched_fields.append(field) reasoning_factors[field] = user_val elif score == -1: total_score -= 1 mismatched_fields.append(field) if hard_excluded: continue # skip this genus entirely extra_notes = str(row.get("Extra Notes", "")).strip() if "Extra Notes" in row else "" r = IdentificationResult( genus=genus, total_score=total_score, matched_fields=matched_fields, mismatched_fields=mismatched_fields, reasoning_factors=reasoning_factors, total_fields_evaluated=total_fields_evaluated, total_fields_possible=total_fields_possible, extra_notes=extra_notes, ) # Attach extended score if available if genus in extended_scores: r.extended_score = extended_scores[genus] r.extended_explanation = extended_explanation results.append(r) # Sort by core score descending results.sort(key=lambda r: r.total_score, reverse=True) # Suggest next tests for top few if results: next_tests = self.suggest_next_tests(results[:5], user_input) next_tests_str = ", ".join(next_tests) if next_tests else "" for r in results[:5]: r.reasoning_factors["next_tests"] = next_tests_str # Return top 10 return results[:10]