| import os
|
| import numpy as np
|
| import pandas as pd
|
| from typing import Tuple
|
|
|
class CognitiveLoadDataset:
    """Data loader for the 'Cognitive Load Assessment Through EEG' dataset (Mendeley Data).

    Recursively collects ``.txt`` files under ``data_path``, parses each one as a
    numeric table, keeps 8 channels, and cuts the signal into fixed-length
    windows. The class label is inferred from the (lower-cased) file name by
    substring match, tested in this order:

    * ``"low"``     -> 0
    * ``"mid"``     -> 1
    * ``"high"``    -> 2
    * ``"natural"`` -> 0

    Files matching none of these keywords are skipped rather than silently
    labelled 0.

    NOTE(review): the exact on-disk layout (delimiter, header, presence of a
    leading index column, channel order) is not visible here. The parsing
    assumes the longer table axis is time and, when 9+ channel rows are
    present, that the first row is a non-EEG column. Confirm against the
    dataset documentation.
    """

    # (keyword, label) pairs checked in order against the lower-cased filename.
    _LABEL_KEYWORDS = (("low", 0), ("mid", 1), ("high", 2), ("natural", 0))
    # Number of channel rows kept per window.
    _N_CHANNELS = 8

    def __init__(self, data_path: str, sfreq: int = 250,
                 window_size: int = 1000, stride: int = 125):
        """
        Args:
            data_path: Root directory searched recursively for ``.txt`` files.
            sfreq: Presumably the sampling rate in Hz. It is stored, but not
                used by the code in this class.
            window_size: Samples per window (default 1000). Recordings shorter
                than this are zero-padded on the right to this length.
            stride: Step in samples between consecutive window starts
                (default 125).

        Raises:
            ValueError: If ``window_size`` or ``stride`` is not positive.
        """
        if window_size <= 0 or stride <= 0:
            raise ValueError(
                f"window_size and stride must be positive, got {window_size} and {stride}."
            )
        self.data_path = data_path
        self.sfreq = sfreq
        self.window_size = window_size
        self.stride = stride
        self.X = []
        self.y = []

    def load_data(self):
        """Load, label and window every ``.txt`` file under ``data_path``.

        On success, ``self.X`` becomes a float32 array of shape
        ``(n_windows, 8, window_size)`` and ``self.y`` an integer label array.
        Data from a previous call is discarded. If nothing can be loaded, both
        are left as empty lists. Unreadable or malformed files are reported
        and skipped (best effort).

        Raises:
            FileNotFoundError: If ``data_path`` does not exist.
        """
        if not os.path.exists(self.data_path):
            raise FileNotFoundError(f"Data path {self.data_path} not found.")

        # Reset state. Otherwise a second call would try to ``append`` to the
        # ndarray produced by the first one.
        self.X = []
        self.y = []

        files = self._find_txt_files()
        if not files:
            print(f"No .txt files found in {self.data_path}. Please run download_data.py and unzip.")
            return

        print(f"Found {len(files)} txt files. Loading...")

        windows = []
        labels = []
        for file_path in files:
            filename = os.path.basename(file_path).lower()

            label = self._label_from_filename(filename)
            if label is None:
                print(f"Skipping {filename}: could not infer a load label from the file name.")
                continue

            try:
                df = self._read_table(file_path)
            except Exception as e:
                print(f"Failed to read {filename} with pandas: {e}")
                continue

            try:
                data = self._orient(df.values)
                n_rows = data.shape[0]
                if n_rows < self._N_CHANNELS:
                    print(f"Skipping {filename}: Not enough channels ({n_rows}) found after transpose.")
                    continue
                if n_rows > self._N_CHANNELS:
                    # Drop the first row (assumed non-EEG, e.g. a sample index;
                    # TODO confirm) and keep the next 8.
                    data = data[1:1 + self._N_CHANNELS, :]

                file_windows = self._segment(data)
            except Exception as e:
                # Best effort: one malformed file must not abort the whole load.
                print(f"Skipping {filename}: {e}")
                continue

            # Extend only after the file was fully processed, so a failure
            # cannot leave a partial set of windows behind.
            windows.extend(file_windows)
            labels.extend([label] * len(file_windows))

        if windows:
            self.X = np.array(windows, dtype=np.float32)
            self.y = np.array(labels, dtype=np.longlong)
            print(f"Loaded {len(self.X)} windows. Shape: {self.X.shape}")
        else:
            print("No valid data loaded.")

    def _find_txt_files(self):
        """Return the sorted paths of all ``.txt`` files under ``data_path``."""
        files = []
        for root, _, filenames in os.walk(self.data_path):
            for name in filenames:
                # "._*" files are presumably macOS AppleDouble metadata, not data.
                if name.lower().endswith('.txt') and not name.startswith('._'):
                    files.append(os.path.join(root, name))
        # os.walk order is filesystem-dependent; sort for reproducible sample order.
        return sorted(files)

    @classmethod
    def _label_from_filename(cls, filename: str):
        """Return the label for a lower-cased filename, or None if no keyword matches."""
        for keyword, label in cls._LABEL_KEYWORDS:
            if keyword in filename:
                return label
        return None

    @staticmethod
    def _read_table(file_path: str) -> pd.DataFrame:
        """Parse a whitespace- or comma-separated text file into an all-numeric DataFrame."""
        df = pd.read_csv(file_path, header=None, sep=r'\s+|,\s*', engine='python')
        # Non-numeric cells (e.g. header text) become NaN...
        df = df.apply(pd.to_numeric, errors='coerce')
        # ...columns that are entirely non-numeric/empty are dropped...
        df = df.dropna(axis=1, how='all')
        # ...and any remaining gaps are zero-filled.
        return df.fillna(0)

    @staticmethod
    def _orient(data: np.ndarray) -> np.ndarray:
        """Transpose a 2-D array if needed so the shorter axis (assumed channels) is axis 0."""
        if data.shape[0] > data.shape[1]:
            data = data.T
        return data

    def _segment(self, data: np.ndarray):
        """Cut ``data`` of shape (channels, time) into float32 windows of ``window_size`` samples."""
        n_timepoints = data.shape[1]
        if n_timepoints < self.window_size:
            # Too short for a single window: zero-pad on the right.
            pad_width = self.window_size - n_timepoints
            padded = np.pad(data, ((0, 0), (0, pad_width)), mode='constant')
            return [padded.astype(np.float32)]
        return [
            data[:, start:start + self.window_size].astype(np.float32)
            for start in range(0, n_timepoints - self.window_size + 1, self.stride)
        ]

    def get_data(self) -> Tuple[np.ndarray, np.ndarray]:
        """Return ``(X, y)`` as populated by :meth:`load_data`.

        Both are empty lists if :meth:`load_data` has not run or loaded nothing.
        """
        return self.X, self.y
|
|
|