PierreGtch commited on
Commit
186dceb
·
verified ·
1 Parent(s): ef750ef

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +79 -3
README.md CHANGED
@@ -3,27 +3,103 @@ license: mit
3
  ---
4
  # Usage
5
 
 
6
  ```python
7
  from braindecode.models import SignalJEPA
8
  from huggingface_hub import hf_hub_download
9
 
10
- weights_path = hf_hub_download(repo_id='braindecode/SignalJEPA', filename='signal-jepa_16s-60_adeuwv4s.pth')
11
  model_state_dict = torch.load(weights_path)
12
 
13
  # Signal-related arguments
14
  # raw: mne.io.BaseRaw
15
  chs_info = raw.info["chs"]
16
- sfreq = raw.info['sfreq']
17
 
18
  model = SignalJEPA(
19
  sfreq=sfreq,
20
  input_window_seconds=2,
21
  chs_info=chs_info,
22
- n_outputs=1,
23
  )
24
  missing_keys, unexpected_keys = model.load_state_dict(model_state_dict, strict=False)
25
  assert unexpected_keys == []
26
  # The spatial positional encoder is initialized using the `chs_info`:
27
  assert set(missing_keys) == {"pos_encoder.pos_encoder_spat.weight"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  ```
 
3
  ---
4
  # Usage
5
 
6
+ **Instantiate the base model**
7
  ```python
8
  from braindecode.models import SignalJEPA
9
  from huggingface_hub import hf_hub_download
10
 
11
+ weights_path = hf_hub_download(repo_id="braindecode/SignalJEPA", filename="signal-jepa_16s-60_adeuwv4s.pth")
12
  model_state_dict = torch.load(weights_path)
13
 
14
  # Signal-related arguments
15
  # raw: mne.io.BaseRaw
16
  chs_info = raw.info["chs"]
17
+ sfreq = raw.info["sfreq"]
18
 
19
  model = SignalJEPA(
20
  sfreq=sfreq,
21
  input_window_seconds=2,
22
  chs_info=chs_info,
 
23
  )
24
  missing_keys, unexpected_keys = model.load_state_dict(model_state_dict, strict=False)
25
  assert unexpected_keys == []
26
  # The spatial positional encoder is initialized using the `chs_info`:
27
  assert set(missing_keys) == {"pos_encoder.pos_encoder_spat.weight"}
28
+ ```
29
+
30
+ **Instantiate the downstream architectures**
31
+
32
+ The downstream architectures are equipped with a classification head.
33
+ See the article [arXiv:2403.11772](https://arxiv.org/abs/2403.11772) for more details.
34
+ ```python
35
+ from braindecode.models import (
36
+ SignalJEPA_Contextual,
37
+ SignalJEPA_PreLocal,
38
+ SignalJEPA_PostLocal,
39
+ )
40
+ from huggingface_hub import hf_hub_download
41
+
42
+ weights_path = hf_hub_download(repo_id="braindecode/SignalJEPA", filename="signal-jepa_16s-60_adeuwv4s.pth")
43
+ model_state_dict = torch.load(weights_path)
44
+
45
+ # Signal-related arguments
46
+ # raw: mne.io.BaseRaw
47
+ chs_info = raw.info["chs"]
48
+ sfreq = raw.info["sfreq"]
49
+
50
+ # The downstream architectures are equipped with an additional classification head
51
+ # which was not pre-trained. It has the following new parameters:
52
+ final_layer_keys = {
53
+ "final_layer.spat_conv.weight",
54
+ "final_layer.spat_conv.bias",
55
+ "final_layer.linear.weight",
56
+ "final_layer.linear.bias",
57
+ }
58
+
59
+ # a) Contextual downstream architecture
60
+ # ----------------------------------
61
+ model = SignalJEPA_Contextual(
62
+ sfreq=sfreq,
63
+ input_window_seconds=2,
64
+ chs_info=chs_info,
65
+ n_outputs=1,
66
+ )
67
+ missing_keys, unexpected_keys = model.load_state_dict(model_state_dict, strict=False)
68
+ assert unexpected_keys == []
69
+ # The spatial positional encoder is initialized using the `chs_info`:
70
+ assert set(missing_keys) == final_layer_keys | {"pos_encoder.pos_encoder_spat.weight"}
71
 
72
+ # In the post-local (b) and pre-local (c) architectures, the transformer is discarded:
73
+ FILTERED_model_state_dict = {
74
+ k: v for k, v in model_state_dict.items() if not any(k.startswith(pre) for pre in ["transformer.", "pos_encoder."])
75
+ }
76
+
77
+ # b) Post-local downstream architecture
78
+ # ----------------------------------
79
+ model = SignalJEPA_PostLocal(
80
+ sfreq=sfreq,
81
+ input_window_seconds=2,
82
+ n_chans=len(chs_info), # detailed channel info is not needed for this model
83
+ n_outputs=1,
84
+ )
85
+ missing_keys, unexpected_keys = model.load_state_dict(FILTERED_model_state_dict, strict=False)
86
+ assert unexpected_keys == []
87
+ assert set(missing_keys) == final_layer_keys
88
+
89
+ # c) Pre-local architecture
90
+ # ----------------------
91
+ model = SignalJEPA_PreLocal(
92
+ sfreq=sfreq,
93
+ input_window_seconds=2,
94
+ n_chans=len(chs_info), # detailed channel info is not needed for this model
95
+ n_outputs=1,
96
+ )
97
+ missing_keys, unexpected_keys = model.load_state_dict(FILTERED_model_state_dict, strict=False)
98
+ assert unexpected_keys == []
99
+ assert set(missing_keys) == {
100
+ "spatial_conv.1.weight",
101
+ "spatial_conv.1.bias",
102
+ "final_layer.1.weight",
103
+ "final_layer.1.bias",
104
+ }
105
  ```