SignalJEPA / README.md
PierreGtch's picture
Update README.md
ef750ef verified
|
raw
history blame
748 Bytes
metadata
license: mit

Usage

import torch
from braindecode.models import SignalJEPA
from huggingface_hub import hf_hub_download

# Fetch the pretrained checkpoint file from the Hugging Face Hub.
weights_path = hf_hub_download(
    repo_id='braindecode/SignalJEPA',
    filename='signal-jepa_16s-60_adeuwv4s.pth',
)
# map_location="cpu" places the tensors on the CPU regardless of the device
# they were saved from, so the snippet also runs on CPU-only machines.
# weights_only=True restricts unpickling to tensors and plain containers,
# which is the safe way to load a file downloaded from the internet.
model_state_dict = torch.load(weights_path, map_location="cpu", weights_only=True)

# Signal-related arguments
# `raw` is NOT defined in this snippet: it must be your own recording,
# an instance of mne.io.BaseRaw, loaded beforehand.
chs_info = raw.info["chs"]
sfreq = raw.info['sfreq']

model = SignalJEPA(
    sfreq=sfreq,
    input_window_seconds=2,
    chs_info=chs_info,
    n_outputs=1,
)
# strict=False because one parameter is intentionally absent from the
# checkpoint (see below); the returned key lists are verified explicitly.
missing_keys, unexpected_keys = model.load_state_dict(model_state_dict, strict=False)
assert unexpected_keys == [], f"Unexpected keys in checkpoint: {unexpected_keys}"
# The spatial positional encoder is initialized using the `chs_info`:
assert set(missing_keys) == {"pos_encoder.pos_encoder_spat.weight"}, (
    f"Unexpected set of missing keys: {missing_keys}"
)