import torch
from transformers import AutoTokenizer

from src.hf_custom_transformer.config import CustomTransformerConfig
from src.hf_custom_transformer.modeling import CustomTransformerModel


# Load the raw PyTorch checkpoint onto the CPU.
state_dict = torch.load("transformer_model.pth", map_location="cpu")


# Reuse the GPT-2 tokenizer; it has no pad token, so fall back to EOS.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token


# Build the config with the same hyperparameters the checkpoint was trained with.
config = CustomTransformerConfig(
    vocab_size=tokenizer.vocab_size,
    d_model=256,
    n_heads=4,
    n_layers=4,
    max_len=256,
    max_rel=32,
    dropout=0.1,
)
model = CustomTransformerModel(config)


# Load the trained weights, then export model, tokenizer, and config
# in Hugging Face format.
model.load_state_dict(state_dict)
model.save_pretrained("hf_custom_transformer")
tokenizer.save_pretrained("hf_custom_transformer")
config.save_pretrained("hf_custom_transformer")

print("✅ Saved HF repo in hf_custom_transformer/")
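

# Minimal round-trip sanity check (a sketch, not part of the export itself).
# It assumes CustomTransformerModel subclasses transformers.PreTrainedModel,
# so that from_pretrained() is available, and that save_pretrained() /
# from_pretrained() preserve the state_dict keys unchanged.
reloaded = CustomTransformerModel.from_pretrained("hf_custom_transformer")

# Every exported tensor should match the original checkpoint exactly.
for name, tensor in reloaded.state_dict().items():
    assert torch.equal(tensor, state_dict[name]), f"weight mismatch in {name}"
print("✅ Round-trip check passed")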
|
|