bruAristimunha commited on
Commit
253541b
·
verified ·
1 Parent(s): 7de012f

Add ST-EEGFormer conversion script

Browse files
Files changed (1) hide show
  1. convert_steegformer_checkpoints.py +228 -0
convert_steegformer_checkpoints.py ADDED
@@ -0,0 +1,228 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Convert and re-host ST-EEGFormer checkpoints on the Hugging Face Hub.
2
+
3
+ The official ST-EEGFormer checkpoints are MAE pre-training checkpoints from
4
+ ``LiuyinYang1101/STEEGFormer``. They are stored under timm-style encoder keys,
5
+ include decoder-only tensors for the pre-training objective, and fuse attention
6
+ queries/keys/values in ``attn.qkv``. This script converts those files into the
7
+ braindecode model format used by :class:`~braindecode.models.STEEGFormer`.
8
+
9
+ The conversion is intentionally kept out of ``STEEGFormer.load_state_dict``:
10
+ runtime loading expects braindecode-format checkpoints, while this one-off
11
+ remapping script is archived with the re-hosted weights.
12
+
13
+ Run from the repo root with the official checkpoint available locally::
14
+
15
+ python hf_assets/model_cards/convert_steegformer_checkpoints.py \
16
+ --src STEEGFormer_small.pth --variant small
17
+
18
+ Add ``--push`` (requires write access to the ``braindecode`` HF org) to upload
19
+ the converted folder, including this script, to the corresponding Hub repo.
20
+ """
21
+
22
+ from __future__ import annotations
23
+
24
+ import argparse
25
+ import shutil
26
+ import sys
27
+ from collections import OrderedDict
28
+ from pathlib import Path
29
+ from typing import Any
30
+
31
+ import torch
32
+
33
+ _REPO_ROOT = Path(__file__).resolve().parents[2]
34
+ if (_REPO_ROOT / "braindecode" / "models").exists():
35
+ sys.path.insert(0, str(_REPO_ROOT))
36
+
37
+ VARIANTS: dict[str, dict[str, Any]] = {
38
+ "small": {
39
+ "embed_dim": 512,
40
+ "depth": 8,
41
+ "num_heads": 8,
42
+ "n_chans_pos": 145,
43
+ },
44
+ "base": {
45
+ "embed_dim": 768,
46
+ "depth": 12,
47
+ "num_heads": 12,
48
+ "n_chans_pos": 145,
49
+ },
50
+ "large": {
51
+ "embed_dim": 1024,
52
+ "depth": 24,
53
+ "num_heads": 16,
54
+ "n_chans_pos": 145,
55
+ },
56
+ "largeV2": {
57
+ "embed_dim": 1024,
58
+ "depth": 24,
59
+ "num_heads": 16,
60
+ "n_chans_pos": 256,
61
+ },
62
+ }
63
+
64
+ HUB_REPOS = {
65
+ "small": "braindecode/STEEGFormer-small",
66
+ "base": "braindecode/STEEGFormer-base",
67
+ "large": "braindecode/STEEGFormer-large",
68
+ "largeV2": "braindecode/STEEGFormer-largeV2",
69
+ }
70
+
71
+ # timm key inside one block -> this module's encoder-block key. The fused
72
+ # ``attn.qkv`` tensor is split separately.
73
+ _TIMM_BLOCK_RENAMES = {
74
+ "norm1": "0.fn.0",
75
+ "attn.proj": "0.fn.1.projection",
76
+ "norm2": "1.fn.0",
77
+ "mlp.fc1": "1.fn.1.0",
78
+ "mlp.fc2": "1.fn.1.3",
79
+ }
80
+
81
+ _DROP_PREFIXES = ("decoder_", "dec_", "mask_token", "enc_temporal_emd")
82
+ _DROP_EXACT = frozenset(
83
+ {"pos_embed", "fc_norm.weight", "fc_norm.bias", "head.weight", "head.bias"}
84
+ )
85
+ _EXPECTED_MISSING = {"final_layer.weight", "final_layer.bias"}
86
+
87
+
88
+ def _load_checkpoint(src: Path) -> dict[str, Any]:
89
+ try:
90
+ return torch.load(src, map_location="cpu", weights_only=True)
91
+ except Exception:
92
+ return torch.load(src, map_location="cpu", weights_only=False)
93
+
94
+
95
+ def _unwrap_state_dict(checkpoint: dict[str, Any]) -> dict[str, torch.Tensor]:
96
+ for key in ("model", "state_dict"):
97
+ nested = checkpoint.get(key)
98
+ if isinstance(nested, dict):
99
+ return nested
100
+ return checkpoint
101
+
102
+
103
+ def remap_official_state_dict(
104
+ checkpoint: dict[str, Any], embed_dim: int
105
+ ) -> OrderedDict[str, torch.Tensor]:
106
+ """Return a braindecode-format state dict from an official timm checkpoint."""
107
+ state_dict = _unwrap_state_dict(checkpoint)
108
+ if not any("attn.qkv" in key for key in state_dict):
109
+ raise ValueError("Expected an official timm checkpoint with attn.qkv keys.")
110
+
111
+ remapped: OrderedDict[str, torch.Tensor] = OrderedDict()
112
+ unknown_block_keys: list[str] = []
113
+
114
+ for key, value in state_dict.items():
115
+ if key.startswith(_DROP_PREFIXES) or key in _DROP_EXACT:
116
+ continue
117
+
118
+ if key == "enc_channel_emd.channel_transformation.weight":
119
+ remapped["channel_pos.embedding.weight"] = value
120
+ elif key.startswith("blocks."):
121
+ _, idx, rest = key.split(".", 2)
122
+ dst = f"encoder.{idx}."
123
+ if rest in ("attn.qkv.weight", "attn.qkv.bias"):
124
+ if value.shape[0] != 3 * embed_dim:
125
+ raise ValueError(
126
+ f"{key} has first dimension {value.shape[0]}, expected "
127
+ f"{3 * embed_dim} for embed_dim={embed_dim}."
128
+ )
129
+ suffix = rest.rsplit(".", 1)[1]
130
+ remapped[f"{dst}0.fn.1.queries.{suffix}"] = value[:embed_dim]
131
+ remapped[f"{dst}0.fn.1.keys.{suffix}"] = value[
132
+ embed_dim : 2 * embed_dim
133
+ ]
134
+ remapped[f"{dst}0.fn.1.values.{suffix}"] = value[2 * embed_dim :]
135
+ else:
136
+ for src, new in _TIMM_BLOCK_RENAMES.items():
137
+ if rest.startswith(src + "."):
138
+ remapped[dst + new + rest[len(src) :]] = value
139
+ break
140
+ else:
141
+ unknown_block_keys.append(key)
142
+ else:
143
+ # cls_token, patch_embed.proj.*, and norm.* share braindecode names.
144
+ remapped[key] = value
145
+
146
+ if unknown_block_keys:
147
+ shown = ", ".join(unknown_block_keys[:8])
148
+ if len(unknown_block_keys) > 8:
149
+ shown += ", ..."
150
+ raise RuntimeError(f"Unmapped timm block keys: {shown}")
151
+
152
+ return remapped
153
+
154
+
155
+ def convert(src: Path, variant: str, out: Path, n_outputs: int = 4) -> None:
156
+ """Convert one official checkpoint into a braindecode pretrained folder."""
157
+ from braindecode.models import STEEGFormer
158
+
159
+ config = dict(VARIANTS[variant])
160
+ checkpoint = _load_checkpoint(src)
161
+ state_dict = remap_official_state_dict(checkpoint, embed_dim=config["embed_dim"])
162
+
163
+ model = STEEGFormer(
164
+ n_chans=22,
165
+ n_times=1000,
166
+ n_outputs=n_outputs,
167
+ patch_size=16,
168
+ mlp_ratio=4,
169
+ **config,
170
+ )
171
+ missing, unexpected = model.load_state_dict(state_dict, strict=False)
172
+ extra_missing = sorted(set(missing) - _EXPECTED_MISSING)
173
+ if extra_missing or unexpected:
174
+ raise RuntimeError(
175
+ f"Converted checkpoint mismatch: missing={extra_missing}, "
176
+ f"unexpected={list(unexpected)}"
177
+ )
178
+
179
+ out.mkdir(parents=True, exist_ok=True)
180
+ model.save_pretrained(out)
181
+ shutil.copy2(Path(__file__).resolve(), out / Path(__file__).name)
182
+ print(f"Saved {variant} to {out} ({len(state_dict)} tensors loaded)")
183
+
184
+
185
+ def push(out: Path, variant: str) -> None:
186
+ """Upload a converted folder to the matching braindecode Hub repo."""
187
+ from huggingface_hub import HfApi
188
+
189
+ repo_id = HUB_REPOS[variant]
190
+ api = HfApi()
191
+ api.create_repo(repo_id, repo_type="model", exist_ok=True)
192
+ api.upload_folder(
193
+ repo_id=repo_id,
194
+ folder_path=str(out),
195
+ commit_message="Add ST-EEGFormer braindecode checkpoint conversion",
196
+ )
197
+ print(f"Pushed {out} -> {repo_id}")
198
+
199
+
200
+ def main() -> None:
201
+ parser = argparse.ArgumentParser(
202
+ description=__doc__,
203
+ formatter_class=argparse.RawDescriptionHelpFormatter,
204
+ )
205
+ parser.add_argument("--src", type=Path, required=True)
206
+ parser.add_argument("--variant", choices=sorted(VARIANTS), required=True)
207
+ parser.add_argument(
208
+ "--out",
209
+ type=Path,
210
+ default=None,
211
+ help="output folder (default: hf_export/STEEGFormer-<variant>)",
212
+ )
213
+ parser.add_argument("--n-outputs", type=int, default=4)
214
+ parser.add_argument(
215
+ "--push",
216
+ action="store_true",
217
+ help="upload to the matching braindecode/STEEGFormer-* repo",
218
+ )
219
+ args = parser.parse_args()
220
+
221
+ out = args.out or Path("hf_export") / f"STEEGFormer-{args.variant}"
222
+ convert(args.src, args.variant, out, n_outputs=args.n_outputs)
223
+ if args.push:
224
+ push(out, args.variant)
225
+
226
+
227
+ if __name__ == "__main__":
228
+ main()