---
license: mit
library_name: braindecode
tags:
  - eeg
  - foundation-model
  - self-supervised
  - signal-jepa
pipeline_tag: feature-extraction
---

![sjepa](https://cdn-uploads.huggingface.co/production/uploads/646e0135174cc96d509582a6/DS-cXrFyxZ78hK48ft0iU.png)

# Signal-JEPA
 
Self-supervised pre-trained weights for the Signal-JEPA foundation model from
[Guetschel et al. (2024)](https://arxiv.org/abs/2403.11772), packaged for use
with [braindecode](https://braindecode.org/). See the full API reference in
the docs: [`braindecode.models.SignalJEPA`](https://braindecode.org/stable/generated/braindecode.models.SignalJEPA.html).

The model was pre-trained on the Lee2019 dataset (62 EEG channels in the
10-10 layout, sampled at 128 Hz). The repo ships the weights together with a
`config.json`, so they can be loaded in one line with
`ModelClass.from_pretrained(repo_id, ...)`, where `ModelClass` is `SignalJEPA`
or one of its downstream variants.
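Since the weights were learned on 128 Hz recordings, a common choice is to resample your data to that rate. The `n_times` argument of the downstream models is then the window duration multiplied by the sampling rate. The following is a minimal sketch under that assumption; `n_times_for` is a hypothetical helper, not part of braindecode.

```python
def n_times_for(duration_s: float, sfreq: float = 128.0) -> int:
    """Number of samples in a window of `duration_s` seconds at `sfreq` Hz.

    Hypothetical helper (not a braindecode function). It raises if the window
    does not span a whole number of samples, which usually points to a
    mismatched sampling rate.
    """
    n = duration_s * sfreq
    if abs(n - round(n)) > 1e-9:
        raise ValueError(f"{duration_s} s at {sfreq} Hz is not a whole number of samples")
    return int(round(n))


print(n_times_for(2.0))  # 256, the value used in the downstream examples
print(n_times_for(4.0))  # 512
```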

## Available checkpoints

Two variants are published:

| repo ID | channel embedding included | when to use |
| --- | --- | --- |
| [`braindecode/signal-jepa`](https://huggingface.co/braindecode/signal-jepa) | ✓ 62-row `_ChannelEmbedding` aligned with the pre-training layout | your recording channels are a **subset** (by name, case-insensitive) of the 62 pre-training channels, and you want to reuse the learned spatial embeddings |
| [`braindecode/signal-jepa_without-chans`](https://huggingface.co/braindecode/signal-jepa_without-chans) | ✗ only the SSL backbone (feature encoder + transformer) | your channels are **not** a subset of the pre-training set, or you prefer to train channel embeddings from scratch |

If you are unsure, start with `braindecode/signal-jepa_without-chans`: it
always works, regardless of your electrode layout.
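The selection rule in the table is a case-insensitive subset test, which can be sketched as follows. `choose_checkpoint` is a hypothetical helper, not a braindecode function. You must supply the 62 pre-training channel names yourself; the six-name list in the example is only a stand-in.

```python
from typing import Iterable, Sequence

FULL = "braindecode/signal-jepa"
WITHOUT_CHANS = "braindecode/signal-jepa_without-chans"


def choose_checkpoint(recording_chs: Iterable[str], pretrain_chs: Sequence[str]) -> str:
    """Hypothetical helper: pick a repo ID following the table above."""
    pretrain = {name.lower() for name in pretrain_chs}
    recording = {name.lower() for name in recording_chs}
    # Use the full checkpoint only if every recording channel appears
    # (case-insensitively) in the pre-training layout; otherwise fall back
    # to the backbone-only checkpoint, which works with any layout.
    return FULL if recording and recording <= pretrain else WITHOUT_CHANS


toy_pretrain = ["Fp1", "Fp2", "C3", "Cz", "C4", "Pz"]  # stand-in, NOT the real 62
print(choose_checkpoint(["cz", "PZ"], toy_pretrain))  # braindecode/signal-jepa
print(choose_checkpoint(["Cz", "T9"], toy_pretrain))  # braindecode/signal-jepa_without-chans
```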

## Quick start

### Base model (pre-training architecture)

The base model outputs contextual features, not class predictions. Use it
for downstream feature extraction or further SSL.

```python
from braindecode.models import SignalJEPA

# `raw` is an mne.io.Raw object holding your recording.

# With the pre-trained channel embeddings (recording channels ⊂ pre-train set):
model = SignalJEPA.from_pretrained("braindecode/signal-jepa")

# Or: with your own channels, kept aligned to the pre-training embedding table:
model = SignalJEPA.from_pretrained(
    "braindecode/signal-jepa",
    chs_info=raw.info["chs"],           # subset of the 62 pre-training channels
    channel_embedding="pretrain_aligned",
)

# Or: without pre-trained channel embeddings (any electrode layout):
model = SignalJEPA.from_pretrained(
    "braindecode/signal-jepa_without-chans",
    chs_info=raw.info["chs"],
    strict=False,  # the channel-embedding weight is intentionally missing
)
```

### Downstream architectures

Three classification architectures are introduced in the paper:

- **a) Contextual**: uses the full transformer encoder
- **b) Post-local**: discards the transformer; spatial convolution after local features
- **c) Pre-local**: discards the transformer; spatial convolution before local features

All three add a freshly-initialized classification head on top of the SSL
backbone. The head is **not** part of the checkpoint and will be trained from
scratch during fine-tuning; pass `strict=False` so `from_pretrained` does not
complain about those missing keys.
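To see what `strict=False` changes, here is a toy, pure-Python version of the key comparison that PyTorch's `torch.nn.Module.load_state_dict(state_dict, strict=...)` performs. It assumes that `from_pretrained` ultimately forwards `strict` to that call. The parameter names are made up for illustration and are not braindecode's actual state-dict keys.

```python
from typing import Iterable, List, Tuple


def check_keys(
    model_keys: Iterable[str], ckpt_keys: Iterable[str], strict: bool = True
) -> Tuple[List[str], List[str]]:
    """Toy version of the missing/unexpected key check in load_state_dict."""
    model_set, ckpt_set = set(model_keys), set(ckpt_keys)
    missing = sorted(model_set - ckpt_set)     # expected by the model, absent from the file
    unexpected = sorted(ckpt_set - model_set)  # present in the file, unused by the model
    if strict and (missing or unexpected):
        raise RuntimeError(f"missing keys: {missing}; unexpected keys: {unexpected}")
    return missing, unexpected


# Hypothetical parameter names, for illustration only.
model_keys = ["feature_encoder.weight", "transformer.weight", "final_layer.weight"]
ckpt_keys = ["feature_encoder.weight", "transformer.weight"]

print(check_keys(model_keys, ckpt_keys, strict=False))
# (['final_layer.weight'], [])
```

With `strict=True` the same call raises. In PyTorch, `strict=False` tolerates unexpected keys as well as missing ones, so it also covers checkpoint weights that a smaller model does not use.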

```python
from braindecode.models import (
    SignalJEPA_Contextual,
    SignalJEPA_PreLocal,
    SignalJEPA_PostLocal,
)

# a) Contextual: keeps the transformer
model = SignalJEPA_Contextual.from_pretrained(
    "braindecode/signal-jepa",          # or "braindecode/signal-jepa_without-chans"
    n_times=256,                         # e.g. 2 s at 128 Hz
    n_outputs=4,
    strict=False,                        # classification head is not in the checkpoint
)

# b) Post-local: transformer discarded
model = SignalJEPA_PostLocal.from_pretrained(
    "braindecode/signal-jepa_without-chans",
    n_chans=19,
    n_times=256,
    n_outputs=4,
    strict=False,
)

# c) Pre-local: transformer discarded
model = SignalJEPA_PreLocal.from_pretrained(
    "braindecode/signal-jepa_without-chans",
    n_chans=19,
    n_times=256,
    n_outputs=4,
    strict=False,
)
```

See the braindecode tutorial
[Fine-tuning a Foundation Model (Signal-JEPA)](https://braindecode.org/stable/auto_examples/advanced_training/plot_finetune_foundation_model.html)
for a complete example, including layer freezing and training with
`braindecode.EEGClassifier` (a skorch-based wrapper).

## Channel embedding modes

`SignalJEPA` and `SignalJEPA_Contextual` accept a `channel_embedding` kwarg:

- `"scratch"` (default): the `_ChannelEmbedding` table has one row per user
  channel, initialized from `chs_info`. Compatible with the
  `without-chans` checkpoint.
- `"pretrain_aligned"`: the table has 62 rows in the pre-training order, and
  `forward` indexes into the rows matching your `chs_info` (matched by
  channel name, case-insensitive). Compatible with the full checkpoint.
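The `"pretrain_aligned"` lookup can be pictured as a case-insensitive mapping from channel names to table rows. The sketch below illustrates that idea in plain Python; it is not braindecode's implementation. `aligned_indices` is a hypothetical helper, the six-channel list stands in for the 62 pre-training channels, and the nested-list table stands in for the learned embedding tensor. Raising on unknown names is our own choice here.

```python
from typing import List, Sequence


def aligned_indices(user_chs: Sequence[str], pretrain_chs: Sequence[str]) -> List[int]:
    """Hypothetical helper: row of the pre-training table for each user channel."""
    lookup = {name.lower(): i for i, name in enumerate(pretrain_chs)}
    missing = [name for name in user_chs if name.lower() not in lookup]
    if missing:
        raise KeyError(f"not in the pre-training layout: {missing}")
    return [lookup[name.lower()] for name in user_chs]


toy_pretrain = ["Fp1", "Fp2", "C3", "Cz", "C4", "Pz"]  # stand-in, NOT the real 62
# Toy embedding table: one 2-D row per pre-training channel.
table = [[float(i), float(i) * 10] for i in range(len(toy_pretrain))]

idx = aligned_indices(["cz", "PZ", "c3"], toy_pretrain)
rows = [table[i] for i in idx]  # embeddings used for the user's channels, in user order
print(idx)  # [3, 5, 2]
```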

`from_pretrained` picks the right mode automatically based on the checkpoint's
`config.json`; override with the `channel_embedding=` kwarg if needed.

## Citation

```bibtex
@article{guetschel2024sjepa,
  title   = {S-JEPA: towards seamless cross-dataset transfer
             through dynamic spatial attention},
  author  = {Guetschel, Pierre and Moreau, Thomas and Tangermann, Michael},
  journal = {arXiv preprint arXiv:2403.11772},
  year    = {2024},
}
```