""" Evaluation Script for Trained Classifier Head Evaluates the trained classifier on test/validation data and generates comprehensive metrics including per-class accuracy, confusion matrix, etc. Usage: python training/evaluate_classifier.py --checkpoint models/checkpoints/classifier_head_best.pt """ import logging import argparse import json import yaml from pathlib import Path from typing import Dict, List, Any import sys import torch import numpy as np from sklearn.metrics import ( accuracy_score, f1_score, precision_score, recall_score, confusion_matrix, classification_report ) import matplotlib.pyplot as plt import seaborn as sns # Add project root to path sys.path.insert(0, str(Path(__file__).parent.parent)) from training.train_classifier_head import PhonemeDataset, collate_fn from torch.utils.data import DataLoader from models.speech_pathology_model import SpeechPathologyClassifier from models.phoneme_mapper import PhonemeMapper from inference.inference_pipeline import InferencePipeline from config import default_audio_config, default_model_config, default_inference_config logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) def load_trained_model(checkpoint_path: Path, config_path: Path = Path("training/config.yaml")) -> torch.nn.Module: """Load trained classifier head from checkpoint.""" # Load config with open(config_path, 'r') as f: config = yaml.safe_load(f) # Initialize inference pipeline inference_pipeline = InferencePipeline( audio_config=default_audio_config, model_config=default_model_config, inference_config=default_inference_config ) model = inference_pipeline.model # Load checkpoint checkpoint = torch.load(checkpoint_path, map_location='cpu') model.classifier_head.load_state_dict(checkpoint['model_state_dict']) logger.info(f"āœ… Loaded checkpoint from epoch {checkpoint.get('epoch', 'unknown')}") logger.info(f" Validation loss: {checkpoint.get('val_loss', 'unknown'):.4f}") logger.info(f" Validation accuracy: {checkpoint.get('val_accuracy', 'unknown'):.4f}") return model def evaluate_model( model: torch.nn.Module, dataloader: DataLoader, device: torch.device, class_names: List[str] ) -> Dict[str, Any]: """Evaluate model and return comprehensive metrics.""" model.eval() all_preds = [] all_labels = [] all_probs = [] with torch.no_grad(): for batch in dataloader: features = batch['features'].to(device) labels = batch['labels'].to(device) batch_size, seq_len, feat_dim = features.shape features_flat = features.view(-1, feat_dim) labels_flat = labels.view(-1) # Forward pass shared_features = model.classifier_head.shared_layers(features_flat) logits = model.classifier_head.full_head(shared_features) probs = torch.softmax(logits, dim=-1) preds = torch.argmax(logits, dim=-1).cpu().numpy() all_preds.extend(preds) all_labels.extend(labels_flat.cpu().numpy()) all_probs.extend(probs.cpu().numpy()) # Calculate metrics accuracy = accuracy_score(all_labels, all_preds) f1_macro = f1_score(all_labels, all_preds, average='macro', zero_division=0) f1_weighted = f1_score(all_labels, all_preds, average='weighted', zero_division=0) precision_macro = precision_score(all_labels, all_preds, average='macro', zero_division=0) recall_macro = recall_score(all_labels, all_preds, average='macro', zero_division=0) # Per-class metrics cm = confusion_matrix(all_labels, all_preds, labels=list(range(len(class_names)))) # Per-class accuracy per_class_accuracy = cm.diagonal() / cm.sum(axis=1) per_class_accuracy = np.nan_to_num(per_class_accuracy) # Handle division by zero # Classification report report = classification_report( all_labels, all_preds, target_names=class_names, output_dict=True, zero_division=0 ) # Confidence analysis all_probs = np.array(all_probs) max_probs = np.max(all_probs, axis=1) correct_mask = np.array(all_preds) == np.array(all_labels) avg_confidence_correct = np.mean(max_probs[correct_mask]) if np.any(correct_mask) else 0.0 avg_confidence_incorrect = np.mean(max_probs[~correct_mask]) if np.any(~correct_mask) else 0.0 return { 'overall_accuracy': float(accuracy), 'f1_macro': float(f1_macro), 'f1_weighted': float(f1_weighted), 'precision_macro': float(precision_macro), 'recall_macro': float(recall_macro), 'confusion_matrix': cm.tolist(), 'per_class_accuracy': per_class_accuracy.tolist(), 'classification_report': report, 'confidence': { 'avg_correct': float(avg_confidence_correct), 'avg_incorrect': float(avg_confidence_incorrect), 'confidence_distribution': { 'mean': float(np.mean(max_probs)), 'std': float(np.std(max_probs)), 'min': float(np.min(max_probs)), 'max': float(np.max(max_probs)) } }, 'num_samples': len(all_labels) } def plot_confusion_matrix(cm: np.ndarray, class_names: List[str], output_path: Path): """Plot and save confusion matrix.""" plt.figure(figsize=(10, 8)) sns.heatmap( cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names ) plt.title('Confusion Matrix') plt.ylabel('True Label') plt.xlabel('Predicted Label') plt.tight_layout() plt.savefig(output_path) logger.info(f"āœ… Saved confusion matrix to {output_path}") def main(): parser = argparse.ArgumentParser(description="Evaluate trained classifier") parser.add_argument('--checkpoint', type=str, required=True, help='Path to checkpoint file') parser.add_argument('--config', type=str, default='training/config.yaml', help='Path to config file') parser.add_argument('--dataset', type=str, default='data/training_dataset.json', help='Path to evaluation dataset') parser.add_argument('--output', type=str, default='training/evaluation_results.json', help='Path to save evaluation results') parser.add_argument('--plot', type=str, default='training/confusion_matrix.png', help='Path to save confusion matrix plot') args = parser.parse_args() # Load config with open(args.config, 'r') as f: config = yaml.safe_load(f) # Set device device = torch.device('cuda' if torch.cuda.is_available() and config['device']['use_cuda'] else 'cpu') logger.info(f"Using device: {device}") # Load model checkpoint_path = Path(args.checkpoint) if not checkpoint_path.exists(): logger.error(f"Checkpoint not found: {checkpoint_path}") return model = load_trained_model(checkpoint_path, Path(args.config)) model = model.to(device) # Load evaluation dataset dataset_path = Path(args.dataset) if not dataset_path.exists(): logger.error(f"Dataset not found: {dataset_path}") return with open(dataset_path, 'r') as f: eval_data = json.load(f) logger.info(f"Loaded {len(eval_data)} evaluation samples") # Create dataset and dataloader inference_pipeline = InferencePipeline( audio_config=default_audio_config, model_config=default_model_config, inference_config=default_inference_config ) phoneme_mapper = PhonemeMapper(frame_duration_ms=20, sample_rate=16000) from training.train_classifier_head import PhonemeDataset dataset = PhonemeDataset(eval_data, inference_pipeline, phoneme_mapper) dataloader = DataLoader( dataset, batch_size=config['training']['batch_size'], shuffle=False, collate_fn=collate_fn ) # Class names class_names = [ "Normal", "Substitution", "Omission", "Distortion", "Normal+Stutter", "Substitution+Stutter", "Omission+Stutter", "Distortion+Stutter" ] # Evaluate logger.info("Evaluating model...") metrics = evaluate_model(model, dataloader, device, class_names) # Print results logger.info("\n" + "="*50) logger.info("EVALUATION RESULTS") logger.info("="*50) logger.info(f"Overall Accuracy: {metrics['overall_accuracy']:.4f}") logger.info(f"F1 Score (macro): {metrics['f1_macro']:.4f}") logger.info(f"F1 Score (weighted): {metrics['f1_weighted']:.4f}") logger.info(f"Precision (macro): {metrics['precision_macro']:.4f}") logger.info(f"Recall (macro): {metrics['recall_macro']:.4f}") logger.info(f"\nPer-Class Accuracy:") for i, (name, acc) in enumerate(zip(class_names, metrics['per_class_accuracy'])): logger.info(f" {name}: {acc:.4f}") logger.info(f"\nConfidence Analysis:") logger.info(f" Avg confidence (correct): {metrics['confidence']['avg_correct']:.4f}") logger.info(f" Avg confidence (incorrect): {metrics['confidence']['avg_incorrect']:.4f}") # Save results output_path = Path(args.output) output_path.parent.mkdir(parents=True, exist_ok=True) with open(output_path, 'w') as f: json.dump(metrics, f, indent=2) logger.info(f"\nāœ… Saved evaluation results to {output_path}") # Plot confusion matrix if args.plot: plot_path = Path(args.plot) plot_path.parent.mkdir(parents=True, exist_ok=True) cm = np.array(metrics['confusion_matrix']) plot_confusion_matrix(cm, class_names, plot_path) if __name__ == "__main__": main()