# STEEGFormer-base / convert_steegformer_checkpoints.py
# Hugging Face Hub file view (commit 253541b, verified, 7.75 kB), uploaded by
# bruAristimunha: "Add ST-EEGFormer conversion script".
"""Convert and re-host ST-EEGFormer checkpoints on the Hugging Face Hub.
The official ST-EEGFormer checkpoints are MAE pre-training checkpoints from
``LiuyinYang1101/STEEGFormer``. They are stored under timm-style encoder keys,
include decoder-only tensors for the pre-training objective, and fuse attention
queries/keys/values in ``attn.qkv``. This script converts those files into the
braindecode model format used by :class:`~braindecode.models.STEEGFormer`.
The conversion is intentionally kept out of ``STEEGFormer.load_state_dict``:
runtime loading expects braindecode-format checkpoints, while this one-off
remapping script is archived with the re-hosted weights.
Run from the repo root with the official checkpoint available locally::
python hf_assets/model_cards/convert_steegformer_checkpoints.py \
--src STEEGFormer_small.pth --variant small
Add ``--push`` (requires write access to the ``braindecode`` HF org) to upload
the converted folder, including this script, to the corresponding Hub repo.
"""
from __future__ import annotations
import argparse
import shutil
import sys
from collections import OrderedDict
from pathlib import Path
from typing import Any
import torch
# Locate the checkout that contains this script so that a source tree's
# ``braindecode`` package is importable ahead of any installed copy.  Inside the
# repo the script sits two directories below the root
# (``hf_assets/model_cards/``), hence index 2.  The index is clamped because the
# script is also archived next to the converted weights and may then be run
# from a shallow location such as ``/content/<script>.py``; there ``parents[2]``
# does not exist and an unclamped lookup would raise IndexError at import time.
_SCRIPT_PARENTS = Path(__file__).resolve().parents
_REPO_ROOT = _SCRIPT_PARENTS[min(2, len(_SCRIPT_PARENTS) - 1)]
if (_REPO_ROOT / "braindecode" / "models").exists():
    sys.path.insert(0, str(_REPO_ROOT))
# Encoder hyper-parameters for every released checkpoint size.  ``convert``
# copies the selected entry and forwards it as keyword arguments to
# ``STEEGFormer``; ``embed_dim`` is additionally used to split the fused
# ``attn.qkv`` tensors.
VARIANTS: dict[str, dict[str, Any]] = dict(
    small=dict(embed_dim=512, depth=8, num_heads=8, n_chans_pos=145),
    base=dict(embed_dim=768, depth=12, num_heads=12, n_chans_pos=145),
    large=dict(embed_dim=1024, depth=24, num_heads=16, n_chans_pos=145),
    # Same encoder size as ``large`` but with a larger ``n_chans_pos``.
    largeV2=dict(embed_dim=1024, depth=24, num_heads=16, n_chans_pos=256),
)
# Destination Hub repository for each variant name; read by ``push``.
HUB_REPOS = dict(
    small="braindecode/STEEGFormer-small",
    base="braindecode/STEEGFormer-base",
    large="braindecode/STEEGFormer-large",
    largeV2="braindecode/STEEGFormer-largeV2",
)
# Per-block rename table: timm sub-module name (the part after ``blocks.<i>.``)
# -> sub-module path under ``encoder.<i>.`` in the converted state dict.  A key
# matches an entry when it starts with the timm name followed by a dot; the
# remainder (e.g. ``.weight``) is carried over unchanged.  The fused
# ``attn.qkv`` tensors are not listed here because they are split into
# separate query/key/value tensors by dedicated code.
_TIMM_BLOCK_RENAMES = dict(
    (
        ("norm1", "0.fn.0"),
        ("attn.proj", "0.fn.1.projection"),
        ("norm2", "1.fn.0"),
        ("mlp.fc1", "1.fn.1.0"),
        ("mlp.fc2", "1.fn.1.3"),
    )
)
# Source keys starting with any of these prefixes are skipped during remapping.
# NOTE(review): presumably these are the MAE decoder / pre-training-only
# tensors mentioned in the module docstring - confirm against the upstream model.
_DROP_PREFIXES = (
    "decoder_",
    "dec_",
    "mask_token",
    "enc_temporal_emd",
)

# Source keys skipped only on an exact match.
_DROP_EXACT = frozenset(
    (
        "pos_embed",
        "fc_norm.weight",
        "fc_norm.bias",
        "head.weight",
        "head.bias",
    )
)

