| """Convert and re-host ST-EEGFormer checkpoints on the Hugging Face Hub. |
| |
| The official ST-EEGFormer checkpoints are MAE pre-training checkpoints from |
| ``LiuyinYang1101/STEEGFormer``. They are stored under timm-style encoder keys, |
| include decoder-only tensors for the pre-training objective, and fuse attention |
| queries/keys/values in ``attn.qkv``. This script converts those files into the |
| braindecode model format used by :class:`~braindecode.models.STEEGFormer`. |
| |
| The conversion is intentionally kept out of ``STEEGFormer.load_state_dict``: |
| runtime loading expects braindecode-format checkpoints, while this one-off |
| remapping script is archived with the re-hosted weights. |
| |
| Run from the repo root with the official checkpoint available locally:: |
| |
| python hf_assets/model_cards/convert_steegformer_checkpoints.py \ |
| --src STEEGFormer_small.pth --variant small |
| |
| Add ``--push`` (requires write access to the ``braindecode`` HF org) to upload |
| the converted folder, including this script, to the corresponding Hub repo. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import shutil |
| import sys |
| from collections import OrderedDict |
| from pathlib import Path |
| from typing import Any |
|
|
| import torch |
|
|
# Directory two levels above the folder that contains this script.
_REPO_ROOT = Path(__file__).resolve().parents[2]
# If that directory holds a braindecode source tree, put it first on sys.path
# so the deferred ``from braindecode.models import ...`` in convert() searches
# the local checkout before anything else.
if (_REPO_ROOT / "braindecode" / "models").exists():
    sys.path.insert(0, str(_REPO_ROOT))
|
|
# Per-variant keyword arguments. convert() forwards the selected entry to the
# ``STEEGFormer`` constructor and reads ``embed_dim`` from it to split the
# fused qkv tensors. The keys are also the accepted ``--variant`` choices.
# NOTE(review): ``n_chans_pos`` presumably sizes the channel position-embedding
# table -- confirm against the STEEGFormer signature.
VARIANTS: dict[str, dict[str, Any]] = {
    "small": {
        "embed_dim": 512,
        "depth": 8,
        "num_heads": 8,
        "n_chans_pos": 145,
    },
    "base": {
        "embed_dim": 768,
        "depth": 12,
        "num_heads": 12,
        "n_chans_pos": 145,
    },
    "large": {
        "embed_dim": 1024,
        "depth": 24,
        "num_heads": 16,
        "n_chans_pos": 145,
    },
    "largeV2": {
        "embed_dim": 1024,
        "depth": 24,
        "num_heads": 16,
        "n_chans_pos": 256,
    },
}
|
|
# Hub repo id that push() uploads each variant to; the keys mirror VARIANTS.
HUB_REPOS = {
    "small": "braindecode/STEEGFormer-small",
    "base": "braindecode/STEEGFormer-base",
    "large": "braindecode/STEEGFormer-large",
    "largeV2": "braindecode/STEEGFormer-largeV2",
}
|
|
| |
| |
# Prefix rewrites for keys under ``blocks.<idx>.``: when the remainder of a key
# starts with ``<source prefix>.``, remap_official_state_dict() stores it under
# ``encoder.<idx>.<target prefix>`` followed by the rest of the original name.
# The fused ``attn.qkv`` tensors are absent here because they are split into
# separate query/key/value entries instead of being renamed.
_TIMM_BLOCK_RENAMES = {
    "norm1": "0.fn.0",
    "attn.proj": "0.fn.1.projection",
    "norm2": "1.fn.0",
    "mlp.fc1": "1.fn.1.0",
    "mlp.fc2": "1.fn.1.3",
}
|
|
# Checkpoint keys that start with any of these prefixes are discarded during
# remapping.
_DROP_PREFIXES = ("decoder_", "dec_", "mask_token", "enc_temporal_emd")
# Checkpoint keys discarded when they match one of these names exactly.
_DROP_EXACT = frozenset(
    {"pos_embed", "fc_norm.weight", "fc_norm.bias", "head.weight", "head.bias"}
)
# Model keys that convert() tolerates as missing after loading the remapped
# weights; any other missing key makes convert() raise.
_EXPECTED_MISSING = {"final_layer.weight", "final_layer.bias"}
|
|
|
|
def _load_checkpoint(src: Path) -> dict[str, Any]:
    """Load a checkpoint file onto the CPU, preferring the restricted unpickler.

    The file is first read with ``weights_only=True``. If that fails for any
    reason other than the file being unreadable, the load is retried with
    ``weights_only=False`` and a ``RuntimeWarning`` is emitted so the
    downgrade is never silent.

    Parameters
    ----------
    src : Path
        Location of the checkpoint file.

    Returns
    -------
    dict
        Whatever object ``torch.load`` deserialized from ``src``.

    Raises
    ------
    OSError
        If the file cannot be opened (missing, a directory, no permission).
        No fallback is attempted, since it would fail the same way.
    """
    import warnings

    try:
        return torch.load(src, map_location="cpu", weights_only=True)
    except OSError:
        # An unreadable file cannot be fixed by switching unpicklers.
        raise
    except Exception as err:
        # SECURITY: weights_only=False uses the full pickle machinery, which
        # can execute arbitrary code stored in the file. The fallback is kept
        # as a deliberate best effort, but it is announced loudly.
        warnings.warn(
            f"Safe loading of {src} failed ({type(err).__name__}: {err}); "
            "falling back to weights_only=False, which can execute arbitrary "
            "code. Only use checkpoints from a trusted source.",
            RuntimeWarning,
            stacklevel=2,
        )
        return torch.load(src, map_location="cpu", weights_only=False)
|
|
|
|
def _unwrap_state_dict(checkpoint: dict[str, Any]) -> dict[str, torch.Tensor]:
    """Return the nested state dict of a checkpoint, or the checkpoint itself.

    The ``"model"`` entry is inspected first, then ``"state_dict"``; the first
    one whose value is a dict is returned. When neither holds a dict, the
    checkpoint is returned unchanged.
    """
    wrapped = (checkpoint.get(name) for name in ("model", "state_dict"))
    # The generator is lazy, so "state_dict" is only looked up when "model"
    # did not yield a dict.
    return next((entry for entry in wrapped if isinstance(entry, dict)), checkpoint)
|
|
|
|
def remap_official_state_dict(
    checkpoint: dict[str, Any], embed_dim: int
) -> OrderedDict[str, torch.Tensor]:
    """Translate an official timm-style checkpoint into braindecode key names.

    Keys listed in ``_DROP_EXACT`` or starting with a ``_DROP_PREFIXES`` entry
    are skipped. The channel-embedding weight is renamed, keys under
    ``blocks.<idx>.`` move to ``encoder.<idx>.`` with fused ``attn.qkv``
    tensors cut into query, key and value slices along the first dimension,
    and every other key is copied under its original name.

    Raises ``ValueError`` when no ``attn.qkv`` key is present or a fused tensor
    does not have ``3 * embed_dim`` rows, and ``RuntimeError`` when a block key
    matches none of the known renames.
    """
    source = _unwrap_state_dict(checkpoint)
    if all("attn.qkv" not in name for name in source):
        raise ValueError("Expected an official timm checkpoint with attn.qkv keys.")

    converted: OrderedDict[str, torch.Tensor] = OrderedDict()
    unmapped: list[str] = []
    qkv_targets = ("queries", "keys", "values")

    for name, tensor in source.items():
        # Tensors that are not carried over into the converted checkpoint.
        if name in _DROP_EXACT or name.startswith(_DROP_PREFIXES):
            continue

        if name == "enc_channel_emd.channel_transformation.weight":
            converted["channel_pos.embedding.weight"] = tensor
            continue

        if not name.startswith("blocks."):
            # Anything outside the transformer blocks keeps its name.
            converted[name] = tensor
            continue

        _, layer, tail = name.split(".", 2)
        prefix = f"encoder.{layer}."

        if tail in ("attn.qkv.weight", "attn.qkv.bias"):
            rows = tensor.shape[0]
            if rows != 3 * embed_dim:
                raise ValueError(
                    f"{name} has first dimension {rows}, expected "
                    f"{3 * embed_dim} for embed_dim={embed_dim}."
                )
            kind = tail.rsplit(".", 1)[1]
            # Consecutive embed_dim-row slices, in query/key/value order.
            for slot, target in enumerate(qkv_targets):
                start = slot * embed_dim
                converted[f"{prefix}0.fn.1.{target}.{kind}"] = tensor[
                    start : start + embed_dim
                ]
            continue

        matched = next(
            (old for old in _TIMM_BLOCK_RENAMES if tail.startswith(old + ".")),
            None,
        )
        if matched is None:
            unmapped.append(name)
        else:
            renamed = _TIMM_BLOCK_RENAMES[matched] + tail[len(matched) :]
            converted[prefix + renamed] = tensor

    if unmapped:
        # Report at most eight offending keys to keep the message readable.
        preview = ", ".join(unmapped[:8]) + (", ..." if len(unmapped) > 8 else "")
        raise RuntimeError(f"Unmapped timm block keys: {preview}")

    return converted
|
|
|
|
def convert(src: Path, variant: str, out: Path, n_outputs: int = 4) -> None:
    """Turn one official checkpoint into a braindecode pretrained folder.

    The checkpoint at ``src`` is loaded and remapped, then loaded into a
    freshly built ``STEEGFormer`` configured from ``VARIANTS[variant]``. Any
    unexpected key, or any missing key outside ``_EXPECTED_MISSING``, raises
    ``RuntimeError``. On success the model is written to ``out`` with
    ``save_pretrained`` and a copy of this script is placed next to it.
    """
    # Imported lazily, after the module-level sys.path adjustment has run.
    from braindecode.models import STEEGFormer

    hparams = dict(VARIANTS[variant])
    weights = remap_official_state_dict(
        _load_checkpoint(src), embed_dim=hparams["embed_dim"]
    )

    net = STEEGFormer(
        n_chans=22,
        n_times=1000,
        n_outputs=n_outputs,
        patch_size=16,
        mlp_ratio=4,
        **hparams,
    )

    # Non-strict load so the expected gaps can be told apart from real ones.
    missing, unexpected = net.load_state_dict(weights, strict=False)
    extra_missing = sorted(set(missing) - _EXPECTED_MISSING)
    if extra_missing or unexpected:
        raise RuntimeError(
            f"Converted checkpoint mismatch: missing={extra_missing}, "
            f"unexpected={list(unexpected)}"
        )

    out.mkdir(parents=True, exist_ok=True)
    net.save_pretrained(out)

    # Archive the conversion script alongside the exported weights.
    script = Path(__file__)
    shutil.copy2(script.resolve(), out / script.name)
    print(f"Saved {variant} to {out} ({len(weights)} tensors loaded)")
|
|
|
|
def push(out: Path, variant: str) -> None:
    """Upload the converted folder ``out`` to the Hub repo listed for ``variant``.

    The repo id comes from ``HUB_REPOS``. The repo is created with
    ``exist_ok=True`` before the folder is uploaded.
    """
    from huggingface_hub import HfApi

    target_repo = HUB_REPOS[variant]
    hub = HfApi()
    hub.create_repo(target_repo, repo_type="model", exist_ok=True)
    hub.upload_folder(
        folder_path=str(out),
        repo_id=target_repo,
        commit_message="Add ST-EEGFormer braindecode checkpoint conversion",
    )
    print(f"Pushed {out} -> {target_repo}")
|
|
|
|
def main(argv: list[str] | None = None) -> None:
    """Parse command-line options, convert a checkpoint and optionally push it.

    Parameters
    ----------
    argv : list of str, optional
        Arguments to parse. When ``None`` (the default) argparse reads
        ``sys.argv[1:]``, so running the script from a shell behaves as
        before. Passing a list lets other code drive the CLI directly.
    """
    parser = argparse.ArgumentParser(
        description=__doc__,
        # Keep the module docstring's line breaks in ``--help`` output.
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument("--src", type=Path, required=True)
    parser.add_argument("--variant", choices=sorted(VARIANTS), required=True)
    parser.add_argument(
        "--out",
        type=Path,
        default=None,
        help="output folder (default: hf_export/STEEGFormer-<variant>)",
    )
    parser.add_argument("--n-outputs", type=int, default=4)
    parser.add_argument(
        "--push",
        action="store_true",
        help="upload to the matching braindecode/STEEGFormer-* repo",
    )
    args = parser.parse_args(argv)

    # Default export location is derived from the chosen variant.
    out = args.out or Path("hf_export") / f"STEEGFormer-{args.variant}"
    convert(args.src, args.variant, out, n_outputs=args.n_outputs)
    if args.push:
        push(out, args.variant)
|
|
|
# Run the CLI only when the file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|