PierreGtch commited on
Commit
71c8b43
Β·
verified Β·
1 Parent(s): 7b53b04

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +114 -86
README.md CHANGED
@@ -1,40 +1,79 @@
1
  ---
2
  license: mit
 
 
 
 
 
 
 
3
  ---
 
4
  ![sjepa](https://cdn-uploads.huggingface.co/production/uploads/646e0135174cc96d509582a6/DS-cXrFyxZ78hK48ft0iU.png)
5
- # Usage
6
 
7
- **Instantiate the Base Model**
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  ```python
9
  from braindecode.models import SignalJEPA
10
- from huggingface_hub import hf_hub_download
11
 
12
- weights_path = hf_hub_download(repo_id="braindecode/SignalJEPA", filename="signal-jepa_16s-60_adeuwv4s.pth")
13
- model_state_dict = torch.load(weights_path)
14
 
15
- # Signal-related arguments
16
- # raw: mne.io.BaseRaw
17
- chs_info = raw.info["chs"]
18
- sfreq = raw.info["sfreq"]
 
 
19
 
20
- model = SignalJEPA(
21
- sfreq=sfreq,
22
- input_window_seconds=2,
23
- chs_info=chs_info,
 
24
  )
25
- missing_keys, unexpected_keys = model.load_state_dict(model_state_dict, strict=False)
26
- assert unexpected_keys == []
27
- # The spatial positional encoder is initialized using the `chs_info`:
28
- assert set(missing_keys) == {"pos_encoder.pos_encoder_spat.weight"}
29
  ```
30
 
31
- **Instantiate the Downstream Architectures**
 
 
 
 
 
 
32
 
33
- Contrary to the base model, the downstream architectures are equipped with a classification head which is not pre-trained.
34
- Guetschel et al. (2024) [arXiv:2403.11772](https://arxiv.org/abs/2403.11772) introduce three downstream architectures:
35
- - a) Contextual downstream architecture
36
- - b) Post-local downstream architecture
37
- - c) Pre-local architecture
38
 
39
  ```python
40
  from braindecode.models import (
@@ -42,72 +81,61 @@ from braindecode.models import (
42
  SignalJEPA_PreLocal,
43
  SignalJEPA_PostLocal,
44
  )
45
- from huggingface_hub import hf_hub_download
46
-
47
- weights_path = hf_hub_download(repo_id="braindecode/SignalJEPA", filename="signal-jepa_16s-60_adeuwv4s.pth")
48
- model_state_dict = torch.load(weights_path)
49
-
50
- # Signal-related arguments
51
- # raw: mne.io.BaseRaw
52
- chs_info = raw.info["chs"]
53
- sfreq = raw.info["sfreq"]
54
-
55
- # The downstream architectures are equipped with an additional classification head
56
- # which was not pre-trained. It has the following new parameters:
57
- final_layer_keys = {
58
- "final_layer.spat_conv.weight",
59
- "final_layer.spat_conv.bias",
60
- "final_layer.linear.weight",
61
- "final_layer.linear.bias",
62
- }
63
-
64
 
65
- # a) Contextual downstream architecture
66
- # ----------------------------------
67
- model = SignalJEPA_Contextual(
68
- sfreq=sfreq,
69
- input_window_seconds=2,
70
- chs_info=chs_info,
71
- n_outputs=1,
72
  )
73
- missing_keys, unexpected_keys = model.load_state_dict(model_state_dict, strict=False)
74
- assert unexpected_keys == []
75
- # The spatial positional encoder is initialized using the `chs_info`:
76
- assert set(missing_keys) == final_layer_keys | {"pos_encoder.pos_encoder_spat.weight"}
77
-
78
- # In the post-local (b) and pre-local (c) architectures, the transformer is discarded:
79
- FILTERED_model_state_dict = {
80
- k: v for k, v in model_state_dict.items() if not any(k.startswith(pre) for pre in ["transformer.", "pos_encoder."])
81
- }
82
-
83
 
84
- # b) Post-local downstream architecture
85
- # ----------------------------------
86
- model = SignalJEPA_PostLocal(
87
- sfreq=sfreq,
88
- input_window_seconds=2,
89
- n_chans=len(chs_info), # detailed channel info is not needed for this model
90
- n_outputs=1,
91
  )
92
- missing_keys, unexpected_keys = model.load_state_dict(FILTERED_model_state_dict, strict=False)
93
- assert unexpected_keys == []
94
- assert set(missing_keys) == final_layer_keys
95
-
96
-
97
- # c) Pre-local architecture
98
- # ----------------------
99
- model = SignalJEPA_PreLocal(
100
- sfreq=sfreq,
101
- input_window_seconds=2,
102
- n_chans=len(chs_info), # detailed channel info is not needed for this model
103
- n_outputs=1,
104
  )
105
- missing_keys, unexpected_keys = model.load_state_dict(FILTERED_model_state_dict, strict=False)
106
- assert unexpected_keys == []
107
- assert set(missing_keys) == {
108
- "spatial_conv.1.weight",
109
- "spatial_conv.1.bias",
110
- "final_layer.1.weight",
111
- "final_layer.1.bias",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
112
  }
113
- ```
 
