Spaces:
Runtime error
Runtime error
| """Phase 2: Preprocessing pipeline -- filtering, segmentation, windowing.""" | |
| import os | |
| import numpy as np | |
| from pathlib import Path | |
| from scipy.signal import butter, filtfilt, welch | |
| import matplotlib | |
| matplotlib.use("Agg") | |
| import matplotlib.pyplot as plt | |
# This module lives one directory below the project root.
PROJECT_ROOT = Path(__file__).resolve().parents[1]
DATA_DIR = PROJECT_ROOT.joinpath("data")
RESULTS_DIR = PROJECT_ROOT.joinpath("results")

# Channel labels, listed in channel (column) order.
CHANNEL_NAMES = ["AFF6", "AFp2", "AFp1", "AFF5", "FCz", "CPz"]

# Sampling rate handed to the filter and to the Welch PSD estimate.
FS = 500.0
def bandpass_filter(data, low=8.0, high=30.0, fs=500.0, order=4):
    """Band-pass filter EEG data along the sample axis.

    Args:
        data: Array laid out as (n_samples, n_channels); filtering runs
            along axis 0.
        low: Lower band edge, in the same unit as ``fs``.
        high: Upper band edge, in the same unit as ``fs``.
        fs: Sampling rate of ``data``.
        order: Order argument forwarded to ``scipy.signal.butter``.

    Returns:
        The filtered array, same shape as ``data``.
    """
    nyquist = fs / 2.0
    # butter() takes the band edges as fractions of the Nyquist frequency.
    band_edges = [low / nyquist, high / nyquist]
    numerator, denominator = butter(order, band_edges, btype="band")
    # filtfilt applies the filter forwards and backwards (zero phase shift).
    return filtfilt(numerator, denominator, data, axis=0)
def extract_active_segment(eeg, duration, fs=500.0, stim_onset_s=3.0):
    """Return the slice of ``eeg`` recorded while the stimulus was active.

    The slice begins ``stim_onset_s`` seconds into the recording and spans
    ``duration`` seconds; it is cut short if the recording ends first.

    Args:
        eeg: Array with time along axis 0.
        duration: Length of the active period, in seconds.
        fs: Sampling rate used to convert seconds to sample indices.
        stim_onset_s: Time of stimulus onset, in seconds.

    Returns:
        ``eeg[first:last]`` for the computed sample bounds.
    """
    first = int(stim_onset_s * fs)
    n_active = int(duration * fs)
    # Never index past the end of the recording.
    last = min(first + n_active, eeg.shape[0])
    return eeg[first:last]
def normalize_channels(data):
    """Standardise every channel to zero mean and unit variance.

    Statistics are taken along axis 0, so each column is scaled on its own.
    Columns whose standard deviation is below 1e-8 are only mean-centred
    (divided by 1.0) to avoid dividing by ~0.
    """
    channel_mean = data.mean(axis=0, keepdims=True)
    channel_std = data.std(axis=0, keepdims=True)
    # Swap near-zero deviations for 1.0 so flat channels stay finite.
    safe_std = np.where(channel_std < 1e-8, 1.0, channel_std)
    return (data - channel_mean) / safe_std
def segment_windows(data, window_size=500, overlap=250):
    """Segment data into fixed-length, overlapping windows along axis 0.

    Args:
        data: Array with samples along axis 0.
        window_size: Number of samples per window; must be positive.
        overlap: Number of samples shared by consecutive windows; must be
            smaller than ``window_size`` so the window start advances.

    Returns:
        A list of slices of ``data``, each ``window_size`` samples long.
        Trailing samples that do not fill a whole window are dropped, and
        the list is empty when ``data`` is shorter than one window.

    Raises:
        ValueError: If ``window_size`` is not positive or ``overlap`` is
            not smaller than ``window_size``.
    """
    if window_size <= 0:
        raise ValueError(f"window_size must be positive, got {window_size}")
    step = window_size - overlap
    if step <= 0:
        # A zero step makes range() fail with an unrelated message and a
        # negative step silently yields no windows; report the real cause.
        raise ValueError(
            f"overlap ({overlap}) must be smaller than "
            f"window_size ({window_size})"
        )
    last_start = data.shape[0] - window_size
    return [
        data[start:start + window_size]
        for start in range(0, last_start + 1, step)
    ]
