---
license: mit
---

# Usage
**Instantiate the Base Model**
```python
import torch
from braindecode.models import SignalJEPA
from huggingface_hub import hf_hub_download

weights_path = hf_hub_download(
    repo_id="braindecode/SignalJEPA",
    filename="signal-jepa_16s-60_adeuwv4s.pth",
)
model_state_dict = torch.load(weights_path)

# Signal-related arguments
# raw: mne.io.BaseRaw
chs_info = raw.info["chs"]
sfreq = raw.info["sfreq"]

model = SignalJEPA(
    sfreq=sfreq,
    input_window_seconds=2,
    chs_info=chs_info,
)
missing_keys, unexpected_keys = model.load_state_dict(model_state_dict, strict=False)
assert unexpected_keys == []
# The spatial positional encoder is initialized using the `chs_info`:
assert set(missing_keys) == {"pos_encoder.pos_encoder_spat.weight"}
```
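After loading, the base model can be used as a feature extractor. Below is a minimal, hypothetical sanity check; the batch size and the exact shape of the returned embeddings are assumptions for illustration, not part of the original recipe:
```python
import torch

# Hypothetical forward pass on random data shaped (batch, n_chans, n_times),
# with n_times = sfreq * input_window_seconds as configured above.
n_times = int(sfreq * 2)
x = torch.randn(8, len(chs_info), n_times)
with torch.no_grad():
    embeddings = model(x)
print(embeddings.shape)  # local feature embeddings; exact shape depends on the encoder configuration
```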
**Instantiate the Downstream Architectures**
Unlike the base model, the downstream architectures are equipped with a classification head, which is not pre-trained.
Guetschel et al. (2024) [arXiv:2403.11772](https://arxiv.org/abs/2403.11772) introduce three downstream architectures, instantiated in the example below:
- a) Contextual downstream architecture
- b) Post-local downstream architecture
- c) Pre-local downstream architecture
```python
import torch
from braindecode.models import (
    SignalJEPA_Contextual,
    SignalJEPA_PreLocal,
    SignalJEPA_PostLocal,
)
from huggingface_hub import hf_hub_download

weights_path = hf_hub_download(
    repo_id="braindecode/SignalJEPA",
    filename="signal-jepa_16s-60_adeuwv4s.pth",
)
model_state_dict = torch.load(weights_path)

# Signal-related arguments
# raw: mne.io.BaseRaw
chs_info = raw.info["chs"]
sfreq = raw.info["sfreq"]

# The downstream architectures are equipped with an additional classification head
# which was not pre-trained. It has the following new parameters:
final_layer_keys = {
    "final_layer.spat_conv.weight",
    "final_layer.spat_conv.bias",
    "final_layer.linear.weight",
    "final_layer.linear.bias",
}

# a) Contextual downstream architecture
# -------------------------------------
model = SignalJEPA_Contextual(
    sfreq=sfreq,
    input_window_seconds=2,
    chs_info=chs_info,
    n_outputs=1,
)
missing_keys, unexpected_keys = model.load_state_dict(model_state_dict, strict=False)
assert unexpected_keys == []
# The spatial positional encoder is initialized using the `chs_info`:
assert set(missing_keys) == final_layer_keys | {"pos_encoder.pos_encoder_spat.weight"}

# In the post-local (b) and pre-local (c) architectures, the transformer is discarded:
filtered_model_state_dict = {
    k: v
    for k, v in model_state_dict.items()
    if not any(k.startswith(prefix) for prefix in ("transformer.", "pos_encoder."))
}

# b) Post-local downstream architecture
# -------------------------------------
model = SignalJEPA_PostLocal(
    sfreq=sfreq,
    input_window_seconds=2,
    n_chans=len(chs_info),  # detailed channel info is not needed for this model
    n_outputs=1,
)
missing_keys, unexpected_keys = model.load_state_dict(filtered_model_state_dict, strict=False)
assert unexpected_keys == []
assert set(missing_keys) == final_layer_keys

# c) Pre-local downstream architecture
# ------------------------------------
model = SignalJEPA_PreLocal(
    sfreq=sfreq,
    input_window_seconds=2,
    n_chans=len(chs_info),  # detailed channel info is not needed for this model
    n_outputs=1,
)
missing_keys, unexpected_keys = model.load_state_dict(filtered_model_state_dict, strict=False)
assert unexpected_keys == []
assert set(missing_keys) == {
    "spatial_conv.1.weight",
    "spatial_conv.1.bias",
    "final_layer.1.weight",
    "final_layer.1.bias",
}
```
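Once the pre-trained weights are loaded, a downstream model can be fine-tuned like any other PyTorch module. Below is a minimal sketch of one training step on random stand-in data, assuming an `AdamW` optimizer and an `MSELoss` objective for the single output configured above; none of these choices come from the original training recipe:
```python
import torch

# Illustrative fine-tuning step on random stand-in data (assumption, not real EEG).
n_times = int(sfreq * 2)                     # input_window_seconds=2, as configured above
x = torch.randn(8, len(chs_info), n_times)   # (batch, n_chans, n_times)
y = torch.randn(8, 1)                        # stand-in target for n_outputs=1

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
criterion = torch.nn.MSELoss()

model.train()
optimizer.zero_grad()
loss = criterion(model(x), y)
loss.backward()
optimizer.step()
```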