| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
|
|
# Build a freshly constructed JiRackTernary1B, transplant the token-embedding
# and lm_head rows of an existing checkpoint into it, initialise the rows that
# belong to newly added vocabulary entries, and save the result together with
# the new tokenizer.
import os

import torch
from safetensors.torch import load_file, save_file
from transformers import AutoTokenizer

from JiRackTernary_new import JiRackConfig, JiRackTernary1B


def _transplant_rows(target, source, num_old):
    """Copy the first ``num_old`` rows of ``source`` into ``target``.

    Any remaining rows of ``target`` are filled with the mean of the copied
    rows. The mean is computed in float32 and cast back, so a half-precision
    checkpoint does not lose accuracy in the reduction.

    Must be called under ``torch.no_grad()`` when ``target`` is a parameter.
    """
    kept = source[:num_old]
    target[:num_old].copy_(kept)
    if target.shape[0] > num_old:
        target[num_old:] = kept.float().mean(dim=0).to(target.dtype)


# NOTE(review): the leading "π" looks like a mis-decoded emoji; the original
# character cannot be recovered from this file, so the string is left as is.
print("π Copying embeddings and lm_head...")

old_model_path = "."
new_tokenizer_path = "./jirack_code_tokenizer"
save_path = "./JiRack_init_model_with_new_vocab"

os.makedirs(save_path, exist_ok=True)

# The new tokenizer defines how many vocabulary entries the model should have.
tokenizer = AutoTokenizer.from_pretrained(new_tokenizer_path)
new_vocab_size = len(tokenizer)

print(f"New vocab size: {new_vocab_size}")

# Fresh model built from the default configuration.
# NOTE(review): the configuration is not given new_vocab_size here; the check
# below warns if the model's embedding table does not match the tokenizer.
config = JiRackConfig()
model = JiRackTernary1B(config)

# Weights of the previous model.
old_state = load_file(os.path.join(old_model_path, "model.safetensors"))
old_emb = old_state["token_emb.weight"]
old_head = old_state["lm_head.weight"]

old_vocab_size = 128256

# Fail early with a readable message instead of an opaque shape error (or a
# silently wrong report) further down.
for tensor_name, tensor in (
    ("token_emb.weight", old_emb),
    ("lm_head.weight", old_head),
):
    if tensor.shape[0] < old_vocab_size:
        raise ValueError(
            f"Checkpoint tensor {tensor_name} has {tensor.shape[0]} rows, "
            f"fewer than old_vocab_size={old_vocab_size}"
        )

model_vocab_rows = model.token_emb.weight.shape[0]
if model_vocab_rows < old_vocab_size:
    raise ValueError(
        f"Model embedding table has {model_vocab_rows} rows, "
        f"fewer than old_vocab_size={old_vocab_size}"
    )
if model_vocab_rows != new_vocab_size:
    print(
        f"WARNING: model embedding table has {model_vocab_rows} rows but the "
        f"tokenizer has {new_vocab_size} entries"
    )

with torch.no_grad():
    # Each matrix is extended with the mean of its own rows: the output
    # projection is not initialised from the input-embedding mean.
    _transplant_rows(model.token_emb.weight, old_emb, old_vocab_size)
    _transplant_rows(model.lm_head.weight, old_head, old_vocab_size)

print(f"✅ Copied {old_vocab_size} tokens")
print(f"✅ Initialized {model_vocab_rows - old_vocab_size} new tokens")

# Persist the weights and the tokenizer side by side.
# NOTE(review): safetensors' save_file refuses tensors that share memory; if
# token_emb and lm_head are tied inside JiRackTernary1B this call needs
# attention — confirm against the model definition.
save_file(model.state_dict(), os.path.join(save_path, "model.safetensors"))
tokenizer.save_pretrained(save_path)

print(f"\nπ Done! New model saved to: {save_path}")
print("Use this folder as the starting weights for training from scratch.")