def preprocess_file(fpath, window_size=500, overlap=250):
    """Run the full preprocessing chain on one .npz file.

    Steps: band-pass filter (8-30 Hz), cut out the stimulus-active
    segment, normalise each channel, then split into overlapping windows.
    A non-empty segment shorter than one window is zero-padded to
    ``window_size`` and returned as a single window.

    Args:
        fpath: ``pathlib.Path`` of the .npz file; it must hold a
            ``feature_eeg`` array and a ``label`` entry whose ``.item()``
            has the keys ``label``, ``subject_id`` and ``duration``.
        window_size: Samples per window.
        overlap: Samples shared by consecutive windows.

    Returns:
        ``(windows_list, label_str, subject_id)``, or ``None`` when the
        filtered signal contains NaN/Inf or the active segment is empty.
    """
    # NOTE(security): allow_pickle=True is needed to read the object-typed
    # "label" entry, but unpickling can execute arbitrary code -- only use
    # this on trusted data files.
    # The context manager closes the underlying file; without it every
    # processed recording leaks an open handle.
    with np.load(str(fpath), allow_pickle=True) as arr:
        eeg_raw = arr["feature_eeg"]  # (7499, 6) per the original comment
        label_info = arr["label"].item()

    label_str = label_info["label"]
    subject_id = label_info["subject_id"]
    duration = label_info["duration"]

    # Step 1: Bandpass filter
    eeg_filtered = bandpass_filter(eeg_raw, low=8.0, high=30.0, fs=FS)

    # Reject recordings whose filtered signal contains NaN or Inf.
    if not np.isfinite(eeg_filtered).all():
        print(f" WARNING: NaN/Inf in {fpath.name}, skipping.")
        return None

    # Step 2: Extract active segment
    eeg_active = extract_active_segment(eeg_filtered, duration, fs=FS)

    # Step 3: Normalize
    eeg_norm = normalize_channels(eeg_active)

    # Step 4: Segment into windows
    windows = segment_windows(eeg_norm, window_size, overlap)

    # Edge case: recording too short to fill a single window.
    if not windows:
        if eeg_norm.shape[0] == 0:
            print(f" WARNING: Empty active segment in {fpath.name}, skipping.")
            return None
        # Pad with zeros up to window_size so one window can be returned.
        padded = np.zeros((window_size, eeg_norm.shape[1]))
        padded[:eeg_norm.shape[0]] = eeg_norm
        windows = [padded]
        print(f" WARNING: Short recording in {fpath.name}, zero-padded.")

    return windows, label_str, subject_id
def preprocess_all(window_size=500, overlap=250):
    """Preprocess every .npz file in DATA_DIR and save the stacked result.

    Each file goes through ``preprocess_file``; files for which it returns
    ``None`` are counted as skipped. All windows are stacked into one
    float32 array with a label and a subject id per window, then written to
    ``preprocessed_data.npz`` in the project root.

    Returns:
        ``(X, y, subjects)``: the window array, labels and subject ids.
    """
    RESULTS_DIR.mkdir(parents=True, exist_ok=True)
    npz_paths = sorted(DATA_DIR.glob("*.npz"))
    n_files = len(npz_paths)
    print(f"Processing {n_files} files...")

    window_pool, label_pool, subject_pool = [], [], []
    n_skipped = 0
    for position, npz_path in enumerate(npz_paths, start=1):
        # Progress line every 100 files.
        if position % 100 == 0:
            print(f" [{position}/{n_files}]...")

        outcome = preprocess_file(npz_path, window_size, overlap)
        if outcome is None:
            n_skipped += 1
            continue

        file_windows, file_label, file_subject = outcome
        # One label / subject entry per window keeps the arrays aligned.
        window_pool.extend(file_windows)
        label_pool.extend([file_label] * len(file_windows))
        subject_pool.extend([file_subject] * len(file_windows))

    X = np.array(window_pool, dtype=np.float32)
    y = np.array(label_pool)
    subjects = np.array(subject_pool)

    print("\nPreprocessing complete:")
    print(f" Total windows: {X.shape[0]}")
    print(f" Window shape: {X.shape[1:]}")
    print(f" Skipped files: {n_skipped}")
    print(f" Unique labels: {np.unique(y)}")
    print(f" Unique subjects: {np.unique(subjects)}")

    # Save
    out_path = PROJECT_ROOT / "preprocessed_data.npz"
    np.savez_compressed(str(out_path), X=X, y=y, subjects=subjects)
    size_mb = out_path.stat().st_size / 1e6
    print(f" Saved to {out_path} ({size_mb:.1f} MB)")
    return X, y, subjects
