| import argparse |
| import json |
| import math |
| import os |
| from pathlib import Path |
|
|
| import torch |
| from mistral_common.tokens.tokenizers.mistral import MistralTokenizer |
| from mistral_common.tokens.tokenizers.sentencepiece import SentencePieceTokenizer |
|
|
| from model.args import ModelArgs |
|
|
# Row index at which the additional vocabulary rows are inserted; rows below
# this index are copied over unchanged by extend_model.
FIRST_PIECE_ID = 3
# Vocabulary size a checkpoint must have to be accepted by extend_model.
OLD_VOCAB_SIZE = 32_000
# Vocabulary size of the extended checkpoint (the tokenizer size that
# extend_model expects).
NEW_VOCAB_SIZE = 32_768
|
|
|
|
def _require(condition: bool, message: str = "") -> None:
    """Raise ``AssertionError(message)`` unless *condition* holds.

    Unlike a bare ``assert`` this is not stripped under ``python -O``, while
    the exception type stays the one the original assertions raised.
    """
    if not condition:
        raise AssertionError(message)


def _extend_weight(weight: torch.Tensor, new_vocab_size: int) -> torch.Tensor:
    """Return a copy of *weight* grown to *new_vocab_size* rows.

    The first ``FIRST_PIECE_ID`` rows keep their position, the newly added
    rows are inserted right after them (normally initialised with
    ``std = 1 / sqrt(dim)``), and the remaining original rows are shifted
    behind the inserted ones.
    """
    old_vocab_size = weight.shape[0]
    dim = weight.shape[1]
    delta = new_vocab_size - old_vocab_size
    insert_end = FIRST_PIECE_ID + delta

    extended = torch.zeros(new_vocab_size, dim, dtype=weight.dtype)
    extended[:FIRST_PIECE_ID] = weight[:FIRST_PIECE_ID]
    extended[insert_end:] = weight[FIRST_PIECE_ID:]

    new_rows = torch.empty(delta, dim, dtype=weight.dtype)
    torch.nn.init.normal_(new_rows, std=1 / math.sqrt(dim))
    extended[FIRST_PIECE_ID:insert_end] = new_rows
    return extended


