StrokeMitra-API / training /evaluate.py
DhruvB1906's picture
Upload folder using huggingface_hub
f468cc8 verified
"""
Comprehensive model evaluation on test set.
This script:
1. Loads trained ensemble model with calibration
2. Evaluates on held-out test set
3. Computes classification metrics (accuracy, F1, AUC, sensitivity, specificity)
4. Generates confusion matrix, ROC curve, PR curve
5. Performs error analysis
Usage:
python training/evaluate.py
"""
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import mlflow
import numpy as np
from tqdm import tqdm
import yaml
import pandas as pd
import logging
from sklearn.metrics import (
accuracy_score,
f1_score,
roc_auc_score,
confusion_matrix,
classification_report,
roc_curve,
precision_recall_curve,
average_precision_score,
)
import matplotlib.pyplot as plt
import seaborn as sns
from training.dataset import DysarthriaDataset
from training.train_hubert_salr import HuBERTSALRModel
from training.train_cnn_bilstm import CNNBiLSTMTransformer
# ══════════════════════════════════════════════════════════════════════════════
# Model Inference
# ══════════════════════════════════════════════════════════════════════════════
def evaluate_model(hubert_model, cnn_model, dataloader, alpha, platt_a, platt_b, device):
"""
Evaluate calibrated ensemble on test set.
Returns:
predictions, probabilities, labels, file_paths
"""
all_preds = []
all_probs = []
all_labels = []
all_files = []
hubert_model.eval()
cnn_model.eval()
with torch.no_grad():
for batch in tqdm(dataloader, desc="Evaluating"):
waveform = batch["waveform"].to(device)
spectrogram = batch["spectrogram"].to(device)
labels = batch["label"]
file_paths = batch.get("file_path", [""] * len(labels))
# Ensemble logits
hubert_logits = hubert_model(waveform)
cnn_logits = cnn_model(spectrogram)
ensemble_logits = alpha * hubert_logits + (1 - alpha) * cnn_logits
# Apply Platt scaling
raw_logits = ensemble_logits[:, 1].cpu().numpy()
z = platt_a * raw_logits + platt_b
calibrated_probs = 1 / (1 + np.exp(-z))
# Predictions
preds = (calibrated_probs > 0.5).astype(int)
all_preds.extend(preds)
all_probs.extend(calibrated_probs)
all_labels.extend(labels.numpy())
all_files.extend(file_paths)
return (
np.array(all_preds),
np.array(all_probs),
np.array(all_labels),
all_files,
)
# ══════════════════════════════════════════════════════════════════════════════
# Metrics Computation
# ══════════════════════════════════════════════════════════════════════════════
def compute_metrics(y_true, y_pred, y_prob):
"""Compute comprehensive classification metrics."""
# Basic metrics
accuracy = accuracy_score(y_true, y_pred)
f1 = f1_score(y_true, y_pred, average="binary")
auc = roc_auc_score(y_true, y_prob)
ap = average_precision_score(y_true, y_prob)
# Confusion matrix
cm = confusion_matrix(y_true, y_pred)
tn, fp, fn, tp = cm.ravel()
# Sensitivity and specificity
sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0
specificity = tn / (tn + fp) if (tn + fp) > 0 else 0
# Positive and negative predictive value
ppv = tp / (tp + fp) if (tp + fp) > 0 else 0
npv = tn / (tn + fn) if (tn + fn) > 0 else 0
return {
"accuracy": accuracy,
"f1": f1,
"auc": auc,
"average_precision": ap,
"sensitivity": sensitivity,
"specificity": specificity,
"ppv": ppv,
"npv": npv,
"tp": int(tp),
"tn": int(tn),
"fp": int(fp),
"fn": int(fn),
"confusion_matrix": cm,
}
# ══════════════════════════════════════════════════════════════════════════════
# Visualization
# ══════════════════════════════════════════════════════════════════════════════
def plot_confusion_matrix(cm, output_path: Path):
"""Plot confusion matrix."""
plt.figure(figsize=(8, 6))
sns.heatmap(
cm,
annot=True,
fmt="d",
cmap="Blues",
xticklabels=["Healthy", "Dysarthric"],
yticklabels=["Healthy", "Dysarthric"],
cbar_kws={"label": "Count"},
)
plt.title("Confusion Matrix - Test Set", fontsize=16, fontweight="bold")
plt.ylabel("True Label", fontsize=14)
plt.xlabel("Predicted Label", fontsize=14)
plt.tight_layout()
plt.savefig(output_path, dpi=300, bbox_inches="tight")
plt.close()
def plot_roc_curve(y_true, y_prob, auc_score, output_path: Path):
"""Plot ROC curve."""
fpr, tpr, thresholds = roc_curve(y_true, y_prob)
plt.figure(figsize=(8, 6))
plt.plot(fpr, tpr, linewidth=2, label=f"Model (AUC = {auc_score:.4f})")
plt.plot([0, 1], [0, 1], "k--", linewidth=1, label="Random Classifier")
plt.xlabel("False Positive Rate", fontsize=14)
plt.ylabel("True Positive Rate", fontsize=14)
plt.title("ROC Curve - Test Set", fontsize=16, fontweight="bold")
plt.legend(fontsize=12)
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig(output_path, dpi=300, bbox_inches="tight")
plt.close()
def plot_precision_recall_curve(y_true, y_prob, ap_score, output_path: Path):
"""Plot Precision-Recall curve."""
precision, recall, thresholds = precision_recall_curve(y_true, y_prob)
plt.figure(figsize=(8, 6))
plt.plot(recall, precision, linewidth=2, label=f"Model (AP = {ap_score:.4f})")
plt.xlabel("Recall", fontsize=14)
plt.ylabel("Precision", fontsize=14)
plt.title("Precision-Recall Curve - Test Set", fontsize=16, fontweight="bold")
plt.legend(fontsize=12)
plt.grid(True, alpha=0.3)
plt.xlim([0, 1])
plt.ylim([0, 1])
plt.tight_layout()
plt.savefig(output_path, dpi=300, bbox_inches="tight")
plt.close()
def plot_probability_distribution(y_true, y_prob, output_path: Path):
"""Plot distribution of predicted probabilities by class."""
plt.figure(figsize=(10, 6))
mask_positive = y_true == 1
mask_negative = y_true == 0
plt.hist(
y_prob[mask_negative],
bins=30,
alpha=0.6,
color="blue",
label="Healthy",
edgecolor="black",
)
plt.hist(
y_prob[mask_positive],
bins=30,
alpha=0.6,
color="red",
label="Dysarthric",
edgecolor="black",
)
plt.axvline(0.5, color="black", linestyle="--", linewidth=2, label="Decision Threshold")
plt.xlabel("Predicted Probability", fontsize=14)
plt.ylabel("Count", fontsize=14)
plt.title("Predicted Probability Distribution - Test Set", fontsize=16, fontweight="bold")
plt.legend(fontsize=12)
plt.grid(True, alpha=0.3, axis="y")
plt.tight_layout()
plt.savefig(output_path, dpi=300, bbox_inches="tight")
plt.close()
# ══════════════════════════════════════════════════════════════════════════════
# Error Analysis
# ══════════════════════════════════════════════════════════════════════════════
def perform_error_analysis(y_true, y_pred, y_prob, file_paths, output_path: Path):
"""Identify and save misclassified samples."""
errors = []
for i, (true_label, pred_label, prob, file_path) in enumerate(
zip(y_true, y_pred, y_prob, file_paths)
):
if true_label != pred_label:
error_type = "False Positive" if pred_label == 1 else "False Negative"
confidence = prob if pred_label == 1 else (1 - prob)
errors.append({
"file_path": file_path,
"true_label": "Dysarthric" if true_label == 1 else "Healthy",
"predicted_label": "Dysarthric" if pred_label == 1 else "Healthy",
"probability": prob,
"confidence": confidence,
"error_type": error_type,
})
errors_df = pd.DataFrame(errors)
errors_df = errors_df.sort_values("confidence", ascending=False)
errors_df.to_csv(output_path, index=False)
return errors_df
# ══════════════════════════════════════════════════════════════════════════════
# Main
# ══════════════════════════════════════════════════════════════════════════════
def main():
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f"Using device: {device}")
# ──────────────────────────────────────────────────────────────────────────
# Load Configuration
# ──────────────────────────────────────────────────────────────────────────
with open("configs/model_config.yaml") as f:
config = yaml.safe_load(f)
alpha = config.get("ensemble", {}).get("alpha", 0.6)
# Load Platt scaling parameters
calibration_file = Path("reports/calibration/calibration_params.yaml")
if calibration_file.exists():
with open(calibration_file) as f:
cal_config = yaml.safe_load(f)
platt_a = cal_config["platt_scaling"]["a"]
platt_b = cal_config["platt_scaling"]["b"]
logger.info(f"Loaded Platt parameters: a={platt_a:.6f}, b={platt_b:.6f}")
else:
platt_a, platt_b = 1.0, 0.0
logger.warning("Calibration parameters not found, using identity mapping")
# ──────────────────────────────────────────────────────────────────────────
# Load Models
# ──────────────────────────────────────────────────────────────────────────
logger.info("Loading models...")
hubert_checkpoint = Path("models/hubert_salr_best.pt")
cnn_checkpoint = Path("models/cnn_bilstm_best.pt")
hubert_model = HuBERTSALRModel()
hubert_model.load_state_dict(torch.load(hubert_checkpoint, map_location=device)["model_state_dict"])
hubert_model.to(device)
cnn_model = CNNBiLSTMTransformer()
cnn_model.load_state_dict(torch.load(cnn_checkpoint, map_location=device)["model_state_dict"])
cnn_model.to(device)
# ──────────────────────────────────────────────────────────────────────────
# Load Test Data
# ──────────────────────────────────────────────────────────────────────────
test_manifest = Path("data/manifests/test.csv")
test_dataset = DysarthriaDataset(test_manifest, augmentor=None, mode="test")
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False, num_workers=4)
logger.info(f"Test samples: {len(test_dataset)}")
# ──────────────────────────────────────────────────────────────────────────
# Evaluate
# ──────────────────────────────────────────────────────────────────────────
mlflow.set_experiment("model_evaluation")
with mlflow.start_run():
logger.info("\nEvaluating on test set...")
y_pred, y_prob, y_true, file_paths = evaluate_model(
hubert_model, cnn_model, test_loader, alpha, platt_a, platt_b, device
)
# Compute metrics
metrics = compute_metrics(y_true, y_pred, y_prob)
# ──────────────────────────────────────────────────────────────────────
# Print Results
# ──────────────────────────────────────────────────────────────────────
logger.info("\n" + "=" * 80)
logger.info("TEST SET EVALUATION RESULTS")
logger.info("=" * 80)
logger.info(f"Accuracy: {metrics['accuracy']:.4f}")
logger.info(f"F1 Score: {metrics['f1']:.4f}")
logger.info(f"AUC-ROC: {metrics['auc']:.4f}")
logger.info(f"Average Precision: {metrics['average_precision']:.4f}")
logger.info(f"Sensitivity: {metrics['sensitivity']:.4f}")
logger.info(f"Specificity: {metrics['specificity']:.4f}")
logger.info(f"PPV: {metrics['ppv']:.4f}")
logger.info(f"NPV: {metrics['npv']:.4f}")
logger.info("")
logger.info("Confusion Matrix:")
logger.info(f" True Negatives: {metrics['tn']}")
logger.info(f" False Positives: {metrics['fp']}")
logger.info(f" False Negatives: {metrics['fn']}")
logger.info(f" True Positives: {metrics['tp']}")
logger.info("=" * 80)
# Log to MLflow
mlflow.log_params({
"ensemble_alpha": alpha,
"platt_a": platt_a,
"platt_b": platt_b,
"test_samples": len(y_true),
})
mlflow.log_metrics({
"test_accuracy": metrics["accuracy"],
"test_f1": metrics["f1"],
"test_auc": metrics["auc"],
"test_ap": metrics["average_precision"],
"test_sensitivity": metrics["sensitivity"],
"test_specificity": metrics["specificity"],
"test_ppv": metrics["ppv"],
"test_npv": metrics["npv"],
})
# ──────────────────────────────────────────────────────────────────────
# Save Results
# ──────────────────────────────────────────────────────────────────────
output_dir = Path("reports/evaluation")
output_dir.mkdir(parents=True, exist_ok=True)
# Save metrics
metrics_file = output_dir / "test_metrics.yaml"
with open(metrics_file, "w") as f:
# Convert numpy types to Python types
metrics_to_save = {k: v for k, v in metrics.items() if k != "confusion_matrix"}
yaml.dump(metrics_to_save, f, default_flow_style=False)
mlflow.log_artifact(str(metrics_file))
logger.info(f"\n✓ Metrics saved to {metrics_file}")
# Classification report
report = classification_report(
y_true,
y_pred,
target_names=["Healthy", "Dysarthric"],
digits=4,
)
report_file = output_dir / "classification_report.txt"
with open(report_file, "w") as f:
f.write(report)
mlflow.log_artifact(str(report_file))
logger.info(f"✓ Classification report saved to {report_file}")
# Confusion matrix
cm_path = output_dir / "confusion_matrix.png"
plot_confusion_matrix(metrics["confusion_matrix"], cm_path)
mlflow.log_artifact(str(cm_path))
logger.info(f"✓ Confusion matrix plot saved to {cm_path}")
# ROC curve
roc_path = output_dir / "roc_curve.png"
plot_roc_curve(y_true, y_prob, metrics["auc"], roc_path)
mlflow.log_artifact(str(roc_path))
logger.info(f"✓ ROC curve saved to {roc_path}")
# Precision-Recall curve
pr_path = output_dir / "precision_recall_curve.png"
plot_precision_recall_curve(y_true, y_prob, metrics["average_precision"], pr_path)
mlflow.log_artifact(str(pr_path))
logger.info(f"✓ Precision-Recall curve saved to {pr_path}")
# Probability distribution
prob_dist_path = output_dir / "probability_distribution.png"
plot_probability_distribution(y_true, y_prob, prob_dist_path)
mlflow.log_artifact(str(prob_dist_path))
logger.info(f"✓ Probability distribution saved to {prob_dist_path}")
# Error analysis
errors_file = output_dir / "misclassified_samples.csv"
errors_df = perform_error_analysis(y_true, y_pred, y_prob, file_paths, errors_file)
mlflow.log_artifact(str(errors_file))
logger.info(f"✓ Error analysis saved to {errors_file}")
logger.info(f" Total errors: {len(errors_df)}")
logger.info(f" False Positives: {len(errors_df[errors_df['error_type'] == 'False Positive'])}")
logger.info(f" False Negatives: {len(errors_df[errors_df['error_type'] == 'False Negative'])}")
logger.info("\n✓ Evaluation complete!")
if __name__ == "__main__":
main()