def verify_psd(sample_file=None):
    """Generate a PSD verification plot: raw vs band-pass filtered.

    Loads one recording, filters it to 8-30 Hz and plots the Welch PSD of
    the raw and filtered signal for each channel in ``CHANNEL_NAMES`` on a
    2x3 grid. The figure is saved as ``psd_verification.png`` in
    RESULTS_DIR.

    Args:
        sample_file: Path (``pathlib.Path`` or ``str``) of the .npz file to
            plot. Defaults to the first .npz file in DATA_DIR, sorted.

    Raises:
        FileNotFoundError: If ``sample_file`` is None and DATA_DIR holds no
            .npz files.
    """
    RESULTS_DIR.mkdir(parents=True, exist_ok=True)
    if sample_file is None:
        candidates = sorted(DATA_DIR.glob("*.npz"))
        if not candidates:
            # Fail with a clear message instead of an IndexError on [0].
            raise FileNotFoundError(f"No .npz files found in {DATA_DIR}")
        sample_file = candidates[0]
    # Accept plain strings too; .name is needed for the figure title.
    sample_file = Path(sample_file)

    # NOTE(security): allow_pickle=True can execute arbitrary code while
    # loading -- only use this on trusted data files.
    # The context manager closes the file handle once the array is read.
    with np.load(str(sample_file), allow_pickle=True) as arr:
        eeg_raw = arr["feature_eeg"]
    eeg_filtered = bandpass_filter(eeg_raw, low=8.0, high=30.0, fs=FS)

    fig, axes = plt.subplots(2, 3, figsize=(15, 8))
    fig.suptitle(f"PSD Verification: {sample_file.name}\nRaw (blue) vs Filtered 8-30Hz (orange)")
    # axes.flat walks the grid row by row, i.e. axes[ch // 3, ch % 3].
    for ch, (ax, ch_name) in enumerate(zip(axes.flat, CHANNEL_NAMES)):
        freqs_raw, psd_raw = welch(eeg_raw[:, ch], fs=FS, nperseg=1024)
        freqs_filt, psd_filt = welch(eeg_filtered[:, ch], fs=FS, nperseg=1024)
        ax.semilogy(freqs_raw, psd_raw, label="Raw", alpha=0.7)
        ax.semilogy(freqs_filt, psd_filt, label="Filtered 8-30Hz", alpha=0.7)
        # Dashed guides at the pass-band edges.
        ax.axvline(8, color="gray", linestyle="--", alpha=0.5, label="8 Hz")
        ax.axvline(30, color="gray", linestyle="--", alpha=0.5, label="30 Hz")
        ax.set_title(f"Ch {ch}: {ch_name}")
        ax.set_xlabel("Frequency (Hz)")
        ax.set_ylabel("PSD (uV^2/Hz)")
        ax.set_xlim(0, 60)
        ax.legend(fontsize=7)
    fig.tight_layout()

    out_path = RESULTS_DIR / "psd_verification.png"
    # Save and close this figure explicitly rather than the implicit
    # "current" figure.
    fig.savefig(str(out_path), dpi=150)
    plt.close(fig)
    print(f"PSD verification saved to {out_path}")
if __name__ == "__main__":
    # Script entry point: write the PSD sanity-check plot first, then
    # build and save the windowed dataset.
    verify_psd()
    X, y, subjects = preprocess_all()