|
|
""" |
|
|
Evaluation Script for Trained Classifier Head |
|
|
|
|
|
Evaluates the trained classifier on test/validation data and generates |
|
|
comprehensive metrics including per-class accuracy, confusion matrix, etc. |
|
|
|
|
|
Usage: |
|
|
python training/evaluate_classifier.py --checkpoint models/checkpoints/classifier_head_best.pt |
|
|
""" |
|
|
|
|
|
import logging |
|
|
import argparse |
|
|
import json |
|
|
import yaml |
|
|
from pathlib import Path |
|
|
from typing import Dict, List, Any |
|
|
import sys |
|
|
|
|
|
import torch |
|
|
import numpy as np |
|
|
from sklearn.metrics import ( |
|
|
accuracy_score, f1_score, precision_score, recall_score, |
|
|
confusion_matrix, classification_report |
|
|
) |
|
|
import matplotlib.pyplot as plt |
|
|
import seaborn as sns |
|
|
|
|
|
|
|
|
sys.path.insert(0, str(Path(__file__).parent.parent)) |
|
|
|
|
|
from training.train_classifier_head import PhonemeDataset, collate_fn |
|
|
from torch.utils.data import DataLoader |
|
|
from models.speech_pathology_model import SpeechPathologyClassifier |
|
|
from models.phoneme_mapper import PhonemeMapper |
|
|
from inference.inference_pipeline import InferencePipeline |
|
|
from config import default_audio_config, default_model_config, default_inference_config |
|
|
|
|
|
logging.basicConfig(level=logging.INFO) |
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
def load_trained_model(checkpoint_path: Path, config_path: Path = Path("training/config.yaml")) -> torch.nn.Module:
    """Load the full inference model with trained classifier-head weights.

    Rebuilds the default ``InferencePipeline`` model, then overwrites its
    classifier head's parameters with the state dict stored in the checkpoint.

    Args:
        checkpoint_path: Path to a ``.pt`` checkpoint containing a
            ``'model_state_dict'`` entry (plus optional ``'epoch'``,
            ``'val_loss'``, ``'val_accuracy'`` metadata).
        config_path: Path to the training YAML config. Currently only parsed;
            its values are not used in this function.

    Returns:
        The pipeline's model with the trained classifier head loaded (on CPU;
        the caller is responsible for moving it to the target device).
    """
    # NOTE(review): config is parsed but unused here — kept so a bad/missing
    # config file still fails fast; confirm whether it is needed at all.
    with open(config_path, 'r') as f:
        config = yaml.safe_load(f)  # noqa: F841

    # Rebuild the exact model architecture the training script used.
    inference_pipeline = InferencePipeline(
        audio_config=default_audio_config,
        model_config=default_model_config,
        inference_config=default_inference_config
    )

    model = inference_pipeline.model

    # Load on CPU so this works regardless of where training ran.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    model.classifier_head.load_state_dict(checkpoint['model_state_dict'])

    logger.info(f"✅ Loaded checkpoint from epoch {checkpoint.get('epoch', 'unknown')}")

    # Bug fix: the original applied ':.4f' directly to checkpoint.get(...,
    # 'unknown'), which raises ValueError when the key is missing because
    # the str fallback cannot be float-formatted. Format only real numbers.
    val_loss = checkpoint.get('val_loss')
    val_acc = checkpoint.get('val_accuracy')
    logger.info(
        f"   Validation loss: {val_loss:.4f}" if isinstance(val_loss, (int, float))
        else "   Validation loss: unknown"
    )
    logger.info(
        f"   Validation accuracy: {val_acc:.4f}" if isinstance(val_acc, (int, float))
        else "   Validation accuracy: unknown"
    )

    return model
