NeuroMusicLab / enhanced_utils.py
sofieff's picture
Initial commit: EEG Motor Imagery Music Composer
fa96cf5
raw
history blame
7.86 kB
"""
Enhanced utilities incorporating useful functions from the original src/ templates
"""
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, classification_report, accuracy_score
from typing import Dict, List, Tuple, Optional
import logging
def evaluate_model_performance(model, dataloader, device, class_names: List[str]) -> Dict:
    """
    Run the model over a dataloader and collect predictions, class
    probabilities, accuracy, a confusion matrix, and a per-class report.

    Enhanced version of src/evaluate.py.

    Args:
        model: trained classifier producing per-class logits.
        dataloader: iterable of (inputs, labels) batches.
        device: torch device to run inference on.
        class_names: label names used in the classification report.

    Returns:
        Dict with keys 'accuracy', 'predictions', 'labels',
        'probabilities', 'confusion_matrix', 'classification_report'.
    """
    model.eval()
    predictions, targets, probabilities = [], [], []

    with torch.no_grad():
        for batch_inputs, batch_labels in dataloader:
            batch_inputs = batch_inputs.to(device)
            batch_labels = batch_labels.to(device)

            logits = model(batch_inputs)
            batch_probs = torch.softmax(logits, dim=1)
            batch_preds = logits.argmax(dim=1)

            predictions.extend(batch_preds.cpu().numpy())
            targets.extend(batch_labels.cpu().numpy())
            probabilities.extend(batch_probs.cpu().numpy())

    return {
        'accuracy': accuracy_score(targets, predictions),
        'predictions': predictions,
        'labels': targets,
        'probabilities': probabilities,
        'confusion_matrix': confusion_matrix(targets, predictions),
        'classification_report': classification_report(
            targets, predictions, target_names=class_names, output_dict=True
        ),
    }
def plot_confusion_matrix(cm: np.ndarray, class_names: List[str], title: str = "Confusion Matrix") -> plt.Figure:
    """
    Render a confusion matrix as an annotated heatmap and return the figure.

    Enhanced version from src/visualize.py.
    """
    fig, ax = plt.subplots(figsize=(8, 6))
    sns.heatmap(
        cm,
        annot=True,
        fmt='d',            # integer counts in each cell
        cmap='Blues',
        xticklabels=class_names,
        yticklabels=class_names,
        ax=ax,
    )
    ax.set_title(title)
    ax.set_ylabel('True Label')
    ax.set_xlabel('Predicted Label')
    plt.tight_layout()
    return fig
def plot_classification_probabilities(probabilities: np.ndarray, class_names: List[str],
                                      sample_indices: Optional[List[int]] = None) -> plt.Figure:
    """
    Draw grouped bars of per-class probabilities for selected samples.

    Args:
        probabilities: per-sample probability vectors, one row per sample.
        class_names: x-axis labels, one per class.
        sample_indices: which samples to show; defaults to the first
            (up to) 10 samples.
    """
    if sample_indices is None:
        sample_indices = list(range(min(10, len(probabilities))))

    fig, ax = plt.subplots(figsize=(12, 6))
    positions = np.arange(len(class_names))
    bar_width = 0.8 / len(sample_indices)  # fit all groups inside 0.8 units

    for slot, idx in enumerate(sample_indices):
        shift = (slot - len(sample_indices) / 2) * bar_width
        ax.bar(positions + shift, probabilities[idx], bar_width,
               label=f'Sample {idx}', alpha=0.8)

    ax.set_xlabel('Motor Imagery Classes')
    ax.set_ylabel('Probability')
    ax.set_title('Classification Probabilities')
    ax.set_xticks(positions)
    ax.set_xticklabels(class_names, rotation=45)
    ax.legend()
    ax.grid(True, alpha=0.3)
    plt.tight_layout()
    return fig
def plot_training_history(history: Dict[str, List[float]]) -> plt.Figure:
    """
    Plot accuracy and loss curves side by side.

    Enhanced version from src/visualize.py. Each panel is drawn only when
    both its train and validation series are present in `history`.
    """
    fig, (acc_ax, loss_ax) = plt.subplots(1, 2, figsize=(15, 5))

    # Accuracy panel
    if 'train_accuracy' in history and 'val_accuracy' in history:
        acc_ax.plot(history['train_accuracy'], label='Train', linewidth=2)
        acc_ax.plot(history['val_accuracy'], label='Validation', linewidth=2)
        acc_ax.set_title('Model Accuracy')
        acc_ax.set_xlabel('Epoch')
        acc_ax.set_ylabel('Accuracy')
        acc_ax.legend()
        acc_ax.grid(True, alpha=0.3)

    # Loss panel
    if 'train_loss' in history and 'val_loss' in history:
        loss_ax.plot(history['train_loss'], label='Train', linewidth=2)
        loss_ax.plot(history['val_loss'], label='Validation', linewidth=2)
        loss_ax.set_title('Model Loss')
        loss_ax.set_xlabel('Epoch')
        loss_ax.set_ylabel('Loss')
        loss_ax.legend()
        loss_ax.grid(True, alpha=0.3)

    plt.tight_layout()
    return fig