def extend_model(original_model: Path, extended_model: Path) -> None:
    """Extend a checkpoint's vocabulary from OLD_VOCAB_SIZE to NEW_VOCAB_SIZE.

    Reads ``consolidated.00.pth`` and ``params.json`` from *original_model*,
    grows ``tok_embeddings.weight`` and ``output.weight`` to the size of the
    ``MistralTokenizer.v3()`` vocabulary, and writes the resulting checkpoint
    together with an updated ``params.json`` into *extended_model*.

    Args:
        original_model: Folder holding the original ``consolidated.00.pth``
            and ``params.json``.
        extended_model: Output folder. It is created when missing and must
            be empty.

    Raises:
        AssertionError: If the original vocabulary size is not
            OLD_VOCAB_SIZE, the output folder is not empty, the tokenizer
            does not have the expected type, size or layout, or the
            checkpoint tensors do not match the expected shapes and dtypes.
    """
    original_ckpt = torch.load(str(original_model / "consolidated.00.pth"), mmap=True)
    model_args = ModelArgs.load(str(original_model / "params.json"))

    original_vocab_size = model_args.vocab_size
    _require(
        original_vocab_size == OLD_VOCAB_SIZE,
        f"Original vocab size {original_vocab_size} is not equal to {OLD_VOCAB_SIZE}. "
        f"Can only extend models with vocab_size of {OLD_VOCAB_SIZE}",
    )

    if not extended_model.exists():
        extended_model.mkdir(parents=True, exist_ok=True)
        print(f"Created empty directory {extended_model}.")

    _require(
        not any(extended_model.iterdir()),
        f"Make sure {extended_model} is empty",
    )

    # Load the tokenizer and verify its type BEFORE touching its private
    # ``_model`` attribute below.
    mistral_tokenizer = MistralTokenizer.v3()
    tokenizer = mistral_tokenizer.instruct_tokenizer.tokenizer
    _require(
        isinstance(tokenizer, SentencePieceTokenizer),
        f"Expected a SentencePieceTokenizer but got {type(tokenizer).__name__}.",
    )

    new_vocab_size = tokenizer.n_words
    _require(
        new_vocab_size == NEW_VOCAB_SIZE,
        f"New Tokenizer has vocab_size: {new_vocab_size} but has to be equal to "
        f"{NEW_VOCAB_SIZE}. Make sure to pass a v2 or v3 tokenizer file",
    )

    vocabulary_delta = new_vocab_size - original_vocab_size

    # Sanity-check the layout the row shuffling relies on: the piece just
    # before the insertion point is "</s>", and the first piece after the
    # inserted range is "<0x00>".
    _require(
        tokenizer._model.id_to_piece(vocabulary_delta + FIRST_PIECE_ID) == "<0x00>",
        f'Expected piece "<0x00>" at id {vocabulary_delta + FIRST_PIECE_ID}.',
    )
    _require(
        tokenizer._model.id_to_piece(FIRST_PIECE_ID - 1) == "</s>",
        f'Expected piece "</s>" at id {FIRST_PIECE_ID - 1}.',
    )

    # Validate both weight matrices up front so nothing is modified when the
    # checkpoint does not have the expected shape.
    original_embeddings = original_ckpt["tok_embeddings.weight"]
    _require(
        original_vocab_size == original_embeddings.shape[0],
        f"Original vocab size {original_vocab_size} is not equal to original "
        f"embeddings shape {original_embeddings.shape[0]}.",
    )
    dim = original_embeddings.shape[1]

    original_output = original_ckpt["output.weight"]
    _require(
        original_output.shape[0] == original_vocab_size,
        f"Original output shape {original_output.shape[0]} is not equal to "
        f"{original_vocab_size}.",
    )
    _require(
        original_output.shape[1] == dim,
        f"Original output dim {original_output.shape[1]} is not equal to "
        f"embedding dim {dim}.",
    )
    _require(
        original_output.dtype == original_embeddings.dtype,
        f"Original output and embeddings have different dtypes: "
        f"{original_output.dtype} vs {original_embeddings.dtype}.",
    )

    # Embeddings first, output second: keeps the order in which random
    # values are drawn for the new rows.
    original_ckpt["tok_embeddings.weight"] = _extend_weight(
        original_embeddings, new_vocab_size
    )
    original_ckpt["output.weight"] = _extend_weight(original_output, new_vocab_size)

    new_ckpt_path = extended_model / "consolidated.00.pth"
    print(f"Exporting extended model to {extended_model} ...")
    torch.save(original_ckpt, new_ckpt_path)

    # Build the params dict completely before opening the file so a failure
    # here does not leave an empty params.json behind.
    model_dict = model_args.to_dict()
    model_dict.pop("lora", None)
    if model_dict.get("moe") is None:
        model_dict.pop("moe", None)
    model_dict["vocab_size"] = new_vocab_size

    params_path = extended_model / "params.json"
    params_path.write_text(json.dumps(model_dict, indent=4), encoding="utf-8")
|
|
|
|
def main():
    """Parse the command-line arguments and run ``extend_model``.

    Both ``--original_model_ckpt`` and ``--extended_model_ckpt`` are
    required; argparse exits with a usage error when either is missing
    instead of passing ``None`` on to ``extend_model``.
    """
    parser = argparse.ArgumentParser(
        description="Extend a model using the specified original model and extended model paths."
    )
    parser.add_argument(
        "--original_model_ckpt",
        type=Path,
        required=True,
        help="Path to the original model folder.",
    )
    parser.add_argument(
        "--extended_model_ckpt",
        type=Path,
        required=True,
        help="Path to the extended model folder.",
    )
    args = parser.parse_args()

    extend_model(
        original_model=args.original_model_ckpt,
        extended_model=args.extended_model_ckpt,
    )


# Run the command-line interface only when executed as a script.
if __name__ == "__main__":
    main()
|
|