# GPT2/src/convert_ckpt_to_hf.py
# Uploaded via huggingface_hub by svshrithik12 (commit c96ac34, verified).
import os
import torch
from transformers import GPT2Config, GPT2LMHeadModel, GPT2TokenizerFast
# Parameter-name suffixes whose tensors must be transposed during conversion:
# HF's GPT-2 implements these projections with Conv1D, which stores weights as
# (in_features, out_features) — transposed relative to nn.Linear's layout.
TRANSPOSED_SUFFIXES = (
    "attn.c_attn.weight",
    "attn.c_proj.weight",
    "mlp.c_fc.weight",
    "mlp.c_proj.weight",
)
def main(ckpt_path: str, out_dir: str) -> None:
    """Convert a nanoGPT-style checkpoint into a Hugging Face GPT-2 folder.

    Args:
        ckpt_path: Path to a ``torch.save``d dict containing "model" (the raw
            state dict) and "config" (an object exposing vocab_size,
            block_size, n_embd, n_layer, n_head).
        out_dir: Directory to write the safetensors weights, config, and the
            standard GPT-2 tokenizer files into (created if missing).

    Raises:
        KeyError: if the checkpoint lacks a parameter the HF model expects.
        ValueError: if a parameter's shape does not match after transposition.
    """
    # NOTE(security): torch.load unpickles arbitrary objects — only run this
    # on checkpoints from a trusted source.
    ckpt = torch.load(ckpt_path, map_location="cpu")
    sd = ckpt["model"]
    cfg = ckpt["config"]

    # torch.compile wraps the module and prefixes saved parameter names with
    # "_orig_mod."; strip it so keys line up with the HF module names.
    prefix = "_orig_mod."
    sd = {(k[len(prefix):] if k.startswith(prefix) else k): v for k, v in sd.items()}

    hf_cfg = GPT2Config(
        vocab_size=cfg.vocab_size,
        n_positions=cfg.block_size,
        n_ctx=cfg.block_size,
        n_embd=cfg.n_embd,
        n_layer=cfg.n_layer,
        n_head=cfg.n_head,
        bos_token_id=50256,
        eos_token_id=50256,
    )
    model = GPT2LMHeadModel(hf_cfg)
    sd_hf = model.state_dict()
    # state_dict() tensors are detached views sharing storage with the live
    # parameters, so copy_ below updates the model directly — no subsequent
    # load_state_dict call is needed.
    with torch.no_grad():
        for k in sd_hf.keys():
            # Causal-mask buffers are (re)built by HF itself; training
            # checkpoints do not carry them.
            if k.endswith(".attn.bias") or k.endswith(".attn.masked_bias"):
                continue
            if k not in sd:
                raise KeyError(f"checkpoint is missing parameter {k!r}")
            # HF GPT-2 uses Conv1D modules whose weights are stored transposed
            # relative to nn.Linear, hence the .t() on these tensors.
            if any(k.endswith(suf) for suf in TRANSPOSED_SUFFIXES):
                src = sd[k].t()
            else:
                src = sd[k]
            if src.shape != sd_hf[k].shape:
                raise ValueError(
                    f"shape mismatch for {k!r}: checkpoint {tuple(src.shape)} "
                    f"vs model {tuple(sd_hf[k].shape)}"
                )
            sd_hf[k].copy_(src)
    os.makedirs(out_dir, exist_ok=True)
    model.save_pretrained(out_dir, safe_serialization=True)
    tok = GPT2TokenizerFast.from_pretrained("gpt2")
    tok.save_pretrained(out_dir)
if __name__ == "__main__":
    import argparse

    # Minimal CLI: both flags are mandatory.
    parser = argparse.ArgumentParser()
    for flag in ("--ckpt", "--out"):
        parser.add_argument(flag, required=True)
    ns = parser.parse_args()
    main(ns.ckpt, ns.out)