Spaces:
Running
Running
| # engine/features.py | |
| import json | |
| import numpy as np | |
| import re | |
| from typing import Dict, List, Any | |
| # ------------------------------------------------------------------ | |
| # Load schema once | |
| # ------------------------------------------------------------------ | |
| _FEATURE_SCHEMA_PATH = "data/feature_schema.json" | |
| with open(_FEATURE_SCHEMA_PATH, "r", encoding="utf-8") as f: | |
| SCHEMA = json.load(f) | |
| FEATURES = SCHEMA["features"] | |
| # ------------------------------------------------------------------ | |
| # Helper mappings | |
| # ------------------------------------------------------------------ | |
| PNV_MAP = { | |
| "positive": 1.0, | |
| "negative": -1.0, | |
| "variable": 0.5, | |
| "unknown": 0.0, | |
| None: 0.0, | |
| } | |
| SHAPE_MAP = { | |
| "cocci": 1.0, | |
| "rods": 2.0, | |
| "short rods": 2.5, | |
| "spiral": 3.0, | |
| "yeast": 4.0, | |
| "variable": 0.5, | |
| "unknown": 0.0, | |
| } | |
| OXYGEN_MAP = { | |
| "aerobic": 1.0, | |
| "anaerobic": 2.0, | |
| "facultative anaerobe": 3.0, | |
| "microaerophilic": 4.0, | |
| "capnophilic": 5.0, | |
| "unknown": 0.0, | |
| } | |
| # CATEGORY LABELSETS — deterministic, fixed, ML-friendly | |
| CATEGORY_MAPS = { | |
| "Motility Type": [ | |
| "none", "tumbling", "peritrichous", "polar", "monotrichous", | |
| "lophotrichous", "amphitrichous", "axial", "gliding" | |
| ], | |
| "Pigment": [ | |
| "none", "pyocyanin", "pyoverdine", "green", "yellow", "pink", | |
| "red", "orange", "brown", "black", "violet", "cream" | |
| ], | |
| "Odor": [ | |
| "none", "grape", "fruity", "earthy", "musty", "putrid", | |
| "buttery", "yeasty", "medicinal", "fishy", "almond", "burnt", "mousy" | |
| ], | |
| "Colony Pattern": [ | |
| "none", "mucoid", "rough", "smooth", "filamentous", | |
| "spreading", "chalky", "corroding", "swarming", | |
| "sticky", "ground-glass", "molar-tooth" | |
| ], | |
| "TSI Pattern": [ | |
| "unknown", "a/a", "k/a", "k/k", | |
| "k/a+h2s", "a/a+gas" | |
| ], | |
| } | |
| # Make fast lookup: label → integer code | |
| CATEGORY_ENCODERS = { | |
| field: {lab: idx for idx, lab in enumerate(labels)} | |
| for field, labels in CATEGORY_MAPS.items() | |
| } | |
| # Temperature flags: a direct binary interpretation | |
| TEMP_FLAGS = {"4c", "25c", "30c", "37c", "42c"} | |
| # ------------------------------------------------------------------ | |
| # Normalisation helpers | |
| # ------------------------------------------------------------------ | |
| def _norm(x: Any) -> str: | |
| if not x: | |
| return "unknown" | |
| return str(x).strip().lower() | |
| def _map_pnv(x: Any) -> float: | |
| return PNV_MAP.get(_norm(x), 0.0) | |
| def _map_shape(x: Any) -> float: | |
| return SHAPE_MAP.get(_norm(x), 0.0) | |
| def _map_oxygen(x: Any) -> float: | |
| return OXYGEN_MAP.get(_norm(x), 0.0) | |
| def _extract_temperature_flags(value: str): | |
| """ | |
| Convert things like "25//37" → {"25c":1, "37c":1} | |
| """ | |
| flags = {k: 0.0 for k in TEMP_FLAGS} | |
| if not value: | |
| return flags | |
| s = value.lower() | |
| nums = re.findall(r"\b(\d{1,2})\s*c?\b", s) | |
| for n in nums: | |
| key = f"{n}c" | |
| if key in flags: | |
| flags[key] = 1.0 | |
| return flags | |
| def _growth_minmax(v: Any): | |
| """Convert '30//37' → (30,37).""" | |
| if not v: | |
| return (0.0, 0.0) | |
| if not isinstance(v, str): | |
| v = str(v) | |
| m = re.match(r"^\s*(\d+)\s*//\s*(\d+)\s*$", v) | |
| if not m: | |
| return (0.0, 0.0) | |
| return (float(m.group(1)), float(m.group(2))) | |
| def _media_flag(media_field: Any, medium: str) -> float: | |
| if not media_field: | |
| return 0.0 | |
| mf = str(media_field).lower() | |
| return 1.0 if medium.lower() in mf else 0.0 | |
| # ------------------------------------------------------------------ | |
| # CATEGORY mapping helper | |
| # ------------------------------------------------------------------ | |
| def _map_category(field: str, value: Any) -> float: | |
| """ | |
| Deterministic integer encoding. | |
| Unknown → 0 (first element) | |
| Multi-values like "yellow; orange" → choose first matching token. | |
| """ | |
| labels = CATEGORY_MAPS.get(field) | |
| if not labels: | |
| return 0.0 # should not happen | |
| enc = CATEGORY_ENCODERS[field] | |
| s = _norm(value) | |
| # Multi-list: pick first match | |
| parts = [p.strip() for p in re.split(r"[;/,]", s) if p.strip()] | |
| for p in parts: | |
| if p in enc: | |
| return float(enc[p]) | |
| # No match → return index for "none" or "unknown" | |
| fallback = "none" if "none" in enc else "unknown" | |
| return float(enc.get(fallback, 0)) | |
| # ------------------------------------------------------------------ | |
| # Main feature extractor | |
| # ------------------------------------------------------------------ | |
| def extract_feature_vector(fused_fields: Dict[str, Any]) -> np.ndarray: | |
| """ | |
| Convert fused fields into a fixed-length ML-ready numeric vector. | |
| ORDER must match feature_schema.json exactly. | |
| """ | |
| vec: List[float] = [] | |
| growth_temp = fused_fields.get("Growth Temperature") | |
| temp_flags = _extract_temperature_flags(growth_temp) | |
| for f in FEATURES: | |
| name = f["name"] | |
| kind = f["kind"] | |
| value = fused_fields.get(name, "Unknown") | |
| norm = _norm(value) | |
| # --------------------------- | |
| # pnv | |
| # --------------------------- | |
| if kind == "pnv": | |
| vec.append(_map_pnv(norm)) | |
| # --------------------------- | |
| # shape | |
| # --------------------------- | |
| elif kind == "shape": | |
| vec.append(_map_shape(norm)) | |
| # --------------------------- | |
| # oxygen requirement | |
| # --------------------------- | |
| elif kind == "oxygen": | |
| vec.append(_map_oxygen(norm)) | |
| # --------------------------- | |
| # category → integer encoding | |
| # --------------------------- | |
| elif kind == "category": | |
| vec.append(_map_category(name, value)) | |
| # --------------------------- | |
| # binary flag (temperatures) | |
| # --------------------------- | |
| elif kind == "binary": | |
| key = name.lower().replace("temperature_", "") | |
| vec.append(temp_flags.get(key, 0.0)) | |
| # --------------------------- | |
| # media flag | |
| # --------------------------- | |
| elif kind == "media_flag": | |
| medium = name.replace("Growth", "").strip() | |
| media_field = fused_fields.get("Media Grown On") | |
| vec.append(_media_flag(media_field, medium)) | |
| # --------------------------- | |
| # numeric_from_growth_temp (legacy support) | |
| # --------------------------- | |
| elif kind == "numeric_from_growth_temp": | |
| lo, hi = _growth_minmax(growth_temp) | |
| if "min" in name.lower(): | |
| vec.append(lo) | |
| elif "max" in name.lower(): | |
| vec.append(hi) | |
| else: | |
| vec.append(0.0) | |
| # --------------------------- | |
| # unknown kind | |
| # --------------------------- | |
| else: | |
| vec.append(0.0) | |
| return np.array(vec, dtype=float) |