Spaces:
Sleeping
Sleeping
| """ | |
| Enhanced utilities incorporating useful functions from the original src/ templates | |
| """ | |
| import torch | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| import seaborn as sns | |
| from sklearn.metrics import confusion_matrix, classification_report, accuracy_score | |
| from typing import Dict, List, Tuple, Optional | |
| import logging | |
def evaluate_model_performance(model, dataloader, device, class_names: List[str]) -> Dict:
    """Run a full evaluation pass over ``dataloader`` and gather metrics.

    Args:
        model: Trained torch module producing class logits.
        dataloader: Iterable of ``(inputs, labels)`` batches.
        device: Device the model lives on; batches are moved there.
        class_names: Human-readable class labels for the report.

    Returns:
        Dict with keys ``accuracy``, ``predictions``, ``labels``,
        ``probabilities``, ``confusion_matrix`` and ``classification_report``.
    """
    model.eval()
    collected_preds: List = []
    collected_labels: List = []
    collected_probs: List = []

    # Inference only — gradients are never needed here.
    with torch.no_grad():
        for batch_inputs, batch_labels in dataloader:
            batch_inputs = batch_inputs.to(device)
            batch_labels = batch_labels.to(device)

            logits = model(batch_inputs)
            batch_probs = torch.softmax(logits, dim=1)
            batch_preds = logits.argmax(dim=1)

            collected_preds.extend(batch_preds.cpu().numpy())
            collected_labels.extend(batch_labels.cpu().numpy())
            collected_probs.extend(batch_probs.cpu().numpy())

    # Aggregate sklearn metrics over the whole epoch.
    return {
        'accuracy': accuracy_score(collected_labels, collected_preds),
        'predictions': collected_preds,
        'labels': collected_labels,
        'probabilities': collected_probs,
        'confusion_matrix': confusion_matrix(collected_labels, collected_preds),
        'classification_report': classification_report(
            collected_labels, collected_preds,
            target_names=class_names, output_dict=True,
        ),
    }
def plot_confusion_matrix(cm: np.ndarray, class_names: List[str], title: str = "Confusion Matrix") -> plt.Figure:
    """Render a confusion matrix as an annotated heatmap.

    Args:
        cm: Square integer confusion matrix (true rows, predicted columns).
        class_names: Tick labels for both axes.
        title: Figure title.

    Returns:
        The matplotlib Figure containing the heatmap.
    """
    figure, axis = plt.subplots(figsize=(8, 6))
    # Integer annotations ('d') since the matrix holds raw counts.
    sns.heatmap(
        cm,
        annot=True,
        fmt='d',
        cmap='Blues',
        xticklabels=class_names,
        yticklabels=class_names,
        ax=axis,
    )
    axis.set_title(title)
    axis.set_ylabel('True Label')
    axis.set_xlabel('Predicted Label')
    plt.tight_layout()
    return figure
def plot_classification_probabilities(probabilities: np.ndarray, class_names: List[str],
                                      sample_indices: Optional[List[int]] = None) -> plt.Figure:
    """Draw grouped bars of per-class probabilities for chosen samples.

    Args:
        probabilities: Array of shape ``(n_samples, n_classes)``.
        class_names: One label per class, used on the x axis.
        sample_indices: Which rows to plot; defaults to the first 10 (or fewer).

    Returns:
        The matplotlib Figure with the grouped bar chart.
    """
    if sample_indices is None:
        sample_indices = list(range(min(10, len(probabilities))))

    figure, axis = plt.subplots(figsize=(12, 6))
    positions = np.arange(len(class_names))
    # Split a total group width of 0.8 evenly across the plotted samples.
    bar_width = 0.8 / len(sample_indices)

    for slot, row_idx in enumerate(sample_indices):
        shift = (slot - len(sample_indices) / 2) * bar_width
        axis.bar(
            positions + shift,
            probabilities[row_idx],
            bar_width,
            label=f'Sample {row_idx}',
            alpha=0.8,
        )

    axis.set_xlabel('Motor Imagery Classes')
    axis.set_ylabel('Probability')
    axis.set_title('Classification Probabilities')
    axis.set_xticks(positions)
    axis.set_xticklabels(class_names, rotation=45)
    axis.legend()
    axis.grid(True, alpha=0.3)
    plt.tight_layout()
    return figure
