# custom-transformer/project/save_model.py
# (uploaded to Hugging Face by abeimam, commit e66238c, 939 bytes)
# save_model.py
#
# Convert a raw PyTorch checkpoint ("transformer_model.pth") into a
# Hugging Face-style repo directory ("hf_custom_transformer/") containing
# the model weights, tokenizer files, and config.
import torch
from transformers import AutoTokenizer
from src.hf_custom_transformer.config import CustomTransformerConfig
from src.hf_custom_transformer.modeling import CustomTransformerModel

# 1. Load your checkpoint.
#    weights_only=True restricts unpickling to tensors/primitive containers,
#    guarding against arbitrary-code execution from a tampered checkpoint
#    file (and matches the default behavior in torch >= 2.6).
state_dict = torch.load(
    "transformer_model.pth", map_location="cpu", weights_only=True
)

# 2. Load tokenizer (GPT-2 vocab). GPT-2 ships with no pad token, so reuse
#    EOS — this adds no new token, keeping vocab_size unchanged below.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

# 3. Build config & model. These hyperparameters must match the ones used
#    at training time, or load_state_dict below fails on shape mismatches.
config = CustomTransformerConfig(
    vocab_size=tokenizer.vocab_size,
    d_model=256,
    n_heads=4,
    n_layers=4,
    max_len=256,
    max_rel=32,
    dropout=0.1,
)
model = CustomTransformerModel(config)

# 4. Load weights & save HF format. load_state_dict is strict by default,
#    so any missing/unexpected keys raise immediately rather than saving a
#    partially-initialized model.
model.load_state_dict(state_dict)
model.save_pretrained("hf_custom_transformer")
tokenizer.save_pretrained("hf_custom_transformer")
# model.save_pretrained already writes config.json; kept for parity with
# the original script (harmless overwrite with identical content).
config.save_pretrained("hf_custom_transformer")
print("✅ Saved HF repo in hf_custom_transformer/")