---
license: apache-2.0
tags:
- eeg
- neuroscience
- foundation-model
- embeddings
- matryoshka
pipeline_tag: feature-extraction
library_name: neuroencoder
extra_gated_prompt: |-
  The MRL model is currently gated. Access is granted to verified researchers.
  Please briefly describe your institution, role, and intended use.
  If you have a private invitation code, paste it in the "Intended use" field.
extra_gated_fields:
  Institution: text
  Role: text
  Intended use: text
  I agree to use this model for research purposes only: checkbox
---

# EPI Embedding

An EEG embedding model distilled from EPI-250k, a foundation model trained on ~250,000 hours of clinical EEG.

The model produces a 768-dimensional embedding that can be used at **768** dimensions or truncated to **384, 192, 48, or 16** dimensions via [Matryoshka Representation Learning](https://arxiv.org/abs/2205.13147).

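Matryoshka training is designed so that a prefix of the full vector is itself a usable embedding. The sketch below illustrates that idea on random stand-in data, assuming truncation means keeping the first `dim` components and L2-normalizing again. Whether `model.embed(..., dim=...)` does exactly this internally is an assumption, and `truncate_embedding` is a hypothetical helper, not part of `neuroencoder`.

```python
import numpy as np


def truncate_embedding(full: np.ndarray, dim: int) -> np.ndarray:
    """Keep the first `dim` components of each row and L2-normalize again."""
    if not 1 <= dim <= full.shape[-1]:
        raise ValueError(f"dim must be in [1, {full.shape[-1]}], got {dim}")
    head = full[..., :dim]
    norms = np.linalg.norm(head, axis=-1, keepdims=True)
    # Guard against division by zero for an all-zero prefix.
    return head / np.clip(norms, 1e-12, None)


rng = np.random.default_rng(0)
full = rng.normal(size=(4, 768))      # stand-in for four 768-d embeddings
small = truncate_embedding(full, 48)  # shape (4, 48), each row has unit norm
```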
## Usage

Install:

```bash
pip install neuroencoder
```

Then:

```python
import mne
import neuroencoder as ne
from neuroencoder import MRL

raw = mne.io.read_raw_edf("recording.edf", preload=True)
model = MRL.from_pretrained()  # auto-downloads on first use

embeddings = model.embed(
    raw.get_data(),
    sfreq=raw.info["sfreq"],
    channel_names=raw.ch_names,
    dim=192,
)
# -> numpy array, shape [N, 192], L2-normalized

ne.explore(embeddings)  # interactive Apple Embedding Atlas
```

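Because the returned rows are L2-normalized, cosine similarity between two windows reduces to a dot product. The following is a minimal sketch on random stand-in data rather than real model output. `cosine_similarity_matrix` and `nearest_windows` are hypothetical helper names, not part of `neuroencoder`.

```python
import numpy as np


def cosine_similarity_matrix(embeddings: np.ndarray) -> np.ndarray:
    # For unit-norm rows, cosine similarity is the plain dot product.
    return embeddings @ embeddings.T


def nearest_windows(embeddings: np.ndarray, query_index: int, k: int = 3) -> np.ndarray:
    """Indices of the k windows most similar to the query window."""
    sims = cosine_similarity_matrix(embeddings)[query_index].copy()
    sims[query_index] = -np.inf  # exclude the query itself
    return np.argsort(sims)[::-1][:k]


rng = np.random.default_rng(0)
fake = rng.normal(size=(10, 192))                    # stand-in for [N, 192] embeddings
fake /= np.linalg.norm(fake, axis=1, keepdims=True)  # L2-normalize, as the model output is
sims = cosine_similarity_matrix(fake)
neighbours = nearest_windows(fake, query_index=0, k=3)
```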
`model.embed` runs the full pipeline (filter -> resample -> 8-region average -> 30 s sliding window -> embed) and returns a NumPy array. For more control, run the two stages separately:

```python
images = ne.preprocess(eeg, sfreq=256, channel_names=ch_names)  # [N, 8, 224, 224]
embeddings = model.predict(images, dim=192)  # torch tensor on model device
```

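The leading dimension `N` in the shapes above is the number of 30-second windows cut from the recording. The sketch below shows how such a count could arise. It assumes a fixed stride and that an incomplete trailing window is dropped; the stride and tail handling that `neuroencoder` actually uses are not stated here, and `sliding_windows` is a hypothetical helper.

```python
import numpy as np


def sliding_windows(signal: np.ndarray, sfreq: float,
                    window_s: float = 30.0, stride_s: float = 30.0) -> np.ndarray:
    """Split [n_channels, n_samples] into [n_windows, n_channels, window_len]."""
    win = int(round(window_s * sfreq))
    step = int(round(stride_s * sfreq))
    n_samples = signal.shape[-1]
    if n_samples < win:
        return np.empty((0, signal.shape[0], win), dtype=signal.dtype)
    starts = range(0, n_samples - win + 1, step)
    return np.stack([signal[:, s:s + win] for s in starts])


sfreq = 256.0
eeg = np.zeros((8, int(95 * sfreq)))   # 95 s of 8-channel placeholder data
windows = sliding_windows(eeg, sfreq)  # three full 30 s windows; the last 5 s are dropped
```

With `stride_s=15.0` the same recording yields five overlapping windows, so the stride directly controls how many embeddings one recording produces.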
## Loading directly from a checkpoint

```python
model = MRL.from_checkpoint("path/to/last.ckpt")
```

`from_checkpoint` handles both raw state dicts and PyTorch Lightning checkpoints.

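As a rough illustration of how a loader can tell the two formats apart: PyTorch Lightning checkpoints store the weights under a `"state_dict"` key next to training metadata, while a raw state dict is the weight mapping itself. The sketch below uses plain dictionaries with placeholder values instead of tensors. The `"model."` key prefix is a hypothetical example of a wrapper prefix, and none of this is taken from the `neuroencoder` source.

```python
from typing import Any, Mapping


def extract_state_dict(checkpoint: Mapping[str, Any], prefix: str = "model.") -> dict:
    """Return a flat weight mapping from either checkpoint format."""
    # Lightning format: weights live under "state_dict"; otherwise treat the
    # whole mapping as the state dict.
    weights = checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint
    # Hypothetical: strip a wrapper prefix added by the training module.
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in weights.items()
    }


raw_format = {"encoder.weight": 1, "encoder.bias": 2}  # placeholder values
lightning_format = {
    "epoch": 3,
    "state_dict": {"model.encoder.weight": 1, "model.encoder.bias": 2},
}
```

Both inputs map to the same flat dictionary, which is what a model's `load_state_dict` would then consume.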
## Benchmarks

All results use linear probes on frozen embeddings with 5-fold subject-level cross-validation and report balanced accuracy (%). The first column is **EPI-250k**, our base foundation model (not publicly released), which serves as a nominal upper bound on what the distilled MRL model can preserve; a few rows (Artifact / Wake, TUSL, Mumtaz) show truncated embeddings scoring above it. The remaining columns are the MRL model at each truncation dimension.

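Two parts of this protocol are easy to get wrong: folds must be split by subject, so that no subject contributes epochs to both train and test, and balanced accuracy is the mean of per-class recall rather than plain accuracy. The sketch below shows both in NumPy on toy data. It is not the evaluation code behind the tables, the helper names are hypothetical, and the linear probe itself is omitted.

```python
import numpy as np


def subject_folds(subject_ids, n_folds: int = 5, seed: int = 0):
    """Yield (train_idx, test_idx) pairs so that no subject appears in both."""
    subject_ids = np.asarray(subject_ids)
    subjects = np.unique(subject_ids)
    rng = np.random.default_rng(seed)
    rng.shuffle(subjects)
    for held_out in np.array_split(subjects, n_folds):
        test_mask = np.isin(subject_ids, held_out)
        yield np.flatnonzero(~test_mask), np.flatnonzero(test_mask)


def balanced_accuracy(y_true, y_pred) -> float:
    """Mean of per-class recall, so rare classes count as much as common ones."""
    y_true, y_pred = np.asarray(y_true), np.asarray(y_pred)
    recalls = [np.mean(y_pred[y_true == c] == c) for c in np.unique(y_true)]
    return float(np.mean(recalls))


subject_ids = np.repeat(np.arange(10), 4)  # 10 subjects, 4 epochs each
folds = list(subject_folds(subject_ids))
# Always predicting class 0: plain accuracy is 0.75, balanced accuracy is 0.5.
score = balanced_accuracy([0, 0, 0, 1], [0, 0, 0, 0])
```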
### Private clinical tasks

40,909 annotated 30-second epochs from the Swiss Epilepsy Center.

| | Task | EPI-250k | 768 | 384 | 192 | 48 | 16 | |
| |------|:-------:|:---:|:---:|:---:|:--:|:--:| |
| | Seizure / Wake | **93.4** | 93.1 | 92.7 | 92.5 | 91.5 | 84.1 | |
| | Sleep (5-class) | **85.1** | 77.0 | 77.4 | 76.9 | 76.5 | 73.2 | |
| | Artifact / Wake | **90.2** | 90.5 | 90.3 | 90.5 | 90.7 | 65.9 | |
| | Seizure / Sleep | **88.8** | 85.2 | 84.9 | 84.0 | 82.1 | 79.4 | |
| | Spike / Seizure | **81.5** | 76.2 | 75.9 | 74.7 | 71.0 | 65.5 | |
| | Spike / Wake | **97.0** | 94.8 | 94.7 | 94.6 | 92.9 | 87.2 | |
| | Artifact / Spike | **78.8** | 76.0 | 75.6 | 75.3 | 74.4 | 70.4 | |
| | Category (6-cls) | **36.3** | 33.6 | 33.3 | 32.8 | 31.7 | 27.4 | |
| | Clinical Sub (7-cls) | **42.7** | 31.4 | 31.4 | 31.4 | 27.0 | 23.7 | |
| | All Sublabels (49-cls) | **22.1** | 14.8 | 14.4 | 13.7 | 12.3 | 10.6 | |

### Public benchmarks

10 standard public EEG datasets, evaluated under identical conditions.

| Task | EPI-250k | 768 | 384 | 192 | 48 | 16 |
|------|:-------:|:---:|:---:|:---:|:--:|:--:|
| TUAB | **73.1** | 72.4 | 72.5 | 72.9 | 72.2 | 70.4 |
| TUEV | **54.5** | 45.9 | 47.2 | 46.7 | 42.8 | 32.1 |
| TUAR | **45.2** | 43.0 | 42.9 | 42.2 | 39.5 | 36.5 |
| TUSL | **73.3** | 71.5 | 75.1 | 77.1 | 71.3 | 69.7 |
| Mumtaz | **82.1** | 80.7 | 81.8 | 82.6 | 83.2 | 83.1 |
| Schizo | **71.1** | 70.1 | 69.4 | 69.5 | 69.4 | 66.7 |
| MentArith | **60.9** | 60.2 | 59.9 | 58.6 | 55.6 | 52.2 |
| ADFTD | **43.2** | 40.0 | 40.0 | 41.0 | 38.6 | 35.9 |
| PhysioMI | **30.3** | 28.3 | 28.4 | 27.3 | 27.7 | 25.2 |
| Parkinsons | **62.9** | 58.9 | 58.6 | 58.2 | 55.9 | 53.2 |

Numeric column headers (`768`, `384`, ...) are the MRL truncation dimensions.

## Documentation

- Docs: [docs.neuroencoder.com](https://docs.neuroencoder.com)
- GitHub: [github.com/avocardio/neuroencoder](https://github.com/avocardio/neuroencoder)

## Citation

Paper in preparation. A citation will be added once published.