import os
from typing import List, Optional, Tuple

import numpy as np
import pandas as pd


class CognitiveLoadDataset:
    """
    Data Loader for the 'Cognitive Load Assessment Through EEG' Dataset (Mendeley Data).

    Expects .txt files (searched recursively) under ``data_path``. Each file is
    parsed as whitespace- or comma-separated text; columns that contain no
    numeric value at all (e.g. formatted date/time strings) are dropped, 8
    channels are selected, and the recording is cut into fixed-length,
    overlapping windows.

    After :meth:`load_data`, :meth:`get_data` returns ``(X, y)`` where ``X`` is
    a float32 array of shape ``(n_windows, 8, window_size)`` and ``y`` is an
    ``np.longlong`` array of shape ``(n_windows,)``. Both are empty arrays
    (never plain lists) when nothing has been loaded.
    """

    # Number of EEG channels kept from every recording.
    N_CHANNELS = 8

    # (substring, label) pairs tested in order against the lower-cased file
    # name; the first match wins, e.g. "lowlevel-1.txt", "highlevel-1.txt",
    # "natural-1.txt". "low" and "natural" deliberately share label 0.
    _LABEL_KEYWORDS = (("low", 0), ("mid", 1), ("high", 2), ("natural", 0))

    # Label assigned when no keyword matches (Natural/Rest).
    _DEFAULT_LABEL = 0

    def __init__(self, data_path: str, sfreq: int = 250,
                 window_size: int = 1000, stride: int = 125):
        """
        Args:
            data_path: Root folder searched recursively for .txt recordings.
            sfreq: Sampling frequency in Hz. Stored on the instance; the
                loader itself does not use it (window_size/stride are in
                samples, not seconds).
            window_size: Window length in samples
                (default 1000 = 4 s @ 250 Hz).
            stride: Step between consecutive window starts in samples
                (default 125 = 0.5 s @ 250 Hz -> 87.5% overlap).

        Raises:
            ValueError: If ``window_size`` or ``stride`` is not positive.
        """
        if window_size <= 0 or stride <= 0:
            raise ValueError("window_size and stride must be positive integers.")
        self.data_path = data_path
        self.sfreq = sfreq
        self.window_size = window_size
        self.stride = stride
        self._reset()

    def _reset(self) -> None:
        """Set X / y to empty arrays with the documented dtypes and shapes."""
        self.X = np.empty((0, self.N_CHANNELS, self.window_size), dtype=np.float32)
        self.y = np.empty((0,), dtype=np.longlong)

    @staticmethod
    def _find_txt_files(root_dir: str) -> List[str]:
        """Return all .txt paths under root_dir, sorted for reproducible order.

        Files starting with '._' (macOS AppleDouble metadata) are skipped.
        """
        found = []
        for root, _, filenames in os.walk(root_dir):
            for name in filenames:
                if name.endswith('.txt') and not name.startswith('._'):
                    found.append(os.path.join(root, name))
        return sorted(found)

    @classmethod
    def _label_from_filename(cls, filename: str) -> Optional[int]:
        """Return the class label encoded in a lower-cased file name, or None."""
        for keyword, label in cls._LABEL_KEYWORDS:
            if keyword in filename:
                return label
        return None

    @staticmethod
    def _read_numeric_table(file_path: str) -> np.ndarray:
        """Parse a text recording into a 2-D array of its numeric columns.

        Lines starting with '%' (OpenBCI-style header/comment lines) are
        ignored. Non-numeric cells (e.g. the '2022-06-29' timestamp fields)
        become NaN. Columns and rows that are entirely NaN are dropped, and
        any remaining NaNs are filled with 0.
        """
        df = pd.read_csv(file_path, header=None, sep=r'\s+|,\s*',
                         engine='python', comment='%')
        df = df.apply(pd.to_numeric, errors='coerce')
        # All-NaN columns are timestamp/date strings; drop them.
        df = df.dropna(axis=1, how='all')
        # All-NaN rows are text rows (e.g. a column-name header); drop them
        # instead of turning them into fake all-zero samples.
        df = df.dropna(axis=0, how='all')
        df = df.fillna(0)
        return df.values

    @classmethod
    def _select_channels(cls, data: np.ndarray) -> Optional[np.ndarray]:
        """Pick the EEG channels from a (channels, time) array.

        OpenBCI raw txt usually has [Index, Ch1..Ch8, Accel1..3, ...], so
        with more than 8 rows the first is treated as the sample index and
        rows 1..8 are kept. Exactly 8 rows are kept as-is. Fewer rows give
        None.
        """
        if data.shape[0] > cls.N_CHANNELS:
            return data[1:cls.N_CHANNELS + 1, :]
        if data.shape[0] == cls.N_CHANNELS:
            return data
        return None

    def _make_windows(self, data: np.ndarray) -> List[np.ndarray]:
        """Cut a (channels, time) array into float32 sliding windows.

        Recordings shorter than one window are zero-padded at the end into a
        single window. Otherwise, trailing samples that do not fill a whole
        window are dropped.
        """
        n_timepoints = data.shape[1]
        if n_timepoints < self.window_size:
            pad_width = self.window_size - n_timepoints
            padded = np.pad(data, ((0, 0), (0, pad_width)), mode='constant')
            return [padded.astype(np.float32)]
        return [
            data[:, start:start + self.window_size].astype(np.float32)
            for start in range(0, n_timepoints - self.window_size + 1, self.stride)
        ]

    def load_data(self) -> None:
        """Load, window and label every .txt recording under ``data_path``.

        Any previously loaded data is discarded, so calling this method again
        reloads from scratch. Files that cannot be parsed or that have too few
        channels are reported and skipped. A file contributes either all of
        its windows or none.

        Raises:
            FileNotFoundError: If ``data_path`` does not exist.
        """
        self._reset()
        if not os.path.exists(self.data_path):
            raise FileNotFoundError(f"Data path {self.data_path} not found.")

        # Walk subdirectories because the dataset may contain folders like
        # "Arithmetic_Data".
        files = self._find_txt_files(self.data_path)
        if not files:
            print(f"No .txt files found in {self.data_path}. Please run download_data.py and unzip.")
            return

        print(f"Found {len(files)} txt files. Loading...")
        windows: List[np.ndarray] = []
        labels: List[int] = []
        for file_path in files:
            filename = os.path.basename(file_path).lower()
            label = self._label_from_filename(filename)
            if label is None:
                print(f"Warning: no class keyword in {filename}; "
                      f"using default label {self._DEFAULT_LABEL}.")
                label = self._DEFAULT_LABEL

            # Best effort: a single malformed file must not abort the load.
            try:
                data = self._read_numeric_table(file_path)
            except Exception as e:
                print(f"Failed to read {filename} with pandas: {e}")
                continue

            try:
                # Each parsed line is one sample, so normally rows = time.
                # Switch to (channels, time) when rows outnumber columns.
                if data.shape[0] > data.shape[1]:
                    data = data.T
                channels = self._select_channels(data)
                if channels is None:
                    print(f"Skipping {filename}: Not enough channels ({data.shape[0]}) found after transpose.")
                    continue
                file_windows = self._make_windows(channels)
            except Exception as e:
                print(f"Skipping {filename}: {e}")
                continue

            windows.extend(file_windows)
            labels.extend([label] * len(file_windows))

        if windows:
            self.X = np.array(windows, dtype=np.float32)
            self.y = np.array(labels, dtype=np.longlong)
            print(f"Loaded {len(self.X)} windows. Shape: {self.X.shape}")
        else:
            print("No valid data loaded.")

    def get_data(self) -> Tuple[np.ndarray, np.ndarray]:
        """Return ``(X, y)``; both are empty arrays until data is loaded."""
        return self.X, self.y