---
license: mit
---

# Usage

**Instantiate the Base Model**
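Both snippets on this page read the channel information and sampling frequency from an `mne.io.BaseRaw` object named `raw`. As a purely illustrative example (the file name is a placeholder), such an object could be obtained with:

```python
import mne

# Hypothetical recording; any mne.io.BaseRaw object works here.
raw = mne.io.read_raw_edf("recording.edf", preload=True)
```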
```python
import torch
from braindecode.models import SignalJEPA
from huggingface_hub import hf_hub_download

weights_path = hf_hub_download(
    repo_id="braindecode/SignalJEPA",
    filename="signal-jepa_16s-60_adeuwv4s.pth",
)
model_state_dict = torch.load(weights_path)

# Signal-related arguments
# raw: mne.io.BaseRaw
chs_info = raw.info["chs"]
sfreq = raw.info["sfreq"]

model = SignalJEPA(
    sfreq=sfreq,
    input_window_seconds=2,
    chs_info=chs_info,
)
missing_keys, unexpected_keys = model.load_state_dict(model_state_dict, strict=False)
assert unexpected_keys == []
# The spatial positional encoder is initialized from `chs_info`,
# so its weights are expected to be missing from the checkpoint:
assert set(missing_keys) == {"pos_encoder.pos_encoder_spat.weight"}
```
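With the weights loaded, the base model can be used as a feature extractor. Below is a minimal forward-pass sketch on random placeholder data, assuming inputs of shape `(batch, n_channels, n_times)`; the exact shape of the returned features depends on the model's feature encoder.

```python
import torch

n_times = int(2 * sfreq)  # input_window_seconds * sfreq
x = torch.randn(4, len(chs_info), n_times)  # placeholder batch of 2-second windows
model.eval()
with torch.no_grad():
    features = model(x)
print(features.shape)
```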
**Instantiate the Downstream Architectures**

Unlike the base model, the downstream architectures are equipped with a classification head, which is not pre-trained.
Guetschel et al. (2024) [arXiv:2403.11772](https://arxiv.org/abs/2403.11772) introduce three downstream architectures:

- a) Contextual downstream architecture
- b) Post-local downstream architecture
- c) Pre-local downstream architecture
```python
import torch
from braindecode.models import (
    SignalJEPA_Contextual,
    SignalJEPA_PostLocal,
    SignalJEPA_PreLocal,
)
from huggingface_hub import hf_hub_download

weights_path = hf_hub_download(
    repo_id="braindecode/SignalJEPA",
    filename="signal-jepa_16s-60_adeuwv4s.pth",
)
model_state_dict = torch.load(weights_path)

# Signal-related arguments
# raw: mne.io.BaseRaw
chs_info = raw.info["chs"]
sfreq = raw.info["sfreq"]

# The downstream architectures are equipped with an additional classification head
# which was not pre-trained. It introduces the following new parameters:
final_layer_keys = {
    "final_layer.spat_conv.weight",
    "final_layer.spat_conv.bias",
    "final_layer.linear.weight",
    "final_layer.linear.bias",
}

# a) Contextual downstream architecture
# -------------------------------------
model = SignalJEPA_Contextual(
    sfreq=sfreq,
    input_window_seconds=2,
    chs_info=chs_info,
    n_outputs=1,
)
missing_keys, unexpected_keys = model.load_state_dict(model_state_dict, strict=False)
assert unexpected_keys == []
# The spatial positional encoder is initialized from `chs_info`:
assert set(missing_keys) == final_layer_keys | {"pos_encoder.pos_encoder_spat.weight"}

# In the post-local (b) and pre-local (c) architectures, the transformer is discarded,
# so the corresponding weights are dropped from the checkpoint:
filtered_model_state_dict = {
    k: v
    for k, v in model_state_dict.items()
    if not any(k.startswith(prefix) for prefix in ("transformer.", "pos_encoder."))
}

# b) Post-local downstream architecture
# -------------------------------------
model = SignalJEPA_PostLocal(
    sfreq=sfreq,
    input_window_seconds=2,
    n_chans=len(chs_info),  # detailed channel info is not needed for this model
    n_outputs=1,
)
missing_keys, unexpected_keys = model.load_state_dict(filtered_model_state_dict, strict=False)
assert unexpected_keys == []
assert set(missing_keys) == final_layer_keys

# c) Pre-local downstream architecture
# ------------------------------------
model = SignalJEPA_PreLocal(
    sfreq=sfreq,
    input_window_seconds=2,
    n_chans=len(chs_info),  # detailed channel info is not needed for this model
    n_outputs=1,
)
missing_keys, unexpected_keys = model.load_state_dict(filtered_model_state_dict, strict=False)
assert unexpected_keys == []
# The pre-local head replaces the pre-trained spatial filtering,
# so its parameters are also newly initialized:
assert set(missing_keys) == {
    "spatial_conv.1.weight",
    "spatial_conv.1.bias",
    "final_layer.1.weight",
    "final_layer.1.bias",
}
```
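Each downstream architecture is a regular PyTorch module and can be fine-tuned directly. The following is a hypothetical single training step on random placeholder data, assuming the model returns one logit per window when `n_outputs=1`:

```python
import torch

n_times = int(2 * sfreq)
x = torch.randn(8, len(chs_info), n_times)  # placeholder windows
y = torch.randint(0, 2, (8,), dtype=torch.float32)  # placeholder binary labels

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
criterion = torch.nn.BCEWithLogitsLoss()

model.train()
optimizer.zero_grad()
logits = model(x).squeeze(-1)  # assumed shape (batch, 1) -> (batch,)
loss = criterion(logits, y)
loss.backward()
optimizer.step()
```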