File size: 939 Bytes
e66238c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
# save_model.py
import torch
from transformers import AutoTokenizer
from src.hf_custom_transformer.config import CustomTransformerConfig
from src.hf_custom_transformer.modeling import CustomTransformerModel

# Convert a raw PyTorch checkpoint into a Hugging Face-style directory
# holding the model weights, the tokenizer files, and the config.

# Locations and identifiers used by this script.
CHECKPOINT_PATH = "transformer_model.pth"
TOKENIZER_NAME = "gpt2"
OUTPUT_DIR = "hf_custom_transformer"

# Step 1: read the checkpoint, mapping every tensor onto the CPU.
# NOTE(review): torch.load unpickles the file -- only point this at a
# checkpoint you trust.
state_dict = torch.load(CHECKPOINT_PATH, map_location="cpu")

# Step 2: fetch the GPT-2 tokenizer and reuse its EOS token for padding.
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)
tokenizer.pad_token = tokenizer.eos_token

# Step 3: describe the architecture, then instantiate the model from it.
# NOTE(review): these values presumably mirror the training run that
# produced the checkpoint -- confirm before changing any of them.
hyperparameters = {
    "vocab_size": tokenizer.vocab_size,
    "d_model": 256,
    "n_heads": 4,
    "n_layers": 4,
    "max_len": 256,
    "max_rel": 32,
    "dropout": 0.1,
}
config = CustomTransformerConfig(**hyperparameters)
model = CustomTransformerModel(config)

# Step 4: restore the trained weights, then write the model, tokenizer,
# and config (in that order) into the same output directory.
model.load_state_dict(state_dict)
for artifact in (model, tokenizer, config):
    artifact.save_pretrained(OUTPUT_DIR)

print("✅ Saved HF repo in hf_custom_transformer/")