def plot_training_history(history: Dict[str, List[float]]) -> plt.Figure:
    """Plot accuracy and loss curves side by side.

    Args:
        history: Mapping with optional keys ``train_accuracy``,
            ``val_accuracy``, ``train_loss`` and ``val_loss``; each panel is
            drawn only when both of its train/val series are present.

    Returns:
        The matplotlib Figure with up to two panels (accuracy, loss).
    """
    figure, (acc_axis, loss_axis) = plt.subplots(1, 2, figsize=(15, 5))

    # Accuracy panel — skipped when either series is missing.
    if 'train_accuracy' in history and 'val_accuracy' in history:
        acc_axis.plot(history['train_accuracy'], label='Train', linewidth=2)
        acc_axis.plot(history['val_accuracy'], label='Validation', linewidth=2)
        acc_axis.set_title('Model Accuracy')
        acc_axis.set_xlabel('Epoch')
        acc_axis.set_ylabel('Accuracy')
        acc_axis.legend()
        acc_axis.grid(True, alpha=0.3)

    # Loss panel — same convention.
    if 'train_loss' in history and 'val_loss' in history:
        loss_axis.plot(history['train_loss'], label='Train', linewidth=2)
        loss_axis.plot(history['val_loss'], label='Validation', linewidth=2)
        loss_axis.set_title('Model Loss')
        loss_axis.set_xlabel('Epoch')
        loss_axis.set_ylabel('Loss')
        loss_axis.legend()
        loss_axis.grid(True, alpha=0.3)

    plt.tight_layout()
    return figure
def plot_eeg_channels(eeg_data: np.ndarray, channel_names: Optional[List[str]] = None,
                      sample_rate: int = 256, title: str = "EEG Channels") -> plt.Figure:
    """Plot each EEG channel in its own subplot on a near-square grid.

    Args:
        eeg_data: Array of shape ``(n_channels, n_samples)``.
        channel_names: Optional per-channel titles; defaults to "Channel i".
        sample_rate: Samples per second, used to build the time axis.
        title: Figure-level suptitle.

    Returns:
        The matplotlib Figure with one subplot per channel.
    """
    n_channels, n_samples = eeg_data.shape
    seconds = np.arange(n_samples) / sample_rate

    # Near-square grid: rows = ceil(sqrt(n)), cols fill the remainder.
    grid_rows = int(np.ceil(np.sqrt(n_channels)))
    grid_cols = int(np.ceil(n_channels / grid_rows))

    figure, subplot_axes = plt.subplots(grid_rows, grid_cols, figsize=(15, 10))
    # plt.subplots returns a bare Axes for a 1x1 grid — normalize to a list.
    if n_channels == 1:
        subplot_axes = [subplot_axes]
    else:
        subplot_axes = subplot_axes.flatten()

    for channel_idx in range(n_channels):
        axis = subplot_axes[channel_idx]
        axis.plot(seconds, eeg_data[channel_idx], 'b-', linewidth=1)
        if channel_names:
            axis.set_title(channel_names[channel_idx])
        else:
            axis.set_title(f'Channel {channel_idx + 1}')
        axis.set_xlabel('Time (s)')
        axis.set_ylabel('Amplitude')
        axis.grid(True, alpha=0.3)

    # Grid cells beyond the channel count stay blank.
    for extra_idx in range(n_channels, len(subplot_axes)):
        subplot_axes[extra_idx].set_visible(False)

    plt.suptitle(title)
    plt.tight_layout()
    return figure
