"""
Evaluation Script for Trained Classifier Head
Evaluates the trained classifier on test/validation data and generates
comprehensive metrics including per-class accuracy, confusion matrix, etc.
Usage:
python training/evaluate_classifier.py --checkpoint models/checkpoints/classifier_head_best.pt
"""
import argparse
import json
import logging
import sys
from pathlib import Path
from typing import Any, Dict, List

import matplotlib
matplotlib.use('Agg')  # render off-screen so the script also works without a display
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
import yaml
from sklearn.metrics import (
    accuracy_score, f1_score, precision_score, recall_score,
    confusion_matrix, classification_report
)
from torch.utils.data import DataLoader
# Add project root to path so the project-local packages resolve
sys.path.insert(0, str(Path(__file__).parent.parent))

from training.train_classifier_head import PhonemeDataset, collate_fn
from models.phoneme_mapper import PhonemeMapper
from inference.inference_pipeline import InferencePipeline
from config import default_audio_config, default_model_config, default_inference_config
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def load_trained_model(checkpoint_path: Path) -> torch.nn.Module:
    """Load the trained classifier head from a checkpoint into a fresh model."""
# Initialize inference pipeline
inference_pipeline = InferencePipeline(
audio_config=default_audio_config,
model_config=default_model_config,
inference_config=default_inference_config
)
model = inference_pipeline.model
# Load checkpoint
checkpoint = torch.load(checkpoint_path, map_location='cpu')
model.classifier_head.load_state_dict(checkpoint['model_state_dict'])
logger.info(f"✅ Loaded checkpoint from epoch {checkpoint.get('epoch', 'unknown')}")
logger.info(f" Validation loss: {checkpoint.get('val_loss', 'unknown'):.4f}")
logger.info(f" Validation accuracy: {checkpoint.get('val_accuracy', 'unknown'):.4f}")
return model
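
# Example programmatic use (checkpoint path taken from the Usage line above):
#   model = load_trained_model(Path("models/checkpoints/classifier_head_best.pt"))
#   model.eval()
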
def evaluate_model(
model: torch.nn.Module,
dataloader: DataLoader,
device: torch.device,
class_names: List[str]
) -> Dict[str, Any]:
"""Evaluate model and return comprehensive metrics."""
model.eval()
all_preds = []
all_labels = []
all_probs = []
with torch.no_grad():
for batch in dataloader:
features = batch['features'].to(device)
labels = batch['labels'].to(device)
batch_size, seq_len, feat_dim = features.shape
features_flat = features.view(-1, feat_dim)
labels_flat = labels.view(-1)
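            # NOTE (assumption): if collate_fn pads variable-length sequences
            # with an ignore label (e.g. -100), padded frames should be dropped
            # before scoring, roughly:
            #   valid = labels_flat >= 0
            #   features_flat = features_flat[valid]
            #   labels_flat = labels_flat[valid]
            # The padding convention lives in training/train_classifier_head.py,
            # so this is left as a comment rather than applied blindly.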
# Forward pass
shared_features = model.classifier_head.shared_layers(features_flat)
logits = model.classifier_head.full_head(shared_features)
probs = torch.softmax(logits, dim=-1)
preds = torch.argmax(logits, dim=-1).cpu().numpy()
all_preds.extend(preds)
all_labels.extend(labels_flat.cpu().numpy())
all_probs.extend(probs.cpu().numpy())
# Calculate metrics
accuracy = accuracy_score(all_labels, all_preds)
f1_macro = f1_score(all_labels, all_preds, average='macro', zero_division=0)
f1_weighted = f1_score(all_labels, all_preds, average='weighted', zero_division=0)
precision_macro = precision_score(all_labels, all_preds, average='macro', zero_division=0)
recall_macro = recall_score(all_labels, all_preds, average='macro', zero_division=0)
# Per-class metrics
cm = confusion_matrix(all_labels, all_preds, labels=list(range(len(class_names))))
    # Per-class accuracy (rows of the confusion matrix sum to zero when a
    # class never appears in the eval set, so guard the division)
    row_sums = cm.sum(axis=1)
    per_class_accuracy = np.divide(
        cm.diagonal(), row_sums,
        out=np.zeros_like(row_sums, dtype=float),
        where=row_sums > 0
    )
    # Classification report (explicit labels keep target_names aligned even
    # when a class is missing from the eval set)
    report = classification_report(
        all_labels, all_preds,
        labels=list(range(len(class_names))),
        target_names=class_names,
        output_dict=True,
        zero_division=0
    )
# Confidence analysis
all_probs = np.array(all_probs)
max_probs = np.max(all_probs, axis=1)
correct_mask = np.array(all_preds) == np.array(all_labels)
avg_confidence_correct = np.mean(max_probs[correct_mask]) if np.any(correct_mask) else 0.0
avg_confidence_incorrect = np.mean(max_probs[~correct_mask]) if np.any(~correct_mask) else 0.0
return {
'overall_accuracy': float(accuracy),
'f1_macro': float(f1_macro),
'f1_weighted': float(f1_weighted),
'precision_macro': float(precision_macro),
'recall_macro': float(recall_macro),
'confusion_matrix': cm.tolist(),
'per_class_accuracy': per_class_accuracy.tolist(),
'classification_report': report,
'confidence': {
'avg_correct': float(avg_confidence_correct),
'avg_incorrect': float(avg_confidence_incorrect),
'confidence_distribution': {
'mean': float(np.mean(max_probs)),
'std': float(np.std(max_probs)),
'min': float(np.min(max_probs)),
'max': float(np.max(max_probs))
}
},
'num_samples': len(all_labels)
}
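
# Optional diagnostic, not wired into evaluate_model: a binned expected
# calibration error (ECE) over the same max-probability confidences. A minimal
# sketch assuming the standard binned-ECE definition; `correct` is the boolean
# mask of right predictions, `max_probs` the per-frame max softmax probability.
def expected_calibration_error(max_probs: np.ndarray, correct: np.ndarray, n_bins: int = 10) -> float:
    """Binned ECE: population-weighted |accuracy - confidence| gap per bin."""
    bin_edges = np.linspace(0.0, 1.0, n_bins + 1)
    ece = 0.0
    for lo, hi in zip(bin_edges[:-1], bin_edges[1:]):
        in_bin = (max_probs > lo) & (max_probs <= hi)
        if in_bin.any():
            # Weight each bin's accuracy/confidence gap by its population share
            ece += in_bin.mean() * abs(correct[in_bin].mean() - max_probs[in_bin].mean())
    return float(ece)
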
def plot_confusion_matrix(cm: np.ndarray, class_names: List[str], output_path: Path):
"""Plot and save confusion matrix."""
plt.figure(figsize=(10, 8))
sns.heatmap(
cm,
annot=True,
fmt='d',
cmap='Blues',
xticklabels=class_names,
yticklabels=class_names
)
plt.title('Confusion Matrix')
plt.ylabel('True Label')
plt.xlabel('Predicted Label')
plt.tight_layout()
    plt.savefig(output_path)
    plt.close()  # free the figure so repeated calls don't accumulate memory
    logger.info(f"✅ Saved confusion matrix to {output_path}")
def main():
parser = argparse.ArgumentParser(description="Evaluate trained classifier")
parser.add_argument('--checkpoint', type=str, required=True,
help='Path to checkpoint file')
parser.add_argument('--config', type=str, default='training/config.yaml',
help='Path to config file')
parser.add_argument('--dataset', type=str, default='data/training_dataset.json',
help='Path to evaluation dataset')
parser.add_argument('--output', type=str, default='training/evaluation_results.json',
help='Path to save evaluation results')
parser.add_argument('--plot', type=str, default='training/confusion_matrix.png',
help='Path to save confusion matrix plot')
args = parser.parse_args()
# Load config
with open(args.config, 'r') as f:
config = yaml.safe_load(f)
# Set device
device = torch.device('cuda' if torch.cuda.is_available() and config['device']['use_cuda'] else 'cpu')
logger.info(f"Using device: {device}")
# Load model
checkpoint_path = Path(args.checkpoint)
if not checkpoint_path.exists():
logger.error(f"Checkpoint not found: {checkpoint_path}")
return
    model = load_trained_model(checkpoint_path)
model = model.to(device)
# Load evaluation dataset
dataset_path = Path(args.dataset)
if not dataset_path.exists():
logger.error(f"Dataset not found: {dataset_path}")
return
with open(dataset_path, 'r') as f:
eval_data = json.load(f)
logger.info(f"Loaded {len(eval_data)} evaluation samples")
    # Create dataset and dataloader (the dataset needs a pipeline for feature
    # extraction; the evaluated classifier weights were already loaded above)
    inference_pipeline = InferencePipeline(
        audio_config=default_audio_config,
        model_config=default_model_config,
        inference_config=default_inference_config
    )
    # Frame duration and sample rate must match the values used at training time
    phoneme_mapper = PhonemeMapper(frame_duration_ms=20, sample_rate=16000)
    dataset = PhonemeDataset(eval_data, inference_pipeline, phoneme_mapper)
dataloader = DataLoader(
dataset,
batch_size=config['training']['batch_size'],
shuffle=False,
collate_fn=collate_fn
)
# Class names
class_names = [
"Normal",
"Substitution",
"Omission",
"Distortion",
"Normal+Stutter",
"Substitution+Stutter",
"Omission+Stutter",
"Distortion+Stutter"
]
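    # NOTE (assumption): this ordering must match the label encoding used in
    # training/train_classifier_head.py; a mismatch silently relabels every row
    # of the report. If the head ends in a Linear layer, a cheap guard would be:
    #   assert model.classifier_head.full_head[-1].out_features == len(class_names)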
# Evaluate
logger.info("Evaluating model...")
metrics = evaluate_model(model, dataloader, device, class_names)
# Print results
logger.info("\n" + "="*50)
logger.info("EVALUATION RESULTS")
logger.info("="*50)
logger.info(f"Overall Accuracy: {metrics['overall_accuracy']:.4f}")
logger.info(f"F1 Score (macro): {metrics['f1_macro']:.4f}")
logger.info(f"F1 Score (weighted): {metrics['f1_weighted']:.4f}")
logger.info(f"Precision (macro): {metrics['precision_macro']:.4f}")
logger.info(f"Recall (macro): {metrics['recall_macro']:.4f}")
logger.info(f"\nPer-Class Accuracy:")
for i, (name, acc) in enumerate(zip(class_names, metrics['per_class_accuracy'])):
logger.info(f" {name}: {acc:.4f}")
logger.info(f"\nConfidence Analysis:")
logger.info(f" Avg confidence (correct): {metrics['confidence']['avg_correct']:.4f}")
logger.info(f" Avg confidence (incorrect): {metrics['confidence']['avg_incorrect']:.4f}")
# Save results
output_path = Path(args.output)
output_path.parent.mkdir(parents=True, exist_ok=True)
with open(output_path, 'w') as f:
json.dump(metrics, f, indent=2)
logger.info(f"\n✅ Saved evaluation results to {output_path}")
# Plot confusion matrix
if args.plot:
plot_path = Path(args.plot)
plot_path.parent.mkdir(parents=True, exist_ok=True)
cm = np.array(metrics['confusion_matrix'])
plot_confusion_matrix(cm, class_names, plot_path)
if __name__ == "__main__":
main()