Spaces:
Running
Running
| """Comprehensive Model Training Pipeline for DeepAMR. | |
| This module provides a complete training pipeline for AMR prediction models, | |
| supporting both traditional ML (sklearn) and deep learning (PyTorch) approaches. | |
| Works with preprocessed data from any source (NCBI, PATRIC, CARD, ResFinder). | |
| """ | |
| import json | |
| import logging | |
| import os | |
| from pathlib import Path | |
| from typing import Dict, List, Optional, Tuple, Union | |
| import numpy as np | |
| import joblib | |
| # Sklearn imports | |
| from sklearn.ensemble import ( | |
| RandomForestClassifier, | |
| GradientBoostingClassifier, | |
| AdaBoostClassifier, | |
| ExtraTreesClassifier, | |
| ) | |
| from sklearn.linear_model import LogisticRegression, SGDClassifier | |
| from sklearn.svm import SVC, LinearSVC | |
| from sklearn.neighbors import KNeighborsClassifier | |
| from sklearn.neural_network import MLPClassifier | |
| from sklearn.multiclass import OneVsRestClassifier | |
| from sklearn.preprocessing import StandardScaler | |
| from sklearn.metrics import ( | |
| accuracy_score, | |
| precision_score, | |
| recall_score, | |
| f1_score, | |
| roc_auc_score, | |
| classification_report, | |
| confusion_matrix, | |
| hamming_loss, | |
| ) | |
| logging.basicConfig(level=logging.INFO) | |
| logger = logging.getLogger(__name__) | |
| # ============================================================================ | |
| # Data Loading Utilities | |
| # ============================================================================ | |
def load_processed_data(data_dir: str, prefix: str) -> Dict:
    """Load preprocessed train/val/test splits and optional metadata.

    Args:
        data_dir: Directory containing processed files
        prefix: File prefix (e.g., 'ncbi_organism', 'ncbi_amr')

    Returns:
        Dictionary with X_train, X_val, X_test, y_train, y_val, y_test, metadata
    """
    base = Path(data_dir)

    # Load the six feature/label arrays by composing "<kind>_<split>" keys.
    data = {}
    for kind in ("X", "y"):
        for split in ("train", "val", "test"):
            key = f"{kind}_{split}"
            data[key] = np.load(base / f"{prefix}_{key}.npy")

    # Metadata is optional; fall back to an empty dict when absent.
    meta_path = base / f"{prefix}_metadata.json"
    if meta_path.exists():
        data["metadata"] = json.loads(meta_path.read_text())
    else:
        data["metadata"] = {}
    return data
| # ============================================================================ | |
| # Model Factory | |
| # ============================================================================ | |
class ModelFactory:
    """Factory for creating ML models with default configurations.

    ``create`` and ``list_models`` are classmethods so they can be invoked
    directly on the class (as done elsewhere in this module via
    ``ModelFactory.create(...)``); the original code omitted the
    ``@classmethod`` decorators, which made those class-level calls pass the
    first argument as ``cls`` and fail.
    """

    # Registry: model name -> estimator class plus sensible default
    # hyperparameters (seeded for reproducibility, class-balanced where the
    # estimator supports it).
    MODELS = {
        # Tree-based models
        "random_forest": {
            "class": RandomForestClassifier,
            "params": {
                "n_estimators": 200,
                "max_depth": 20,
                "min_samples_split": 5,
                "min_samples_leaf": 2,
                "n_jobs": -1,
                "random_state": 42,
                "class_weight": "balanced",
            },
        },
        "extra_trees": {
            "class": ExtraTreesClassifier,
            "params": {
                "n_estimators": 200,
                "max_depth": 20,
                "n_jobs": -1,
                "random_state": 42,
                "class_weight": "balanced",
            },
        },
        "gradient_boosting": {
            "class": GradientBoostingClassifier,
            "params": {
                "n_estimators": 100,
                "max_depth": 5,
                "learning_rate": 0.1,
                "random_state": 42,
            },
        },
        "adaboost": {
            "class": AdaBoostClassifier,
            "params": {
                "n_estimators": 100,
                "learning_rate": 0.1,
                "random_state": 42,
            },
        },
        # Linear models
        "logistic_regression": {
            "class": LogisticRegression,
            "params": {
                "max_iter": 1000,
                "random_state": 42,
                "class_weight": "balanced",
                "n_jobs": -1,
            },
        },
        "sgd": {
            "class": SGDClassifier,
            "params": {
                "max_iter": 1000,
                "random_state": 42,
                "class_weight": "balanced",
                "n_jobs": -1,
            },
        },
        # SVM models
        "svm": {
            "class": SVC,
            "params": {
                "kernel": "rbf",
                "probability": True,
                "random_state": 42,
                "class_weight": "balanced",
            },
        },
        "linear_svm": {
            "class": LinearSVC,
            "params": {
                "max_iter": 1000,
                "random_state": 42,
                "class_weight": "balanced",
            },
        },
        # Other models
        "knn": {
            "class": KNeighborsClassifier,
            "params": {
                "n_neighbors": 5,
                "n_jobs": -1,
            },
        },
        "mlp": {
            "class": MLPClassifier,
            "params": {
                "hidden_layer_sizes": (256, 128, 64),
                "max_iter": 500,
                "random_state": 42,
                "early_stopping": True,
            },
        },
    }

    @classmethod
    def create(
        cls,
        model_name: str,
        task_type: str = "multiclass",
        custom_params: Optional[Dict] = None,
    ):
        """Create a model instance.

        Args:
            model_name: Name of the model
            task_type: 'binary', 'multiclass', or 'multilabel'
            custom_params: Custom parameters to override defaults

        Returns:
            Model instance (wrapped in OneVsRestClassifier for multilabel)

        Raises:
            ValueError: If ``model_name`` is not in the registry.
        """
        if model_name not in cls.MODELS:
            raise ValueError(
                f"Unknown model: {model_name}. "
                f"Available: {list(cls.MODELS.keys())}"
            )
        config = cls.MODELS[model_name]
        # Copy defaults so overrides never mutate the shared registry.
        params = config["params"].copy()
        if custom_params:
            params.update(custom_params)
        model = config["class"](**params)
        # Wrap in one-vs-rest for multi-label classification.
        if task_type == "multilabel":
            model = OneVsRestClassifier(model, n_jobs=-1)
        return model

    @classmethod
    def list_models(cls) -> List[str]:
        """List available model names."""
        return list(cls.MODELS.keys())
