bruAristimunha commited on
Commit
09addfe
·
verified ·
1 Parent(s): b8d5107

Remove duplicate converter (keep convert_steegformer_checkpoints.py)

Browse files
Files changed (1) hide show
  1. convert_checkpoint.py +0 -79
convert_checkpoint.py DELETED
@@ -1,79 +0,0 @@
1
- """Convert an official ST-EEGFormer MAE checkpoint to braindecode format.
2
-
3
- The braindecode ``STEEGFormer`` loads braindecode-format state dicts directly,
4
- so this one-off converter is what produced the re-hosted ``model.safetensors``
5
- in this repo. It remaps the upstream ``timm`` keys (drops the MAE decoder and
6
- downstream-only keys, renames the blocks, and splits the fused ``attn.qkv``
7
- into the separate queries/keys/values of braindecode's ``MultiHeadAttention``).
8
-
9
- Usage:
10
- python convert_checkpoint.py checkpoint-300.pth ./out \\
11
- --embed-dim 512 --depth 8 --num-heads 8 --n-chans-pos 145
12
- """
13
- import argparse
14
- from collections import OrderedDict
15
-
16
- import torch
17
-
18
- from braindecode.models import STEEGFormer
19
-
20
- _BLOCK = {
21
- "norm1": "0.fn.0",
22
- "attn.proj": "0.fn.1.projection",
23
- "norm2": "1.fn.0",
24
- "mlp.fc1": "1.fn.1.0",
25
- "mlp.fc2": "1.fn.1.3",
26
- }
27
- _DROP_PREFIX = ("decoder_", "dec_", "mask_token", "enc_temporal_emd")
28
- _DROP_EXACT = {"pos_embed", "fc_norm.weight", "fc_norm.bias", "head.weight", "head.bias"}
29
-
30
-
31
- def remap(state_dict, embed_dim):
32
- if isinstance(state_dict.get("model"), dict):
33
- state_dict = state_dict["model"]
34
- e, out = embed_dim, OrderedDict()
35
- for k, v in state_dict.items():
36
- if k.startswith(_DROP_PREFIX) or k in _DROP_EXACT:
37
- continue
38
- if k == "enc_channel_emd.channel_transformation.weight":
39
- out["channel_pos.embedding.weight"] = v
40
- elif k.startswith("blocks."):
41
- _, i, rest = k.split(".", 2)
42
- d = f"encoder.{i}."
43
- if rest in ("attn.qkv.weight", "attn.qkv.bias"):
44
- s = rest.rsplit(".", 1)[1]
45
- out[f"{d}0.fn.1.queries.{s}"] = v[:e]
46
- out[f"{d}0.fn.1.keys.{s}"] = v[e : 2 * e]
47
- out[f"{d}0.fn.1.values.{s}"] = v[2 * e :]
48
- else:
49
- for a, b in _BLOCK.items():
50
- if rest.startswith(a + "."):
51
- out[d + b + rest[len(a) :]] = v
52
- break
53
- else: # cls_token, patch_embed.proj.*, norm.* keep their names
54
- out[k] = v
55
- return out
56
-
57
-
58
- def main():
59
- p = argparse.ArgumentParser(description=__doc__)
60
- p.add_argument("checkpoint")
61
- p.add_argument("out_dir")
62
- p.add_argument("--embed-dim", type=int, required=True)
63
- p.add_argument("--depth", type=int, required=True)
64
- p.add_argument("--num-heads", type=int, required=True)
65
- p.add_argument("--n-chans-pos", type=int, default=145)
66
- a = p.parse_args()
67
- ck = torch.load(a.checkpoint, map_location="cpu", weights_only=False)
68
- model = STEEGFormer(
69
- n_chans=22, n_outputs=4, n_times=1000, embed_dim=a.embed_dim,
70
- depth=a.depth, num_heads=a.num_heads, n_chans_pos=a.n_chans_pos,
71
- )
72
- res = model.load_state_dict(remap(ck, a.embed_dim), strict=False)
73
- assert not res.unexpected_keys, f"unexpected: {res.unexpected_keys[:5]}"
74
- model.save_pretrained(a.out_dir)
75
- print(f"saved braindecode-format checkpoint to {a.out_dir}")
76
-
77
-
78
- if __name__ == "__main__":
79
- main()