---
license: mit
---
![sjepa](https://cdn-uploads.huggingface.co/production/uploads/646e0135174cc96d509582a6/DS-cXrFyxZ78hK48ft0iU.png)
# Usage

**Instantiate the Base Model**
```python
from braindecode.models import SignalJEPA
from huggingface_hub import hf_hub_download
import torch

weights_path = hf_hub_download(repo_id="braindecode/SignalJEPA", filename="signal-jepa_16s-60_adeuwv4s.pth")
model_state_dict = torch.load(weights_path, map_location="cpu")

# Signal-related arguments
# raw: mne.io.BaseRaw
chs_info = raw.info["chs"]
sfreq = raw.info["sfreq"]

model = SignalJEPA(
    sfreq=sfreq,
    input_window_seconds=2,
    chs_info=chs_info,
)
missing_keys, unexpected_keys = model.load_state_dict(model_state_dict, strict=False)
assert unexpected_keys == []
# The spatial positional encoder is initialized using the `chs_info`:
assert set(missing_keys) == {"pos_encoder.pos_encoder_spat.weight"}
```
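As a side note, `load_state_dict(strict=False)` reports mismatched keys instead of raising, which is what the assertions above rely on. A minimal plain-PyTorch illustration of this behaviour (toy modules, not SignalJEPA itself):

```python
import torch.nn as nn

# A toy "pretrained" module and a wrapper whose parameter names differ.
pretrained = nn.Linear(4, 4)            # keys: "weight", "bias"
model = nn.Sequential(nn.Linear(4, 4))  # keys: "0.weight", "0.bias"

# With strict=False, mismatches are returned rather than raised.
result = model.load_state_dict(pretrained.state_dict(), strict=False)
print(sorted(result.missing_keys))     # parameters of `model` left untouched
print(sorted(result.unexpected_keys))  # checkpoint entries with no match
```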

**Instantiate the Downstream Architectures**

Unlike the base model, the downstream architectures are equipped with a classification head that is not pre-trained.
Guetschel et al. (2024) [arXiv:2403.11772](https://arxiv.org/abs/2403.11772) introduce three downstream architectures:
- a) Contextual downstream architecture
- b) Post-local downstream architecture
- c) Pre-local downstream architecture

```python
from braindecode.models import (
    SignalJEPA_Contextual,
    SignalJEPA_PreLocal,
    SignalJEPA_PostLocal,
)
from huggingface_hub import hf_hub_download
import torch

weights_path = hf_hub_download(repo_id="braindecode/SignalJEPA", filename="signal-jepa_16s-60_adeuwv4s.pth")
model_state_dict = torch.load(weights_path, map_location="cpu")

# Signal-related arguments
# raw: mne.io.BaseRaw
chs_info = raw.info["chs"]
sfreq = raw.info["sfreq"]

# The downstream architectures are equipped with an additional classification head
# which was not pre-trained. It has the following new parameters:
final_layer_keys = {
    "final_layer.spat_conv.weight",
    "final_layer.spat_conv.bias",
    "final_layer.linear.weight",
    "final_layer.linear.bias",
}


# a) Contextual downstream architecture
#    ----------------------------------
model = SignalJEPA_Contextual(
    sfreq=sfreq,
    input_window_seconds=2,
    chs_info=chs_info,
    n_outputs=1,
)
missing_keys, unexpected_keys = model.load_state_dict(model_state_dict, strict=False)
assert unexpected_keys == []
# The spatial positional encoder is initialized using the `chs_info`:
assert set(missing_keys) == final_layer_keys | {"pos_encoder.pos_encoder_spat.weight"}

# In the post-local (b) and pre-local (c) architectures, the transformer is discarded:
filtered_state_dict = {
    k: v
    for k, v in model_state_dict.items()
    if not any(k.startswith(prefix) for prefix in ("transformer.", "pos_encoder."))
}


# b) Post-local downstream architecture
#    ----------------------------------
model = SignalJEPA_PostLocal(
    sfreq=sfreq,
    input_window_seconds=2,
    n_chans=len(chs_info),  # detailed channel info is not needed for this model
    n_outputs=1,
)
missing_keys, unexpected_keys = model.load_state_dict(filtered_state_dict, strict=False)
assert unexpected_keys == []
assert set(missing_keys) == final_layer_keys


# c) Pre-local downstream architecture
#    ----------------------------------
model = SignalJEPA_PreLocal(
    sfreq=sfreq,
    input_window_seconds=2,
    n_chans=len(chs_info),  # detailed channel info is not needed for this model
    n_outputs=1,
)
missing_keys, unexpected_keys = model.load_state_dict(filtered_state_dict, strict=False)
assert unexpected_keys == []
assert set(missing_keys) == {
    "spatial_conv.1.weight",
    "spatial_conv.1.bias",
    "final_layer.1.weight",
    "final_layer.1.bias",
}
```
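After loading the pre-trained weights, a common next step is to train only the freshly initialized head before (or instead of) fine-tuning the whole network. A generic sketch of this in plain PyTorch (the modules below are placeholders, not the actual SignalJEPA attributes):

```python
import torch.nn as nn

# Placeholder modules standing in for the pre-trained encoder and the
# untrained classification head; the names are illustrative only.
backbone = nn.Linear(8, 16)
head = nn.Linear(16, 1)
model = nn.Sequential(backbone, head)

# Freeze the pre-trained part so an optimizer only updates the head.
for p in backbone.parameters():
    p.requires_grad = False

trainable = [name for name, p in model.named_parameters() if p.requires_grad]
print(trainable)  # only the head's parameters remain trainable
```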