Spaces:
Running
Running
| """Unified Training Pipeline for DeepAMR. | |
| This module provides a comprehensive training system that: | |
| 1. Combines multiple data sources (NCBI, CARD, PATRIC) | |
| 2. Handles class imbalance with multiple strategies | |
| 3. Supports both sklearn and PyTorch models | |
| 4. Implements proper cross-validation | |
| 5. Provides detailed evaluation metrics | |
| """ | |
| import json | |
| import logging | |
| from pathlib import Path | |
| from typing import Dict, List, Optional, Tuple, Union | |
| from datetime import datetime | |
| import numpy as np | |
| import joblib | |
| import torch | |
| import torch.nn as nn | |
| from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler | |
| from sklearn.ensemble import ( | |
| RandomForestClassifier, | |
| GradientBoostingClassifier, | |
| ExtraTreesClassifier, | |
| ) | |
| from sklearn.linear_model import LogisticRegression | |
| from sklearn.neural_network import MLPClassifier | |
| from sklearn.multiclass import OneVsRestClassifier | |
| from sklearn.preprocessing import StandardScaler | |
| from sklearn.model_selection import StratifiedKFold, cross_val_predict | |
| from sklearn.metrics import ( | |
| accuracy_score, precision_score, recall_score, f1_score, | |
| roc_auc_score, classification_report, hamming_loss, | |
| precision_recall_curve, average_precision_score, | |
| ) | |
| from sklearn.utils.class_weight import compute_class_weight, compute_sample_weight | |
| from imblearn.over_sampling import SMOTE, ADASYN, RandomOverSampler | |
| from imblearn.under_sampling import RandomUnderSampler | |
| from imblearn.combine import SMOTETomek | |
| logging.basicConfig(level=logging.INFO) | |
| logger = logging.getLogger(__name__) | |
| # ============================================================================= | |
| # Data Loading Utilities | |
| # ============================================================================= | |
def load_dataset(data_dir: str, prefix: str) -> Dict:
    """Load a preprocessed dataset (six .npy splits plus optional JSON metadata)."""
    base = Path(data_dir)
    data = {}
    # Feature and label arrays for each split, e.g. "<prefix>_X_train.npy".
    for matrix in ("X", "y"):
        for split in ("train", "val", "test"):
            key = f"{matrix}_{split}"
            data[key] = np.load(base / f"{prefix}_{key}.npy")
    # Metadata is optional; fall back to an empty dict when absent.
    meta_path = base / f"{prefix}_metadata.json"
    data["metadata"] = json.loads(meta_path.read_text()) if meta_path.exists() else {}
    return data
def combine_datasets(datasets: List[Dict], task_type: str = "multilabel") -> Dict:
    """Combine multiple datasets for training.

    For multilabel tasks, each dataset's label matrix is remapped into a
    unified, sorted class space so label columns line up across datasets.

    Args:
        datasets: Outputs of ``load_dataset``; each may carry
            ``metadata["class_names"]`` describing its label columns.
        task_type: "multilabel" (2-D y, stacked row-wise) or anything else
            (1-D y, concatenated).

    Returns:
        Dict with the same split keys plus merged metadata
        (``class_names``, ``n_classes``, ``n_samples``, ``n_features``,
        and per-source sample counts).
    """
    if len(datasets) == 1:
        return datasets[0]

    # Union of class names across all datasets, sorted for a stable order.
    all_classes = set()
    for ds in datasets:
        all_classes.update(ds.get("metadata", {}).get("class_names", []))
    all_classes = sorted(all_classes)
    # Precomputed lookup avoids O(n^2) list.index() calls during remapping.
    class_index = {cls: idx for idx, cls in enumerate(all_classes)}

    combined = {
        "X_train": [], "X_val": [], "X_test": [],
        "y_train": [], "y_val": [], "y_test": [],
        "metadata": {
            "class_names": all_classes,
            "task_type": task_type,
            "n_classes": len(all_classes),
            "source_datasets": [],
        },
    }

    for ds in datasets:
        ds_classes = ds.get("metadata", {}).get("class_names", [])
        combined["metadata"]["source_datasets"].append({
            "n_samples": len(ds["X_train"]) + len(ds["X_val"]) + len(ds["X_test"]),
            "classes": ds_classes,
        })
        if task_type == "multilabel" and ds_classes:
            for split in ["train", "val", "test"]:
                combined[f"X_{split}"].append(ds[f"X_{split}"])
                # Scatter each dataset-local label column into its position in
                # the unified class space; classes absent from a dataset stay 0.
                y_old = ds[f"y_{split}"]
                y_new = np.zeros((len(y_old), len(all_classes)), dtype=y_old.dtype)
                for old_idx, cls in enumerate(ds_classes):
                    if cls in class_index:
                        y_new[:, class_index[cls]] = y_old[:, old_idx]
                combined[f"y_{split}"].append(y_new)
        else:
            # No class metadata (or not multilabel): pass labels through as-is.
            for split in ["train", "val", "test"]:
                combined[f"X_{split}"].append(ds[f"X_{split}"])
                combined[f"y_{split}"].append(ds[f"y_{split}"])

    # Concatenate per split: features always stack row-wise; labels stack
    # row-wise for 2-D multilabel targets, concatenate for 1-D targets.
    for split in ["train", "val", "test"]:
        combined[f"X_{split}"] = np.vstack(combined[f"X_{split}"])
        combined[f"y_{split}"] = np.vstack(combined[f"y_{split}"]) if task_type == "multilabel" else np.concatenate(combined[f"y_{split}"])

    combined["metadata"]["n_samples"] = len(combined["X_train"]) + len(combined["X_val"]) + len(combined["X_test"])
    combined["metadata"]["n_features"] = combined["X_train"].shape[1]
    return combined
