## Metadata

- License: MIT
## Usage
import torch
from braindecode.models import SignalJEPA
from huggingface_hub import hf_hub_download

# Fetch the pretrained SignalJEPA checkpoint from the Hugging Face Hub.
weights_path = hf_hub_download(
    repo_id="braindecode/SignalJEPA",
    filename="signal-jepa_16s-60_adeuwv4s.pth",
)
# map_location="cpu": load the tensors on CPU regardless of the device the
# checkpoint was saved from; move the model afterwards if needed.
# weights_only=True: the file comes from the network and only holds a state
# dict, so restrict unpickling to tensors and plain containers.
model_state_dict = torch.load(weights_path, map_location="cpu", weights_only=True)

# Signal-related arguments.
# `raw` must already be defined as an `mne.io.BaseRaw` recording (for example,
# loaded with MNE); its channel info and sampling rate configure the model.
chs_info = raw.info["chs"]
sfreq = raw.info["sfreq"]

model = SignalJEPA(
    sfreq=sfreq,
    input_window_seconds=2,
    chs_info=chs_info,
    n_outputs=1,
)

# strict=False: the checkpoint does not cover every parameter of this model,
# so report mismatches instead of raising.
missing_keys, unexpected_keys = model.load_state_dict(model_state_dict, strict=False)
# Every key stored in the checkpoint must map onto a model parameter.
assert unexpected_keys == []
# The spatial positional encoder is initialized using the `chs_info`, so it is
# the only parameter expected to be absent from the checkpoint.
assert set(missing_keys) == {"pos_encoder.pos_encoder_spat.weight"}