r"""Convert and re-host ST-EEGFormer checkpoints on the Hugging Face Hub.

The official ST-EEGFormer checkpoints are MAE pre-training checkpoints from
``LiuyinYang1101/STEEGFormer``. They are stored under timm-style encoder keys,
include decoder-only tensors for the pre-training objective, and fuse the
attention query/key/value projections into a single ``attn.qkv`` tensor. This
script converts those files into the braindecode model format used by
:class:`~braindecode.models.STEEGFormer`.

The conversion is intentionally kept out of ``STEEGFormer.load_state_dict``:
runtime loading expects braindecode-format checkpoints, while this one-off
remapping script is archived alongside the re-hosted weights.

Run from the repo root with the official checkpoint available locally::

    python hf_assets/model_cards/convert_steegformer_checkpoints.py \
        --src STEEGFormer_small.pth --variant small

Add ``--push`` (requires write access to the ``braindecode`` HF org) to upload
the converted folder, including this script, to the corresponding Hub repo.
"""

from __future__ import annotations

import argparse
import shutil
import sys
from collections import OrderedDict
from pathlib import Path
from typing import Any

import torch

# In the repository this script lives at <repo>/hf_assets/model_cards/, so the
# repo root is two directories up. A copy of the script is also shipped next to
# the converted weights and may be run from a shallow directory, so clamp the
# index instead of letting ``parents[2]`` raise IndexError.
_SCRIPT_PARENTS = Path(__file__).resolve().parents
_REPO_ROOT = _SCRIPT_PARENTS[min(2, len(_SCRIPT_PARENTS) - 1)]
if (_REPO_ROOT / "braindecode" / "models").exists():
    # Put the in-repo braindecode first on the import path so it is used
    # instead of any installed copy.
    sys.path.insert(0, str(_REPO_ROOT))

# Architecture hyper-parameters for each released variant; these are passed
# straight to ``STEEGFormer(**config)`` during conversion.
VARIANTS: dict[str, dict[str, Any]] = {
    "small": {
        "embed_dim": 512,
        "depth": 8,
        "num_heads": 8,
        "n_chans_pos": 145,
    },
    "base": {
        "embed_dim": 768,
        "depth": 12,
        "num_heads": 12,
        "n_chans_pos": 145,
    },
    "large": {
        "embed_dim": 1024,
        "depth": 24,
        "num_heads": 16,
        "n_chans_pos": 145,
    },
    "largeV2": {
        "embed_dim": 1024,
        "depth": 24,
        "num_heads": 16,
        "n_chans_pos": 256,
    },
}

# Hub repository that receives each converted variant when ``--push`` is given.
HUB_REPOS = {
    "small": "braindecode/STEEGFormer-small",
    "base": "braindecode/STEEGFormer-base",
    "large": "braindecode/STEEGFormer-large",
    "largeV2": "braindecode/STEEGFormer-largeV2",
}

# timm key inside one block -> this module's encoder-block key. The fused
# ``attn.qkv`` tensor is split separately.
_TIMM_BLOCK_RENAMES = {
    "norm1": "0.fn.0",
    "attn.proj": "0.fn.1.projection",
    "norm2": "1.fn.0",
    "mlp.fc1": "1.fn.1.0",
    "mlp.fc2": "1.fn.1.3",
}

# Checkpoint keys that have no braindecode counterpart and are discarded
# during conversion (matched by prefix and by exact name respectively).
_DROP_PREFIXES = ("decoder_", "dec_", "mask_token", "enc_temporal_emd")
_DROP_EXACT = frozenset(
    {"pos_embed", "fc_norm.weight", "fc_norm.bias", "head.weight", "head.bias"}
)
# The output layer is never present in the official checkpoints, so these keys
# are allowed to be missing; they keep the model's freshly initialised values.
_EXPECTED_MISSING = {"final_layer.weight", "final_layer.bias"}