| # ============================================================================= | |
| # Class Imbalance Handling | |
| # ============================================================================= | |
class ImbalanceHandler:
    """Handle class imbalance with multiple strategies.

    Strategies fall into two families:
      * weighting ("class_weight"): compute per-class / per-sample weights
        and leave the data untouched;
      * resampling ("smote", "adasyn", ...): synthesize or drop samples so
        the training set itself becomes more balanced.
    """

    STRATEGIES = ["none", "class_weight", "smote", "adasyn", "random_over", "random_under", "smote_tomek"]

    def __init__(self, strategy: str = "class_weight", random_state: int = 42):
        if strategy not in self.STRATEGIES:
            raise ValueError(f"Unknown strategy: {strategy}. Choose from {self.STRATEGIES}")
        self.strategy = strategy
        self.random_state = random_state

    def get_sample_weights(self, y: np.ndarray, task_type: str = "multiclass") -> np.ndarray:
        """Compute per-sample weights for imbalanced data.

        For multilabel targets each unique label *combination* is treated as
        one pseudo-class and weighted by inverse frequency; weights are
        rescaled so they sum to ``len(y)`` (mean weight of 1).
        """
        if task_type == "multilabel":
            from collections import Counter
            combo_counts = Counter(tuple(row) for row in y)
            weights = np.array([1.0 / combo_counts[tuple(row)] for row in y])
            # Normalize so the average weight is 1.
            return weights / weights.sum() * len(weights)
        return compute_sample_weight("balanced", y)

    def get_class_weights(self, y: np.ndarray, task_type: str = "multiclass") -> Union[Dict, np.ndarray]:
        """Compute class weights.

        Returns an array of per-label positive weights (neg/pos ratio, capped
        at 10) for multilabel, or a {class: weight} dict for multiclass.
        """
        if task_type == "multilabel":
            n_samples = len(y)
            weights = []
            for i in range(y.shape[1]):
                n_pos = y[:, i].sum()
                n_neg = n_samples - n_pos
                weight = n_neg / n_pos if n_pos > 0 else 1.0
                # Cap at 10x so extremely rare labels don't dominate the loss.
                weights.append(min(weight, 10.0))
            return np.array(weights)
        classes = np.unique(y)
        weights = compute_class_weight("balanced", classes=classes, y=y)
        return dict(zip(classes, weights))

    def resample(self, X: np.ndarray, y: np.ndarray, task_type: str = "multiclass") -> Tuple[np.ndarray, np.ndarray]:
        """Resample (X, y) according to the configured strategy.

        Weighting-only strategies return the inputs untouched. On any sampler
        failure the original data is returned (best-effort by design).
        """
        if self.strategy == "none" or self.strategy == "class_weight":
            return X, y
        if task_type == "multilabel":
            # Multilabel samplers are not generally available; encode each
            # label combination as a pseudo-class and randomly oversample it.
            logger.warning("Resampling for multilabel is experimental. Using random oversampling.")
            label_strings = [''.join(map(str, row)) for row in y]
            from sklearn.preprocessing import LabelEncoder
            le = LabelEncoder()
            y_encoded = le.fit_transform(label_strings)
            sampler = RandomOverSampler(random_state=self.random_state)
            X_res, y_res_encoded = sampler.fit_resample(X, y_encoded)
            # Decode pseudo-classes back into binary label rows.
            y_res_strings = le.inverse_transform(y_res_encoded)
            y_res = np.array([[int(c) for c in s] for s in y_res_strings])
            return X_res, y_res
        # Multiclass resampling
        try:
            if self.strategy == "smote":
                sampler = SMOTE(random_state=self.random_state)
            elif self.strategy == "adasyn":
                sampler = ADASYN(random_state=self.random_state)
            elif self.strategy == "random_over":
                sampler = RandomOverSampler(random_state=self.random_state)
            elif self.strategy == "random_under":
                sampler = RandomUnderSampler(random_state=self.random_state)
            elif self.strategy == "smote_tomek":
                sampler = SMOTETomek(random_state=self.random_state)
            else:
                return X, y
            X_res, y_res = sampler.fit_resample(X, y)
            logger.info(f"Resampled: {len(X)} -> {len(X_res)} samples")
            return X_res, y_res
        except Exception as e:
            # SMOTE-family samplers can fail on tiny classes (too few
            # neighbors); fall back to the original data rather than abort.
            logger.warning(f"Resampling failed: {e}. Using original data.")
            return X, y
