# NeuroMusicLab / classifier.py
# Author: sofieff
# Initial commit: EEG Motor Imagery Music Composer
# Commit: fa96cf5
"""
EEG Motor Imagery Classifier Module
----------------------------------
Handles model loading, inference, and real-time prediction for motor imagery classification.
Based on the ShallowFBCSPNet architecture from the original eeg_motor_imagery.py script.
"""
import torch
import torch.nn as nn
import numpy as np
from braindecode.models import ShallowFBCSPNet
from typing import Dict, Tuple, Optional
import os
from sklearn.metrics import accuracy_score
from data_processor import EEGDataProcessor
from config import DEMO_DATA_PATHS
class MotorImageryClassifier:
    """
    Motor imagery classifier using ShallowFBCSPNet model.

    Pre-trained weights are loaded with :meth:`load_model`. If they are
    missing or incompatible, the first call to :meth:`predict` trains a fresh
    network on the demo EEG files (Leave-One-Session-Out split) and uses that
    network for this and all later predictions.
    """

    # Hyper-parameters of the on-demand LOSO fallback training.
    LOSO_EPOCHS = 50
    LOSO_LEARNING_RATE = 0.001
    LOSO_LOG_EVERY = 5  # print the training loss every N epochs

    def __init__(self, model_path: str = "shallow_weights_all.pth"):
        """
        Args:
            model_path: Path of the pre-trained ShallowFBCSPNet ``state_dict``
                checkpoint read by :meth:`load_model`.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = None
        self.model_path = model_path
        # Network output index -> class label. The fallback model is built with
        # one output per entry and predict() reports one probability per entry.
        self.class_names = {
            0: "left_hand",
            1: "right_hand",
            2: "neutral",
            3: "left_leg",
            4: "tongue",
            5: "right_leg",
        }
        # True once self.model holds usable weights (pre-trained or LOSO-trained).
        self.is_loaded = False

    def load_model(self, n_chans: int, n_times: int, n_outputs: int = 6):
        """
        Load the pre-trained ShallowFBCSPNet model.

        Builds the network for the given input geometry and, if ``model_path``
        exists, loads its state_dict into it. Failures are reported with a
        print and leave ``is_loaded`` False; no training happens here.
        :meth:`predict` performs the LOSO fallback training on its first call.

        Args:
            n_chans: Number of EEG channels the network expects.
            n_times: Number of time samples per trial.
            n_outputs: Number of output classes.
        """
        try:
            self.model = ShallowFBCSPNet(
                n_chans=n_chans,
                n_outputs=n_outputs,
                n_times=n_times,
                final_conv_length="auto",
            ).to(self.device)
            if os.path.exists(self.model_path):
                try:
                    state_dict = torch.load(self.model_path, map_location=self.device)
                    self.model.load_state_dict(state_dict)
                    self.model.eval()
                    self.is_loaded = True
                    print(f"βœ… Pre-trained model loaded successfully from {self.model_path}")
                except Exception as model_error:
                    # Checkpoint exists but does not match this architecture/geometry.
                    print(f"⚠️ Pre-trained model found but incompatible: {model_error}")
                    print("πŸ”„ Starting LOSO training with available EEG data...")
                    self.is_loaded = False
            else:
                print(f"❌ Pre-trained model weights not found at {self.model_path}")
                print("πŸ”„ Starting LOSO training with available EEG data...")
                self.is_loaded = False
        except Exception as e:
            # Network construction itself failed.
            print(f"❌ Error loading model: {e}")
            print("πŸ”„ Starting LOSO training with available EEG data...")
            self.is_loaded = False

    def get_model_status(self) -> str:
        """Return a short, user-facing description of which model is in use."""
        if self.is_loaded:
            return "βœ… Pre-trained model loaded and ready"
        return "πŸ”„ Using LOSO training (training new model from EEG data)"

    def predict(self, eeg_data: np.ndarray) -> Tuple[int, float, Dict[str, float]]:
        """
        Predict motor imagery class from EEG data.

        If no usable model is loaded yet, a LOSO model is trained first
        (see :meth:`_fallback_loso_classification`).

        Args:
            eeg_data: EEG data array of shape (n_channels, n_times), or a
                batch of shape (batch, n_channels, n_times). For a batch,
                only the first trial is reported.

        Returns:
            predicted_class: Predicted class index (Python ``int``).
            confidence: Probability of the predicted class (Python ``float``).
            probabilities: Class name -> probability (Python ``float`` values).
        """
        if not self.is_loaded:
            return self._fallback_loso_classification(eeg_data)

        # Ensure input is the right shape: (batch, channels, time)
        if eeg_data.ndim == 2:
            eeg_data = eeg_data[np.newaxis, ...]

        x = torch.from_numpy(eeg_data.astype(np.float32)).to(self.device)
        with torch.no_grad():
            output = self.model(x)
            probabilities = torch.softmax(output, dim=1)

        # Every reported value is taken from the same (first) row, so the
        # confidence always belongs to the predicted trial, even for batches.
        first = probabilities[0].cpu()
        predicted_class = int(torch.argmax(first).item())
        confidence = float(first[predicted_class].item())
        prob_dict = {
            name: float(first[idx].item())
            for idx, name in self.class_names.items()
        }
        return predicted_class, confidence, prob_dict

    def _fallback_loso_classification(self, eeg_data: np.ndarray) -> Tuple[int, float, Dict[str, float]]:
        """
        Fallback classification using LOSO (Leave-One-Session-Out) training.

        Trains a model on available data when pre-trained model isn't
        available, then classifies ``eeg_data`` with it.

        Raises:
            RuntimeError: if the LOSO model could not be trained (the original
                error is chained as ``__cause__``).
        """
        try:
            self._train_loso_model()
        except Exception as e:
            print(f"Error in LOSO training: {e}")
            raise RuntimeError(
                f"Failed to initialize classifier. Neither pre-trained model nor LOSO training succeeded: {e}"
            ) from e
        # Deliberately outside the try: an error here is a prediction problem
        # (e.g. input shape mismatch), not a training failure.
        return self.predict(eeg_data)

    def _train_loso_model(self) -> None:
        """Train a ShallowFBCSPNet on the demo EEG files and install it as ``self.model``."""
        print("πŸ”„ No pre-trained model available. Training new model using LOSO method...")
        print("⏳ This may take a moment - training on real EEG data...")

        processor = EEGDataProcessor()

        # Check if demo data files exist
        available_files = [f for f in DEMO_DATA_PATHS if os.path.exists(f)]
        if len(available_files) < 2:
            raise ValueError(f"Not enough data files for LOSO training. Need at least 2 files, found {len(available_files)}. "
                             f"Available files: {available_files}")

        # LOSO split with the first session held out. The held-out session
        # (X_test / y_test) is excluded from training and not evaluated here.
        X_train, y_train, X_test, y_test, _session_info = processor.prepare_loso_split(
            available_files, test_session_idx=0
        )

        # Build into a local so a failed run never leaves a half-trained
        # network in self.model. X_train is indexed as (trials, channels, time).
        model = ShallowFBCSPNet(
            n_chans=X_train.shape[1],
            n_outputs=len(self.class_names),
            n_times=X_train.shape[2],
            final_conv_length="auto",
        ).to(self.device)

        optimizer = torch.optim.Adam(model.parameters(), lr=self.LOSO_LEARNING_RATE)
        criterion = nn.CrossEntropyLoss()

        # Full-batch training; assumes y_train holds integer class indices
        # matching class_names — TODO confirm against EEGDataProcessor.
        X_train_tensor = torch.from_numpy(X_train).float().to(self.device)
        y_train_tensor = torch.from_numpy(y_train).long().to(self.device)

        model.train()
        for epoch in range(self.LOSO_EPOCHS):
            optimizer.zero_grad()
            loss = criterion(model(X_train_tensor), y_train_tensor)
            loss.backward()
            optimizer.step()
            if epoch % self.LOSO_LOG_EVERY == 0:
                print(f"LOSO Training - Epoch {epoch}, Loss: {loss.item():.4f}")

        model.eval()
        self.model = model
        self.is_loaded = True
        print("βœ… LOSO model trained successfully! Ready for classification.")