# Model parameters that the converted state dict is allowed to leave unset when
# it is loaded non-strictly; any other missing key is treated as an error.
_EXPECTED_MISSING = {
    "final_layer.weight",
    "final_layer.bias",
}
def _load_checkpoint(src: Path) -> dict[str, Any]:
    """Load the checkpoint stored at ``src`` onto the CPU.

    The file is first read with ``weights_only=True``.  If that attempt fails
    for any reason other than an I/O error, a ``RuntimeWarning`` is emitted and
    the file is read again with ``weights_only=False``.

    Parameters
    ----------
    src : Path
        Location of the checkpoint file.

    Returns
    -------
    dict
        The object deserialised by ``torch.load``.

    Raises
    ------
    OSError
        If the file cannot be opened or read (for example it does not exist).
        Such errors are propagated immediately instead of being retried.
    """
    import warnings

    try:
        return torch.load(src, map_location="cpu", weights_only=True)
    except OSError:
        # A missing or unreadable file will not be fixed by unpickling harder;
        # surface the original error rather than a chained duplicate.
        raise
    except Exception as err:
        # SECURITY: ``weights_only=False`` performs full pickle deserialisation,
        # which can execute arbitrary code embedded in the file.  The fallback
        # is kept as a deliberate best-effort for checkpoints that store extra
        # Python objects, but it is made visible so it never happens silently.
        warnings.warn(
            f"Could not load {src} with weights_only=True "
            f"({type(err).__name__}: {err}); falling back to full unpickling, "
            "which can execute arbitrary code. Only use checkpoints from a "
            "trusted source.",
            RuntimeWarning,
            stacklevel=2,
        )
        return torch.load(src, map_location="cpu", weights_only=False)
def _unwrap_state_dict(checkpoint: dict[str, Any]) -> dict[str, torch.Tensor]:
    """Return the mapping of weights held by ``checkpoint``.

    The entries ``"model"`` and ``"state_dict"`` are examined in that order and
    the first one whose value is a ``dict`` is returned.  When neither holds a
    ``dict``, ``checkpoint`` itself is returned unchanged.
    """
    wrapped = (checkpoint.get(name) for name in ("model", "state_dict"))
    return next((entry for entry in wrapped if isinstance(entry, dict)), checkpoint)
def remap_official_state_dict(
    checkpoint: dict[str, Any], embed_dim: int
) -> OrderedDict[str, torch.Tensor]:
    """Translate an official timm-style checkpoint into braindecode key names.

    Parameters
    ----------
    checkpoint : dict
        Loaded checkpoint; it may wrap the weights under ``"model"`` or
        ``"state_dict"``.
    embed_dim : int
        Width of the encoder.  Each fused ``attn.qkv`` tensor must have
        ``3 * embed_dim`` rows and is cut into three ``embed_dim``-row slices.

    Returns
    -------
    OrderedDict
        Renamed tensors, in the order their source keys were encountered.

    Raises
    ------
    ValueError
        If no key contains ``attn.qkv``, or a fused tensor has an unexpected
        first dimension.
    RuntimeError
        If a ``blocks.*`` key matches no known rename rule.
    """
    source = _unwrap_state_dict(checkpoint)
    if not any("attn.qkv" in name for name in source):
        raise ValueError("Expected an official timm checkpoint with attn.qkv keys.")

    converted: OrderedDict[str, torch.Tensor] = OrderedDict()
    unmapped: list[str] = []

    for name, tensor in source.items():
        # Keys on the drop lists are not carried over.
        if name.startswith(_DROP_PREFIXES) or name in _DROP_EXACT:
            continue

        if name == "enc_channel_emd.channel_transformation.weight":
            converted["channel_pos.embedding.weight"] = tensor
            continue

        if not name.startswith("blocks."):
            # cls_token, patch_embed.proj.*, and norm.* share braindecode names.
            converted[name] = tensor
            continue

        # "blocks.<layer>.<tail>" -> "encoder.<layer>.<renamed tail>"
        _, layer, tail = name.split(".", 2)
        prefix = f"encoder.{layer}."

        if tail in ("attn.qkv.weight", "attn.qkv.bias"):
            rows = tensor.shape[0]
            if rows != 3 * embed_dim:
                raise ValueError(
                    f"{name} has first dimension {rows}, expected "
                    f"{3 * embed_dim} for embed_dim={embed_dim}."
                )
            kind = tail.rsplit(".", 1)[1]  # "weight" or "bias"
            # The fused tensor stacks queries, keys, values along dim 0.
            for slot, part in enumerate(("queries", "keys", "values")):
                start = slot * embed_dim
                converted[f"{prefix}0.fn.1.{part}.{kind}"] = tensor[
                    start : start + embed_dim
                ]
            continue

        # First rename rule whose timm name is a dotted prefix of the tail.
        hit = next(
            (timm for timm in _TIMM_BLOCK_RENAMES if tail.startswith(timm + ".")),
            None,
        )
        if hit is None:
            unmapped.append(name)
        else:
            converted[prefix + _TIMM_BLOCK_RENAMES[hit] + tail[len(hit) :]] = tensor

    if unmapped:
        # Report at most eight offenders to keep the message readable.
        preview = ", ".join(unmapped[:8])
        if len(unmapped) > 8:
            preview += ", ..."
        raise RuntimeError(f"Unmapped timm block keys: {preview}")
    return converted