| # ============================================================================= | |
| # PyTorch Models | |
| # ============================================================================= | |
class AMRNet(nn.Module):
    """Feed-forward network for AMR prediction.

    Each hidden block is Linear -> BatchNorm -> ReLU -> Dropout. The network
    outputs raw logits: BCEWithLogitsLoss (multilabel) and CrossEntropyLoss
    (multiclass) both apply their activation internally, so no final
    Sigmoid/Softmax layer is added here.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        hidden_dims: Optional[List[int]] = None,
        dropout: float = 0.3,
        task_type: str = "multiclass",
    ):
        super().__init__()
        # None sentinel avoids a mutable default argument; the default stack
        # is unchanged ([512, 256, 128]).
        if hidden_dims is None:
            hidden_dims = [512, 256, 128]
        self.task_type = task_type
        layers = []
        prev_dim = input_dim
        for hidden_dim in hidden_dims:
            layers.extend([
                nn.Linear(prev_dim, hidden_dim),
                nn.BatchNorm1d(hidden_dim),
                nn.ReLU(),
                nn.Dropout(dropout),
            ])
            prev_dim = hidden_dim
        layers.append(nn.Linear(prev_dim, output_dim))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return raw logits of shape (batch, output_dim)."""
        return self.model(x)
class PyTorchTrainer:
    """Training wrapper for PyTorch models.

    Handles device placement, class-weighted losses for imbalanced data,
    optional weighted sampling, early stopping on validation loss, and
    numpy-in / numpy-out predict APIs.
    """

    def __init__(
        self,
        model: nn.Module,
        task_type: str = "multiclass",
        class_weights: Optional[np.ndarray] = None,
        device: str = "auto",
    ):
        # Prefer CUDA, then Apple MPS, then CPU.
        if device == "auto":
            self.device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
        else:
            self.device = torch.device(device)
        self.model = model.to(self.device)
        self.task_type = task_type
        if task_type == "multilabel":
            if class_weights is not None:
                # pos_weight scales the positive term per label in BCE.
                pos_weight = torch.FloatTensor(class_weights).to(self.device)
                self.criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
            else:
                self.criterion = nn.BCEWithLogitsLoss()
        else:
            if class_weights is not None:
                # Accepts either a {class: weight} dict or a plain array.
                weight = torch.FloatTensor(list(class_weights.values()) if isinstance(class_weights, dict) else class_weights).to(self.device)
                self.criterion = nn.CrossEntropyLoss(weight=weight)
            else:
                self.criterion = nn.CrossEntropyLoss()

    def fit(
        self,
        X_train: np.ndarray,
        y_train: np.ndarray,
        X_val: Optional[np.ndarray] = None,
        y_val: Optional[np.ndarray] = None,
        epochs: int = 100,
        batch_size: int = 32,
        lr: float = 0.001,
        patience: int = 10,
        sample_weights: Optional[np.ndarray] = None,
    ) -> Dict:
        """Train the model; returns a {"train_loss": [...], "val_loss": [...]} history.

        Early stopping (on validation loss) and LR scheduling are active only
        when X_val/y_val are provided. The best validation weights are
        restored at the end of training.
        """
        # Prepare data: float targets for BCE, long targets for CE.
        X_train_t = torch.FloatTensor(X_train)
        if self.task_type == "multilabel":
            y_train_t = torch.FloatTensor(y_train)
        else:
            y_train_t = torch.LongTensor(y_train)
        train_dataset = TensorDataset(X_train_t, y_train_t)
        # Use weighted sampling (with replacement) if sample weights provided.
        if sample_weights is not None:
            sampler = WeightedRandomSampler(
                weights=sample_weights,
                num_samples=len(sample_weights),
                replacement=True,
            )
            train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=sampler)
        else:
            train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
        # Validation tensors live on-device for full-batch evaluation.
        if X_val is not None:
            X_val_t = torch.FloatTensor(X_val).to(self.device)
            if self.task_type == "multilabel":
                y_val_t = torch.FloatTensor(y_val).to(self.device)
            else:
                y_val_t = torch.LongTensor(y_val).to(self.device)
        optimizer = torch.optim.AdamW(self.model.parameters(), lr=lr, weight_decay=1e-5)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=5, factor=0.5)
        best_val_loss = float('inf')
        best_state = None
        patience_counter = 0
        history = {"train_loss": [], "val_loss": []}
        for epoch in range(epochs):
            # Training
            self.model.train()
            train_loss = 0.0
            for batch_X, batch_y in train_loader:
                batch_X, batch_y = batch_X.to(self.device), batch_y.to(self.device)
                optimizer.zero_grad()
                outputs = self.model(batch_X)
                loss = self.criterion(outputs, batch_y)
                loss.backward()
                optimizer.step()
                train_loss += loss.item()
            train_loss /= len(train_loader)
            history["train_loss"].append(train_loss)
            # Validation + early stopping
            if X_val is not None:
                self.model.eval()
                with torch.no_grad():
                    val_outputs = self.model(X_val_t)
                    val_loss = self.criterion(val_outputs, y_val_t).item()
                history["val_loss"].append(val_loss)
                scheduler.step(val_loss)
                if val_loss < best_val_loss:
                    best_val_loss = val_loss
                    # Clone the tensors: state_dict() returns references to the
                    # live parameters, so a shallow dict .copy() would keep
                    # tracking later updates instead of snapshotting the best
                    # weights (making the restore below a no-op).
                    best_state = {k: v.detach().clone() for k, v in self.model.state_dict().items()}
                    patience_counter = 0
                else:
                    patience_counter += 1
                if patience_counter >= patience:
                    logger.info(f"Early stopping at epoch {epoch+1}")
                    break
            if (epoch + 1) % 10 == 0:
                val_str = f", Val Loss: {val_loss:.4f}" if X_val is not None else ""
                logger.info(f"Epoch {epoch+1}/{epochs} - Train Loss: {train_loss:.4f}{val_str}")
        # Restore the best validation checkpoint, if one was captured.
        if best_state is not None:
            self.model.load_state_dict(best_state)
        return history

    def predict(self, X: np.ndarray) -> np.ndarray:
        """Predict hard labels: 0/1 matrix (multilabel) or class indices."""
        self.model.eval()
        X_t = torch.FloatTensor(X).to(self.device)
        with torch.no_grad():
            outputs = self.model(X_t)
            if self.task_type == "multilabel":
                # Apply sigmoid for inference (BCEWithLogitsLoss uses raw logits)
                # and threshold at 0.5.
                probs = torch.sigmoid(outputs)
                return (probs.cpu().numpy() > 0.5).astype(int)
            else:
                return outputs.argmax(dim=1).cpu().numpy()

    def predict_proba(self, X: np.ndarray) -> np.ndarray:
        """Predict probabilities: per-label sigmoid or per-class softmax."""
        self.model.eval()
        X_t = torch.FloatTensor(X).to(self.device)
        with torch.no_grad():
            outputs = self.model(X_t)
            if self.task_type == "multilabel":
                return torch.sigmoid(outputs).cpu().numpy()
            else:
                return torch.softmax(outputs, dim=1).cpu().numpy()
