EphAsad committed on
Commit
02a1f10
·
verified ·
1 Parent(s): 4811a82

Update engine/train_genus_model.py

Browse files
Files changed (1) hide show
  1. engine/train_genus_model.py +90 -126
engine/train_genus_model.py CHANGED
@@ -2,27 +2,16 @@
2
  """
3
  Train a genus-level classifier (XGBoost) from gold tests.
4
 
5
- Inputs:
6
- - training/gold_tests.json
7
- Each item should have:
8
- - "name" (e.g. "Salmonella enterica")
9
- - and a dict of expected fields:
10
- "fields" or "expected_fields" or "schema" or "expected"
11
-
12
- We:
13
- 1) Extract the genus from "name" (first token).
14
- 2) Turn expected fields into an ML feature vector via engine.features.extract_feature_vector.
15
- 3) Train an XGBoost multi-class classifier (one class per genus).
16
- 4) Save:
17
- models/genus_xgb.json (the model)
18
- models/genus_xgb_meta.json (label map + feature meta)
19
-
20
- This module exposes:
21
-
22
- train_genus_model() -> dict summary
23
-
24
- so the Gradio UI can call it and display the JSON summary, and also keeps a
25
- CLI entry via `python -m engine.train_genus_model` or direct execution.
26
  """
27
 
28
  from __future__ import annotations
@@ -37,7 +26,10 @@ import xgboost as xgb
37
 
38
  from .features import extract_feature_vector, FEATURES
39
 
 
40
  # Paths
 
 
41
  GOLD_TESTS_PATH = "training/gold_tests.json"
42
  MODEL_DIR = "models"
43
  MODEL_PATH = os.path.join(MODEL_DIR, "genus_xgb.json")
@@ -45,35 +37,44 @@ META_PATH = os.path.join(MODEL_DIR, "genus_xgb_meta.json")
45
 
46
 
47
  # ---------------------------------------------------------------------------
48
- # Helpers
49
  # ---------------------------------------------------------------------------
50
 
51
-
52
  def _load_gold_tests(path: str) -> List[Dict[str, Any]]:
 
 
 
53
  with open(path, "r", encoding="utf-8") as f:
54
  data = json.load(f)
 
55
  if not isinstance(data, list):
56
- raise ValueError("gold_tests.json should contain a list of samples.")
 
57
  return data
58
 
59
 
 
 
 
 
60
  def _extract_genus(sample: Dict[str, Any]) -> str | None:
61
  """
62
- Get genus from sample["name"] / ["Name"] / ["organism"] etc.
63
- We just take the first word.
 
64
  """
65
  for key in ("name", "Name", "organism", "Organism"):
66
  if key in sample and sample[key]:
67
- text = str(sample[key]).strip()
68
- if not text:
69
- continue
70
- return text.split()[0]
71
  return None
72
 
73
 
74
  def _extract_fields(sample: Dict[str, Any]) -> Dict[str, Any]:
75
  """
76
- Try several possible keys for the expected fields in gold_tests.json.
 
77
  """
78
  for key in ("fields", "expected_fields", "schema", "expected"):
79
  if key in sample and isinstance(sample[key], dict):
@@ -81,19 +82,19 @@ def _extract_fields(sample: Dict[str, Any]) -> Dict[str, Any]:
81
  return {}
82
 
83
 
84
- def _build_dataset(
85
- samples: List[Dict[str, Any]]
86
- ) -> Tuple[np.ndarray, np.ndarray, Dict[str, int]]:
87
- """
88
- Build X (features) and y (integer genus labels) from gold tests.
89
 
90
- Returns:
91
- X: (N, D) feature matrix
92
- y: (N,) integer labels
93
- genus_to_idx: mapping from genus string → class index
94
  """
95
- X: List[np.ndarray] = []
96
- y: List[int] = []
 
 
 
 
 
97
  genus_to_idx: Dict[str, int] = {}