1
  ---
2
  license: mit
3
+ library_name: braindecode
4
+ tags:
5
+ - eeg
6
+ - foundation-model
7
+ - self-supervised
8
+ - signal-jepa
9
+ pipeline_tag: feature-extraction
10
  ---
11
+
12
  ![sjepa](https://cdn-uploads.huggingface.co/production/uploads/646e0135174cc96d509582a6/DS-cXrFyxZ78hK48ft0iU.png)
 
13
 
14
+ # Signal-JEPA
15
+
16
+ Self-supervised pre-trained weights for the Signal-JEPA foundation model from
17
+ [Guetschel et al. (2024)](https://arxiv.org/abs/2403.11772), packaged for use
18
+ with [braindecode](https://braindecode.org/).
19
+
20
+ The model was pre-trained on the Lee2019 dataset (62 EEG channels in the
21
+ 10-10 layout, sampled at 128 Hz). The repo ships the weights together with a
22
+ `config.json` so they can be loaded in one line with
23
+ `YourModelClass.from_pretrained(repo_id, ...)`.
24
+
25
+ ## Available checkpoints
26
+
27
+ Two variants are published:
28
+
29
+ | repo ID | channel embedding included | when to use |
30
+ | --- | --- | --- |
31
+ | [`braindecode/signal-jepa`](https://huggingface.co/braindecode/signal-jepa) | βœ“ 62-row `_ChannelEmbedding` aligned with the pre-training layout | your recording channels are a **subset** (by name, case-insensitive) of the 62 pre-training channels β€” you want to reuse the learned spatial embeddings |
32
+ | [`braindecode/signal-jepa_without-chans`](https://huggingface.co/braindecode/signal-jepa_without-chans) | βœ— only the SSL backbone (feature encoder + transformer) | your channels are **not** a subset of the pre-training set, or you prefer to train channel embeddings from scratch |
33
+
34
+ If you are unsure, start with `braindecode/signal-jepa_without-chans`: it
35
+ always works, regardless of your electrode layout.
36
+
37
+ ## Quick start
38
+
39
+ ### Base model (pre-training architecture)
40
+
41
+ The base model outputs contextual features, not class predictions. Use it
42
+ for downstream feature extraction or further SSL.
43
+
44
  ```python
45
  from braindecode.models import SignalJEPA
 
46
 
47
+ # With the pre-trained channel embeddings (recording channels βŠ‚ pre-train set):
48
+ model = SignalJEPA.from_pretrained("braindecode/signal-jepa")
49
 
50
+ # Or: with your own channels, kept aligned to the pre-training embedding table
51
+ model = SignalJEPA.from_pretrained(
52
+ "braindecode/signal-jepa",
53
+ chs_info=raw.info["chs"], # subset of the 62 pre-training channels
54
+ channel_embedding="pretrain_aligned",
55
+ )
56
 
57
+ # Or: without pre-trained channel embeddings (any electrode layout):
58
+ model = SignalJEPA.from_pretrained(
59
+ "braindecode/signal-jepa_without-chans",
60
+ chs_info=raw.info["chs"],
61
+ strict=False, # the channel-embedding weight is intentionally missing
62
  )
 
 
 
 
63
  ```
64
 
65
+ ### Downstream architectures
66
+
67
+ Three classification architectures are introduced in the paper:
68
+
69
+ - **a) Contextual** β€” uses the full transformer encoder
70
+ - **b) Post-local** β€” discards the transformer; spatial convolution after local features
71
+ - **c) Pre-local** β€” discards the transformer; spatial convolution before local features
72
 
73
+ All three add a freshly-initialized classification head on top of the SSL
74
+ backbone. The head is **not** part of the checkpoint and will be trained from
75
+ scratch during fine-tuning; pass `strict=False` so `from_pretrained` does not
76
+ complain about those missing keys.
 
77
 
78
  ```python
79
  from braindecode.models import (
 
81
  SignalJEPA_PreLocal,
82
  SignalJEPA_PostLocal,
83
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
 
85
+ # a) Contextual β€” keeps the transformer
86
+ model = SignalJEPA_Contextual.from_pretrained(
87
+ "braindecode/signal-jepa", # or "signal-jepa_without-chans"
88
+ n_times=256, # e.g. 2 s at 128 Hz
89
+ n_outputs=4,
90
+ strict=False, # ignore un-trained classification head
 
91
  )
 
 
 
 
 
 
 
 
 
 
92
 
93
+ # b) Post-local β€” transformer discarded
94
+ model = SignalJEPA_PostLocal.from_pretrained(
95
+ "braindecode/signal-jepa_without-chans",
96
+ n_chans=19,
97
+ n_times=256,
98
+ n_outputs=4,
99
+ strict=False,
100
  )
101
+
102
+ # c) Pre-local β€” transformer discarded
103
+ model = SignalJEPA_PreLocal.from_pretrained(
104
+ "braindecode/signal-jepa_without-chans",
105
+ n_chans=19,
106
+ n_times=256,
107
+ n_outputs=4,
108
+ strict=False,
 
 
 
 
109
  )
110
+ ```
111
+
112
+ See the braindecode tutorial
113
+ [Fine-tuning a Foundation Model (Signal-JEPA)](https://braindecode.org/stable/auto_examples/advanced_training/plot_finetune_foundation_model.html)
114
+ for a complete example including layer freezing and training with
115
+ `skorch.EEGClassifier`.
116
+
117
+ ## Channel embedding modes
118
+
119
+ `SignalJEPA` and `SignalJEPA_Contextual` accept a `channel_embedding` kwarg:
120
+
121
+ - `"scratch"` (default): the `_ChannelEmbedding` table has one row per user
122
+ channel, initialized from `chs_info`. Compatible with the
123
+ `without-chans` checkpoint.
124
+ - `"pretrain_aligned"`: the table has 62 rows in the pre-training order,
125
+ `forward` indexes into the subset matching your `chs_info` (matched by
126
+ channel name, case-insensitive). Compatible with the full checkpoint.
127
+
128
+ `from_pretrained` picks the right mode automatically based on the checkpoint's
129
+ `config.json`; override with the `channel_embedding=` kwarg if needed.
130
+
131
+ ## Citation
132
+
133
+ ```bibtex
134
+ @article{guetschel2024sjepa,
135
+ title = {S-JEPA: towards seamless cross-dataset transfer
136
+ through dynamic spatial attention},
137
+ author = {Guetschel, Pierre and Moreau, Thomas and Tangermann, Michael},
138
+ journal = {arXiv preprint arXiv:2403.11772},
139
+ year = {2024},
140
  }
141
+ ```