|
|
|
|
|
""" |
|
|
Train a genus-level classifier (XGBoost) from gold tests. |
|
|
|
|
|
Pipeline: |
|
|
• Load gold_tests.json |
|
|
• Extract genus (first token of organism name) |
|
|
• Convert expected_fields → feature vector (via engine.features.extract_feature_vector) |
|
|
• Train an XGBoost multi-class classifier |
|
|
• Save: |
|
|
models/genus_xgb.json |
|
|
models/genus_xgb_meta.json |
|
|
|
|
|
Compatible with FEATURE SCHEMA v2 (category, binary temperature flags, pigment, odor, colony pattern, TSI, etc.) |
|
|
""" |
|
|
|
|
|
from __future__ import annotations |
|
|
|
|
|
import json |
|
|
import os |
|
|
import random |
|
|
from typing import Any, Dict, List, Tuple |
|
|
|
|
|
import numpy as np |
|
|
import xgboost as xgb |
|
|
|
|
|
from .features import extract_feature_vector, FEATURES |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Input: labelled gold-test samples used as training data.
GOLD_TESTS_PATH = "training/gold_tests.json"

# Output directory for all trained-model artifacts.
MODEL_DIR = "models"

# Serialized XGBoost booster (saved via Booster.save_model).
MODEL_PATH = os.path.join(MODEL_DIR, "genus_xgb.json")

# Sidecar metadata (label mappings, metrics, feature names) for inference.
META_PATH = os.path.join(MODEL_DIR, "genus_xgb_meta.json")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _load_gold_tests(path: str) -> List[Dict[str, Any]]: |
|
|
if not os.path.exists(path): |
|
|
raise FileNotFoundError(f"Missing gold test file: {path}") |
|
|
|
|
|
with open(path, "r", encoding="utf-8") as f: |
|
|
data = json.load(f) |
|
|
|
|
|
if not isinstance(data, list): |
|
|
raise ValueError("gold_tests.json must contain a list.") |
|
|
|
|
|
return data |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _extract_genus(sample: Dict[str, Any]) -> str | None: |
|
|
""" |
|
|
Extract genus from: |
|
|
name / Name / organism / Organism |
|
|
(genus = first token before space) |
|
|
""" |
|
|
for key in ("name", "Name", "organism", "Organism"): |
|
|
if key in sample and sample[key]: |
|
|
val = str(sample[key]).strip() |
|
|
if val: |
|
|
return val.split()[0] |
|
|
return None |
|
|
|
|
|
|
|
|
def _extract_fields(sample: Dict[str, Any]) -> Dict[str, Any]: |
|
|
""" |
|
|
Extract expected field dict from any of: |
|
|
fields / expected_fields / schema / expected |
|
|
""" |
|
|
for key in ("fields", "expected_fields", "schema", "expected"): |
|
|
if key in sample and isinstance(sample[key], dict): |
|
|
return sample[key] |
|
|
return {} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _build_dataset(samples: List[Dict[str, Any]]) -> Tuple[np.ndarray, np.ndarray, Dict[str, int]]:
    """Convert gold-test samples into a supervised training set.

    Returns:
        X: stacked feature matrix, one row per usable sample.
        y: int32 class labels aligned with the rows of X.
        genus_to_idx: genus name -> label index, assigned in the order
            genera are first encountered.

    Raises:
        ValueError: when no sample yields both a genus and a field dict.
    """
    feature_rows: List[np.ndarray] = []
    labels: List[int] = []
    genus_to_idx: Dict[str, int] = {}

    for sample in samples:
        genus = _extract_genus(sample)
        if not genus:
            continue

        fields = _extract_fields(sample)
        if not fields:
            continue

        row = extract_feature_vector(fields)

        # setdefault assigns the next free index on first sight of a genus.
        label = genus_to_idx.setdefault(genus, len(genus_to_idx))
        feature_rows.append(row)
        labels.append(label)

    if not feature_rows:
        raise ValueError("No usable gold tests found.")

    return np.vstack(feature_rows), np.array(labels, dtype=np.int32), genus_to_idx
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _train_xgboost(
    X: np.ndarray,
    y: np.ndarray,
    num_classes: int,
    seed: int = 42
) -> Tuple[xgb.Booster, Dict[str, float]]:
    """Train a multi-class XGBoost classifier on a shuffled 80/20 split.

    Args:
        X: feature matrix, shape (n_samples, n_features).
        y: integer class labels aligned with the rows of X.
        num_classes: number of distinct classes (genera).
        seed: RNG seed used for both the split shuffle and XGBoost.

    Returns:
        (booster, metrics) where metrics holds train/valid accuracy and
        the best boosting iteration chosen by early stopping.

    Raises:
        ValueError: if there are fewer than 2 samples, in which case a
            non-empty train/valid split is impossible.
    """
    n = X.shape[0]
    if n < 2:
        # Previously a 1-sample input produced an empty training split and
        # a cryptic crash inside xgb.train; fail fast with a clear message.
        raise ValueError("Need at least 2 samples for a train/valid split.")

    indices = list(range(n))
    random.Random(seed).shuffle(indices)

    # 80/20 split, clamped so both sides always receive at least one sample
    # (an empty validation set would break early stopping and make the
    # accuracy mean below NaN).
    split = min(max(int(0.8 * n), 1), n - 1)
    train_idx = indices[:split]
    valid_idx = indices[split:]

    X_train, y_train = X[train_idx], y[train_idx]
    X_valid, y_valid = X[valid_idx], y[valid_idx]

    dtrain = xgb.DMatrix(X_train, label=y_train)
    dvalid = xgb.DMatrix(X_valid, label=y_valid)

    params = {
        "objective": "multi:softprob",  # per-class probability output
        "num_class": num_classes,
        "eval_metric": "mlogloss",
        "max_depth": 6,
        "eta": 0.08,
        "subsample": 0.9,
        "colsample_bytree": 0.9,
        "min_child_weight": 1,
        "seed": seed,
    }

    evals = [(dtrain, "train"), (dvalid, "valid")]

    model = xgb.train(
        params,
        dtrain,
        evals=evals,
        num_boost_round=500,
        early_stopping_rounds=40,  # stop once valid mlogloss stalls
        verbose_eval=50,
    )

    # softprob returns (n_rows, num_class); argmax over the class axis
    # recovers the predicted label.
    train_acc = float(
        (np.argmax(model.predict(dtrain), axis=1) == y_train).mean()
    )
    valid_acc = float(
        (np.argmax(model.predict(dvalid), axis=1) == y_valid).mean()
    )

    return model, {
        "train_accuracy": train_acc,
        "valid_accuracy": valid_acc,
        "best_iteration": int(model.best_iteration),
    }
