# eeg-cognitive-load/src/dataset_loader.py
# Hugging Face Hub page residue: "dodo-2100's picture",
# commit message "Upload folder using huggingface_hub", commit 2afe0cd (verified).
import os
import numpy as np
import pandas as pd
from typing import Tuple
class CognitiveLoadDataset:
    """
    Data Loader for the 'Cognitive Load Assessment Through EEG' Dataset (Mendeley Data).

    Recursively collects ``.txt`` recordings under ``data_path``, keeps 8
    channels per recording and slices every recording into fixed-length,
    overlapping windows.

    Labels come from the (lower-cased) file name, checked in this order:
    "low" -> 0, "mid" -> 1, "high" -> 2; "natural" and any unrecognised
    name -> 0.

    Args:
        data_path: Directory that is searched recursively for ``.txt`` files.
        sfreq: Sampling frequency in Hz. Stored on the instance only; the
            windowing is controlled by ``window_size`` and ``stride``.
        window_size: Number of samples per window (default 1000, i.e. 4 s at
            250 Hz).
        stride: Number of samples between consecutive window starts (default
            125, i.e. 0.5 s at 250 Hz -> 87.5% overlap).

    Raises:
        ValueError: If ``window_size`` or ``stride`` is not positive.
    """

    # Number of EEG channels kept from every recording.
    N_CHANNELS = 8

    # (substring, label) pairs, checked in order against the lower-cased name.
    _LABEL_RULES = (("low", 0), ("mid", 1), ("high", 2), ("natural", 0))

    def __init__(self, data_path: str, sfreq: int = 250,
                 window_size: int = 1000, stride: int = 125):
        if window_size <= 0 or stride <= 0:
            raise ValueError("window_size and stride must be positive integers.")
        self.data_path = data_path
        self.sfreq = sfreq
        self.window_size = window_size
        self.stride = stride
        self.X = []
        self.y = []

    def load_data(self):
        """
        Load every recording under ``data_path`` into ``self.X`` / ``self.y``.

        On success ``self.X`` is a float32 array of shape
        (n_windows, 8, window_size) and ``self.y`` an integer array of shape
        (n_windows,). If nothing could be loaded both stay empty lists.
        Files that cannot be parsed, or that hold fewer than 8 numeric
        channels, are reported on stdout and skipped.

        Calling this method again discards the previous result and reloads
        from disk.

        Raises:
            FileNotFoundError: If ``data_path`` does not exist.
        """
        if not os.path.exists(self.data_path):
            raise FileNotFoundError(f"Data path {self.data_path} not found.")

        # Start from a clean state so repeated calls neither fail on the
        # ndarray left by a previous call nor duplicate windows.
        self.X = []
        self.y = []

        files = self._find_txt_files()
        if not files:
            print(f"No .txt files found in {self.data_path}. Please run download_data.py and unzip.")
            return

        print(f"Found {len(files)} txt files. Loading...")

        windows = []
        labels = []
        for file_path in files:
            filename = os.path.basename(file_path).lower()
            label = self._label_from_filename(filename)
            try:
                data = self._read_channels(file_path, filename)
                if data is None:
                    continue
                file_windows = self._slice_windows(data)
            except Exception as e:
                # Deliberately best-effort: one bad file must not abort the load.
                print(f"Skipping {filename}: {e}")
                continue
            windows.extend(file_windows)
            labels.extend([label] * len(file_windows))

        if windows:
            self.X = np.array(windows, dtype=np.float32)
            self.y = np.array(labels, dtype=np.longlong)
            print(f"Loaded {len(self.X)} windows. Shape: {self.X.shape}")
        else:
            print("No valid data loaded.")

    def get_data(self) -> Tuple[np.ndarray, np.ndarray]:
        """Return ``(X, y)`` as populated by :meth:`load_data`."""
        return self.X, self.y

    def _find_txt_files(self) -> list:
        """Return all ``.txt`` paths under ``data_path``, sorted for a reproducible order."""
        found = []
        # Walk subdirectories because the dataset might have subfolders
        # like "Arithmetic_Data".
        for root, _, filenames in os.walk(self.data_path):
            for name in filenames:
                # Names starting with '._' are skipped (presumably macOS
                # resource-fork files left over from unzipping).
                if name.endswith('.txt') and not name.startswith('._'):
                    found.append(os.path.join(root, name))
        return sorted(found)

    @classmethod
    def _label_from_filename(cls, filename: str) -> int:
        """Map a lower-cased file name such as "highlevel-1.txt" to its class label."""
        for needle, label in cls._LABEL_RULES:
            if needle in filename:
                return label
        return 0  # Default (Natural/Rest)

    def _read_channels(self, file_path: str, filename: str):
        """
        Parse one recording and return it as an (8, time) array.

        Returns ``None`` (after printing the reason) when the file cannot be
        read or contains fewer than 8 numeric channels.
        """
        try:
            # Values are separated by whitespace or commas.
            df = pd.read_csv(file_path, header=None, sep=r'\s+|,\s*', engine='python')
            # Non-numeric cells (e.g. timestamps like '2022-06-29') become NaN ...
            df = df.apply(pd.to_numeric, errors='coerce')
            # ... so columns that are entirely non-numeric can be dropped.
            df = df.dropna(axis=1, how='all')
            # Any remaining isolated NaNs are replaced by 0.
            df = df.fillna(0)
        except Exception as e:
            print(f"Failed to read {filename} with pandas: {e}")
            return None

        data = df.values
        # We want (Channels, Time). More rows than columns means the file is
        # laid out as (Time, Channels), so transpose.
        if data.shape[0] > data.shape[1]:
            data = data.T

        n_rows = data.shape[0]
        if n_rows >= self.N_CHANNELS + 1:
            # OpenBCI raw txt usually has [Index, Ch1..Ch8, Accel1..3, ...]:
            # assume row 0 is the sample index and keep the next 8 rows.
            return data[1:self.N_CHANNELS + 1, :]
        if n_rows == self.N_CHANNELS:
            return data  # Already exactly 8 channels.

        print(f"Skipping {filename}: Not enough channels ({data.shape[0]}) found after transpose.")
        return None

    def _slice_windows(self, data: np.ndarray) -> list:
        """
        Cut a (channels, time) array into float32 windows of ``window_size`` samples.

        A recording shorter than one window is zero-padded at the end and
        returned as a single window; otherwise windows start every ``stride``
        samples and a trailing remainder shorter than a window is dropped.
        """
        n_timepoints = data.shape[1]
        if n_timepoints < self.window_size:
            pad_width = self.window_size - n_timepoints
            padded = np.pad(data, ((0, 0), (0, pad_width)), mode='constant')
            return [padded.astype(np.float32)]

        last_start = n_timepoints - self.window_size
        return [
            data[:, start:start + self.window_size].astype(np.float32)
            for start in range(0, last_start + 1, self.stride)
        ]