# File size: 7,402 Bytes
# fa96cf5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
"""

EEG Motor Imagery Classifier Module

----------------------------------

Handles model loading, inference, and real-time prediction for motor imagery classification.

Based on the ShallowFBCSPNet architecture from the original eeg_motor_imagery.py script.

"""

import torch
import torch.nn as nn
import numpy as np
from braindecode.models import ShallowFBCSPNet
from typing import Dict, Tuple, Optional
import os
from sklearn.metrics import accuracy_score
from data_processor import EEGDataProcessor
from config import DEMO_DATA_PATHS

class MotorImageryClassifier:
    """
    Motor imagery classifier using the ShallowFBCSPNet model.

    Pre-trained weights can be loaded with ``load_model``. If no usable model
    is loaded when ``predict`` is first called, a fresh network is trained on
    the demo EEG recordings (``DEMO_DATA_PATHS``) using a leave-one-session-out
    split, and that model is then used for classification.
    """

    def __init__(self, model_path: str = "shallow_weights_all.pth"):
        # Use the GPU when torch can see one.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Set by load_model() or by LOSO training; None until then.
        self.model = None
        self.model_path = model_path
        # Network output index -> label; these labels key predict()'s probability dict.
        self.class_names = {
            0: "left_hand",
            1: "right_hand",
            2: "neutral",
            3: "left_leg",
            4: "tongue",
            5: "right_leg"
        }
        # True once self.model holds usable weights (pre-trained or LOSO-trained).
        self.is_loaded = False

    def load_model(self, n_chans: int, n_times: int, n_outputs: int = 6):
        """
        Build a ShallowFBCSPNet and load pre-trained weights from ``self.model_path``.

        Ordinary exceptions are not propagated. On any failure (missing file,
        incompatible state dict, construction error) a message is printed and
        ``is_loaded`` stays False. The next ``predict`` call then falls back to
        LOSO training.

        Args:
            n_chans: Number of EEG channels the network expects.
            n_times: Number of time samples per trial.
            n_outputs: Number of output classes.
        """
        try:
            self.model = ShallowFBCSPNet(
                n_chans=n_chans,
                n_outputs=n_outputs,
                n_times=n_times,
                final_conv_length="auto"
            ).to(self.device)

            if os.path.exists(self.model_path):
                try:
                    # NOTE(review): torch.load unpickles the file; only load trusted
                    # weights (consider weights_only=True where the torch version supports it).
                    state_dict = torch.load(self.model_path, map_location=self.device)
                    self.model.load_state_dict(state_dict)
                    self.model.eval()
                    self.is_loaded = True
                    print(f"✅ Pre-trained model loaded successfully from {self.model_path}")
                except Exception as model_error:
                    print(f"⚠️  Pre-trained model found but incompatible: {model_error}")
                    print("🔄 Starting LOSO training with available EEG data...")
                    self.is_loaded = False
            else:
                print(f"❌ Pre-trained model weights not found at {self.model_path}")
                print("🔄 Starting LOSO training with available EEG data...")
                self.is_loaded = False

        except Exception as e:
            print(f"❌ Error loading model: {e}")
            print("🔄 Starting LOSO training with available EEG data...")
            self.is_loaded = False

    def get_model_status(self) -> str:
        """Return a short, human-readable model status for the user interface."""
        if self.is_loaded:
            return "✅ Pre-trained model loaded and ready"
        return "🔄 Using LOSO training (training new model from EEG data)"

    def predict(self, eeg_data: np.ndarray) -> Tuple[int, float, Dict[str, float]]:
        """
        Predict the motor imagery class from EEG data.

        If no model is loaded yet, one is first trained via LOSO
        (see ``_fallback_loso_classification``).

        Args:
            eeg_data: EEG array of shape (n_channels, n_times), or a batch of
                shape (batch, n_channels, n_times). For a batch, only the first
                trial is reported.

        Returns:
            predicted_class: Predicted class index (Python ``int``).
            confidence: Softmax probability of ``predicted_class`` (``float``).
            probabilities: Mapping of class name -> softmax probability (``float``).

        Raises:
            RuntimeError: If no model is loaded and LOSO training fails.
        """
        if not self.is_loaded:
            return self._fallback_loso_classification(eeg_data)

        # Contiguous float32 copy/view that torch.from_numpy can always wrap.
        eeg_data = np.ascontiguousarray(eeg_data, dtype=np.float32)
        # The network expects (batch, channels, time); promote a single trial.
        if eeg_data.ndim == 2:
            eeg_data = eeg_data[np.newaxis, ...]

        x = torch.from_numpy(eeg_data).to(self.device)

        with torch.no_grad():
            logits = self.model(x)
            # Report only the first trial. Confidence must come from the same row
            # as the prediction, not from the max over the whole batch.
            first = torch.softmax(logits, dim=1)[0].cpu()

        predicted_class = int(torch.argmax(first).item())
        confidence = float(first[predicted_class].item())
        # Plain Python floats keep the result JSON-serialisable; unknown output
        # indices (model wider than class_names) get a generic label.
        prob_dict = {
            self.class_names.get(i, f"class_{i}"): float(p)
            for i, p in enumerate(first.tolist())
        }
        return predicted_class, confidence, prob_dict

    def _fallback_loso_classification(self, eeg_data: np.ndarray) -> Tuple[int, float, Dict[str, float]]:
        """
        Fallback classification using LOSO (Leave-One-Session-Out) training.

        Trains a model on the available demo data when no pre-trained model is
        loaded, then classifies ``eeg_data`` with it.

        Raises:
            RuntimeError: If LOSO training fails (chained to the original error).
        """
        try:
            self._train_loso_model()
        except Exception as e:
            print(f"Error in LOSO training: {e}")
            raise RuntimeError(f"Failed to initialize classifier. Neither pre-trained model nor LOSO training succeeded: {e}") from e

        # Predict outside the try so input/shape errors surface as themselves
        # instead of being misreported as a training failure.
        return self.predict(eeg_data)

    def _train_loso_model(self, n_epochs: int = 50, lr: float = 0.001) -> None:
        """Train a fresh ShallowFBCSPNet on the demo sessions and install it as ``self.model``."""
        print("🔄 No pre-trained model available. Training new model using LOSO method...")
        print("⏳ This may take a moment - training on real EEG data...")

        processor = EEGDataProcessor()

        # LOSO needs one held-out session plus at least one session to train on.
        available_files = [f for f in DEMO_DATA_PATHS if os.path.exists(f)]
        if len(available_files) < 2:
            raise ValueError(f"Not enough data files for LOSO training. Need at least 2 files, found {len(available_files)}. "
                             f"Available files: {available_files}")

        # Hold out the first session; only the training portion is used here.
        X_train, y_train, _X_test, _y_test, _session_info = processor.prepare_loso_split(
            available_files, test_session_idx=0
        )

        n_chans, n_times = X_train.shape[1], X_train.shape[2]
        model = ShallowFBCSPNet(
            n_chans=n_chans,
            n_outputs=len(self.class_names),
            n_times=n_times,
            final_conv_length="auto"
        ).to(self.device)

        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        criterion = nn.CrossEntropyLoss()

        X_train_tensor = torch.from_numpy(X_train).float().to(self.device)
        y_train_tensor = torch.from_numpy(y_train).long().to(self.device)

        # Full-batch training: each epoch is one optimizer step over all trials.
        model.train()
        for epoch in range(n_epochs):
            optimizer.zero_grad()
            loss = criterion(model(X_train_tensor), y_train_tensor)
            loss.backward()
            optimizer.step()

            if epoch % 5 == 0:
                print(f"LOSO Training - Epoch {epoch}, Loss: {loss.item():.4f}")

        model.eval()
        # Install only a fully trained model, so a failure above leaves prior state intact.
        self.model = model
        self.is_loaded = True

        print("✅ LOSO model trained successfully! Ready for classification.")