|
|
|
|
|
|
|
|
def evaluate_model(
    model: torch.nn.Module,
    dataloader: DataLoader,
    device: torch.device,
    class_names: List[str]
) -> Dict[str, Any]:
    """Evaluate model and return comprehensive metrics.

    Runs the classifier head frame-by-frame over every batch, then computes
    accuracy/F1/precision/recall, a confusion matrix, per-class accuracy,
    a per-class classification report, and a confidence analysis.

    Args:
        model: Model exposing ``classifier_head.shared_layers`` and
            ``classifier_head.full_head`` submodules.
        dataloader: Yields dicts with ``'features'`` of shape
            (batch, seq, feat) and ``'labels'`` of shape (batch, seq).
        device: Device to run inference on.
        class_names: Ordered human-readable names, one per class index.

    Returns:
        JSON-serializable dict of metrics (see keys in the return statement).
    """
    model.eval()

    all_preds: List[int] = []
    all_labels: List[int] = []
    all_probs = []

    with torch.no_grad():
        for batch in dataloader:
            features = batch['features'].to(device)
            labels = batch['labels'].to(device)

            # Flatten (batch, seq, feat) -> (batch*seq, feat): each frame is
            # classified independently.
            # NOTE(review): assumes collate_fn does not pad labels with a
            # sentinel (e.g. -100); padded frames would pollute the metrics —
            # confirm against the training collate_fn.
            batch_size, seq_len, feat_dim = features.shape
            features_flat = features.view(-1, feat_dim)
            labels_flat = labels.view(-1)

            # Run only the classifier head (features are precomputed).
            shared_features = model.classifier_head.shared_layers(features_flat)
            logits = model.classifier_head.full_head(shared_features)
            probs = torch.softmax(logits, dim=-1)

            preds = torch.argmax(logits, dim=-1).cpu().numpy()
            all_preds.extend(preds)
            all_labels.extend(labels_flat.cpu().numpy())
            all_probs.extend(probs.cpu().numpy())

    # Fixed label universe so metrics are stable even when some classes are
    # absent from the evaluation data.
    label_ids = list(range(len(class_names)))

    # Aggregate metrics (zero_division=0 keeps absent classes from raising).
    accuracy = accuracy_score(all_labels, all_preds)
    f1_macro = f1_score(all_labels, all_preds, average='macro', zero_division=0)
    f1_weighted = f1_score(all_labels, all_preds, average='weighted', zero_division=0)
    precision_macro = precision_score(all_labels, all_preds, average='macro', zero_division=0)
    recall_macro = recall_score(all_labels, all_preds, average='macro', zero_division=0)

    cm = confusion_matrix(all_labels, all_preds, labels=label_ids)

    # Per-class accuracy = diagonal / row sums; nan_to_num maps 0/0 (class
    # never seen) to 0.0 instead of NaN.
    per_class_accuracy = cm.diagonal() / cm.sum(axis=1)
    per_class_accuracy = np.nan_to_num(per_class_accuracy)

    # Bug fix: pass labels= so classification_report does not raise
    # ValueError when a class is missing from the eval data (target_names
    # length must otherwise match the number of distinct observed labels).
    report = classification_report(
        all_labels, all_preds,
        labels=label_ids,
        target_names=class_names,
        output_dict=True,
        zero_division=0
    )

    # Confidence analysis: is the model more confident when it is right?
    all_probs = np.array(all_probs)
    max_probs = np.max(all_probs, axis=1)
    correct_mask = np.array(all_preds) == np.array(all_labels)

    avg_confidence_correct = np.mean(max_probs[correct_mask]) if np.any(correct_mask) else 0.0
    avg_confidence_incorrect = np.mean(max_probs[~correct_mask]) if np.any(~correct_mask) else 0.0

    return {
        'overall_accuracy': float(accuracy),
        'f1_macro': float(f1_macro),
        'f1_weighted': float(f1_weighted),
        'precision_macro': float(precision_macro),
        'recall_macro': float(recall_macro),
        'confusion_matrix': cm.tolist(),
        'per_class_accuracy': per_class_accuracy.tolist(),
        'classification_report': report,
        'confidence': {
            'avg_correct': float(avg_confidence_correct),
            'avg_incorrect': float(avg_confidence_incorrect),
            'confidence_distribution': {
                'mean': float(np.mean(max_probs)),
                'std': float(np.std(max_probs)),
                'min': float(np.min(max_probs)),
                'max': float(np.max(max_probs))
            }
        },
        'num_samples': len(all_labels)
    }
|
|
|
|
|
|
|
|
def plot_confusion_matrix(cm: np.ndarray, class_names: List[str], output_path: Path):
    """Plot a labeled confusion-matrix heatmap and save it to *output_path*.

    Args:
        cm: Square confusion matrix of integer counts, shape
            (len(class_names), len(class_names)).
        class_names: Axis tick labels, in class-index order.
        output_path: File path for the saved figure (format inferred from
            the extension by matplotlib).
    """
    fig = plt.figure(figsize=(10, 8))
    sns.heatmap(
        cm,
        annot=True,
        fmt='d',        # integer counts, not scientific notation
        cmap='Blues',
        xticklabels=class_names,
        yticklabels=class_names
    )
    plt.title('Confusion Matrix')
    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')
    plt.tight_layout()
    plt.savefig(output_path)
    # Bug fix: close the figure after saving — the original leaked one open
    # figure per call (matplotlib keeps figures alive until closed).
    plt.close(fig)
    logger.info(f"✅ Saved confusion matrix to {output_path}")
