""" Comprehensive model evaluation on test set. This script: 1. Loads trained ensemble model with calibration 2. Evaluates on held-out test set 3. Computes classification metrics (accuracy, F1, AUC, sensitivity, specificity) 4. Generates confusion matrix, ROC curve, PR curve 5. Performs error analysis Usage: python training/evaluate.py """ import sys from pathlib import Path sys.path.insert(0, str(Path(__file__).parent.parent)) import torch import torch.nn as nn from torch.utils.data import DataLoader import mlflow import numpy as np from tqdm import tqdm import yaml import pandas as pd import logging from sklearn.metrics import ( accuracy_score, f1_score, roc_auc_score, confusion_matrix, classification_report, roc_curve, precision_recall_curve, average_precision_score, ) import matplotlib.pyplot as plt import seaborn as sns from training.dataset import DysarthriaDataset from training.train_hubert_salr import HuBERTSALRModel from training.train_cnn_bilstm import CNNBiLSTMTransformer # ══════════════════════════════════════════════════════════════════════════════ # Model Inference # ══════════════════════════════════════════════════════════════════════════════ def evaluate_model(hubert_model, cnn_model, dataloader, alpha, platt_a, platt_b, device): """ Evaluate calibrated ensemble on test set. Returns: predictions, probabilities, labels, file_paths """ all_preds = [] all_probs = [] all_labels = [] all_files = [] hubert_model.eval() cnn_model.eval() with torch.no_grad(): for batch in tqdm(dataloader, desc="Evaluating"): waveform = batch["waveform"].to(device) spectrogram = batch["spectrogram"].to(device) labels = batch["label"] file_paths = batch.get("file_path", [""] * len(labels)) # Ensemble logits hubert_logits = hubert_model(waveform) cnn_logits = cnn_model(spectrogram) ensemble_logits = alpha * hubert_logits + (1 - alpha) * cnn_logits # Apply Platt scaling raw_logits = ensemble_logits[:, 1].cpu().numpy() z = platt_a * raw_logits + platt_b calibrated_probs = 1 / (1 + np.exp(-z)) # Predictions preds = (calibrated_probs > 0.5).astype(int) all_preds.extend(preds) all_probs.extend(calibrated_probs) all_labels.extend(labels.numpy()) all_files.extend(file_paths) return ( np.array(all_preds), np.array(all_probs), np.array(all_labels), all_files, ) # ══════════════════════════════════════════════════════════════════════════════ # Metrics Computation # ══════════════════════════════════════════════════════════════════════════════ def compute_metrics(y_true, y_pred, y_prob): """Compute comprehensive classification metrics.""" # Basic metrics accuracy = accuracy_score(y_true, y_pred) f1 = f1_score(y_true, y_pred, average="binary") auc = roc_auc_score(y_true, y_prob) ap = average_precision_score(y_true, y_prob) # Confusion matrix cm = confusion_matrix(y_true, y_pred) tn, fp, fn, tp = cm.ravel() # Sensitivity and specificity sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0 specificity = tn / (tn + fp) if (tn + fp) > 0 else 0 # Positive and negative predictive value ppv = tp / (tp + fp) if (tp + fp) > 0 else 0 npv = tn / (tn + fn) if (tn + fn) > 0 else 0 return { "accuracy": accuracy, "f1": f1, "auc": auc, "average_precision": ap, "sensitivity": sensitivity, "specificity": specificity, "ppv": ppv, "npv": npv, "tp": int(tp), "tn": int(tn), "fp": int(fp), "fn": int(fn), "confusion_matrix": cm, } # ══════════════════════════════════════════════════════════════════════════════ # Visualization # ══════════════════════════════════════════════════════════════════════════════ def plot_confusion_matrix(cm, output_path: Path): """Plot confusion matrix.""" plt.figure(figsize=(8, 6)) sns.heatmap( cm, annot=True, fmt="d", cmap="Blues", xticklabels=["Healthy", "Dysarthric"], yticklabels=["Healthy", "Dysarthric"], cbar_kws={"label": "Count"}, ) plt.title("Confusion Matrix - Test Set", fontsize=16, fontweight="bold") plt.ylabel("True Label", fontsize=14) plt.xlabel("Predicted Label", fontsize=14) plt.tight_layout() plt.savefig(output_path, dpi=300, bbox_inches="tight") plt.close() def plot_roc_curve(y_true, y_prob, auc_score, output_path: Path): """Plot ROC curve.""" fpr, tpr, thresholds = roc_curve(y_true, y_prob) plt.figure(figsize=(8, 6)) plt.plot(fpr, tpr, linewidth=2, label=f"Model (AUC = {auc_score:.4f})") plt.plot([0, 1], [0, 1], "k--", linewidth=1, label="Random Classifier") plt.xlabel("False Positive Rate", fontsize=14) plt.ylabel("True Positive Rate", fontsize=14) plt.title("ROC Curve - Test Set", fontsize=16, fontweight="bold") plt.legend(fontsize=12) plt.grid(True, alpha=0.3) plt.tight_layout() plt.savefig(output_path, dpi=300, bbox_inches="tight") plt.close() def plot_precision_recall_curve(y_true, y_prob, ap_score, output_path: Path): """Plot Precision-Recall curve.""" precision, recall, thresholds = precision_recall_curve(y_true, y_prob) plt.figure(figsize=(8, 6)) plt.plot(recall, precision, linewidth=2, label=f"Model (AP = {ap_score:.4f})") plt.xlabel("Recall", fontsize=14) plt.ylabel("Precision", fontsize=14) plt.title("Precision-Recall Curve - Test Set", fontsize=16, fontweight="bold") plt.legend(fontsize=12) plt.grid(True, alpha=0.3) plt.xlim([0, 1]) plt.ylim([0, 1]) plt.tight_layout() plt.savefig(output_path, dpi=300, bbox_inches="tight") plt.close() def plot_probability_distribution(y_true, y_prob, output_path: Path): """Plot distribution of predicted probabilities by class.""" plt.figure(figsize=(10, 6)) mask_positive = y_true == 1 mask_negative = y_true == 0 plt.hist( y_prob[mask_negative], bins=30, alpha=0.6, color="blue", label="Healthy", edgecolor="black", ) plt.hist( y_prob[mask_positive], bins=30, alpha=0.6, color="red", label="Dysarthric", edgecolor="black", ) plt.axvline(0.5, color="black", linestyle="--", linewidth=2, label="Decision Threshold") plt.xlabel("Predicted Probability", fontsize=14) plt.ylabel("Count", fontsize=14) plt.title("Predicted Probability Distribution - Test Set", fontsize=16, fontweight="bold") plt.legend(fontsize=12) plt.grid(True, alpha=0.3, axis="y") plt.tight_layout() plt.savefig(output_path, dpi=300, bbox_inches="tight") plt.close() # ══════════════════════════════════════════════════════════════════════════════ # Error Analysis # ══════════════════════════════════════════════════════════════════════════════ def perform_error_analysis(y_true, y_pred, y_prob, file_paths, output_path: Path): """Identify and save misclassified samples.""" errors = [] for i, (true_label, pred_label, prob, file_path) in enumerate( zip(y_true, y_pred, y_prob, file_paths) ): if true_label != pred_label: error_type = "False Positive" if pred_label == 1 else "False Negative" confidence = prob if pred_label == 1 else (1 - prob) errors.append({ "file_path": file_path, "true_label": "Dysarthric" if true_label == 1 else "Healthy", "predicted_label": "Dysarthric" if pred_label == 1 else "Healthy", "probability": prob, "confidence": confidence, "error_type": error_type, }) errors_df = pd.DataFrame(errors) errors_df = errors_df.sort_values("confidence", ascending=False) errors_df.to_csv(output_path, index=False) return errors_df # ══════════════════════════════════════════════════════════════════════════════ # Main # ══════════════════════════════════════════════════════════════════════════════ def main(): logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") logger.info(f"Using device: {device}") # ────────────────────────────────────────────────────────────────────────── # Load Configuration # ────────────────────────────────────────────────────────────────────────── with open("configs/model_config.yaml") as f: config = yaml.safe_load(f) alpha = config.get("ensemble", {}).get("alpha", 0.6) # Load Platt scaling parameters calibration_file = Path("reports/calibration/calibration_params.yaml") if calibration_file.exists(): with open(calibration_file) as f: cal_config = yaml.safe_load(f) platt_a = cal_config["platt_scaling"]["a"] platt_b = cal_config["platt_scaling"]["b"] logger.info(f"Loaded Platt parameters: a={platt_a:.6f}, b={platt_b:.6f}") else: platt_a, platt_b = 1.0, 0.0 logger.warning("Calibration parameters not found, using identity mapping") # ────────────────────────────────────────────────────────────────────────── # Load Models # ────────────────────────────────────────────────────────────────────────── logger.info("Loading models...") hubert_checkpoint = Path("models/hubert_salr_best.pt") cnn_checkpoint = Path("models/cnn_bilstm_best.pt") hubert_model = HuBERTSALRModel() hubert_model.load_state_dict(torch.load(hubert_checkpoint, map_location=device)["model_state_dict"]) hubert_model.to(device) cnn_model = CNNBiLSTMTransformer() cnn_model.load_state_dict(torch.load(cnn_checkpoint, map_location=device)["model_state_dict"]) cnn_model.to(device) # ────────────────────────────────────────────────────────────────────────── # Load Test Data # ────────────────────────────────────────────────────────────────────────── test_manifest = Path("data/manifests/test.csv") test_dataset = DysarthriaDataset(test_manifest, augmentor=None, mode="test") test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False, num_workers=4) logger.info(f"Test samples: {len(test_dataset)}") # ────────────────────────────────────────────────────────────────────────── # Evaluate # ────────────────────────────────────────────────────────────────────────── mlflow.set_experiment("model_evaluation") with mlflow.start_run(): logger.info("\nEvaluating on test set...") y_pred, y_prob, y_true, file_paths = evaluate_model( hubert_model, cnn_model, test_loader, alpha, platt_a, platt_b, device ) # Compute metrics metrics = compute_metrics(y_true, y_pred, y_prob) # ────────────────────────────────────────────────────────────────────── # Print Results # ────────────────────────────────────────────────────────────────────── logger.info("\n" + "=" * 80) logger.info("TEST SET EVALUATION RESULTS") logger.info("=" * 80) logger.info(f"Accuracy: {metrics['accuracy']:.4f}") logger.info(f"F1 Score: {metrics['f1']:.4f}") logger.info(f"AUC-ROC: {metrics['auc']:.4f}") logger.info(f"Average Precision: {metrics['average_precision']:.4f}") logger.info(f"Sensitivity: {metrics['sensitivity']:.4f}") logger.info(f"Specificity: {metrics['specificity']:.4f}") logger.info(f"PPV: {metrics['ppv']:.4f}") logger.info(f"NPV: {metrics['npv']:.4f}") logger.info("") logger.info("Confusion Matrix:") logger.info(f" True Negatives: {metrics['tn']}") logger.info(f" False Positives: {metrics['fp']}") logger.info(f" False Negatives: {metrics['fn']}") logger.info(f" True Positives: {metrics['tp']}") logger.info("=" * 80) # Log to MLflow mlflow.log_params({ "ensemble_alpha": alpha, "platt_a": platt_a, "platt_b": platt_b, "test_samples": len(y_true), }) mlflow.log_metrics({ "test_accuracy": metrics["accuracy"], "test_f1": metrics["f1"], "test_auc": metrics["auc"], "test_ap": metrics["average_precision"], "test_sensitivity": metrics["sensitivity"], "test_specificity": metrics["specificity"], "test_ppv": metrics["ppv"], "test_npv": metrics["npv"], }) # ────────────────────────────────────────────────────────────────────── # Save Results # ────────────────────────────────────────────────────────────────────── output_dir = Path("reports/evaluation") output_dir.mkdir(parents=True, exist_ok=True) # Save metrics metrics_file = output_dir / "test_metrics.yaml" with open(metrics_file, "w") as f: # Convert numpy types to Python types metrics_to_save = {k: v for k, v in metrics.items() if k != "confusion_matrix"} yaml.dump(metrics_to_save, f, default_flow_style=False) mlflow.log_artifact(str(metrics_file)) logger.info(f"\n✓ Metrics saved to {metrics_file}") # Classification report report = classification_report( y_true, y_pred, target_names=["Healthy", "Dysarthric"], digits=4, ) report_file = output_dir / "classification_report.txt" with open(report_file, "w") as f: f.write(report) mlflow.log_artifact(str(report_file)) logger.info(f"✓ Classification report saved to {report_file}") # Confusion matrix cm_path = output_dir / "confusion_matrix.png" plot_confusion_matrix(metrics["confusion_matrix"], cm_path) mlflow.log_artifact(str(cm_path)) logger.info(f"✓ Confusion matrix plot saved to {cm_path}") # ROC curve roc_path = output_dir / "roc_curve.png" plot_roc_curve(y_true, y_prob, metrics["auc"], roc_path) mlflow.log_artifact(str(roc_path)) logger.info(f"✓ ROC curve saved to {roc_path}") # Precision-Recall curve pr_path = output_dir / "precision_recall_curve.png" plot_precision_recall_curve(y_true, y_prob, metrics["average_precision"], pr_path) mlflow.log_artifact(str(pr_path)) logger.info(f"✓ Precision-Recall curve saved to {pr_path}") # Probability distribution prob_dist_path = output_dir / "probability_distribution.png" plot_probability_distribution(y_true, y_prob, prob_dist_path) mlflow.log_artifact(str(prob_dist_path)) logger.info(f"✓ Probability distribution saved to {prob_dist_path}") # Error analysis errors_file = output_dir / "misclassified_samples.csv" errors_df = perform_error_analysis(y_true, y_pred, y_prob, file_paths, errors_file) mlflow.log_artifact(str(errors_file)) logger.info(f"✓ Error analysis saved to {errors_file}") logger.info(f" Total errors: {len(errors_df)}") logger.info(f" False Positives: {len(errors_df[errors_df['error_type'] == 'False Positive'])}") logger.info(f" False Negatives: {len(errors_df[errors_df['error_type'] == 'False Negative'])}") logger.info("\n✓ Evaluation complete!") if __name__ == "__main__": main()