| # ============================================================================= | |
| # Unified Trainer | |
| # ============================================================================= | |
class UnifiedAMRTrainer:
    """Unified training pipeline for AMR prediction.

    Wraps feature scaling, imbalance handling, model construction (sklearn
    or PyTorch), training, and evaluation behind one ``train`` call.
    Accumulated results live in ``self.results``.
    """

    # Factories (not instances) so each train() call builds a fresh estimator.
    SKLEARN_MODELS = {
        "random_forest": lambda: RandomForestClassifier(
            n_estimators=200, max_depth=20, min_samples_split=5,
            n_jobs=-1, random_state=42, class_weight="balanced"
        ),
        "extra_trees": lambda: ExtraTreesClassifier(
            n_estimators=200, max_depth=20,
            n_jobs=-1, random_state=42, class_weight="balanced"
        ),
        "gradient_boosting": lambda: GradientBoostingClassifier(
            n_estimators=100, max_depth=5, learning_rate=0.1, random_state=42
        ),
        "logistic_regression": lambda: LogisticRegression(
            max_iter=1000, random_state=42, class_weight="balanced", n_jobs=-1
        ),
        "mlp": lambda: MLPClassifier(
            hidden_layer_sizes=(256, 128, 64), max_iter=500,
            random_state=42, early_stopping=True
        ),
    }

    def __init__(
        self,
        task_type: str = "multilabel",
        imbalance_strategy: str = "class_weight",
        scale_features: bool = True,
    ):
        self.task_type = task_type
        self.imbalance_handler = ImbalanceHandler(imbalance_strategy)
        self.scale_features = scale_features
        self.scaler = StandardScaler() if scale_features else None
        self.model = None
        self.class_names = None
        self.feature_names = None
        self.results = {}

    def _create_sklearn_model(self, model_name: str):
        """Create an sklearn model, wrapped in OneVsRest for multilabel tasks."""
        if model_name not in self.SKLEARN_MODELS:
            raise ValueError(f"Unknown model: {model_name}")
        base_model = self.SKLEARN_MODELS[model_name]()
        if self.task_type == "multilabel":
            return OneVsRestClassifier(base_model, n_jobs=-1)
        return base_model

    def _create_pytorch_model(self, input_dim: int, output_dim: int, hidden_dims: List[int] = [512, 256, 128]):
        """Create a PyTorch AMRNet for the current task type."""
        return AMRNet(
            input_dim=input_dim,
            output_dim=output_dim,
            hidden_dims=hidden_dims,
            task_type=self.task_type,
        )

    def train(
        self,
        data: Dict,
        model_type: str = "random_forest",
        use_pytorch: bool = False,
        epochs: int = 100,
        batch_size: int = 32,
    ) -> Dict:
        """Train a model on a load_dataset()-style dict and return results.

        Scales features (fit on train only), applies the configured imbalance
        strategy, trains either a sklearn or PyTorch model, and evaluates on
        train/val/test. ``epochs``/``batch_size`` apply to PyTorch only.
        """
        logger.info(f"Training {model_type} ({'PyTorch' if use_pytorch else 'sklearn'})...")
        X_train = data["X_train"].copy()
        y_train = data["y_train"].copy()
        X_val = data.get("X_val")
        y_val = data.get("y_val")
        X_test = data["X_test"]
        y_test = data["y_test"]
        self.class_names = data.get("metadata", {}).get("class_names", [])
        self.feature_names = data.get("metadata", {}).get("feature_names", [])
        # Scale features: fit on the training split only to avoid leakage.
        if self.scaler:
            X_train = self.scaler.fit_transform(X_train)
            if X_val is not None:
                X_val = self.scaler.transform(X_val)
            X_test = self.scaler.transform(X_test)
        # Resampling strategies mutate the training split; weighting does not.
        if self.imbalance_handler.strategy not in ["none", "class_weight"]:
            X_train, y_train = self.imbalance_handler.resample(X_train, y_train, self.task_type)
        class_weights = None
        sample_weights = None
        if self.imbalance_handler.strategy == "class_weight":
            class_weights = self.imbalance_handler.get_class_weights(y_train, self.task_type)
            sample_weights = self.imbalance_handler.get_sample_weights(y_train, self.task_type)
        # Train model
        if use_pytorch:
            output_dim = y_train.shape[1] if self.task_type == "multilabel" else len(np.unique(y_train))
            model = self._create_pytorch_model(X_train.shape[1], output_dim)
            trainer = PyTorchTrainer(model, self.task_type, class_weights)
            history = trainer.fit(
                X_train, y_train, X_val, y_val,
                epochs=epochs, batch_size=batch_size,
                sample_weights=sample_weights,
            )
            self.model = trainer
            self.results["training_history"] = history
        else:
            model = self._create_sklearn_model(model_type)
            # OneVsRestClassifier doesn't accept sample_weight directly; the
            # base estimators already carry class_weight="balanced" where set.
            model.fit(X_train, y_train)
            self.model = model
        # Evaluate all splits (train metrics are on resampled/scaled data).
        train_metrics = self._evaluate(X_train, y_train, "train")
        if X_val is not None:
            val_metrics = self._evaluate(X_val, y_val, "val")
        test_metrics = self._evaluate(X_test, y_test, "test")
        self.results["model_type"] = model_type
        self.results["use_pytorch"] = use_pytorch
        self.results["task_type"] = self.task_type
        self.results["imbalance_strategy"] = self.imbalance_handler.strategy
        self.results["train_metrics"] = train_metrics
        if X_val is not None:
            self.results["val_metrics"] = val_metrics
        self.results["test_metrics"] = test_metrics
        self.results["class_names"] = self.class_names
        self._log_results(test_metrics)
        return self.results

    def _evaluate(self, X: np.ndarray, y_true: np.ndarray, split_name: str) -> Dict:
        """Evaluate model performance on one split; returns a metrics dict."""
        # Both sklearn models and PyTorchTrainer expose predict().
        y_pred = self.model.predict(X)
        metrics = {"split": split_name}
        if self.task_type == "multilabel":
            metrics["hamming_loss"] = float(hamming_loss(y_true, y_pred))
            metrics["micro_f1"] = float(f1_score(y_true, y_pred, average="micro", zero_division=0))
            metrics["macro_f1"] = float(f1_score(y_true, y_pred, average="macro", zero_division=0))
            metrics["weighted_f1"] = float(f1_score(y_true, y_pred, average="weighted", zero_division=0))
            metrics["micro_precision"] = float(precision_score(y_true, y_pred, average="micro", zero_division=0))
            metrics["micro_recall"] = float(recall_score(y_true, y_pred, average="micro", zero_division=0))
            # Per-class metrics
            metrics["per_class"] = {}
            for i, cls in enumerate(self.class_names):
                metrics["per_class"][cls] = {
                    "precision": float(precision_score(y_true[:, i], y_pred[:, i], zero_division=0)),
                    "recall": float(recall_score(y_true[:, i], y_pred[:, i], zero_division=0)),
                    "f1": float(f1_score(y_true[:, i], y_pred[:, i], zero_division=0)),
                    "support": int(y_true[:, i].sum()),
                }
            # AUC is best-effort: it can fail when a label has a single class
            # in y_true, or the model exposes no probabilities.
            try:
                if hasattr(self.model, "predict_proba"):
                    y_proba = self.model.predict_proba(X)
                else:
                    # Fallback: hard predictions give a degenerate (step) AUC.
                    y_proba = y_pred
                metrics["micro_auc"] = float(roc_auc_score(y_true, y_proba, average="micro"))
                metrics["macro_auc"] = float(roc_auc_score(y_true, y_proba, average="macro"))
            except Exception:
                pass
        else:
            metrics["accuracy"] = float(accuracy_score(y_true, y_pred))
            metrics["precision"] = float(precision_score(y_true, y_pred, average="weighted", zero_division=0))
            metrics["recall"] = float(recall_score(y_true, y_pred, average="weighted", zero_division=0))
            metrics["f1"] = float(f1_score(y_true, y_pred, average="weighted", zero_division=0))
            metrics["f1_macro"] = float(f1_score(y_true, y_pred, average="macro", zero_division=0))
            try:
                if hasattr(self.model, "predict_proba"):
                    y_proba = self.model.predict_proba(X)
                    metrics["auc"] = float(roc_auc_score(y_true, y_proba, multi_class="ovr", average="weighted"))
            except Exception:
                pass
        return metrics

    def _log_results(self, metrics: Dict):
        """Log a human-readable summary of one split's metrics."""
        logger.info("=" * 60)
        logger.info("Test Results:")
        logger.info("=" * 60)
        if self.task_type == "multilabel":
            logger.info(f"  Hamming Loss: {metrics['hamming_loss']:.4f}")
            logger.info(f"  Micro F1: {metrics['micro_f1']:.4f}")
            logger.info(f"  Macro F1: {metrics['macro_f1']:.4f}")
            logger.info(f"  Micro AUC: {metrics.get('micro_auc', 'N/A')}")
            logger.info("\nPer-class F1 scores:")
            for cls, cls_metrics in metrics.get("per_class", {}).items():
                logger.info(f"  {cls}: F1={cls_metrics['f1']:.3f}, Support={cls_metrics['support']}")
        else:
            logger.info(f"  Accuracy: {metrics['accuracy']:.4f}")
            logger.info(f"  F1 (Weighted): {metrics['f1']:.4f}")
            logger.info(f"  F1 (Macro): {metrics['f1_macro']:.4f}")
            if "auc" in metrics:
                logger.info(f"  AUC: {metrics['auc']:.4f}")

    def save(self, filepath: str):
        """Save the trained model plus scaler/metadata.

        sklearn models go to ``filepath`` via joblib; PyTorch models are
        saved with torch.save to the same path with a ``.pt`` suffix.
        """
        Path(filepath).parent.mkdir(parents=True, exist_ok=True)
        save_data = {
            "model": self.model.model.state_dict() if isinstance(self.model, PyTorchTrainer) else self.model,
            "scaler": self.scaler,
            "task_type": self.task_type,
            "class_names": self.class_names,
            "feature_names": self.feature_names,
            "results": self.results,
        }
        if isinstance(self.model, PyTorchTrainer):
            out_path = filepath.replace(".joblib", ".pt")
            torch.save(save_data, out_path)
        else:
            out_path = filepath
            joblib.dump(save_data, filepath)
        # Log the path actually written (previously always logged `filepath`,
        # which was wrong for the PyTorch .pt case).
        logger.info(f"Model saved to {out_path}")