def _load_checkpoint(src: Path) -> dict[str, Any]:
    """Load a checkpoint file onto the CPU.

    A safe ``weights_only=True`` load is attempted first. If it fails for any
    reason other than an OS-level error, the file is re-loaded with full
    unpickling. Full unpickling can execute arbitrary code, so it is only
    acceptable for trusted files such as the official release checkpoints.

    Raises
    ------
    OSError
        If the file cannot be opened (missing path, permissions, ...).
    """
    try:
        return torch.load(src, map_location="cpu", weights_only=True)
    except OSError:
        # A missing/unreadable file will not be fixed by unsafe unpickling.
        raise
    except Exception as err:  # noqa: BLE001 - any load failure triggers fallback
        # NOTE(review): full unpickling of an untrusted file is a code-execution
        # risk; make the fallback visible instead of silent.
        print(
            f"warning: weights_only load of {src} failed ({err!r}); "
            "retrying with weights_only=False (only safe for trusted files).",
            file=sys.stderr,
        )
        return torch.load(src, map_location="cpu", weights_only=False)


def _unwrap_state_dict(checkpoint: dict[str, Any]) -> dict[str, torch.Tensor]:
    """Return the tensor mapping stored in ``checkpoint``.

    Training checkpoints often nest the weights under ``"model"`` or
    ``"state_dict"``; the first of those keys holding a dict is returned.
    Otherwise ``checkpoint`` itself is assumed to be the state dict.
    """
    for key in ("model", "state_dict"):
        nested = checkpoint.get(key)
        if isinstance(nested, dict):
            return nested
    return checkpoint


def remap_official_state_dict(
    checkpoint: dict[str, Any], embed_dim: int
) -> OrderedDict[str, torch.Tensor]:
    """Return a braindecode-format state dict from an official timm checkpoint.

    Parameters
    ----------
    checkpoint
        Loaded official checkpoint, either a raw state dict or one nested under
        ``"model"`` / ``"state_dict"``.
    embed_dim
        Encoder embedding dimension. It is used to split each fused
        ``attn.qkv`` tensor into equal query/key/value parts along dim 0.

    Returns
    -------
    OrderedDict
        Renamed tensors. The split q/k/v tensors are cloned, so no returned
        tensor shares storage with the fused input tensor or with each other.

    Raises
    ------
    ValueError
        If no ``attn.qkv`` key is present, or a qkv tensor's first dimension
        is not ``3 * embed_dim``.
    RuntimeError
        If a ``blocks.*`` key matches no known rename rule.
    """
    state_dict = _unwrap_state_dict(checkpoint)
    if not any("attn.qkv" in key for key in state_dict):
        raise ValueError("Expected an official timm checkpoint with attn.qkv keys.")

    remapped: OrderedDict[str, torch.Tensor] = OrderedDict()
    unknown_block_keys: list[str] = []
    for key, value in state_dict.items():
        if key.startswith(_DROP_PREFIXES) or key in _DROP_EXACT:
            continue
        if key == "enc_channel_emd.channel_transformation.weight":
            remapped["channel_pos.embedding.weight"] = value
        elif key.startswith("blocks."):
            _, idx, rest = key.split(".", 2)
            dst = f"encoder.{idx}."
            if rest in ("attn.qkv.weight", "attn.qkv.bias"):
                if value.shape[0] != 3 * embed_dim:
                    raise ValueError(
                        f"{key} has first dimension {value.shape[0]}, expected "
                        f"{3 * embed_dim} for embed_dim={embed_dim}."
                    )
                suffix = rest.rsplit(".", 1)[1]
                # Clone the slices: plain slicing returns views into the fused
                # tensor, which makes the output share storage (rejected by
                # e.g. safetensors) and alias the caller's checkpoint.
                remapped[f"{dst}0.fn.1.queries.{suffix}"] = value[:embed_dim].clone()
                remapped[f"{dst}0.fn.1.keys.{suffix}"] = value[
                    embed_dim : 2 * embed_dim
                ].clone()
                remapped[f"{dst}0.fn.1.values.{suffix}"] = value[
                    2 * embed_dim :
                ].clone()
            else:
                for timm_prefix, bd_prefix in _TIMM_BLOCK_RENAMES.items():
                    if rest.startswith(timm_prefix + "."):
                        remapped[dst + bd_prefix + rest[len(timm_prefix) :]] = value
                        break
                else:
                    unknown_block_keys.append(key)
        else:
            # cls_token, patch_embed.proj.*, and norm.* share braindecode names.
            remapped[key] = value

    if unknown_block_keys:
        shown = ", ".join(unknown_block_keys[:8])
        if len(unknown_block_keys) > 8:
            shown += ", ..."
        raise RuntimeError(f"Unmapped timm block keys: {shown}")
    return remapped


