# Source: BactKing — engine/train_genus_model.py
# Provenance: uploaded by EphAsad, "Update engine/train_genus_model.py", commit 02a1f10 (verified)
# engine/train_genus_model.py
"""
Train a genus-level classifier (XGBoost) from gold tests.
Pipeline:
• Load gold_tests.json
• Extract genus (first token of organism name)
• Convert expected_fields → feature vector (via engine.features.extract_feature_vector)
• Train an XGBoost multi-class classifier
• Save:
models/genus_xgb.json
models/genus_xgb_meta.json
Compatible with FEATURE SCHEMA v2 (category, binary temperature flags, pigment, odor, colony pattern, TSI, etc.)
"""
from __future__ import annotations
import json
import os
import random
from typing import Any, Dict, List, Tuple
import numpy as np
import xgboost as xgb
from .features import extract_feature_vector, FEATURES
# ---------------------------------------------------------------------------
# Paths
# ---------------------------------------------------------------------------
GOLD_TESTS_PATH = "training/gold_tests.json"
MODEL_DIR = "models"
MODEL_PATH = os.path.join(MODEL_DIR, "genus_xgb.json")
META_PATH = os.path.join(MODEL_DIR, "genus_xgb_meta.json")
# ---------------------------------------------------------------------------
# Load gold tests
# ---------------------------------------------------------------------------
def _load_gold_tests(path: str) -> List[Dict[str, Any]]:
    """Load the gold-test samples stored as JSON at *path*.

    Raises:
        FileNotFoundError: when the file is absent.
        ValueError: when the top-level JSON value is not a list.
    """
    if not os.path.exists(path):
        raise FileNotFoundError(f"Missing gold test file: {path}")
    with open(path, "r", encoding="utf-8") as handle:
        samples = json.load(handle)
    if isinstance(samples, list):
        return samples
    raise ValueError("gold_tests.json must contain a list.")
# ---------------------------------------------------------------------------
# Extract genus & expected fields
# ---------------------------------------------------------------------------
def _extract_genus(sample: Dict[str, Any]) -> str | None:
    """Return the genus for a gold sample, or None if unavailable.

    The genus is the first whitespace-separated token of the organism
    name, looked up under any of the keys ``name`` / ``Name`` /
    ``organism`` / ``Organism`` (checked in that order).
    """
    for key in ("name", "Name", "organism", "Organism"):
        raw = sample.get(key)
        if not raw:
            continue
        tokens = str(raw).strip().split()
        if tokens:
            return tokens[0]
    return None
def _extract_fields(sample: Dict[str, Any]) -> Dict[str, Any]:
    """Return the expected-field mapping of a gold sample.

    The first of ``fields`` / ``expected_fields`` / ``schema`` /
    ``expected`` whose value is a dict wins; otherwise an empty dict
    is returned.
    """
    for alias in ("fields", "expected_fields", "schema", "expected"):
        value = sample.get(alias)
        if isinstance(value, dict):
            return value
    return {}
# ---------------------------------------------------------------------------
# Dataset builder
# ---------------------------------------------------------------------------
def _build_dataset(samples: List[Dict[str, Any]]) -> Tuple[np.ndarray, np.ndarray, Dict[str, int]]:
    """Turn gold-test samples into a training dataset.

    Returns:
        X: stacked feature matrix, one row per usable sample.
        y: int32 label array aligned with X.
        genus_to_idx: genus name -> label index, in first-seen order.

    Raises:
        ValueError: when no sample yields both a genus and a field dict.
    """
    rows: List[np.ndarray] = []
    labels: List[int] = []
    genus_to_idx: Dict[str, int] = {}
    for sample in samples:
        genus = _extract_genus(sample)
        fields = _extract_fields(sample)
        # Samples missing either an organism name or expected fields are unusable.
        if not genus or not fields:
            continue
        # ML feature vector per schema v2.
        rows.append(extract_feature_vector(fields))
        # First sighting of a genus claims the next free label index.
        labels.append(genus_to_idx.setdefault(genus, len(genus_to_idx)))
    if not rows:
        raise ValueError("No usable gold tests found.")
    return np.vstack(rows), np.array(labels, dtype=np.int32), genus_to_idx