def plot_eeg_channels(eeg_data: np.ndarray, channel_names: Optional[List[str]] = None,
                      sample_rate: int = 256, title: str = "EEG Channels") -> plt.Figure:
    """
    Plot each EEG channel in its own subplot on a near-square grid.

    Args:
        eeg_data: array of shape (n_channels, n_samples).
        channel_names: optional per-channel labels; defaults to 'Channel i'.
        sample_rate: samples per second, used to build the time axis.
        title: figure-level title.
    """
    n_channels, n_samples = eeg_data.shape
    times = np.arange(n_samples) / sample_rate

    # Near-square grid with at least n_channels cells
    rows = int(np.ceil(np.sqrt(n_channels)))
    cols = int(np.ceil(n_channels / rows))
    fig, axes = plt.subplots(rows, cols, figsize=(15, 10))

    # Normalize to a flat list of axes regardless of grid shape
    axes = [axes] if n_channels == 1 else list(axes.flatten())

    for ch, ax in enumerate(axes[:n_channels]):
        label = channel_names[ch] if channel_names else f'Channel {ch + 1}'
        ax.plot(times, eeg_data[ch], 'b-', linewidth=1)
        ax.set_title(label)
        ax.set_xlabel('Time (s)')
        ax.set_ylabel('Amplitude')
        ax.grid(True, alpha=0.3)

    # Blank out any leftover grid cells
    for ax in axes[n_channels:]:
        ax.set_visible(False)

    plt.suptitle(title)
    plt.tight_layout()
    return fig
class EarlyStopping:
    """
    Stop training when validation loss stops improving.

    From src/types/index.py, with a correctness fix: the original stored
    ``model.state_dict().copy()``, which is a *shallow* copy — the saved
    values were references to the live parameter tensors, so continued
    training silently overwrote the "best" snapshot and restoration gave
    back the latest weights, not the best ones. Each tensor is now
    detached and cloned when checkpointing.

    Args:
        patience: epochs with no improvement to tolerate before stopping.
        min_delta: minimum decrease in loss that counts as an improvement.
        restore_best_weights: reload the best snapshot when stopping.
        verbose: print a message when early stopping triggers.
    """

    def __init__(self, patience=7, min_delta=0, restore_best_weights=True, verbose=False):
        self.patience = patience
        self.min_delta = min_delta
        self.restore_best_weights = restore_best_weights
        self.verbose = verbose
        self.best_loss = None     # lowest validation loss seen so far
        self.counter = 0          # epochs since last improvement
        self.best_weights = None  # cloned state dict at the best loss

    def __call__(self, val_loss, model):
        """
        Record this epoch's validation loss.

        Returns:
            True when training should stop (patience exhausted; best
            weights restored if configured), False otherwise.
        """
        # First observation, or a genuine improvement beyond min_delta.
        if self.best_loss is None or val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.counter = 0
            self.save_checkpoint(model)
            return False

        self.counter += 1
        if self.counter >= self.patience:
            if self.verbose:
                print(f'Early stopping triggered after {self.counter} epochs of no improvement')
            if self.restore_best_weights:
                model.load_state_dict(self.best_weights)
            return True
        return False

    def save_checkpoint(self, model):
        """Snapshot the model weights as true copies, detached from autograd."""
        if self.restore_best_weights:
            # detach().clone() per tensor: a plain dict.copy() would keep
            # references to the live parameters, which training mutates.
            self.best_weights = {
                key: value.detach().clone()
                for key, value in model.state_dict().items()
            }
def create_enhanced_evaluation_report(model, test_loader, class_names: List[str],
                                      device, save_plots: bool = True) -> Dict:
    """
    Build a full evaluation report: metrics, figures, and a short summary.

    Args:
        model: trained classifier to evaluate.
        test_loader: dataloader yielding (inputs, labels) batches.
        class_names: human-readable label names, in class-index order.
        device: torch device the model runs on.
        save_plots: when True, write each figure to '<name>.png' at 300 dpi.

    Returns:
        Dict with 'metrics' (full evaluation results), 'plots'
        (name -> matplotlib figure), and 'summary' (key numbers).
    """
    results = evaluate_model_performance(model, test_loader, device, class_names)

    # Show at most the first 5 samples in the probability chart.
    n_prob_samples = min(5, len(results['probabilities']))
    figures = {
        'confusion_matrix': plot_confusion_matrix(
            results['confusion_matrix'], class_names,
            title="Motor Imagery Classification - Confusion Matrix"),
        'probabilities': plot_classification_probabilities(
            np.array(results['probabilities']), class_names,
            sample_indices=list(range(n_prob_samples))),
    }

    if save_plots:
        for name, figure in figures.items():
            figure.savefig(f'{name}.png', dpi=300, bbox_inches='tight')

    return {
        'metrics': results,
        'plots': figures,
        'summary': {
            'accuracy': results['accuracy'],
            'n_samples': len(results['labels']),
            'n_classes': len(class_names),
            'class_names': class_names,
        },
    }