98
 
99
  for sample in samples:
@@ -103,41 +104,41 @@ def _build_dataset(
103
 
104
  fields = _extract_fields(sample)
105
  if not fields:
106
- # No expected fields for this sample → skip
107
  continue
108
 
109
- # Convert expected fields to feature vector
110
  vec = extract_feature_vector(fields)
111
 
112
- # Map genus to class index
113
  if genus not in genus_to_idx:
114
  genus_to_idx[genus] = len(genus_to_idx)
115
- label = genus_to_idx[genus]
116
 
117
- X.append(vec)
118
- y.append(label)
 
 
 
119
 
120
- if not X:
121
- raise ValueError("No usable samples found in gold_tests.json.")
122
 
123
- X_arr = np.vstack(X)
124
- y_arr = np.array(y, dtype=np.int32)
125
 
126
- return X_arr, y_arr, genus_to_idx
127
 
 
 
 
128
 
129
  def _train_xgboost(
130
  X: np.ndarray,
131
  y: np.ndarray,
132
  num_classes: int,
133
- seed: int = 42,
134
  ) -> Tuple[xgb.Booster, Dict[str, float]]:
135
  """
136
- Train an XGBoost multi-class classifier with a simple train/valid split.
137
- Returns:
138
- model, metrics_dict
139
  """
140
- # Train/valid split (80/20)
141
  n = X.shape[0]
142
  indices = list(range(n))
143
  random.Random(seed).shuffle(indices)
@@ -156,10 +157,10 @@ def _train_xgboost(
156
  "objective": "multi:softprob",
157
  "num_class": num_classes,
158
  "eval_metric": "mlogloss",
159
- "max_depth": 5,
160
- "eta": 0.1,
161
- "subsample": 0.8,
162
- "colsample_bytree": 0.8,
163
  "min_child_weight": 1,
164
  "seed": seed,
165
  }
@@ -169,30 +170,29 @@ def _train_xgboost(
169
  model = xgb.train(
170
  params,
171
  dtrain,
172
- num_boost_round=200,
173
  evals=evals,
174
- early_stopping_rounds=20,
175
- verbose_eval=25,
 
176
  )
177
 
178
- # Simple accuracy on train/valid
179
- train_pred = np.argmax(model.predict(dtrain), axis=1)
180
- valid_pred = np.argmax(model.predict(dvalid), axis=1)
181
-
182
- train_acc = float((train_pred == y_train).mean())
183
- valid_acc = float((valid_pred == y_valid).mean())
 
184
 
185
- metrics = {
186
  "train_accuracy": train_acc,
187
  "valid_accuracy": valid_acc,
188
  "best_iteration": int(model.best_iteration),
189
  }
190
 
191
- return model, metrics
192
-
193
 
194
- def _ensure_model_dir() -> None:
195
- if not os.path.isdir(MODEL_DIR):
196
  os.makedirs(MODEL_DIR, exist_ok=True)
197
 
198
 
@@ -200,49 +200,29 @@ def _ensure_model_dir() -> None:
200
  # Public entry for UI
201
  # ---------------------------------------------------------------------------
202
 
203
-
204
  def train_genus_model() -> Dict[str, Any]:
205
- """
206
- Public function used by the Gradio UI.
207
-
208
- Returns a JSON-serialisable dict, e.g.:
209
-
210
- {
211
- "ok": true,
212
- "message": "...",
213
- "stats": {...},
214
- "metrics": {...},
215
- "paths": {...},
216
- "class_count": 42,
217
- "genus_examples": ["Salmonella", "Staphylococcus", ...]
218
- }
219
- """
220
  try:
221
- print("Loading gold tests from:", GOLD_TESTS_PATH)
222
  samples = _load_gold_tests(GOLD_TESTS_PATH)
223
- print(f"Loaded {len(samples)} gold samples (raw).")
224
 
225
- print("Building dataset...")
226
  X, y, genus_to_idx = _build_dataset(samples)
 
227
  num_classes = len(genus_to_idx)
228
- print(f"Usable samples: {X.shape[0]}")
229
  print(f"Feature dimension: {X.shape[1]}")
230
- print(f"Distinct genera (classes): {num_classes}")
 
231
 
232
- print("Training XGBoost genus classifier...")
233
- model, metrics = _train_xgboost(X, y, num_classes=num_classes)
234
 
235
  print("Training complete.")
236
  print(f"Train accuracy: {metrics['train_accuracy']:.3f}")
237
  print(f"Valid accuracy: {metrics['valid_accuracy']:.3f}")
238
- print(f"Best iteration: {metrics['best_iteration']}")
239
 
240
  _ensure_model_dir()
241
-
242
- print("Saving model to:", MODEL_PATH)
243
  model.save_model(MODEL_PATH)
244
 
245
- # Build index → genus map
246
  idx_to_genus = {idx: genus for genus, idx in genus_to_idx.items()}
247
 
248
  meta = {
@@ -255,18 +235,12 @@ def train_genus_model() -> Dict[str, Any]:
255
  "feature_names": [f["name"] for f in FEATURES],
256
  }
257
 
258
- print("Saving meta to:", META_PATH)
259
  with open(META_PATH, "w", encoding="utf-8") as f:
260
  json.dump(meta, f, indent=2, ensure_ascii=False)
261
 
262
- print("Done.")
263
-
264
- # Compact summary for the UI
265
- genus_examples = sorted(list(genus_to_idx.keys()))[:20]
266
-
267
  return {
268
  "ok": True,
269
- "message": "Genus XGBoost model trained and saved successfully.",
270
  "stats": {
271
  "num_raw_samples": len(samples),
272
  "num_usable_samples": int(X.shape[0]),
@@ -274,19 +248,14 @@ def train_genus_model() -> Dict[str, Any]:
274
  "num_classes": int(num_classes),
275
  },
276
  "metrics": metrics,
277
- "paths": {
278
- "model_path": MODEL_PATH,
279
- "meta_path": META_PATH,
280
- },
281
- "class_count": int(num_classes),
282
- "genus_examples": genus_examples,
283
  }
284
 
285
  except Exception as e:
286
- # If anything blows up, return a clean error for the UI JSON
287
  return {
288
  "ok": False,
289
- "message": f"Error during genus model training: {type(e).__name__}: {e}",
290
  }
291
 
292
 
@@ -294,14 +263,9 @@ def train_genus_model() -> Dict[str, Any]:
294
  # CLI entry
295
  # ---------------------------------------------------------------------------
296
 
297
-
298
- def main() -> None:
299
- """
300
- Keep a CLI entry that prints the same summary.
301
- """
302
- summary = train_genus_model()
303
- print(json.dumps(summary, indent=2, ensure_ascii=False))
304
 
305
 
306
  if __name__ == "__main__":
307
- main()
 
2
  """
3
  Train a genus-level classifier (XGBoost) from gold tests.
4
 
5
+ Pipeline:
6
+ Load gold_tests.json
7
+ Extract genus (first token of organism name)
8
+ Convert expected_fields → feature vector (via engine.features.extract_feature_vector)
9
+ Train an XGBoost multi-class classifier
10
+ Save:
11
+ models/genus_xgb.json
12
+ models/genus_xgb_meta.json
13
+
14
+ Compatible with FEATURE SCHEMA v2 (category, binary temperature flags, pigment, odor, colony pattern, TSI, etc.)
 
 
 
 
 
 
 
 
 
 
 
15
  """
16
 
17
  from __future__ import annotations
 
26
 
27
  from .features import extract_feature_vector, FEATURES
28
 
29
+ # ---------------------------------------------------------------------------
30
  # Paths
31
+ # ---------------------------------------------------------------------------
32
+
33
  GOLD_TESTS_PATH = "training/gold_tests.json"
34
  MODEL_DIR = "models"
35
  MODEL_PATH = os.path.join(MODEL_DIR, "genus_xgb.json")
 
37
 
38
 
39
  # ---------------------------------------------------------------------------
40
+ # Load gold tests
41
  # ---------------------------------------------------------------------------
42
 
 
43
  def _load_gold_tests(path: str) -> List[Dict[str, Any]]:
44
+ if not os.path.exists(path):
45
+ raise FileNotFoundError(f"Missing gold test file: {path}")
46
+
47
  with open(path, "r", encoding="utf-8") as f:
48
  data = json.load(f)
49
+
50
  if not isinstance(data, list):
51
+ raise ValueError("gold_tests.json must contain a list.")
52
+
53
  return data
54
 
55
 
56
+ # ---------------------------------------------------------------------------
57
+ # Extract genus & expected fields
58
+ # ---------------------------------------------------------------------------
59
+
60
  def _extract_genus(sample: Dict[str, Any]) -> str | None:
61
  """
62
+ Extract genus from:
63
+ name / Name / organism / Organism
64
+ (genus = first token before space)
65
  """
66
  for key in ("name", "Name", "organism", "Organism"):
67
  if key in sample and sample[key]:
68
+ val = str(sample[key]).strip()
69
+ if val:
70
+ return val.split()[0]
 
71
  return None
72
 
73
 
74
  def _extract_fields(sample: Dict[str, Any]) -> Dict[str, Any]:
75
  """
76
+ Extract expected field dict from any of:
77
+ fields / expected_fields / schema / expected
78
  """
79
  for key in ("fields", "expected_fields", "schema", "expected"):
80
  if key in sample and isinstance(sample[key], dict):
 
82
  return {}
83
 
84
 
85
+ # ---------------------------------------------------------------------------
86
+ # Dataset builder
87
+ # ---------------------------------------------------------------------------
 
 
88
 
89
+ def _build_dataset(samples: List[Dict[str, Any]]) -> Tuple[np.ndarray, np.ndarray, Dict[str, int]]:
 
 
 
90
  """
91
+ Convert gold tests into:
92
+ X → feature matrix
93
+ y → integer labels
94
+ genus_to_idx → mapping
95
+ """
96
+ X_list: List[np.ndarray] = []
97
+ y_list: List[int] = []
98
  genus_to_idx: Dict[str, int] = {}
99
 
100
  for sample in samples:
 
104
 
105
  fields = _extract_fields(sample)
106
  if not fields:
 
107
  continue
108
 
109
+ # Generate ML feature vector (schema v2)
110
  vec = extract_feature_vector(fields)
111
 
 
112
  if genus not in genus_to_idx:
113
  genus_to_idx[genus] = len(genus_to_idx)
 
114
 
115
+ X_list.append(vec)
116
+ y_list.append(genus_to_idx[genus])
117
+
118
+ if not X_list:
119
+ raise ValueError("No usable gold tests found.")
120
 
121
+ X = np.vstack(X_list)
122
+ y = np.array(y_list, dtype=np.int32)
123
 
124
+ return X, y, genus_to_idx
 
125
 
 
126
 
127
+ # ---------------------------------------------------------------------------
128
+ # Train XGBoost model
129
+ # ---------------------------------------------------------------------------
130
 
131
  def _train_xgboost(
132
  X: np.ndarray,
133
  y: np.ndarray,
134
  num_classes: int,
135
+ seed: int = 42
136
  ) -> Tuple[xgb.Booster, Dict[str, float]]:
137
  """
138
+ Train a multi-class XGBoost classifier.
139
+ 80/20 split.
 
140
  """
141
+
142
  n = X.shape[0]
143
  indices = list(range(n))
144
  random.Random(seed).shuffle(indices)
 
157
  "objective": "multi:softprob",
158
  "num_class": num_classes,
159
  "eval_metric": "mlogloss",
160
+ "max_depth": 6, # Higher depth since schema v2 more complex
161
+ "eta": 0.08, # Slightly slower learning
162
+ "subsample": 0.9,
163
+ "colsample_bytree": 0.9,
164
  "min_child_weight": 1,
165
  "seed": seed,
166
  }
 
170
  model = xgb.train(
171
  params,
172
  dtrain,
 
173
  evals=evals,
174
+ num_boost_round=500, # More rounds since more features
175
+ early_stopping_rounds=40, # Allow more patience for complex space
176
+ verbose_eval=50,
177
  )
178
 
179
+ # Accuracy evaluation
180
+ train_acc = float(
181
+ (np.argmax(model.predict(dtrain), axis=1) == y_train).mean()
182
+ )
183
+ valid_acc = float(
184
+ (np.argmax(model.predict(dvalid), axis=1) == y_valid).mean()
185
+ )
186
 
187
+ return model, {
188
  "train_accuracy": train_acc,
189
  "valid_accuracy": valid_acc,
190
  "best_iteration": int(model.best_iteration),
191
  }
192
 
 
 
193
 
194
+ def _ensure_model_dir():
195
+ if not os.path.exists(MODEL_DIR):
196
  os.makedirs(MODEL_DIR, exist_ok=True)
197
 
198
 
 
200
  # Public entry for UI
201
  # ---------------------------------------------------------------------------
202
 
 
203
  def train_genus_model() -> Dict[str, Any]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
204
  try:
205
+ print(f"Loading gold tests {GOLD_TESTS_PATH}")
206
  samples = _load_gold_tests(GOLD_TESTS_PATH)
 
207
 
208
+ print("Building ML dataset...")
209
  X, y, genus_to_idx = _build_dataset(samples)
210
+
211
  num_classes = len(genus_to_idx)
 
212
  print(f"Feature dimension: {X.shape[1]}")
213
+ print(f"Classes (genera): {num_classes}")
214
+ print(f"Samples: {X.shape[0]}")
215
 
216
+ print("Training XGBoost (schema v2)...")
217
+ model, metrics = _train_xgboost(X, y, num_classes)
218
 
219
  print("Training complete.")
220
  print(f"Train accuracy: {metrics['train_accuracy']:.3f}")
221
  print(f"Valid accuracy: {metrics['valid_accuracy']:.3f}")
 
222
 
223
  _ensure_model_dir()
 
 
224
  model.save_model(MODEL_PATH)
225
 
 
226
  idx_to_genus = {idx: genus for genus, idx in genus_to_idx.items()}
227
 
228
  meta = {
 
235
  "feature_names": [f["name"] for f in FEATURES],
236
  }
237
 
 
238
  with open(META_PATH, "w", encoding="utf-8") as f:
239
  json.dump(meta, f, indent=2, ensure_ascii=False)
240
 
 
 
 
 
 
241
  return {
242
  "ok": True,
243
+ "message": "Genus XGBoost model (schema v2) trained successfully.",
244
  "stats": {
245
  "num_raw_samples": len(samples),
246
  "num_usable_samples": int(X.shape[0]),
 
248
  "num_classes": int(num_classes),
249
  },
250
  "metrics": metrics,
251
+ "paths": {"model_path": MODEL_PATH, "meta_path": META_PATH},
252
+ "genus_examples": sorted(genus_to_idx.keys())[:20],
 
 
 
 
253
  }
254
 
255
  except Exception as e:
 
256
  return {
257
  "ok": False,
258
+ "message": f"Training error: {type(e).__name__}: {e}",
259
  }
260
 
261
 
 
263
  # CLI entry
264
  # ---------------------------------------------------------------------------
265
 
266
+ def main():
267
+ print(json.dumps(train_genus_model(), indent=2, ensure_ascii=False))
 
 
 
 
 
268
 
269
 
270
  if __name__ == "__main__":
271
+ main()