---
license: mit
library_name: braindecode
tags:
- eeg
- foundation-model
- self-supervised
- signal-jepa
pipeline_tag: feature-extraction
---

# Signal-JEPA

Self-supervised pre-trained weights for the Signal-JEPA foundation model from
[Guetschel et al. (2024)](https://arxiv.org/abs/2403.11772), packaged for use
with [braindecode](https://braindecode.org/). See the full API reference in
the docs: [`braindecode.models.SignalJEPA`](https://braindecode.org/stable/generated/braindecode.models.SignalJEPA.html).

The model was pre-trained on the Lee2019 dataset (62 EEG channels in the
10-10 layout, sampled at 128 Hz). The repo ships the weights together with a
`config.json` so they can be loaded in one line with
`YourModelClass.from_pretrained(repo_id, ...)`, where `YourModelClass` is
`SignalJEPA` or one of the downstream classes listed below.

## Available checkpoints

Two variants are published:

| repo ID | channel embedding included | when to use |
| --- | --- | --- |
| [`braindecode/signal-jepa`](https://huggingface.co/braindecode/signal-jepa) | Yes: 62-row `_ChannelEmbedding` aligned with the pre-training layout | your recording channels are a **subset** (by name, case-insensitive) of the 62 pre-training channels and you want to reuse the learned spatial embeddings |
| [`braindecode/signal-jepa_without-chans`](https://huggingface.co/braindecode/signal-jepa_without-chans) | No: only the SSL backbone (feature encoder + transformer) | your channels are **not** a subset of the pre-training set, or you prefer to train channel embeddings from scratch |

If you are unsure, start with `braindecode/signal-jepa_without-chans`: it
always works, regardless of your electrode layout.
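
The choice between the two checkpoints can be sketched as a case-insensitive subset test on channel names. This is a minimal illustration, not part of braindecode: `pick_checkpoint` is a hypothetical helper, and the short `example_pretrain_chs` list is a placeholder standing in for the real 62 pre-training channel names.

```python
def pick_checkpoint(recording_chs, pretrain_chs):
    """Return the repo ID suited to a recording (hypothetical helper).

    recording_chs: channel names of your recording.
    pretrain_chs: channel names used during pre-training (assumed known).
    """
    pretrain = {name.lower() for name in pretrain_chs}
    # Case-insensitive subset test, matching channels by name only.
    is_subset = all(name.lower() in pretrain for name in recording_chs)
    if is_subset:
        return "braindecode/signal-jepa"
    return "braindecode/signal-jepa_without-chans"


# Placeholder list for illustration only, NOT the real 62-channel layout.
example_pretrain_chs = ["Fz", "Cz", "Pz", "Oz"]
print(pick_checkpoint(["fz", "CZ"], example_pretrain_chs))  # braindecode/signal-jepa
print(pick_checkpoint(["Fz", "T7"], example_pretrain_chs))  # braindecode/signal-jepa_without-chans
```

A subset match only means the full checkpoint is usable; training channel embeddings from scratch with the `without-chans` variant remains a valid choice.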

## Quick start

### Base model (pre-training architecture)

The base model outputs contextual features, not class predictions. Use it
for downstream feature extraction or further SSL.

```python
from braindecode.models import SignalJEPA

# `raw` below is assumed to be an `mne.io.Raw` object holding your recording.

# With the pre-trained channel embeddings (recording channels ⊆ pre-training set):
model = SignalJEPA.from_pretrained("braindecode/signal-jepa")

# Or: with your own channels, kept aligned to the pre-training embedding table
model = SignalJEPA.from_pretrained(
    "braindecode/signal-jepa",
    chs_info=raw.info["chs"],  # subset of the 62 pre-training channels
    channel_embedding="pretrain_aligned",
)

# Or: without pre-trained channel embeddings (any electrode layout):
model = SignalJEPA.from_pretrained(
    "braindecode/signal-jepa_without-chans",
    chs_info=raw.info["chs"],
    strict=False,  # the channel-embedding weight is intentionally missing
)
```

### Downstream architectures

Three classification architectures are introduced in the paper:

- **a) Contextual**: uses the full transformer encoder
- **b) Post-local**: discards the transformer; spatial convolution after local features
- **c) Pre-local**: discards the transformer; spatial convolution before local features

All three add a freshly-initialized classification head on top of the SSL
backbone. The head is **not** part of the checkpoint and will be trained from
scratch during fine-tuning; pass `strict=False` so `from_pretrained` does not
complain about those missing keys.
```python
from braindecode.models import (
    SignalJEPA_Contextual,
    SignalJEPA_PreLocal,
    SignalJEPA_PostLocal,
)

# a) Contextual: keeps the transformer
model = SignalJEPA_Contextual.from_pretrained(
    "braindecode/signal-jepa",  # or "braindecode/signal-jepa_without-chans"
    n_times=256,  # e.g. 2 s at 128 Hz
    n_outputs=4,
    strict=False,  # ignore un-trained classification head
)

# b) Post-local: transformer discarded
model = SignalJEPA_PostLocal.from_pretrained(
    "braindecode/signal-jepa_without-chans",
    n_chans=19,
    n_times=256,
    n_outputs=4,
    strict=False,
)

# c) Pre-local: transformer discarded
model = SignalJEPA_PreLocal.from_pretrained(
    "braindecode/signal-jepa_without-chans",
    n_chans=19,
    n_times=256,
    n_outputs=4,
    strict=False,
)
```
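
Why `strict=False` is needed can be illustrated without loading anything: a strict load fails when the model expects parameters that the checkpoint does not contain, whereas a non-strict load leaves those parameters at their fresh initialization. The sketch below mimics that comparison with plain sets. It assumes the usual PyTorch-style notion of "missing" and "unexpected" keys, and the parameter names are invented for illustration; they do not reflect the actual braindecode state dict.

```python
def compare_keys(model_keys, checkpoint_keys):
    """Return (missing, unexpected) key lists for a model/checkpoint pair."""
    model_keys, checkpoint_keys = set(model_keys), set(checkpoint_keys)
    # Expected by the model but absent from the checkpoint.
    missing = sorted(model_keys - checkpoint_keys)
    # Present in the checkpoint but not used by the model.
    unexpected = sorted(checkpoint_keys - model_keys)
    return missing, unexpected


# Hypothetical parameter names, for illustration only.
checkpoint = ["feature_encoder.weight", "transformer.weight"]
model = ["feature_encoder.weight", "transformer.weight", "final_layer.weight"]

missing, unexpected = compare_keys(model, checkpoint)
print(missing)     # ['final_layer.weight']
print(unexpected)  # []
```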

See the braindecode tutorial
[Fine-tuning a Foundation Model (Signal-JEPA)](https://braindecode.org/stable/auto_examples/advanced_training/plot_finetune_foundation_model.html)
for a complete example including layer freezing and training with
braindecode's skorch-based `EEGClassifier`.

## Channel embedding modes

`SignalJEPA` and `SignalJEPA_Contextual` accept a `channel_embedding` kwarg:

- `"scratch"` (default): the `_ChannelEmbedding` table has one row per user
  channel, initialized from `chs_info`. Compatible with the
  `braindecode/signal-jepa_without-chans` checkpoint.
- `"pretrain_aligned"`: the table has 62 rows in the pre-training order, and
  `forward` indexes into the subset matching your `chs_info` (matched by
  channel name, case-insensitive). Compatible with the full
  `braindecode/signal-jepa` checkpoint.

`from_pretrained` picks the right mode automatically based on the checkpoint's
`config.json`; override with the `channel_embedding=` kwarg if needed.
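
The `"pretrain_aligned"` lookup described above amounts to mapping each of your channel names to its row in the pre-training table. The following is a simplified sketch of that idea, not braindecode's implementation: `aligned_rows` is a hypothetical helper, and the four-name `pretrain_order` list is a placeholder for the real 62-row order.

```python
def aligned_rows(user_chs, pretrain_order):
    """Map user channel names to row indices of the pre-training table.

    Matching is by name and case-insensitive; a channel outside the
    pre-training set raises ValueError (hypothetical helper).
    """
    row_of = {name.lower(): i for i, name in enumerate(pretrain_order)}
    unknown = [name for name in user_chs if name.lower() not in row_of]
    if unknown:
        raise ValueError(f"Channels not in the pre-training set: {unknown}")
    return [row_of[name.lower()] for name in user_chs]


# Placeholder order for illustration only, NOT the real 62-channel layout.
pretrain_order = ["Fp1", "Fz", "Cz", "Pz"]
print(aligned_rows(["cz", "FZ"], pretrain_order))  # [2, 1]
```

In `"scratch"` mode no such lookup is needed, since the table simply has one row per user channel.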

## Citation

```bibtex
@article{guetschel2024sjepa,
  title   = {S-JEPA: towards seamless cross-dataset transfer
             through dynamic spatial attention},
  author  = {Guetschel, Pierre and Moreau, Thomas and Tangermann, Michael},
  journal = {arXiv preprint arXiv:2403.11772},
  year    = {2024},
}
```