| # ============================================================================ | |
| # AMR Model Trainer | |
| # ============================================================================ | |
class AMRModelTrainer:
    """Training pipeline for AMR prediction models.

    Wraps a sklearn-style estimator (built by ``ModelFactory``) with optional
    feature standardization, evaluation, feature-importance extraction, and
    save/load via joblib. ``load`` is a classmethod (the original code omitted
    the decorator, so ``AMRModelTrainer.load(path)`` would have bound the
    filepath to ``cls`` and failed).
    """

    def __init__(
        self,
        model_name: str = "random_forest",
        task_type: str = "multiclass",
        scale_features: bool = True,
        model_params: Optional[Dict] = None,
    ):
        """Initialize trainer.

        Args:
            model_name: Name of the model to use
            task_type: 'binary', 'multiclass', or 'multilabel'
            scale_features: Whether to standardize features
            model_params: Custom model parameters
        """
        self.model_name = model_name
        self.task_type = task_type
        self.scale_features = scale_features
        self.model = ModelFactory.create(model_name, task_type, model_params)
        # Scaler is fit on training data in fit(); None disables scaling.
        self.scaler = StandardScaler() if scale_features else None
        self.feature_names: Optional[List[str]] = None
        self.class_names: Optional[List[str]] = None
        self.is_fitted = False
        self.training_history: Dict = {}

    def fit(
        self,
        X_train: np.ndarray,
        y_train: np.ndarray,
        X_val: Optional[np.ndarray] = None,
        y_val: Optional[np.ndarray] = None,
        feature_names: Optional[List[str]] = None,
        class_names: Optional[List[str]] = None,
    ) -> "AMRModelTrainer":
        """Train the model.

        Args:
            X_train: Training features
            y_train: Training labels
            X_val: Validation features (optional)
            y_val: Validation labels (optional)
            feature_names: Feature names
            class_names: Class names

        Returns:
            self
        """
        logger.info(f"Training {self.model_name} ({self.task_type})...")
        logger.info(f" Training samples: {X_train.shape[0]}")
        logger.info(f" Features: {X_train.shape[1]}")
        self.feature_names = feature_names
        self.class_names = class_names
        # Scale features: fit on train only, then apply the same transform
        # to validation data to avoid leakage.
        if self.scaler:
            X_train = self.scaler.fit_transform(X_train)
            if X_val is not None:
                X_val = self.scaler.transform(X_val)
        # Train
        self.model.fit(X_train, y_train)
        self.is_fitted = True
        # Evaluate on validation if provided
        if X_val is not None and y_val is not None:
            val_metrics = self.evaluate(X_val, y_val)
            self.training_history["validation"] = val_metrics
            self._log_metrics("Validation", val_metrics)
        logger.info("Training complete!")
        return self

    def predict(self, X: np.ndarray) -> np.ndarray:
        """Predict labels for ``X`` (scaled with the fitted scaler if enabled)."""
        if not self.is_fitted:
            raise RuntimeError("Model not fitted.")
        if self.scaler:
            X = self.scaler.transform(X)
        return self.model.predict(X)

    def predict_proba(self, X: np.ndarray) -> np.ndarray:
        """Predict class probabilities.

        Falls back to ``decision_function`` scores (not calibrated
        probabilities) for estimators without ``predict_proba``.

        Raises:
            RuntimeError: If the model has not been fitted.
            NotImplementedError: If the estimator supports neither method.
        """
        if not self.is_fitted:
            raise RuntimeError("Model not fitted.")
        if self.scaler:
            X = self.scaler.transform(X)
        if hasattr(self.model, "predict_proba"):
            return self.model.predict_proba(X)
        elif hasattr(self.model, "decision_function"):
            return self.model.decision_function(X)
        else:
            raise NotImplementedError("Model does not support probability prediction")

    def evaluate(self, X: np.ndarray, y_true: np.ndarray) -> Dict:
        """Evaluate model performance.

        Returns a metrics dict whose keys depend on ``task_type``:
        multilabel tasks report hamming loss and micro/macro/weighted F1
        (plus per-class metrics and AUC where possible); binary/multiclass
        tasks report accuracy/precision/recall/F1, a confusion matrix, and
        AUC where possible.
        """
        y_pred = self.predict(X)
        metrics = {}
        if self.task_type == "multilabel":
            metrics["hamming_loss"] = float(hamming_loss(y_true, y_pred))
            metrics["micro_f1"] = float(f1_score(y_true, y_pred, average="micro", zero_division=0))
            metrics["macro_f1"] = float(f1_score(y_true, y_pred, average="macro", zero_division=0))
            metrics["weighted_f1"] = float(f1_score(y_true, y_pred, average="weighted", zero_division=0))
            metrics["micro_precision"] = float(precision_score(y_true, y_pred, average="micro", zero_division=0))
            metrics["micro_recall"] = float(recall_score(y_true, y_pred, average="micro", zero_division=0))
            # Per-class metrics (one binary column per label)
            if self.class_names:
                metrics["per_class"] = {}
                for i, name in enumerate(self.class_names):
                    metrics["per_class"][name] = {
                        "precision": float(precision_score(y_true[:, i], y_pred[:, i], zero_division=0)),
                        "recall": float(recall_score(y_true[:, i], y_pred[:, i], zero_division=0)),
                        "f1": float(f1_score(y_true[:, i], y_pred[:, i], zero_division=0)),
                        "support": int(y_true[:, i].sum()),
                    }
            # AUC is best-effort: skipped when probabilities are unavailable
            # or a label has a single class in y_true.
            try:
                y_proba = self.predict_proba(X)
                metrics["micro_auc"] = float(roc_auc_score(y_true, y_proba, average="micro"))
                metrics["macro_auc"] = float(roc_auc_score(y_true, y_proba, average="macro"))
            except Exception:
                pass
        else:
            metrics["accuracy"] = float(accuracy_score(y_true, y_pred))
            metrics["precision"] = float(precision_score(y_true, y_pred, average="weighted", zero_division=0))
            metrics["recall"] = float(recall_score(y_true, y_pred, average="weighted", zero_division=0))
            metrics["f1"] = float(f1_score(y_true, y_pred, average="weighted", zero_division=0))
            metrics["confusion_matrix"] = confusion_matrix(y_true, y_pred).tolist()
            # Per-class report
            if self.class_names:
                report = classification_report(
                    y_true, y_pred,
                    target_names=self.class_names,
                    output_dict=True,
                    zero_division=0,
                )
                metrics["classification_report"] = report
            # AUC (best-effort, see above)
            try:
                y_proba = self.predict_proba(X)
                if self.task_type == "binary":
                    metrics["auc"] = float(roc_auc_score(y_true, y_proba[:, 1]))
                else:
                    metrics["auc"] = float(roc_auc_score(y_true, y_proba, multi_class="ovr", average="weighted"))
            except Exception:
                pass
        return metrics

    def _log_metrics(self, prefix: str, metrics: Dict) -> None:
        """Log the headline metrics for the given evaluation split."""
        if self.task_type == "multilabel":
            logger.info(f" {prefix} - Hamming Loss: {metrics.get('hamming_loss', 0):.4f}")
            logger.info(f" {prefix} - Micro F1: {metrics.get('micro_f1', 0):.4f}")
            logger.info(f" {prefix} - Macro F1: {metrics.get('macro_f1', 0):.4f}")
        else:
            logger.info(f" {prefix} - Accuracy: {metrics.get('accuracy', 0):.4f}")
            logger.info(f" {prefix} - F1: {metrics.get('f1', 0):.4f}")

    def get_feature_importance(self, top_n: int = 20) -> List[Tuple[str, float]]:
        """Return the top-N (feature name, importance) pairs, descending.

        Supports tree models (``feature_importances_``), OneVsRest wrappers
        (averaged over sub-estimators), and linear models (mean |coef|).
        Returns an empty list for models exposing none of these.
        """
        if not self.is_fitted:
            raise RuntimeError("Model not fitted.")
        if hasattr(self.model, "feature_importances_"):
            importances = self.model.feature_importances_
        elif hasattr(self.model, "estimators_"):
            # e.g. OneVsRestClassifier: average importances across per-label estimators
            importances = np.mean([
                est.feature_importances_ for est in self.model.estimators_
                if hasattr(est, "feature_importances_")
            ], axis=0)
        elif hasattr(self.model, "coef_"):
            importances = np.abs(self.model.coef_).mean(axis=0)
        else:
            return []
        if self.feature_names:
            importance_list = list(zip(self.feature_names, importances))
        else:
            importance_list = [(f"feature_{i}", imp) for i, imp in enumerate(importances)]
        importance_list.sort(key=lambda x: x[1], reverse=True)
        return importance_list[:top_n]

    def save(self, filepath: str) -> None:
        """Persist model, scaler, and metadata to ``filepath`` via joblib."""
        Path(filepath).parent.mkdir(parents=True, exist_ok=True)
        joblib.dump({
            "model": self.model,
            "scaler": self.scaler,
            "model_name": self.model_name,
            "task_type": self.task_type,
            "feature_names": self.feature_names,
            "class_names": self.class_names,
            "is_fitted": self.is_fitted,
            "training_history": self.training_history,
        }, filepath)
        logger.info(f"Model saved to {filepath}")

    @classmethod
    def load(cls, filepath: str) -> "AMRModelTrainer":
        """Load a trainer previously persisted with :meth:`save`.

        NOTE(review): joblib deserialization executes pickled code — only
        load files from trusted sources.
        """
        data = joblib.load(filepath)
        trainer = cls(
            model_name=data["model_name"],
            task_type=data["task_type"],
            scale_features=data["scaler"] is not None,
        )
        trainer.model = data["model"]
        trainer.scaler = data["scaler"]
        trainer.feature_names = data["feature_names"]
        trainer.class_names = data["class_names"]
        trainer.is_fitted = data["is_fitted"]
        trainer.training_history = data["training_history"]
        return trainer
