# Source: NeuroMusicLab / data_processor.py (9.81 kB)
# Commit fa96cf5 by sofieff - "Initial commit: EEG Motor Imagery Music Composer"
"""
EEG Data Processing Module
-------------------------
Handles EEG data loading, preprocessing, and epoching for real-time classification.
Adapted from the original eeg_motor_imagery.py script.
"""
import scipy.io
import numpy as np
import mne
import torch
import pandas as pd
from typing import List, Tuple, Dict, Optional
from pathlib import Path
from scipy.signal import butter, lfilter
class EEGDataProcessor:
    """
    Processes EEG data from .mat files for motor imagery classification.

    The processor loads recordings, wraps them in MNE ``Raw`` objects, turns
    the label track into MNE events, cuts epochs around those events and
    offers a few helpers that simulate a live data feed from the result.

    Attributes set by the instance:
        fs: Sampling frequency of the most recently created Raw object.
        ch_names: Channel names of the most recently created Raw object
            (after auxiliary channels were dropped).
        event_id: Mapping from class name to the integer marker code (1..6).
    """

    def __init__(self):
        self.fs = None
        self.ch_names = None
        self.event_id = {
            "left_hand": 1,
            "right_hand": 2,
            "neutral": 3,
            "left_leg": 4,
            "tongue": 5,
            "right_leg": 6,
        }
        # Cursor used by simulate_real_time_data(mode="sequential").
        self._sequential_idx = 0

    def load_mat_file(self, file_path: str) -> Tuple[np.ndarray, np.ndarray, List[str], int]:
        """
        Load and parse a single .mat EEG file.

        The file must contain a struct named ``'o'``; its fields are read by
        position: index 2 is the sampling frequency, 4 the label track,
        5 the signal matrix and 6 the channel names.

        Args:
            file_path: Path to the .mat file

        Returns:
            signals: Signal matrix as stored in the file
            labels: Flattened label track
            channels: Channel names
            fs: Sampling frequency
        """
        mat = scipy.io.loadmat(file_path)
        content = mat['o'][0, 0]
        fs = int(content[2][0, 0])
        labels = content[4].flatten()
        signals = content[5]
        # Each channel name is stored as a nested array; unwrap to a plain str.
        channels = [ch[0][0] for ch in content[6]]
        return signals, labels, channels, fs

    def create_raw_object(self, signals: np.ndarray, channels: List[str], fs: int,
                          drop_ground_electrodes: bool = True) -> "mne.io.RawArray":
        """
        Create MNE Raw object from signal data.

        Also records the resulting channel names and sampling frequency on
        ``self.ch_names`` / ``self.fs``.

        Args:
            signals: Signal matrix with one column per entry of ``channels``
            channels: Channel names, in column order
            fs: Sampling frequency
            drop_ground_electrodes: If True, remove the auxiliary channels
                'X3' and 'X5' before building the Raw object

        Returns:
            MNE RawArray with all remaining channels typed as "eeg"
        """
        df = pd.DataFrame(signals, columns=channels)
        if drop_ground_electrodes:
            # Drop auxiliary channels that should be excluded
            aux_exclude = ('X3', 'X5')
            columns_to_drop = [ch for ch in channels if ch in aux_exclude]
            df = df.drop(columns=columns_to_drop, errors="ignore")
            print(f"Dropped auxiliary channels {columns_to_drop}. Remaining channels: {len(df.columns)}")

        ch_names = df.columns.tolist()
        self.ch_names = ch_names
        self.fs = fs

        # NOTE(review): values are passed to MNE unscaled - confirm the file's
        # unit matches what MNE expects for "eeg" channels.
        info = mne.create_info(ch_names=ch_names, sfreq=fs, ch_types="eeg")
        # MNE wants [n_channels, n_timepoints], hence the transpose.
        return mne.io.RawArray(df.values.T, info)

    def extract_events(self, labels: np.ndarray) -> np.ndarray:
        """
        Extract events from label array.

        An event is a sample whose label is non-zero while the previous
        sample's label is zero. Only codes 1..6 are kept.

        Args:
            labels: 1-D label track, one entry per sample

        Returns:
            Integer array [n_events, 3] in MNE layout: (sample, 0, code)
        """
        # Rising edges: previous sample is 0, current sample is not.
        onsets = np.where((labels[1:] != 0) & (labels[:-1] == 0))[0] + 1
        event_codes = labels[onsets].astype(int)
        events = np.c_[onsets, np.zeros_like(onsets), event_codes]
        # Keep only relevant events
        mask = np.isin(events[:, 2], np.arange(1, 7))
        return events[mask]

    def create_epochs(self, raw: "mne.io.RawArray", events: np.ndarray,
                      tmin: float = 0, tmax: float = 1.5) -> "mne.Epochs":
        """
        Create epochs from raw data and events.

        Args:
            raw: Continuous recording
            events: Event array as returned by ``extract_events``
            tmin: Epoch start relative to the event, in seconds
            tmax: Epoch end relative to the event, in seconds

        Returns:
            Preloaded MNE Epochs without baseline correction
        """
        return mne.Epochs(
            raw,
            events=events,
            event_id=self.event_id,
            tmin=tmin,
            tmax=tmax,
            baseline=None,
            preload=True,
        )

    def _load_raw(self, file_path: str):
        """Load one file and return (raw, labels, channels, fs)."""
        signals, labels, channels, fs = self.load_mat_file(file_path)
        raw = self.create_raw_object(signals, channels, fs, drop_ground_electrodes=True)
        return raw, labels, channels, fs

    def _load_epochs(self, file_path: str):
        """Load one file and return (epochs, channels, fs)."""
        raw, labels, channels, fs = self._load_raw(file_path)
        epochs = self.create_epochs(raw, self.extract_events(labels))
        return epochs, channels, fs

    @staticmethod
    def _epochs_to_arrays(epochs) -> Tuple[np.ndarray, np.ndarray]:
        """Convert epochs to (float32 data, int64 zero-based class labels)."""
        X = epochs.get_data().astype("float32")
        y = (epochs.events[:, -1] - 1).astype("int64")  # classes 0..5
        return X, y

    def process_files(self, file_paths: List[str]) -> Tuple[np.ndarray, np.ndarray]:
        """
        Process multiple EEG files and return combined data.

        Args:
            file_paths: List of .mat file paths

        Returns:
            X: Epoch data as float32
            y: Zero-based class labels as int64

        Raises:
            ValueError: If ``file_paths`` is empty
        """
        file_paths = list(file_paths)
        if not file_paths:
            raise ValueError("file_paths must contain at least one .mat file")

        all_epochs = [self._load_epochs(file_path)[0] for file_path in file_paths]
        if len(all_epochs) > 1:
            epochs_combined = mne.concatenate_epochs(all_epochs)
        else:
            epochs_combined = all_epochs[0]
        # Convert to arrays for model input
        return self._epochs_to_arrays(epochs_combined)

    def load_continuous_data(self, file_paths: List[str]) -> Tuple[np.ndarray, int]:
        """
        Load continuous raw EEG data without epoching.

        Args:
            file_paths: List of .mat file paths

        Returns:
            raw_data: Continuous EEG data [n_channels, n_timepoints]
            fs: Sampling frequency

        Raises:
            ValueError: If ``file_paths`` is empty, or if the files do not
                all share one sampling frequency
        """
        file_paths = list(file_paths)
        if not file_paths:
            raise ValueError("file_paths must contain at least one .mat file")

        all_raw_data = []
        common_fs = None
        for file_path in file_paths:
            raw, _labels, _channels, fs = self._load_raw(file_path)
            # A single fs is returned for the whole array, so recordings with
            # different rates must not be glued together.
            if common_fs is None:
                common_fs = fs
            elif fs != common_fs:
                raise ValueError(
                    f"Sampling frequency mismatch: {file_path} has {fs} Hz, expected {common_fs} Hz"
                )
            # Extract continuous data (no epoching)
            all_raw_data.append(raw.get_data())  # [n_channels, n_timepoints]

        # Concatenate all continuous data along time axis
        if len(all_raw_data) > 1:
            combined_raw = np.concatenate(all_raw_data, axis=1)
        else:
            combined_raw = all_raw_data[0]
        return combined_raw, common_fs

    def prepare_loso_split(self, file_paths: List[str], test_subject_idx: int = 0) -> Tuple:
        """
        Prepare Leave-One-Subject-Out (LOSO) split for EEG data.

        Args:
            file_paths: List of .mat file paths (one per subject)
            test_subject_idx: Index of subject to use for testing. Negative
                values count from the end, as with list indexing.

        Returns:
            X_train, y_train, X_test, y_test, subject_info

        Raises:
            ValueError: If fewer than two subjects are given
            IndexError: If ``test_subject_idx`` is out of range
        """
        file_paths = list(file_paths)
        n_subjects = len(file_paths)
        # Validate before touching any file so bad arguments fail fast.
        if n_subjects < 2:
            raise ValueError("LOSO split requires at least two subject files")
        if not -n_subjects <= test_subject_idx < n_subjects:
            raise IndexError(
                f"test_subject_idx {test_subject_idx} is out of range for {n_subjects} subjects"
            )
        # Normalise negative indices; otherwise the `!=` filter below would
        # never match and the test subject would leak into the training set.
        test_subject_idx %= n_subjects

        all_subjects_data = []
        subject_info = []
        # Load each subject separately
        for i, file_path in enumerate(file_paths):
            epochs, channels, fs = self._load_epochs(file_path)
            X_subject, y_subject = self._epochs_to_arrays(epochs)
            all_subjects_data.append((X_subject, y_subject))
            subject_info.append({
                'file_path': file_path,
                'subject_id': f"Subject_{i+1}",
                'n_epochs': len(X_subject),
                # Channel list as read from the file (before auxiliary drop).
                'channels': channels,
                'fs': fs
            })

        # LOSO split: one subject for test, others for train
        X_test, y_test = all_subjects_data[test_subject_idx]
        train_subjects = [
            subj for i, subj in enumerate(all_subjects_data) if i != test_subject_idx
        ]
        # Combine training subjects
        if len(train_subjects) > 1:
            X_train = np.concatenate([subj[0] for subj in train_subjects], axis=0)
            y_train = np.concatenate([subj[1] for subj in train_subjects], axis=0)
        else:
            X_train, y_train = train_subjects[0]

        print("LOSO Split:")
        print(f" Test Subject: {subject_info[test_subject_idx]['subject_id']} ({len(X_test)} epochs)")
        print(f" Train Subjects: {len(train_subjects)} subjects ({len(X_train)} epochs)")
        return X_train, y_train, X_test, y_test, subject_info

    def simulate_real_time_data(self, X: np.ndarray, y: np.ndarray, mode: str = "random") -> Tuple[np.ndarray, int]:
        """
        Simulate real-time EEG data for demo purposes.

        Args:
            X: EEG data array (currently epoched data)
            y: Labels array
            mode: "random", "sequential", or "class_balanced". Any other
                value behaves like "random".

        Returns:
            Single epoch and its true label

        Raises:
            ValueError: If ``X`` is empty
        """
        n_epochs = len(X)
        if n_epochs == 0:
            raise ValueError("X must contain at least one epoch")

        if mode == "sequential":
            # Walk through the epochs in order, wrapping around at the end.
            # The cursor lives on the instance so it survives across calls.
            idx = getattr(self, "_sequential_idx", 0) % n_epochs
            self._sequential_idx = idx + 1
        elif mode == "class_balanced":
            # Pick a class uniformly first, then an epoch of that class, so
            # rare classes are not under-sampled.
            target_class = np.random.choice(np.unique(y))
            idx = np.random.choice(np.where(y == target_class)[0])
        else:
            idx = np.random.randint(0, n_epochs)
        return X[idx], y[idx]

    def simulate_continuous_stream(self, raw_data: np.ndarray, fs: int, window_size: float = 1.5) -> np.ndarray:
        """
        Simulate continuous EEG stream by extracting sliding windows from raw data.

        Args:
            raw_data: Continuous EEG data [n_channels, n_timepoints]
            fs: Sampling frequency
            window_size: Window size in seconds

        Returns:
            Single window of EEG data [n_channels, window_samples], or the
            full input if it is not longer than one window
        """
        window_samples = int(window_size * fs)  # e.g., 1.5s * 200Hz = 300 samples
        # Ensure we don't go beyond the data
        max_start = raw_data.shape[1] - window_samples
        if max_start <= 0:
            return raw_data  # Return full data if too short
        # Random starting point in the continuous stream. randint's upper
        # bound is exclusive, so +1 keeps the last valid window reachable.
        start_idx = np.random.randint(0, max_start + 1)
        return raw_data[:, start_idx:start_idx + window_samples]