---
license: mit
---

![sjepa](https://cdn-uploads.huggingface.co/production/uploads/646e0135174cc96d509582a6/DS-cXrFyxZ78hK48ft0iU.png)

# Usage

**Instantiate the Base Model**

```python
import torch
from braindecode.models import SignalJEPA
from huggingface_hub import hf_hub_download

weights_path = hf_hub_download(
    repo_id="braindecode/SignalJEPA",
    filename="signal-jepa_16s-60_adeuwv4s.pth",
)
model_state_dict = torch.load(weights_path)

# Signal-related arguments
# raw: mne.io.BaseRaw (an MNE recording loaded by the user)
chs_info = raw.info["chs"]
sfreq = raw.info["sfreq"]

model = SignalJEPA(
    sfreq=sfreq,
    input_window_seconds=2,
    chs_info=chs_info,
)

missing_keys, unexpected_keys = model.load_state_dict(model_state_dict, strict=False)
assert unexpected_keys == []
# The spatial positional encoder is initialized from `chs_info`, so its
# weights are expected to be missing from the checkpoint:
assert set(missing_keys) == {"pos_encoder.pos_encoder_spat.weight"}
```

**Instantiate the Downstream Architectures**

Unlike the base model, the downstream architectures are equipped with a classification head, which is not pre-trained. Guetschel et al. (2024) [arXiv:2403.11772](https://arxiv.org/abs/2403.11772) introduce three downstream architectures:

- a) Contextual downstream architecture
- b) Post-local downstream architecture
- c) Pre-local downstream architecture

```python
import torch
from braindecode.models import (
    SignalJEPA_Contextual,
    SignalJEPA_PostLocal,
    SignalJEPA_PreLocal,
)
from huggingface_hub import hf_hub_download

weights_path = hf_hub_download(
    repo_id="braindecode/SignalJEPA",
    filename="signal-jepa_16s-60_adeuwv4s.pth",
)
model_state_dict = torch.load(weights_path)

# Signal-related arguments
# raw: mne.io.BaseRaw (an MNE recording loaded by the user)
chs_info = raw.info["chs"]
sfreq = raw.info["sfreq"]

# The downstream architectures are equipped with an additional classification
# head which was not pre-trained. It has the following new parameters:
final_layer_keys = {
    "final_layer.spat_conv.weight",
    "final_layer.spat_conv.bias",
    "final_layer.linear.weight",
    "final_layer.linear.bias",
}

# a) Contextual downstream architecture
# -------------------------------------
model = SignalJEPA_Contextual(
    sfreq=sfreq,
    input_window_seconds=2,
    chs_info=chs_info,
    n_outputs=1,
)
missing_keys, unexpected_keys = model.load_state_dict(model_state_dict, strict=False)
assert unexpected_keys == []
# Both the classification head and the spatial positional encoder
# (initialized from `chs_info`) are missing from the checkpoint:
assert set(missing_keys) == final_layer_keys | {"pos_encoder.pos_encoder_spat.weight"}

# In the post-local (b) and pre-local (c) architectures, the transformer is discarded:
filtered_model_state_dict = {
    k: v
    for k, v in model_state_dict.items()
    if not any(k.startswith(prefix) for prefix in ["transformer.", "pos_encoder."])
}

# b) Post-local downstream architecture
# -------------------------------------
model = SignalJEPA_PostLocal(
    sfreq=sfreq,
    input_window_seconds=2,
    n_chans=len(chs_info),  # detailed channel info is not needed for this model
    n_outputs=1,
)
missing_keys, unexpected_keys = model.load_state_dict(filtered_model_state_dict, strict=False)
assert unexpected_keys == []
assert set(missing_keys) == final_layer_keys

# c) Pre-local downstream architecture
# ------------------------------------
model = SignalJEPA_PreLocal(
    sfreq=sfreq,
    input_window_seconds=2,
    n_chans=len(chs_info),  # detailed channel info is not needed for this model
    n_outputs=1,
)
missing_keys, unexpected_keys = model.load_state_dict(filtered_model_state_dict, strict=False)
assert unexpected_keys == []
# The pre-local model replaces the positional encoder with its own
# (un-pretrained) spatial filter, in addition to the new final layer:
assert set(missing_keys) == {
    "spatial_conv.1.weight",
    "spatial_conv.1.bias",
    "final_layer.1.weight",
    "final_layer.1.bias",
}
```
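**Run a Forward Pass**

Once the weights are loaded, each of the models above behaves like a regular `torch.nn.Module`. The sketch below is illustrative only: it assumes braindecode's usual `(batch, n_channels, n_times)` input convention and uses a random dummy tensor in place of real EEG windows.

```python
import torch

# Dummy batch: 4 windows of input_window_seconds * sfreq samples each.
# (Illustrative shapes; substitute real windowed EEG data in practice.)
n_times = int(2 * sfreq)
x = torch.randn(4, len(chs_info), n_times)

model.eval()
with torch.no_grad():
    out = model(x)
print(out.shape)  # expected: (4, n_outputs) for the downstream architectures
```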