license: bsd-3-clause
library_name: braindecode
pipeline_tag: feature-extraction
tags:
  - eeg
  - biosignal
  - pytorch
  - neuroscience
  - braindecode
  - foundation-model
  - convolutional

SignalJEPA_PostLocal

Post-local downstream architecture introduced in signal-JEPA (Guetschel, P. et al., 2024).

Architecture-only repository. This repo documents the braindecode.models.SignalJEPA_PostLocal class. No pretrained weights are distributed here — instantiate the model and train it on your own data, or fine-tune from a published foundation-model checkpoint separately.

Quick start

```bash
pip install braindecode
```

```python
from braindecode.models import SignalJEPA_PostLocal

model = SignalJEPA_PostLocal(
    n_chans=22,
    sfreq=250,
    input_window_seconds=4.0,
    n_outputs=4,
)
```

The signal-shape arguments above are example defaults — adjust them to match your recording.
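
The number of time samples per input window is the product of the two signal-shape arguments, so the quick-start values give 250 Hz × 4.0 s = 1000 samples. The helpers below are hypothetical (they are not part of braindecode) and only sketch a shape check you might run on a batch before calling the model.

```python
def expected_n_times(sfreq: float, input_window_seconds: float) -> int:
    """Samples in one input window (hypothetical helper, rounded to an int)."""
    return int(round(sfreq * input_window_seconds))


def check_batch_shape(shape: tuple, n_chans: int, sfreq: float,
                      input_window_seconds: float) -> None:
    """Raise ValueError unless `shape` is (batch, n_chans, n_times)."""
    n_times = expected_n_times(sfreq, input_window_seconds)
    if len(shape) != 3:
        raise ValueError(f"expected a 3-D (batch, n_chans, n_times) shape, got {shape}")
    if tuple(shape[1:]) != (n_chans, n_times):
        raise ValueError(f"expected (batch, {n_chans}, {n_times}), got {shape}")


# Quick-start values: 22 channels, 250 Hz, 4-second windows.
print(expected_n_times(250, 4.0))  # 1000
check_batch_shape((8, 22, 1000), n_chans=22, sfreq=250, input_window_seconds=4.0)
```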

Documentation

Architecture description

The block below is the rendered class docstring (parameters, references, architecture figure where available).

Post-local downstream architecture introduced in signal-JEPA (Guetschel, P. et al., 2024) [1]_.

Model categories: Convolution, Channel, Foundation Model

This architecture is one of the variants of :class:`SignalJEPA` that can be used for classification purposes.

.. figure:: https://braindecode.org/dev/_static/model/sjepa_post-local.jpg
   :align: center
   :alt: sJEPA Post-Local.

.. versionadded:: 0.9

.. rubric:: Pretrained Weights

Only the feature-encoder weights are reused from the shared SSL checkpoints. This model has neither a channel embedding nor a transformer, so ``strict=False`` is required at load time to skip the unused keys. Either Hub variant works; the ``_without-chans`` one is slightly smaller.
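
To make the effect of ``strict=False`` concrete, the sketch below compares a checkpoint's parameter names with the names a model expects and splits the difference into "missing" and "unexpected" keys, mirroring the two lists that PyTorch's ``load_state_dict`` reports. All key names here are hypothetical placeholders, not the actual names stored in the braindecode checkpoints.

```python
def diff_state_dict_keys(checkpoint_keys, model_keys):
    """Return (missing, unexpected) key lists, like torch's load_state_dict report."""
    checkpoint_keys, model_keys = set(checkpoint_keys), set(model_keys)
    missing = sorted(model_keys - checkpoint_keys)     # model needs them, checkpoint lacks them
    unexpected = sorted(checkpoint_keys - model_keys)  # checkpoint has them, model has no slot
    return missing, unexpected


# Hypothetical key names, for illustration only.
ssl_checkpoint = [
    "feature_encoder.conv1.weight",
    "feature_encoder.conv2.weight",
    "transformer.layer0.weight",
    "pos_encoder.spat_emb.weight",
]
post_local_model = [
    "feature_encoder.conv1.weight",
    "feature_encoder.conv2.weight",
    "final_layer.spat_conv.weight",
    "final_layer.linear.weight",
]

missing, unexpected = diff_state_dict_keys(ssl_checkpoint, post_local_model)
print("missing:", missing)        # downstream head: keeps its fresh initialization
print("unexpected:", unexpected)  # SSL-only modules: ignored when strict=False
# In PyTorch, strict=True raises if either list is non-empty;
# strict=False loads the shared keys and just returns both lists.
```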

.. important:: Pre-trained Weights Available

   .. code:: python

      from braindecode.models import SignalJEPA_PostLocal

      model = SignalJEPA_PostLocal.from_pretrained(
          "braindecode/signal-jepa_without-chans",
          n_chans=22,
          input_window_seconds=16.0,
          n_outputs=4,
          strict=False,
      )

Requires installing ``braindecode[hub]`` for Hub integration.

.. rubric:: Usage

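.. code:: python

   import torch
   from braindecode.models import SignalJEPA_PostLocal

   model = SignalJEPA_PostLocal(
       n_chans=22,
       input_window_seconds=16.0,
       sfreq=128,
       n_outputs=4,  # e.g., 4-class classification
   )

   # Forward: (batch, n_chans, n_times) -> (batch, n_outputs)
   # n_times = 16.0 s * 128 Hz = 2048 samples
   eeg_data = torch.randn(8, 22, 2048)  # random stand-in for a batch of 8 EEG windows
   output = model(eeg_data)  # (8, 4)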

.. warning::

 Pre-trained at **128 Hz** on EEG bandpass-filtered between
 **0.5 and 40 Hz** and rescaled by a factor of :math:`10^{6}`
 (volts to microvolts). Apply the same preprocessing to your
 data to match the pre-training distribution.
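
The preprocessing in the warning above can be approximated with NumPy alone, as a rough sketch. It assumes the data are already sampled at 128 Hz and stored in volts, and it uses an ideal FFT band mask as a crude stand-in for a proper FIR/IIR band-pass filter; in practice, MNE-Python's ``raw.filter`` and ``raw.resample`` are the usual tools. The function name ``preprocess_like_pretraining`` is hypothetical.

```python
import numpy as np


def preprocess_like_pretraining(data_volts, sfreq=128.0, l_freq=0.5, h_freq=40.0,
                                scale=1e6):
    """Band-limit to [l_freq, h_freq] Hz and rescale volts to microvolts.

    Sketch only: an ideal FFT mask replaces a real band-pass filter, and
    `data_volts` (shape (n_chans, n_times)) is assumed to be at 128 Hz already.
    """
    data_volts = np.asarray(data_volts, dtype=float)
    n_times = data_volts.shape[-1]
    freqs = np.fft.rfftfreq(n_times, d=1.0 / sfreq)
    spectrum = np.fft.rfft(data_volts, axis=-1)
    keep = (freqs >= l_freq) & (freqs <= h_freq)
    spectrum[..., ~keep] = 0.0  # zero every frequency bin outside the pass band
    filtered = np.fft.irfft(spectrum, n=n_times, axis=-1)
    return filtered * scale     # volts -> microvolts


# 8 s of synthetic 2-channel data at 128 Hz: a 10 Hz component (kept)
# plus a 60 Hz component (removed), each with 1e-5 V amplitude.
sfreq = 128.0
t = np.arange(int(8 * sfreq)) / sfreq
x = 1e-5 * (np.sin(2 * np.pi * 10 * t) + np.sin(2 * np.pi * 60 * t))
x = np.stack([x, -x])
y = preprocess_like_pretraining(x, sfreq=sfreq)
print(np.allclose(y[0], 10 * np.sin(2 * np.pi * 10 * t), atol=1e-6))  # True
```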

Parameters

n_spat_filters : int
    Number of spatial filters.
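
As a rough illustration of what ``n_spat_filters`` controls: a spatial filter is a learned linear combination of channels, and in the post-local variant, as the name suggests, this channel mixing comes after the per-channel (local) feature encoder. The NumPy sketch below assumes encoder outputs of shape ``(n_chans, n_tokens, emb_dim)``, random weights, and a flatten-plus-linear head; these sizes and steps are illustrative assumptions, not braindecode's implementation.

```python
import numpy as np

rng = np.random.default_rng(0)

# Illustrative sizes (assumptions, not braindecode defaults).
n_chans, n_tokens, emb_dim = 22, 30, 64
n_spat_filters, n_outputs = 4, 4

# Stand-in for a per-channel ("local") encoder output: one token sequence per channel.
local_features = rng.standard_normal((n_chans, n_tokens, emb_dim))

# Spatial filters: each row mixes all channels into one virtual channel.
spatial_weights = rng.standard_normal((n_spat_filters, n_chans))
mixed = np.einsum("sc,cte->ste", spatial_weights, local_features)

# Flatten the filtered features and apply a linear classification head.
head = rng.standard_normal((n_outputs, mixed.size))
logits = head @ mixed.reshape(-1)

print(mixed.shape, logits.shape)  # (4, 30, 64) (4,)
```

Since the head acts on ``n_spat_filters * n_tokens * emb_dim`` features in this sketch, its size grows linearly with the number of spatial filters.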

References

.. [1] Guetschel, P., Moreau, T., & Tangermann, M. (2024). S-JEPA: towards seamless cross-dataset transfer through dynamic spatial attention. In 9th Graz Brain-Computer Interface Conference, https://www.doi.org/10.3217/978-3-99161-014-4-003

.. rubric:: Hugging Face Hub integration

When the optional huggingface_hub package is installed, all models automatically gain the ability to be pushed to and loaded from the Hugging Face Hub. Install with::

 pip install braindecode[hub]

Pushing a model to the Hub:

.. code:: python

   from braindecode.models import SignalJEPA_PostLocal

   # Train your model
   model = SignalJEPA_PostLocal(n_chans=22, n_outputs=4, n_times=1000)
   # ... training code ...

   # Push to the Hub
   model.push_to_hub(
       repo_id="username/my-signaljepa_postlocal-model",
       commit_message="Initial model upload",
   )

Loading a model from the Hub:

.. code:: python

   from braindecode.models import SignalJEPA_PostLocal

   # Load pretrained model
   model = SignalJEPA_PostLocal.from_pretrained("username/my-signaljepa_postlocal-model")

   # Load with a different number of outputs (head is rebuilt automatically)
   model = SignalJEPA_PostLocal.from_pretrained(
       "username/my-signaljepa_postlocal-model", n_outputs=2
   )

Extracting features and replacing the head:

.. code:: python

   import torch

   x = torch.randn(1, model.n_chans, model.n_times)
   # Extract encoder features (consistent dict across all models)
   out = model(x, return_features=True)
   features = out["features"]

   # Replace the classification head
   model.reset_head(n_outputs=10)

Saving and restoring full configuration:

.. code:: python

   import json

   config = model.get_config()            # all __init__ params
   with open("config.json", "w") as f:
       json.dump(config, f)

   model2 = SignalJEPA_PostLocal.from_config(config)    # reconstruct (no weights)

All model parameters (both EEG-specific and model-specific such as dropout rates, activation functions, number of filters) are automatically saved to the Hub and restored when loading.
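
The configuration round trip works because every constructor argument can be read back and serialized. A toy, self-contained version of that pattern follows; ``ToyModel`` is hypothetical and records its ``__init__`` arguments through ``inspect.signature``, which is a sketch of the idea rather than braindecode's ``get_config`` implementation.

```python
import inspect
import json


class ToyModel:
    """Hypothetical model that can report and rebuild from its constructor arguments."""

    def __init__(self, n_chans, n_outputs, n_times=None, drop_prob=0.0):
        self.n_chans, self.n_outputs = n_chans, n_outputs
        self.n_times, self.drop_prob = n_times, drop_prob

    def get_config(self):
        # Read every __init__ parameter back from the attribute of the same name.
        names = [p for p in inspect.signature(type(self).__init__).parameters if p != "self"]
        return {name: getattr(self, name) for name in names}

    @classmethod
    def from_config(cls, config):
        return cls(**config)


model = ToyModel(n_chans=22, n_outputs=4, n_times=1000, drop_prob=0.1)
restored = ToyModel.from_config(json.loads(json.dumps(model.get_config())))
print(restored.get_config() == model.get_config())  # True
```

Only hyperparameters travel through the JSON; weights, as in the ``from_config`` call above, are not restored.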

See :ref:`load-pretrained-models` for a complete tutorial.

Citation

Please cite both the original paper for this architecture (see the References section above) and braindecode:

```bibtex
@article{aristimunha2025braindecode,
  title   = {Braindecode: a deep learning library for raw electrophysiological data},
  author  = {Aristimunha, Bruno and others},
  journal = {Zenodo},
  year    = {2025},
  doi     = {10.5281/zenodo.17699192},
}
```

License

BSD-3-Clause for the model code (matching braindecode). Pretraining-derived weights, if you fine-tune from a checkpoint, inherit the licence of that checkpoint and its training corpus.