def convert(src: Path, variant: str, out: Path, n_outputs: int = 4) -> None:
    """Convert one official checkpoint into a braindecode pretrained folder.

    Parameters
    ----------
    src : Path
        Official checkpoint file to read.
    variant : str
        Key of ``VARIANTS`` selecting the encoder hyper-parameters.
    out : Path
        Destination folder; created (with parents) if it does not exist.
    n_outputs : int
        Passed to ``STEEGFormer`` as ``n_outputs``.

    Raises
    ------
    KeyError
        If ``variant`` is not a key of ``VARIANTS``.
    RuntimeError
        If, after remapping, the model reports unexpected keys or misses any
        key other than those in ``_EXPECTED_MISSING``.
    """
    # Imported lazily so the module can be imported without braindecode.
    from braindecode.models import STEEGFormer

    config = dict(VARIANTS[variant])
    checkpoint = _load_checkpoint(src)
    state_dict = remap_official_state_dict(checkpoint, embed_dim=config["embed_dim"])

    # NOTE(review): n_chans, n_times, patch_size and mlp_ratio are fixed here
    # for every variant - confirm these match the intended export settings.
    model = STEEGFormer(
        n_chans=22,
        n_times=1000,
        n_outputs=n_outputs,
        patch_size=16,
        mlp_ratio=4,
        **config,
    )

    # Load non-strictly, then enforce our own, narrower tolerance: only the
    # keys in ``_EXPECTED_MISSING`` may be absent and nothing may be left over.
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    extra_missing = sorted(set(missing) - _EXPECTED_MISSING)
    if extra_missing or unexpected:
        raise RuntimeError(
            f"Converted checkpoint mismatch: missing={extra_missing}, "
            f"unexpected={list(unexpected)}"
        )

    out.mkdir(parents=True, exist_ok=True)
    model.save_pretrained(out)

    # Archive this script next to the weights.  When ``out`` is the folder the
    # script already lives in (e.g. re-running the archived copy inside a clone
    # of the Hub repo) source and destination are the same file and
    # ``shutil.copy2`` would raise ``SameFileError`` after the model has already
    # been saved, so the copy is skipped in that case.
    script = Path(__file__).resolve()
    archived = out / Path(__file__).name
    if archived.resolve() != script:
        shutil.copy2(script, archived)
    print(f"Saved {variant} to {out} ({len(state_dict)} tensors loaded)")
def push(out: Path, variant: str) -> None:
    """Publish the converted folder ``out`` to the Hub repo for ``variant``.

    The target repository is looked up in ``HUB_REPOS`` (``KeyError`` for an
    unknown variant), created as a model repo if it does not exist yet, and
    then the whole folder is uploaded in a single commit.
    """
    # Imported lazily so conversion alone does not require huggingface_hub.
    from huggingface_hub import HfApi

    target_repo = HUB_REPOS[variant]
    hub = HfApi()
    hub.create_repo(target_repo, repo_type="model", exist_ok=True)
    upload_options = {
        "repo_id": target_repo,
        "folder_path": str(out),
        "commit_message": "Add ST-EEGFormer braindecode checkpoint conversion",
    }
    hub.upload_folder(**upload_options)
    print(f"Pushed {out} -> {target_repo}")
def main() -> None:
    """Parse the command line, convert one checkpoint, and optionally upload it."""
    cli = argparse.ArgumentParser(
        description=__doc__,
        # Keep the module docstring's line breaks in ``--help`` output.
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    cli.add_argument("--src", type=Path, required=True)
    cli.add_argument("--variant", choices=sorted(VARIANTS), required=True)
    cli.add_argument(
        "--out",
        type=Path,
        default=None,
        help="output folder (default: hf_export/STEEGFormer-<variant>)",
    )
    cli.add_argument("--n-outputs", type=int, default=4)
    cli.add_argument(
        "--push",
        action="store_true",
        help="upload to the matching braindecode/STEEGFormer-* repo",
    )
    opts = cli.parse_args()

    # Fall back to the per-variant default folder when --out is not given.
    if opts.out is None:
        target = Path("hf_export") / f"STEEGFormer-{opts.variant}"
    else:
        target = opts.out

    convert(opts.src, opts.variant, target, n_outputs=opts.n_outputs)
    if opts.push:
        push(target, opts.variant)
# Run the command-line interface only when executed as a script, not on import.
if __name__ == "__main__":
    main()