|
|
|
|
|
|
|
|
def main():
    """CLI entry point: evaluate a trained classifier-head checkpoint.

    Loads the checkpoint and config, rebuilds the feature-extraction
    pipeline and dataset, runs evaluation, logs a summary, and writes the
    metrics JSON plus an optional confusion-matrix plot.
    """
    parser = argparse.ArgumentParser(description="Evaluate trained classifier")
    parser.add_argument('--checkpoint', type=str, required=True,
                        help='Path to checkpoint file')
    parser.add_argument('--config', type=str, default='training/config.yaml',
                        help='Path to config file')
    parser.add_argument('--dataset', type=str, default='data/training_dataset.json',
                        help='Path to evaluation dataset')
    parser.add_argument('--output', type=str, default='training/evaluation_results.json',
                        help='Path to save evaluation results')
    parser.add_argument('--plot', type=str, default='training/confusion_matrix.png',
                        help='Path to save confusion matrix plot')
    args = parser.parse_args()

    with open(args.config, 'r') as f:
        config = yaml.safe_load(f)

    # Respect the config's CUDA switch even when a GPU is present.
    device = torch.device('cuda' if torch.cuda.is_available() and config['device']['use_cuda'] else 'cpu')
    logger.info(f"Using device: {device}")

    # Fail fast with a clear message instead of a raw FileNotFoundError.
    checkpoint_path = Path(args.checkpoint)
    if not checkpoint_path.exists():
        logger.error(f"Checkpoint not found: {checkpoint_path}")
        return

    model = load_trained_model(checkpoint_path, Path(args.config))
    model = model.to(device)

    dataset_path = Path(args.dataset)
    if not dataset_path.exists():
        logger.error(f"Dataset not found: {dataset_path}")
        return

    with open(dataset_path, 'r') as f:
        eval_data = json.load(f)

    logger.info(f"Loaded {len(eval_data)} evaluation samples")

    # Feature extractor used by PhonemeDataset to turn audio into frames.
    inference_pipeline = InferencePipeline(
        audio_config=default_audio_config,
        model_config=default_model_config,
        inference_config=default_inference_config
    )

    phoneme_mapper = PhonemeMapper(frame_duration_ms=20, sample_rate=16000)

    # Fix: removed redundant local re-import of PhonemeDataset — it is
    # already imported at module level.
    dataset = PhonemeDataset(eval_data, inference_pipeline, phoneme_mapper)

    dataloader = DataLoader(
        dataset,
        batch_size=config['training']['batch_size'],
        shuffle=False,            # keep order deterministic for evaluation
        collate_fn=collate_fn
    )

    # Class indices 0..7; order must match the labels used during training.
    class_names = [
        "Normal",
        "Substitution",
        "Omission",
        "Distortion",
        "Normal+Stutter",
        "Substitution+Stutter",
        "Omission+Stutter",
        "Distortion+Stutter"
    ]

    logger.info("Evaluating model...")
    metrics = evaluate_model(model, dataloader, device, class_names)

    # Human-readable summary in the log.
    logger.info("\n" + "="*50)
    logger.info("EVALUATION RESULTS")
    logger.info("="*50)
    logger.info(f"Overall Accuracy: {metrics['overall_accuracy']:.4f}")
    logger.info(f"F1 Score (macro): {metrics['f1_macro']:.4f}")
    logger.info(f"F1 Score (weighted): {metrics['f1_weighted']:.4f}")
    logger.info(f"Precision (macro): {metrics['precision_macro']:.4f}")
    logger.info(f"Recall (macro): {metrics['recall_macro']:.4f}")
    logger.info(f"\nPer-Class Accuracy:")
    for i, (name, acc) in enumerate(zip(class_names, metrics['per_class_accuracy'])):
        logger.info(f"  {name}: {acc:.4f}")
    logger.info(f"\nConfidence Analysis:")
    logger.info(f"  Avg confidence (correct): {metrics['confidence']['avg_correct']:.4f}")
    logger.info(f"  Avg confidence (incorrect): {metrics['confidence']['avg_incorrect']:.4f}")

    # Persist full metrics as JSON.
    output_path = Path(args.output)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, 'w') as f:
        json.dump(metrics, f, indent=2)
    logger.info(f"\n✅ Saved evaluation results to {output_path}")

    # Optional confusion-matrix figure (enabled by default via the
    # argument's default value).
    if args.plot:
        plot_path = Path(args.plot)
        plot_path.parent.mkdir(parents=True, exist_ok=True)
        cm = np.array(metrics['confusion_matrix'])
        plot_confusion_matrix(cm, class_names, plot_path)


if __name__ == "__main__":
    main()
|
|
|
|
|
|