# ---------------------------------------------------------------------------
# Train XGBoost model
# ---------------------------------------------------------------------------
def _train_xgboost(
    X: np.ndarray,
    y: np.ndarray,
    num_classes: int,
    seed: int = 42
) -> Tuple[xgb.Booster, Dict[str, float]]:
    """Train a multi-class XGBoost classifier on a shuffled 80/20 split.

    Args:
        X: feature matrix of shape (n_samples, n_features).
        y: integer class labels aligned with X.
        num_classes: number of distinct classes (genera).
        seed: RNG seed used for both the split shuffle and XGBoost.

    Returns:
        (booster, metrics) where metrics holds train/valid accuracy and
        the early-stopped best iteration.

    Raises:
        ValueError: when there are too few samples for a non-empty
            train/validation split (early stopping needs both).
    """
    n = X.shape[0]
    indices = list(range(n))
    random.Random(seed).shuffle(indices)
    split = int(0.8 * n)
    train_idx = indices[:split]
    valid_idx = indices[split:]
    # Early stopping requires a non-empty validation set; fail with a clear
    # message instead of an obscure xgboost error on tiny datasets.
    if not train_idx or not valid_idx:
        raise ValueError(f"Not enough samples ({n}) for an 80/20 train/valid split.")
    X_train, y_train = X[train_idx], y[train_idx]
    X_valid, y_valid = X[valid_idx], y[valid_idx]
    dtrain = xgb.DMatrix(X_train, label=y_train)
    dvalid = xgb.DMatrix(X_valid, label=y_valid)
    params = {
        "objective": "multi:softprob",
        "num_class": num_classes,
        "eval_metric": "mlogloss",
        "max_depth": 6,  # Higher depth since schema v2 more complex
        "eta": 0.08,  # Slightly slower learning
        "subsample": 0.9,
        "colsample_bytree": 0.9,
        "min_child_weight": 1,
        "seed": seed,
    }
    evals = [(dtrain, "train"), (dvalid, "valid")]
    model = xgb.train(
        params,
        dtrain,
        evals=evals,
        num_boost_round=500,  # More rounds since more features
        early_stopping_rounds=40,  # Allow more patience for complex space
        verbose_eval=50,
    )
    # Accuracy evaluation.
    # xgb.train returns the model from the LAST round, not the best one, so
    # restrict prediction to the early-stopped best iteration; otherwise the
    # reported accuracies reflect the (potentially overfit) full ensemble.
    best_range = (0, model.best_iteration + 1)
    train_acc = float(
        (np.argmax(model.predict(dtrain, iteration_range=best_range), axis=1) == y_train).mean()
    )
    valid_acc = float(
        (np.argmax(model.predict(dvalid, iteration_range=best_range), axis=1) == y_valid).mean()
    )
    return model, {
        "train_accuracy": train_acc,
        "valid_accuracy": valid_acc,
        "best_iteration": int(model.best_iteration),
    }
def _ensure_model_dir() -> None:
    """Create MODEL_DIR if it does not exist (idempotent).

    ``exist_ok=True`` already makes this safe and race-free; the prior
    ``os.path.exists`` pre-check was redundant LBYL.
    """
    os.makedirs(MODEL_DIR, exist_ok=True)
# ---------------------------------------------------------------------------
# Public entry for UI
# ---------------------------------------------------------------------------
def train_genus_model() -> Dict[str, Any]:
    """Train and persist the genus-level XGBoost model (UI entry point).

    Loads gold tests, builds the schema-v2 feature matrix, trains the
    classifier, and writes the booster plus metadata under ``models/``.
    Never raises: any failure is folded into the returned status dict so
    the UI can surface it.

    Returns:
        On success: ``{"ok": True, "message", "stats", "metrics",
        "paths", "genus_examples"}``; on failure: ``{"ok": False,
        "message"}``.
    """
    try:
        print(f"Loading gold tests → {GOLD_TESTS_PATH}")
        samples = _load_gold_tests(GOLD_TESTS_PATH)

        print("Building ML dataset...")
        X, y, genus_to_idx = _build_dataset(samples)
        num_classes = len(genus_to_idx)
        print(f"Feature dimension: {X.shape[1]}")
        print(f"Classes (genera): {num_classes}")
        print(f"Samples: {X.shape[0]}")

        print("Training XGBoost (schema v2)...")
        model, metrics = _train_xgboost(X, y, num_classes)
        print("Training complete.")
        print(f"Train accuracy: {metrics['train_accuracy']:.3f}")
        print(f"Valid accuracy: {metrics['valid_accuracy']:.3f}")

        _ensure_model_dir()
        model.save_model(MODEL_PATH)

        # Persist label mappings and schema info next to the booster so
        # inference can decode predictions later.
        meta: Dict[str, Any] = {
            "genus_to_idx": genus_to_idx,
            "idx_to_genus": {idx: name for name, idx in genus_to_idx.items()},
            "n_features": int(X.shape[1]),
            "num_classes": int(num_classes),
            "metrics": metrics,
            "feature_schema_path": "data/feature_schema.json",
            "feature_names": [feat["name"] for feat in FEATURES],
        }
        with open(META_PATH, "w", encoding="utf-8") as fh:
            json.dump(meta, fh, indent=2, ensure_ascii=False)

        return {
            "ok": True,
            "message": "Genus XGBoost model (schema v2) trained successfully.",
            "stats": {
                "num_raw_samples": len(samples),
                "num_usable_samples": int(X.shape[0]),
                "feature_dim": int(X.shape[1]),
                "num_classes": int(num_classes),
            },
            "metrics": metrics,
            "paths": {"model_path": MODEL_PATH, "meta_path": META_PATH},
            "genus_examples": sorted(genus_to_idx.keys())[:20],
        }
    except Exception as e:
        # Deliberately broad: this is the UI boundary — report, don't crash.
        return {
            "ok": False,
            "message": f"Training error: {type(e).__name__}: {e}",
        }
# ---------------------------------------------------------------------------
# CLI entry
# ---------------------------------------------------------------------------
def main() -> None:
    """CLI entry: run training and pretty-print the result as JSON."""
    result = train_genus_model()
    print(json.dumps(result, indent=2, ensure_ascii=False))


if __name__ == "__main__":
    main()