def convert(src: Path, variant: str, out: Path, n_outputs: int = 4) -> None:
    """Convert one official checkpoint into a braindecode pretrained folder.

    Parameters
    ----------
    src
        Path to the official checkpoint file.
    variant
        Key of ``VARIANTS`` selecting the architecture hyper-parameters.
    out
        Output folder; created if needed. Receives the ``save_pretrained``
        files plus a copy of this script.
    n_outputs
        ``n_outputs`` passed to ``STEEGFormer``. The output layer is not part
        of the official checkpoint and keeps its initial values.

    Raises
    ------
    RuntimeError
        If loading the remapped weights leaves unexpected keys, or missing
        keys other than ``final_layer.*``.
    """
    from braindecode.models import STEEGFormer

    config = dict(VARIANTS[variant])
    checkpoint = _load_checkpoint(src)
    state_dict = remap_official_state_dict(checkpoint, embed_dim=config["embed_dim"])

    model = STEEGFormer(
        n_chans=22,
        n_times=1000,
        n_outputs=n_outputs,
        patch_size=16,
        mlp_ratio=4,
        **config,
    )
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    extra_missing = sorted(set(missing) - _EXPECTED_MISSING)
    if extra_missing or unexpected:
        raise RuntimeError(
            f"Converted checkpoint mismatch: missing={extra_missing}, "
            f"unexpected={list(unexpected)}"
        )

    out.mkdir(parents=True, exist_ok=True)
    model.save_pretrained(out)
    # Archive the conversion script next to the weights. Skip the copy when
    # ``out`` is the script's own folder (copy2 would raise SameFileError).
    script = Path(__file__).resolve()
    target = out / script.name
    if target.resolve() != script:
        shutil.copy2(script, target)
    print(f"Saved {variant} to {out} ({len(state_dict)} tensors loaded)")


def push(out: Path, variant: str) -> None:
    """Upload a converted folder to the matching braindecode Hub repo.

    The repo from ``HUB_REPOS[variant]`` is created if it does not exist yet.
    """
    from huggingface_hub import HfApi

    repo_id = HUB_REPOS[variant]
    api = HfApi()
    api.create_repo(repo_id, repo_type="model", exist_ok=True)
    api.upload_folder(
        repo_id=repo_id,
        folder_path=str(out),
        commit_message="Add ST-EEGFormer braindecode checkpoint conversion",
    )
    print(f"Pushed {out} -> {repo_id}")


def main() -> None:
    """Parse CLI arguments, convert one checkpoint, and optionally push it."""
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument(
        "--src",
        type=Path,
        required=True,
        help="official ST-EEGFormer checkpoint file to convert",
    )
    parser.add_argument("--variant", choices=sorted(VARIANTS), required=True)
    parser.add_argument(
        "--out",
        type=Path,
        default=None,
        help="output folder (default: hf_export/STEEGFormer-<variant>)",
    )
    parser.add_argument(
        "--n-outputs",
        type=int,
        default=4,
        help="n_outputs passed to STEEGFormer (default: 4)",
    )
    parser.add_argument(
        "--push",
        action="store_true",
        help="upload to the matching braindecode/STEEGFormer-* repo",
    )
    args = parser.parse_args()

    # Fail fast with a clear CLI error instead of a torch.load traceback.
    if not args.src.is_file():
        parser.error(f"--src {args.src} does not exist or is not a file")

    out = args.out or Path("hf_export") / f"STEEGFormer-{args.variant}"
    convert(args.src, args.variant, out, n_outputs=args.n_outputs)
    if args.push:
        push(out, args.variant)


if __name__ == "__main__":
    main()