"""Convert a nanoGPT-style checkpoint (the ckpt["model"] / ckpt["config"] layout)
into a Hugging Face GPT-2 model directory, transposing Linear weights into
Conv1D layout along the way."""

import argparse
import os

import torch
from transformers import GPT2Config, GPT2LMHeadModel, GPT2TokenizerFast
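
# These projection weights are stored as nn.Linear (out_features, in_features)
# in the source checkpoint, while Hugging Face GPT-2 uses Conv1D, whose weight
# is (in_features, out_features), so they must be transposed on copy. The list
# mirrors the one nanoGPT uses when loading HF weights in the other direction.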
TRANSPOSED_SUFFIXES = (
    "attn.c_attn.weight",
    "attn.c_proj.weight",
    "mlp.c_fc.weight",
    "mlp.c_proj.weight",
)


def main(ckpt_path: str, out_dir: str):
    # The checkpoint bundles a pickled config object, so weights_only loading
    # (the default since PyTorch 2.6) would fail here.
    ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
    sd = ckpt["model"]
    cfg = ckpt["config"]

    # Checkpoints saved from a torch.compile'd model prefix every key with
    # "_orig_mod."; strip it so the keys line up with the HF names.
    sd = {k.removeprefix("_orig_mod."): v for k, v in sd.items()}

    hf_cfg = GPT2Config(
        vocab_size=cfg.vocab_size,
        n_positions=cfg.block_size,
        n_ctx=cfg.block_size,
        n_embd=cfg.n_embd,
        n_layer=cfg.n_layer,
        n_head=cfg.n_head,
        bos_token_id=50256,
        eos_token_id=50256,
    )
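
    # A randomly initialized HF model with matching shapes; every weight is
    # overwritten below.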
    model = GPT2LMHeadModel(hf_cfg)
    sd_hf = model.state_dict()
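
    # Copy each tensor from the source state dict into the fresh HF model.
    # copy_ writes into the live parameters in place, so the loop must run
    # under no_grad: in-place ops on leaf tensors that require grad raise a
    # RuntimeError otherwise.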
    with torch.no_grad():
        for k in sd_hf.keys():
            # HF-only causal-mask buffers with no counterpart in the source.
            if k.endswith(".attn.bias") or k.endswith(".attn.masked_bias"):
                continue
            if any(k.endswith(suf) for suf in TRANSPOSED_SUFFIXES):
                sd_hf[k].copy_(sd[k].t())
            else:
                sd_hf[k].copy_(sd[k])
    # No separate load_state_dict call is needed: copy_ above already wrote
    # directly into the model's parameters.

    os.makedirs(out_dir, exist_ok=True)
    # safetensors cannot store aliased tensors; save_pretrained deduplicates
    # the tied lm_head/wte weights itself, so safe serialization is fine.
    model.save_pretrained(out_dir, safe_serialization=True)

    # Ship the stock GPT-2 tokenizer alongside the weights so the folder is
    # directly loadable with from_pretrained / pipeline.
    tok = GPT2TokenizerFast.from_pretrained("gpt2")
    tok.save_pretrained(out_dir)


if __name__ == "__main__":
    p = argparse.ArgumentParser(description=__doc__)
    p.add_argument("--ckpt", required=True, help="path to the source .pt checkpoint")
    p.add_argument("--out", required=True, help="output directory for the HF model")
    args = p.parse_args()
    main(args.ckpt, args.out)
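
# Example invocation (paths are hypothetical):
#   python convert_to_hf.py --ckpt out/ckpt.pt --out out/hf-gpt2
#
# Quick sanity check of the exported folder (a sketch, not run by this script):
#   from transformers import pipeline
#   pipe = pipeline("text-generation", model="out/hf-gpt2")
#   print(pipe("Hello,", max_new_tokens=20)[0]["generated_text"])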