# engine/train_genus_model.py
"""
Train a genus-level classifier (XGBoost) from gold tests.

Pipeline:
    • Load gold_tests.json
    • Extract genus (first token of the organism name)
    • Convert expected_fields → feature vector (via engine.features.extract_feature_vector)
    • Train an XGBoost multi-class classifier
    • Save:
        models/genus_xgb.json
        models/genus_xgb_meta.json

Compatible with FEATURE SCHEMA v2 (category, binary temperature flags,
pigment, odor, colony pattern, TSI, etc.).
"""

from __future__ import annotations

import json
import os
import random
from typing import Any, Dict, List, Tuple

import numpy as np
import xgboost as xgb

from .features import extract_feature_vector, FEATURES

# ---------------------------------------------------------------------------
# Paths
# ---------------------------------------------------------------------------
GOLD_TESTS_PATH = "training/gold_tests.json"
MODEL_DIR = "models"
MODEL_PATH = os.path.join(MODEL_DIR, "genus_xgb.json")
META_PATH = os.path.join(MODEL_DIR, "genus_xgb_meta.json")

# ---------------------------------------------------------------------------
# Load gold tests
# ---------------------------------------------------------------------------
def _load_gold_tests(path: str) -> List[Dict[str, Any]]:
    if not os.path.exists(path):
        raise FileNotFoundError(f"Missing gold test file: {path}")
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)
    if not isinstance(data, list):
        raise ValueError("gold_tests.json must contain a list.")
    return data

# ---------------------------------------------------------------------------
# Extract genus & expected fields
# ---------------------------------------------------------------------------
def _extract_genus(sample: Dict[str, Any]) -> str | None:
    """
    Extract the genus from any of the keys:
        name / Name / organism / Organism
    (genus = first whitespace-separated token).
    """
    for key in ("name", "Name", "organism", "Organism"):
        if key in sample and sample[key]:
            val = str(sample[key]).strip()
            if val:
                return val.split()[0]
    return None
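
# Illustrative behaviour of _extract_genus (hypothetical inputs, not taken
# from gold_tests.json):
#
#     _extract_genus({"name": "Escherichia coli"})          -> "Escherichia"
#     _extract_genus({"Organism": "Staphylococcus aureus"}) -> "Staphylococcus"
#     _extract_genus({"id": 7})                             -> None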

def _extract_fields(sample: Dict[str, Any]) -> Dict[str, Any]:
    """
    Extract the expected-field dict from any of the keys:
        fields / expected_fields / schema / expected
    """
    for key in ("fields", "expected_fields", "schema", "expected"):
        if key in sample and isinstance(sample[key], dict):
            return sample[key]
    return {}

# ---------------------------------------------------------------------------
# Dataset builder
# ---------------------------------------------------------------------------
def _build_dataset(samples: List[Dict[str, Any]]) -> Tuple[np.ndarray, np.ndarray, Dict[str, int]]:
    """
    Convert gold tests into:
        X            → feature matrix
        y            → integer labels
        genus_to_idx → genus-to-label mapping
    """
    X_list: List[np.ndarray] = []
    y_list: List[int] = []
    genus_to_idx: Dict[str, int] = {}
    for sample in samples:
        genus = _extract_genus(sample)
        if not genus:
            continue
        fields = _extract_fields(sample)
        if not fields:
            continue
        # Generate the ML feature vector (schema v2)
        vec = extract_feature_vector(fields)
        if genus not in genus_to_idx:
            genus_to_idx[genus] = len(genus_to_idx)
        X_list.append(vec)
        y_list.append(genus_to_idx[genus])
    if not X_list:
        raise ValueError("No usable gold tests found.")
    X = np.vstack(X_list)
    y = np.array(y_list, dtype=np.int32)
    return X, y, genus_to_idx
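
# Labels are assigned in first-seen order. With hypothetical samples named
# Escherichia, Salmonella, Escherichia, Proteus (in that order):
#
#     genus_to_idx == {"Escherichia": 0, "Salmonella": 1, "Proteus": 2}
#     y            == [0, 1, 0, 2]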

# ---------------------------------------------------------------------------
# Train XGBoost model
# ---------------------------------------------------------------------------
def _train_xgboost(
    X: np.ndarray,
    y: np.ndarray,
    num_classes: int,
    seed: int = 42,
) -> Tuple[xgb.Booster, Dict[str, float]]:
    """
    Train a multi-class XGBoost classifier on a shuffled 80/20
    train/validation split.
    """
    n = X.shape[0]
    indices = list(range(n))
    random.Random(seed).shuffle(indices)
    split = int(0.8 * n)
    train_idx = indices[:split]
    valid_idx = indices[split:]
    X_train, y_train = X[train_idx], y[train_idx]
    X_valid, y_valid = X[valid_idx], y[valid_idx]
    dtrain = xgb.DMatrix(X_train, label=y_train)
    dvalid = xgb.DMatrix(X_valid, label=y_valid)
    params = {
        "objective": "multi:softprob",
        "num_class": num_classes,
        "eval_metric": "mlogloss",
        "max_depth": 6,           # deeper trees for the richer schema-v2 feature space
        "eta": 0.08,              # slightly slower learning rate
        "subsample": 0.9,
        "colsample_bytree": 0.9,
        "min_child_weight": 1,
        "seed": seed,
    }
    evals = [(dtrain, "train"), (dvalid, "valid")]
    model = xgb.train(
        params,
        dtrain,
        evals=evals,
        num_boost_round=500,        # more rounds to cover the larger feature set
        early_stopping_rounds=40,   # extra patience for the more complex space
        verbose_eval=50,
    )
    # Accuracy: multi:softprob returns per-class probabilities, so argmax
    # recovers the predicted class index.
    train_acc = float(
        (np.argmax(model.predict(dtrain), axis=1) == y_train).mean()
    )
    valid_acc = float(
        (np.argmax(model.predict(dvalid), axis=1) == y_valid).mean()
    )
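    # Note: with early stopping, Booster.predict() scores with all boosted
    # rounds by default on some XGBoost releases; passing
    # iteration_range=(0, model.best_iteration + 1) would pin the
    # evaluation to the best iteration explicitly.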
    return model, {
        "train_accuracy": train_acc,
        "valid_accuracy": valid_acc,
        "best_iteration": int(model.best_iteration),
    }

def _ensure_model_dir():
    os.makedirs(MODEL_DIR, exist_ok=True)

# ---------------------------------------------------------------------------
# Public entry for UI
# ---------------------------------------------------------------------------
def train_genus_model() -> Dict[str, Any]:
    try:
        print(f"Loading gold tests → {GOLD_TESTS_PATH}")
        samples = _load_gold_tests(GOLD_TESTS_PATH)

        print("Building ML dataset...")
        X, y, genus_to_idx = _build_dataset(samples)
        num_classes = len(genus_to_idx)
        print(f"Feature dimension: {X.shape[1]}")
        print(f"Classes (genera): {num_classes}")
        print(f"Samples: {X.shape[0]}")

        print("Training XGBoost (schema v2)...")
        model, metrics = _train_xgboost(X, y, num_classes)
        print("Training complete.")
        print(f"Train accuracy: {metrics['train_accuracy']:.3f}")
        print(f"Valid accuracy: {metrics['valid_accuracy']:.3f}")

        _ensure_model_dir()
        model.save_model(MODEL_PATH)

        idx_to_genus = {idx: genus for genus, idx in genus_to_idx.items()}
        meta = {
            "genus_to_idx": genus_to_idx,
            "idx_to_genus": idx_to_genus,
            "n_features": int(X.shape[1]),
            "num_classes": int(num_classes),
            "metrics": metrics,
            "feature_schema_path": "data/feature_schema.json",
            "feature_names": [f["name"] for f in FEATURES],
        }
        # json.dump coerces the int keys of idx_to_genus to strings.
        with open(META_PATH, "w", encoding="utf-8") as f:
            json.dump(meta, f, indent=2, ensure_ascii=False)

        return {
            "ok": True,
            "message": "Genus XGBoost model (schema v2) trained successfully.",
            "stats": {
                "num_raw_samples": len(samples),
                "num_usable_samples": int(X.shape[0]),
                "feature_dim": int(X.shape[1]),
                "num_classes": int(num_classes),
            },
            "metrics": metrics,
            "paths": {"model_path": MODEL_PATH, "meta_path": META_PATH},
            "genus_examples": sorted(genus_to_idx.keys())[:20],
        }
    except Exception as e:
        return {
            "ok": False,
            "message": f"Training error: {type(e).__name__}: {e}",
        }

# ---------------------------------------------------------------------------
# CLI entry
# ---------------------------------------------------------------------------
def main():
    print(json.dumps(train_genus_model(), indent=2, ensure_ascii=False))


if __name__ == "__main__":
    main()
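
# ---------------------------------------------------------------------------
# Inference sketch (illustrative; not part of the training pipeline)
# ---------------------------------------------------------------------------
# A minimal sketch of consuming the saved artifacts, assuming engine.features
# is importable and `fields` is a schema-v2 dict like the ones used above:
#
#     import json
#     import numpy as np
#     import xgboost as xgb
#     from engine.features import extract_feature_vector
#
#     booster = xgb.Booster()
#     booster.load_model("models/genus_xgb.json")
#     with open("models/genus_xgb_meta.json", "r", encoding="utf-8") as f:
#         meta = json.load(f)
#
#     vec = extract_feature_vector(fields)
#     probs = booster.predict(xgb.DMatrix(vec.reshape(1, -1)))[0]
#     # idx_to_genus keys were stringified by json.dump, hence str(...).
#     genus = meta["idx_to_genus"][str(int(np.argmax(probs)))]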