def run_comprehensive_training(output_dir: str = "models/unified") -> Dict:
    """Run comprehensive training across all datasets and models.

    Runs three experiments (NCBI imbalance-strategy sweep, CARD model
    comparison, PyTorch deep learning on NCBI), saves the best model from
    each, and writes a JSON summary of every run under ``output_dir``.

    NOTE(review): assumes preprocessed datasets exist under
    ``data/processed/ncbi`` and ``data/processed/card`` — confirm the
    preprocessing step has been run first.

    Returns:
        Dict with a timestamp and one entry per (dataset, model, strategy) run.
    """
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)
    all_results = {
        "timestamp": datetime.now().isoformat(),
        "experiments": [],
    }
    # ==========================================================================
    # Experiment 1: NCBI AMR with different imbalance strategies
    # ==========================================================================
    logger.info("\n" + "="*80)
    logger.info("EXPERIMENT 1: NCBI AMR - Imbalance Strategy Comparison")
    logger.info("="*80)
    ncbi_amr = load_dataset("data/processed/ncbi", "ncbi_amr")
    best_ncbi_result = None
    best_ncbi_f1 = 0
    # Grid over imbalance strategy x model family; each run is independent,
    # so a single failure is logged and skipped rather than aborting the sweep.
    for strategy in ["class_weight", "smote", "random_over"]:
        for model_type in ["random_forest", "extra_trees", "logistic_regression"]:
            logger.info(f"\n--- {model_type} with {strategy} ---")
            trainer = UnifiedAMRTrainer(
                task_type="multilabel",
                imbalance_strategy=strategy,
            )
            try:
                results = trainer.train(ncbi_amr, model_type=model_type)
                all_results["experiments"].append({
                    "dataset": "ncbi_amr",
                    "model": model_type,
                    "strategy": strategy,
                    "test_micro_f1": results["test_metrics"]["micro_f1"],
                    "test_macro_f1": results["test_metrics"]["macro_f1"],
                    "test_micro_auc": results["test_metrics"].get("micro_auc"),
                })
                # Track the best run by test micro-F1 so its model can be saved.
                if results["test_metrics"]["micro_f1"] > best_ncbi_f1:
                    best_ncbi_f1 = results["test_metrics"]["micro_f1"]
                    best_ncbi_result = (trainer, results, model_type, strategy)
            except Exception as e:
                logger.error(f"Failed: {e}")
    # Save best NCBI model
    if best_ncbi_result:
        trainer, results, model_type, strategy = best_ncbi_result
        trainer.save(str(output_path / f"ncbi_amr_best_{model_type}.joblib"))
        logger.info(f"\nBest NCBI AMR model: {model_type} with {strategy} (Micro F1: {best_ncbi_f1:.4f})")
    # ==========================================================================
    # Experiment 2: CARD Drug Class (larger dataset)
    # ==========================================================================
    logger.info("\n" + "="*80)
    logger.info("EXPERIMENT 2: CARD Drug Class")
    logger.info("="*80)
    card_data = load_dataset("data/processed/card", "card_drug_class")
    best_card_result = None
    best_card_f1 = 0
    # Only tree ensembles here, with the fixed class_weight strategy.
    for model_type in ["random_forest", "extra_trees"]:
        logger.info(f"\n--- {model_type} ---")
        trainer = UnifiedAMRTrainer(
            task_type="multilabel",
            imbalance_strategy="class_weight",
        )
        try:
            results = trainer.train(card_data, model_type=model_type)
            all_results["experiments"].append({
                "dataset": "card_drug_class",
                "model": model_type,
                "strategy": "class_weight",
                "test_micro_f1": results["test_metrics"]["micro_f1"],
                "test_macro_f1": results["test_metrics"]["macro_f1"],
            })
            if results["test_metrics"]["micro_f1"] > best_card_f1:
                best_card_f1 = results["test_metrics"]["micro_f1"]
                best_card_result = (trainer, results, model_type)
        except Exception as e:
            logger.error(f"Failed: {e}")
    # Save best CARD model
    if best_card_result:
        trainer, results, model_type = best_card_result
        trainer.save(str(output_path / f"card_drug_class_best_{model_type}.joblib"))
        logger.info(f"\nBest CARD model: {model_type} (Micro F1: {best_card_f1:.4f})")
    # ==========================================================================
    # Experiment 3: PyTorch Deep Learning on NCBI AMR
    # ==========================================================================
    logger.info("\n" + "="*80)
    logger.info("EXPERIMENT 3: Deep Learning on NCBI AMR")
    logger.info("="*80)
    trainer = UnifiedAMRTrainer(
        task_type="multilabel",
        imbalance_strategy="class_weight",
    )
    try:
        results = trainer.train(
            ncbi_amr,
            model_type="deep_learning",
            use_pytorch=True,
            epochs=100,
            batch_size=32,
        )
        all_results["experiments"].append({
            "dataset": "ncbi_amr",
            "model": "deep_learning",
            "strategy": "class_weight",
            "test_micro_f1": results["test_metrics"]["micro_f1"],
            "test_macro_f1": results["test_metrics"]["macro_f1"],
        })
        # Save PyTorch model (save() redirects .joblib paths to .pt for torch)
        trainer.save(str(output_path / "ncbi_amr_deep_learning.pt"))
    except Exception as e:
        logger.error(f"Deep learning failed: {e}")
    # ==========================================================================
    # Summary
    # ==========================================================================
    logger.info("\n" + "="*80)
    logger.info("TRAINING SUMMARY")
    logger.info("="*80)
    # Sort by F1 score (descending) across every experiment run.
    experiments_sorted = sorted(
        all_results["experiments"],
        key=lambda x: x.get("test_micro_f1", 0),
        reverse=True,
    )
    logger.info("\nTop 5 Models by Micro F1:")
    for i, exp in enumerate(experiments_sorted[:5], 1):
        logger.info(f"  {i}. {exp['dataset']} / {exp['model']} / {exp['strategy']}: "
                    f"Micro F1={exp['test_micro_f1']:.4f}, Macro F1={exp['test_macro_f1']:.4f}")
    # Save all results (default=str covers any non-JSON-serializable values)
    with open(output_path / "training_results.json", "w") as f:
        json.dump(all_results, f, indent=2, default=str)
    logger.info(f"\nResults saved to {output_path / 'training_results.json'}")
    return all_results
if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser(description="Unified AMR Training Pipeline")
    # NOTE(review): "ncbi" and "card" are accepted choices but currently fall
    # through to the same quick-test branch as "quick" — confirm intended.
    parser.add_argument("--mode", choices=["full", "quick", "ncbi", "card"], default="full")
    parser.add_argument("--output-dir", default="models/unified")
    args = parser.parse_args()
    if args.mode == "full":
        # Full sweep: all experiments, best models saved under --output-dir.
        run_comprehensive_training(args.output_dir)
    else:
        # Quick mode for testing: single random-forest run on the NCBI data.
        logger.info("Running quick training mode...")
        ncbi_amr = load_dataset("data/processed/ncbi", "ncbi_amr")
        trainer = UnifiedAMRTrainer(task_type="multilabel", imbalance_strategy="class_weight")
        results = trainer.train(ncbi_amr, model_type="random_forest")
        trainer.save(f"{args.output_dir}/quick_test.joblib")