class EarlyStopping:
    """Stop training when validation loss stops improving.

    Call the instance once per epoch with the current validation loss and the
    model; it returns ``True`` when ``patience`` consecutive epochs pass
    without an improvement of more than ``min_delta``.  With
    ``restore_best_weights=True`` the model is rolled back to the best
    checkpoint at the moment stopping triggers.
    """

    def __init__(self, patience=7, min_delta=0, restore_best_weights=True, verbose=False):
        # patience: epochs without improvement tolerated before stopping.
        self.patience = patience
        # min_delta: required improvement margin over the best loss so far.
        self.min_delta = min_delta
        self.restore_best_weights = restore_best_weights
        self.verbose = verbose
        self.best_loss = None      # best validation loss seen so far
        self.counter = 0           # consecutive epochs without improvement
        self.best_weights = None   # deep snapshot of the best state_dict

    def __call__(self, val_loss, model):
        """Record one epoch's validation loss.

        Args:
            val_loss: Validation loss for the epoch just finished.
            model: The model being trained (checkpointed / restored in place).

        Returns:
            bool: ``True`` when training should stop, ``False`` otherwise.
        """
        if self.best_loss is None:
            # First epoch: establish the baseline.
            self.best_loss = val_loss
            self.save_checkpoint(model)
        elif val_loss < self.best_loss - self.min_delta:
            # Improvement: reset the patience counter and snapshot weights.
            self.best_loss = val_loss
            self.counter = 0
            self.save_checkpoint(model)
        else:
            self.counter += 1
            if self.counter >= self.patience:
                if self.verbose:
                    print(f'Early stopping triggered after {self.counter} epochs of no improvement')
                if self.restore_best_weights:
                    model.load_state_dict(self.best_weights)
                return True
        return False

    def save_checkpoint(self, model):
        """Snapshot the model's parameters as the best weights so far.

        BUG FIX: the original used ``model.state_dict().copy()``, which is a
        *shallow* copy — the stored dict aliased the live parameter tensors,
        so continued training mutated the "best" weights in place and
        ``load_state_dict`` restored the latest weights, not the best ones.
        Each tensor must be detached and cloned to get an independent snapshot.
        """
        if self.restore_best_weights:
            self.best_weights = {
                key: value.detach().clone()
                for key, value in model.state_dict().items()
            }
def create_enhanced_evaluation_report(model, test_loader, class_names: List[str],
                                      device, save_plots: bool = True) -> Dict:
    """Evaluate a model on a test set and bundle metrics, plots and a summary.

    Args:
        model: Trained torch module to evaluate.
        test_loader: DataLoader yielding ``(inputs, labels)`` test batches.
        class_names: Human-readable class labels.
        device: Device to run inference on.
        save_plots: When True, each figure is also written to ``<name>.png``.

    Returns:
        Dict with keys ``metrics`` (full evaluation results), ``plots``
        (name -> Figure) and ``summary`` (headline numbers).
    """
    metrics = evaluate_model_performance(model, test_loader, device, class_names)

    figures = {
        'confusion_matrix': plot_confusion_matrix(
            metrics['confusion_matrix'],
            class_names,
            title="Motor Imagery Classification - Confusion Matrix",
        ),
        'probabilities': plot_classification_probabilities(
            np.array(metrics['probabilities']),
            class_names,
            sample_indices=list(range(min(5, len(metrics['probabilities'])))),
        ),
    }

    if save_plots:
        # Persist every figure next to the working directory as <name>.png.
        for figure_name, figure in figures.items():
            figure.savefig(f'{figure_name}.png', dpi=300, bbox_inches='tight')

    return {
        'metrics': metrics,
        'plots': figures,
        'summary': {
            'accuracy': metrics['accuracy'],
            'n_samples': len(metrics['labels']),
            'n_classes': len(class_names),
            'class_names': class_names,
        },
    }