EphAsad commited on
Commit
c9c2c73
·
verified ·
1 Parent(s): 384f774

Create features.py

Browse files
Files changed (1) hide show
  1. engine/features.py +142 -0
engine/features.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # engine/features.py
2
+ import json
3
+ import numpy as np
4
+ import re
5
+ from typing import Dict, List, Any
6
+
7
+
8
+ # ------------------------------------------------------------------
9
+ # Load schema once
10
+ # ------------------------------------------------------------------
11
+
12
+ _FEATURE_SCHEMA_PATH = "data/feature_schema.json"
13
+
14
+ with open(_FEATURE_SCHEMA_PATH, "r", encoding="utf-8") as f:
15
+ SCHEMA = json.load(f)
16
+
17
+ FEATURES = SCHEMA["features"]
18
+
19
+
20
+ # ------------------------------------------------------------------
21
+ # Helper mappings
22
+ # ------------------------------------------------------------------
23
+
24
+ PNV_MAP = {
25
+ "positive": 1.0,
26
+ "negative": -1.0,
27
+ "variable": 0.5,
28
+ "unknown": 0.0,
29
+ None: 0.0
30
+ }
31
+
32
+ SHAPE_MAP = {
33
+ "cocci": 1.0,
34
+ "rods": 2.0,
35
+ "short rods": 2.5,
36
+ "spiral": 3.0,
37
+ "yeast": 4.0,
38
+ "variable": 0.5,
39
+ "unknown": 0.0
40
+ }
41
+
42
+ OXYGEN_MAP = {
43
+ "aerobic": 1.0,
44
+ "anaerobic": 2.0,
45
+ "facultative anaerobe": 3.0,
46
+ "microaerophilic": 4.0,
47
+ "capnophilic": 5.0,
48
+ "unknown": 0.0
49
+ }
50
+
51
+
52
+ # ------------------------------------------------------------------
53
+ # Normalisation helpers
54
+ # ------------------------------------------------------------------
55
+
56
+ def _norm(s: Any) -> str:
57
+ if not s:
58
+ return "unknown"
59
+ return str(s).strip().lower()
60
+
61
+
62
+ def _map_pnv(x: Any) -> float:
63
+ return PNV_MAP.get(_norm(x), 0.0)
64
+
65
+
66
+ def _map_shape(x: Any) -> float:
67
+ return SHAPE_MAP.get(_norm(x), 0.0)
68
+
69
+
70
+ def _map_oxygen(x: Any) -> float:
71
+ return OXYGEN_MAP.get(_norm(x), 0.0)
72
+
73
+
74
+ def _growth_minmax(v: str):
75
+ """
76
+ Convert '30//37' → (30, 37)
77
+ If missing, return (0, 0)
78
+ """
79
+ if not v:
80
+ return (0.0, 0.0)
81
+ m = re.match(r"^\s*(\d+)\s*//\s*(\d+)\s*$", v)
82
+ if not m:
83
+ return (0.0, 0.0)
84
+ return (float(m.group(1)), float(m.group(2)))
85
+
86
+
87
+ def _media_flag(media_field: str, medium: str) -> float:
88
+ """
89
+ Return 1.0 if medium appears in media list, else 0.0.
90
+ """
91
+ if not media_field:
92
+ return 0.0
93
+ mf = media_field.lower()
94
+ return 1.0 if medium.lower() in mf else 0.0
95
+
96
+
97
+ # ------------------------------------------------------------------
98
+ # Main public function
99
+ # ------------------------------------------------------------------
100
+
101
+ def extract_feature_vector(fused_fields: Dict[str, Any]) -> np.ndarray:
102
+ """
103
+ Convert fused tri-fusion fields into a fixed-length numeric vector.
104
+ Ordered exactly according to feature_schema.json.
105
+ Unknowns → 0.0.
106
+ """
107
+ vec: List[float] = []
108
+
109
+ for f in FEATURES:
110
+ name = f["name"]
111
+ kind = f["kind"]
112
+
113
+ value = fused_fields.get(name, "Unknown")
114
+
115
+ if kind == "pnv":
116
+ vec.append(_map_pnv(value))
117
+
118
+ elif kind == "shape":
119
+ vec.append(_map_shape(value))
120
+
121
+ elif kind == "oxygen":
122
+ vec.append(_map_oxygen(value))
123
+
124
+ elif kind == "numeric_from_growth_temp":
125
+ low, high = _growth_minmax(value)
126
+ vec.append(low)
127
+ vec.append(high)
128
+ # IMPORTANT: skip the next schema feature
129
+ # (schema should include two entries but model expects two values)
130
+ continue
131
+
132
+ elif kind == "media_flag":
133
+ # Each media entry in schema specifies the medium name
134
+ # e.g. "MacConkey Growth"
135
+ medium = name.replace("Growth", "").strip()
136
+ vec.append(_media_flag(fused_fields.get("Media Grown On"), medium))
137
+
138
+ else:
139
+ # Unknown kind → default numeric 0
140
+ vec.append(0.0)
141
+
142
+ return np.array(vec, dtype=float)