| # ============================================================================ | |
| # Training Pipeline | |
| # ============================================================================ | |
def train_single_model(
    data_dir: str,
    prefix: str,
    model_name: str = "random_forest",
    output_dir: str = "models",
) -> Dict:
    """Train a single model on preprocessed data.

    Loads the splits, fits the trainer, evaluates on the held-out test set,
    logs the top features, and persists both the model and a JSON results file.

    Args:
        data_dir: Directory with preprocessed data
        prefix: Data file prefix
        model_name: Model to train
        output_dir: Output directory for model and results

    Returns:
        Dictionary with metrics
    """
    banner = "=" * 60
    logger.info(banner)
    logger.info(f"Training {model_name} on {prefix}")
    logger.info(banner)

    # Load data and pull task configuration out of the metadata.
    dataset = load_processed_data(data_dir, prefix)
    meta = dataset.get("metadata", {})
    task_type = meta.get("task_type", "multiclass")
    feature_names = meta.get("feature_names", [])
    class_names = meta.get("class_names", [])

    logger.info(f"Task: {task_type}")
    logger.info(
        f"Train: {dataset['X_train'].shape[0]}, "
        f"Val: {dataset['X_val'].shape[0]}, "
        f"Test: {dataset['X_test'].shape[0]}"
    )
    logger.info(f"Features: {dataset['X_train'].shape[1]}, Classes: {len(class_names)}")

    # Fit with validation monitoring.
    trainer = AMRModelTrainer(model_name=model_name, task_type=task_type)
    trainer.fit(
        dataset["X_train"], dataset["y_train"],
        dataset["X_val"], dataset["y_val"],
        feature_names, class_names,
    )

    # Held-out test evaluation.
    logger.info("\nTest Set Results:")
    test_metrics = trainer.evaluate(dataset["X_test"], dataset["y_test"])
    trainer._log_metrics("Test", test_metrics)

    # Report the most informative features.
    logger.info("\nTop 10 Important Features:")
    for feat, imp in trainer.get_feature_importance(10):
        logger.info(f" {feat}: {imp:.4f}")

    # Persist the model and a JSON summary of the run.
    out_path = Path(output_dir)
    out_path.mkdir(parents=True, exist_ok=True)
    trainer.save(str(out_path / f"{prefix}_{model_name}.joblib"))

    results = {
        "model_name": model_name,
        "task_type": task_type,
        "class_names": class_names,
        "test_metrics": test_metrics,
        "feature_importance": trainer.get_feature_importance(20),
    }
    results_file = out_path / f"{prefix}_{model_name}_results.json"
    with open(results_file, "w") as f:
        json.dump(results, f, indent=2, default=str)
    logger.info(f"\nResults saved to {results_file}")
    return results
