| |
| from __future__ import annotations |
| from dataclasses import dataclass |
| import numpy as np |
| import mne |
|
|
|
|
@dataclass(frozen=True)
class PreprocessConfig:
    """Immutable settings shared by the preprocessing steps.

    ``fs`` is handed to ``mne.create_info`` as the sampling frequency and
    ``f_low`` / ``f_high`` are handed to ``Raw.filter`` as the lower / upper
    band edges by ``bandpass_tc``. Being frozen, instances are hashable and
    can safely be shared between calls.
    """

    # Sampling frequency, passed to MNE as ``sfreq`` (MNE expects Hz).
    fs: float
    # Lower band edge, passed to ``Raw.filter`` as its first argument (``l_freq``).
    f_low: float
    # Upper band edge, passed to ``Raw.filter`` as its second argument (``h_freq``).
    f_high: float
|
|
|
|
def to_time_channel(x: np.ndarray, *, max_channels: int = 256) -> np.ndarray:
    """Return ``x`` as a 2-D ``(time, channel)`` array.

    A 1-D input is treated as a single channel and becomes ``(T, 1)``.
    A 2-D input is assumed to already be ``(time, channel)`` unless its first
    axis is short (``<= max_channels``) and strictly smaller than its second
    axis; in that case it is taken to be ``(channel, time)`` and transposed.
    No copy is made, so the result may be a view of the input.

    Args:
        x: 1-D or 2-D array-like signal.
        max_channels: Largest first-axis length the orientation heuristic
            still treats as a channel count (default 256, the previous
            hard-coded value).

    Returns:
        2-D array with time along axis 0.

    Raises:
        ValueError: if ``x`` is not 1-D or 2-D.
    """
    # asanyarray: accept lists/tuples too, while returning ndarrays (and
    # their subclasses) unchanged.
    x = np.asanyarray(x)
    if x.ndim == 1:
        return x[:, None]
    if x.ndim != 2:
        raise ValueError(f"Expected 1D or 2D array, got {x.shape}")
    n_rows, n_cols = x.shape
    # Heuristic: a short first axis that is smaller than the second one looks
    # like (channel, time), since recordings normally have far more samples
    # than channels. Ambiguous for very short recordings with many channels.
    if n_rows <= max_channels and n_cols > n_rows:
        return x.T
    return x
|
|
|
|
def bandpass_tc(x_tc: np.ndarray, cfg: PreprocessConfig) -> np.ndarray:
    """Band-pass filter a ``(time, channel)`` array with MNE.

    The data are wrapped in a temporary ``mne.io.RawArray`` whose channels are
    named ``ch0``, ``ch1``, ... and typed as EEG. A copy of that object is
    filtered with ``Raw.filter(cfg.f_low, cfg.f_high)`` using MNE's default
    filter design.

    Args:
        x_tc: 2-D array with time along axis 0 and channels along axis 1.
        cfg: Sampling frequency and band edges.

    Returns:
        Filtered data transposed back to ``(time, channel)``.
    """
    channel_names = [f"ch{idx}" for idx in range(x_tc.shape[1])]
    info = mne.create_info(ch_names=channel_names, sfreq=cfg.fs, ch_types="eeg")
    # RawArray expects (channel, time), hence the transpose.
    raw = mne.io.RawArray(x_tc.T, info, verbose=False)
    # Keep the .copy(): with its default copy="auto", RawArray may reuse the
    # caller's float64 buffer instead of copying it, and Raw.filter modifies
    # data in place. Filtering `raw` directly could overwrite x_tc.
    filtered = raw.copy()
    filtered.filter(cfg.f_low, cfg.f_high, verbose=False)
    return filtered.get_data().T
|
|
|
|
def hilbert_envelope_tc(x_tc: np.ndarray) -> np.ndarray:
    """Return the amplitude envelope ``|analytic signal|`` along axis 0.

    The analytic signal is built in the frequency domain with the same
    one-sided weighting as ``scipy.signal.hilbert``:
    - the DC bin (and, for even lengths, the Nyquist bin) keeps weight 1;
    - positive frequencies are doubled;
    - negative frequencies are zeroed.

    Because the FFT treats the signal as periodic, samples near the start and
    end can show wrap-around edge effects.

    Args:
        x_tc: Signal with time along axis 0, presumably real-valued. Usually
            ``(time, channel)``, but any rank >= 1 works; all other axes are
            processed independently.

    Returns:
        ``float32`` array with the same shape as ``x_tc``.

    Raises:
        ValueError: if ``x_tc`` is 0-dimensional (it has no time axis).
    """
    x_tc = np.asanyarray(x_tc)
    if x_tc.ndim == 0:
        raise ValueError(
            "Expected an array with a time axis (ndim >= 1), got a scalar"
        )
    n = x_tc.shape[0]
    if n == 0:
        # np.fft.fft rejects zero-length transforms; an empty signal simply
        # has an empty envelope.
        return np.zeros(x_tc.shape, dtype=np.float32)

    spectrum = np.fft.fft(x_tc, axis=0)
    weights = np.zeros(n)
    weights[0] = 1.0
    weights[1:(n + 1) // 2] = 2.0  # strictly positive frequencies
    if n % 2 == 0:
        weights[n // 2] = 1.0  # the Nyquist bin is its own mirror image
    # Shape the weights as (n, 1, ..., 1) so they broadcast along axis 0 for any
    # rank. A fixed weights[:, None] would turn 1-D input into an (n, n) result.
    weights = weights.reshape((n,) + (1,) * (x_tc.ndim - 1))
    env = np.abs(np.fft.ifft(spectrum * weights, axis=0))
    return env.astype(np.float32)
|
|
|
|
def preprocess_pipeline(x: np.ndarray, cfg: PreprocessConfig):
    """Run the full preprocessing chain on one recording.

    The stages run in order:
    1. ``to_time_channel`` orients ``x`` as ``(time, channel)``.
    2. ``bandpass_tc`` filters it using ``cfg``.
    3. ``hilbert_envelope_tc`` takes the amplitude envelope.

    Returns:
        Dict with keys ``"raw"`` (the oriented input, which may be a view of
        ``x``), ``"filtered"`` and ``"envelope"``, in that order.
    """
    stages = {}
    stages["raw"] = to_time_channel(x)
    stages["filtered"] = bandpass_tc(stages["raw"], cfg)
    stages["envelope"] = hilbert_envelope_tc(stages["filtered"])
    return stages
|
|