EphAsad committed on
Commit
7927663
·
verified ·
1 Parent(s): d71e54e

Update engine/features.py

Browse files
Files changed (1) hide show
  1. engine/features.py +38 -22
engine/features.py CHANGED
@@ -4,7 +4,6 @@ import numpy as np
4
  import re
5
  from typing import Dict, List, Any
6
 
7
-
8
  # ------------------------------------------------------------------
9
  # Load schema once
10
  # ------------------------------------------------------------------
@@ -16,7 +15,6 @@ with open(_FEATURE_SCHEMA_PATH, "r", encoding="utf-8") as f:
16
 
17
  FEATURES = SCHEMA["features"]
18
 
19
-
20
  # ------------------------------------------------------------------
21
  # Helper mappings
22
  # ------------------------------------------------------------------
@@ -26,7 +24,7 @@ PNV_MAP = {
26
  "negative": -1.0,
27
  "variable": 0.5,
28
  "unknown": 0.0,
29
- None: 0.0
30
  }
31
 
32
  SHAPE_MAP = {
@@ -36,7 +34,7 @@ SHAPE_MAP = {
36
  "spiral": 3.0,
37
  "yeast": 4.0,
38
  "variable": 0.5,
39
- "unknown": 0.0
40
  }
41
 
42
  OXYGEN_MAP = {
@@ -45,14 +43,14 @@ OXYGEN_MAP = {
45
  "facultative anaerobe": 3.0,
46
  "microaerophilic": 4.0,
47
  "capnophilic": 5.0,
48
- "unknown": 0.0
49
  }
50
 
51
-
52
  # ------------------------------------------------------------------
53
  # Normalisation helpers
54
  # ------------------------------------------------------------------
55
 
 
56
  def _norm(s: Any) -> str:
57
  if not s:
58
  return "unknown"
@@ -71,26 +69,29 @@ def _map_oxygen(x: Any) -> float:
71
  return OXYGEN_MAP.get(_norm(x), 0.0)
72
 
73
 
74
- def _growth_minmax(v: str):
75
  """
76
  Convert '30//37' → (30, 37)
77
- If missing, return (0, 0)
78
  """
79
  if not v:
80
  return (0.0, 0.0)
 
 
81
  m = re.match(r"^\s*(\d+)\s*//\s*(\d+)\s*$", v)
82
  if not m:
83
  return (0.0, 0.0)
84
  return (float(m.group(1)), float(m.group(2)))
85
 
86
 
87
- def _media_flag(media_field: str, medium: str) -> float:
88
  """
89
- Return 1.0 if medium appears in media list, else 0.0.
 
90
  """
91
  if not media_field:
92
  return 0.0
93
- mf = media_field.lower()
94
  return 1.0 if medium.lower() in mf else 0.0
95
 
96
 
@@ -98,18 +99,26 @@ def _media_flag(media_field: str, medium: str) -> float:
98
  # Main public function
99
  # ------------------------------------------------------------------
100
 
 
101
  def extract_feature_vector(fused_fields: Dict[str, Any]) -> np.ndarray:
102
  """
103
  Convert fused tri-fusion fields into a fixed-length numeric vector.
104
  Ordered exactly according to feature_schema.json.
105
- Unknowns → 0.0.
 
106
  """
107
  vec: List[float] = []
108
 
 
 
 
 
 
109
  for f in FEATURES:
110
  name = f["name"]
111
  kind = f["kind"]
112
 
 
113
  value = fused_fields.get(name, "Unknown")
114
 
115
  if kind == "pnv":
@@ -122,21 +131,28 @@ def extract_feature_vector(fused_fields: Dict[str, Any]) -> np.ndarray:
122
  vec.append(_map_oxygen(value))
123
 
124
  elif kind == "numeric_from_growth_temp":
125
- low, high = _growth_minmax(value)
126
- vec.append(low)
127
- vec.append(high)
128
- # IMPORTANT: skip the next schema feature
129
- # (schema should include two entries but model expects two values)
130
- continue
 
 
 
 
 
131
 
132
  elif kind == "media_flag":
133
- # Each media entry in schema specifies the medium name
134
- # e.g. "MacConkey Growth"
 
135
  medium = name.replace("Growth", "").strip()
136
- vec.append(_media_flag(fused_fields.get("Media Grown On"), medium))
 
137
 
138
  else:
139
- # Unknown kind β†’ default numeric 0
140
  vec.append(0.0)
141
 
142
  return np.array(vec, dtype=float)
 
4
  import re
5
  from typing import Dict, List, Any
6
 
 
7
  # ------------------------------------------------------------------
8
  # Load schema once
9
  # ------------------------------------------------------------------
 
15
 
16
  FEATURES = SCHEMA["features"]
17
 
 
18
  # ------------------------------------------------------------------
19
  # Helper mappings
20
  # ------------------------------------------------------------------
 
24
  "negative": -1.0,
25
  "variable": 0.5,
26
  "unknown": 0.0,
27
+ None: 0.0,
28
  }
29
 
30
  SHAPE_MAP = {
 
34
  "spiral": 3.0,
35
  "yeast": 4.0,
36
  "variable": 0.5,
37
+ "unknown": 0.0,
38
  }
39
 
40
  OXYGEN_MAP = {
 
43
  "facultative anaerobe": 3.0,
44
  "microaerophilic": 4.0,
45
  "capnophilic": 5.0,
46
+ "unknown": 0.0,
47
  }
48
 
 
49
  # ------------------------------------------------------------------
50
  # Normalisation helpers
51
  # ------------------------------------------------------------------
52
 
53
+
54
  def _norm(s: Any) -> str:
55
  if not s:
56
  return "unknown"
 
69
  return OXYGEN_MAP.get(_norm(x), 0.0)
70
 
71
 
72
+ def _growth_minmax(v: Any) -> tuple[float, float]:
73
  """
74
  Convert '30//37' β†’ (30, 37)
75
+ If missing or malformed, return (0, 0).
76
  """
77
  if not v:
78
  return (0.0, 0.0)
79
+ if not isinstance(v, str):
80
+ v = str(v)
81
  m = re.match(r"^\s*(\d+)\s*//\s*(\d+)\s*$", v)
82
  if not m:
83
  return (0.0, 0.0)
84
  return (float(m.group(1)), float(m.group(2)))
85
 
86
 
87
+ def _media_flag(media_field: Any, medium: str) -> float:
88
  """
89
+ Return 1.0 if 'medium' appears in 'Media Grown On' field, else 0.0.
90
+ e.g. medium='MacConkey Agar' and media_field='Blood Agar; MacConkey Agar'
91
  """
92
  if not media_field:
93
  return 0.0
94
+ mf = str(media_field).lower()
95
  return 1.0 if medium.lower() in mf else 0.0
96
 
97
 
 
99
  # Main public function
100
  # ------------------------------------------------------------------
101
 
102
+
103
  def extract_feature_vector(fused_fields: Dict[str, Any]) -> np.ndarray:
104
  """
105
  Convert fused tri-fusion fields into a fixed-length numeric vector.
106
  Ordered exactly according to feature_schema.json.
107
+
108
+ Missing/unknown → 0.0.
109
  """
110
  vec: List[float] = []
111
 
112
+ # We iterate over schema in order and push one numeric value per feature
113
+ # (except Growth Temp Min/Max, which split the single 'Growth Temperature'
114
+ # field into two numeric components).
115
+ growth_value = fused_fields.get("Growth Temperature")
116
+
117
  for f in FEATURES:
118
  name = f["name"]
119
  kind = f["kind"]
120
 
121
+ # Default: look up by the same name in fused_fields
122
  value = fused_fields.get(name, "Unknown")
123
 
124
  if kind == "pnv":
 
131
  vec.append(_map_oxygen(value))
132
 
133
  elif kind == "numeric_from_growth_temp":
134
+ # We assume:
135
+ # - "Growth Temp Min" → min
136
+ # - "Growth Temp Max" → max
137
+ low, high = _growth_minmax(growth_value)
138
+ if "min" in name.lower():
139
+ vec.append(low)
140
+ elif "max" in name.lower():
141
+ vec.append(high)
142
+ else:
143
+ # Fallback: just append 0 if schema name is unexpected
144
+ vec.append(0.0)
145
 
146
  elif kind == "media_flag":
147
+ # For media flags we derive the medium name from the feature name.
148
+ # e.g. "MacConkey Growth" → medium="MacConkey"
149
+ # "Nutrient Growth" → medium="Nutrient"
150
  medium = name.replace("Growth", "").strip()
151
+ media_field = fused_fields.get("Media Grown On")
152
+ vec.append(_media_flag(media_field, medium))
153
 
154
  else:
155
+ # Unknown kind β†’ default numeric 0.0
156
  vec.append(0.0)
157
 
158
  return np.array(vec, dtype=float)