def train_multiple_models(
    data_dir: str,
    prefix: str,
    models: Optional[List[str]] = None,
    output_dir: str = "models",
) -> Dict[str, Dict]:
    """Train multiple models and compare results.

    Args:
        data_dir: Directory with preprocessed data
        prefix: Data file prefix
        models: List of models to train (default: a representative subset)
        output_dir: Output directory

    Returns:
        Dictionary mapping model names to results
    """
    if models is None:
        models = ["random_forest", "extra_trees", "gradient_boosting", "logistic_regression", "mlp"]

    all_results = {}
    for model_name in models:
        # A failure in one model should not abort the whole comparison.
        try:
            results = train_single_model(data_dir, prefix, model_name, output_dir)
            all_results[model_name] = results
        except Exception as e:
            logger.error(f"Error training {model_name}: {e}")
            all_results[model_name] = {"error": str(e)}

    # Summary comparison
    logger.info("\n" + "=" * 60)
    logger.info("Model Comparison Summary")
    logger.info("=" * 60)

    # The task type is already recorded in each successful result, so read it
    # from there instead of re-loading all the (potentially large) feature
    # arrays from disk just to inspect the metadata.
    task_type = next(
        (r["task_type"] for r in all_results.values() if "task_type" in r),
        "multiclass",
    )
    if task_type == "multilabel":
        metric_key = "micro_f1"
        metric_name = "Micro F1"
    else:
        metric_key = "f1"
        metric_name = "F1 Score"

    # Rank models by the headline metric; failed runs score 0.0.
    comparison = []
    for model_name, results in all_results.items():
        if "error" in results:
            comparison.append((model_name, 0.0))
        else:
            score = results.get("test_metrics", {}).get(metric_key, 0.0)
            comparison.append((model_name, score))
    comparison.sort(key=lambda x: x[1], reverse=True)
    for model_name, score in comparison:
        logger.info(f" {model_name}: {metric_name} = {score:.4f}")

    # Save comparison. Ensure the directory exists even if every model
    # errored out before train_single_model could create it.
    out_path = Path(output_dir)
    out_path.mkdir(parents=True, exist_ok=True)
    comparison_file = out_path / f"{prefix}_model_comparison.json"
    with open(comparison_file, "w") as f:
        json.dump(all_results, f, indent=2, default=str)
    return all_results
| # ============================================================================ | |
| # Main Entry Point | |
| # ============================================================================ | |
def main():
    """Command-line entry point: parse arguments and dispatch training."""
    import argparse

    cli = argparse.ArgumentParser(description="Train AMR prediction models")
    cli.add_argument("--data-dir", default="data/processed/ncbi", help="Data directory")
    cli.add_argument("--prefix", default="ncbi_organism", help="Data prefix")
    cli.add_argument("--model", default="random_forest", help="Model name (or 'all' for comparison)")
    cli.add_argument("--output-dir", default="models", help="Output directory")
    cli.add_argument("--list-models", action="store_true", help="List available models")
    args = cli.parse_args()

    # Informational mode: just enumerate the registry and exit.
    if args.list_models:
        print("Available models:")
        for model in ModelFactory.list_models():
            print(f" - {model}")
        return

    # 'all' runs the comparison pipeline; anything else trains one model.
    if args.model == "all":
        train_multiple_models(args.data_dir, args.prefix, output_dir=args.output_dir)
    else:
        train_single_model(args.data_dir, args.prefix, args.model, args.output_dir)


if __name__ == "__main__":
    main()