|
|
|
|
|
|
|
|
def _ensure_model_dir():
    """Create MODEL_DIR (and parents) if it does not already exist.

    `exist_ok=True` already makes this call idempotent, so the previous
    `os.path.exists` pre-check was redundant and a TOCTOU race besides.
    """
    os.makedirs(MODEL_DIR, exist_ok=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def train_genus_model() -> Dict[str, Any]:
    """Run the full pipeline: load gold tests, build the dataset, train the
    XGBoost genus classifier, and persist the model plus its metadata.

    Never raises: on success it returns a dict with ``ok=True`` plus stats,
    metrics, and artifact paths; on any failure it returns ``ok=False`` and
    a message describing the exception, making it safe to call from a CLI
    or API boundary.
    """
    try:
        print(f"Loading gold tests → {GOLD_TESTS_PATH}")
        samples = _load_gold_tests(GOLD_TESTS_PATH)

        print("Building ML dataset...")
        # X: feature matrix, y: integer labels, genus_to_idx: label mapping.
        X, y, genus_to_idx = _build_dataset(samples)

        num_classes = len(genus_to_idx)
        print(f"Feature dimension: {X.shape[1]}")
        print(f"Classes (genera): {num_classes}")
        print(f"Samples: {X.shape[0]}")

        print("Training XGBoost (schema v2)...")
        model, metrics = _train_xgboost(X, y, num_classes)

        print("Training complete.")
        print(f"Train accuracy: {metrics['train_accuracy']:.3f}")
        print(f"Valid accuracy: {metrics['valid_accuracy']:.3f}")

        # Persist the booster itself.
        _ensure_model_dir()
        model.save_model(MODEL_PATH)

        # Inverse mapping so inference code can turn predicted class
        # indices back into genus names.
        idx_to_genus = {idx: genus for genus, idx in genus_to_idx.items()}

        # Sidecar metadata required to reload and serve the model
        # consistently with the feature schema used at training time.
        meta = {
            "genus_to_idx": genus_to_idx,
            "idx_to_genus": idx_to_genus,
            "n_features": int(X.shape[1]),
            "num_classes": int(num_classes),
            "metrics": metrics,
            "feature_schema_path": "data/feature_schema.json",
            "feature_names": [f["name"] for f in FEATURES],
        }

        with open(META_PATH, "w", encoding="utf-8") as f:
            json.dump(meta, f, indent=2, ensure_ascii=False)

        return {
            "ok": True,
            "message": "Genus XGBoost model (schema v2) trained successfully.",
            "stats": {
                "num_raw_samples": len(samples),
                "num_usable_samples": int(X.shape[0]),
                "feature_dim": int(X.shape[1]),
                "num_classes": int(num_classes),
            },
            "metrics": metrics,
            "paths": {"model_path": MODEL_PATH, "meta_path": META_PATH},
            # Small sample of class names for a quick sanity check.
            "genus_examples": sorted(genus_to_idx.keys())[:20],
        }

    except Exception as e:
        # Deliberate broad catch at this boundary: callers receive a
        # structured error payload instead of an exception.
        return {
            "ok": False,
            "message": f"Training error: {type(e).__name__}: {e}",
        }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def main():
    """CLI entry point: run training and pretty-print the result dict."""
    result = train_genus_model()
    print(json.dumps(result, indent=2, ensure_ascii=False))


if __name__ == "__main__":
    main()