# Byte-lingua-code/superbpe/tokenizers_superbpe/bindings/python/scripts/sentencepiece_extractor.py
from argparse import ArgumentParser
from json import dump
from logging import basicConfig, getLogger
from os import linesep, remove
from os.path import exists
from tempfile import NamedTemporaryFile
from typing import Dict, List, Tuple

from requests import get
from sentencepiece import SentencePieceProcessor
from tqdm import trange, tqdm

basicConfig()
logger = getLogger()


class SentencePieceExtractor:
    """
    Extractor implementation for SentencePiece trained models.
    https://github.com/google/sentencepiece
    """

    def __init__(self, model: str):
        # Get SentencePiece
        self.sp = SentencePieceProcessor()
        self.sp.Load(model)
    def extract(self) -> Tuple[Dict[str, int], List[Tuple]]:
        sp = self.sp
        # Vocab: map each piece to its id
        vocab = {sp.id_to_piece(index): index for index in trange(sp.GetPieceSize())}

        # Merges: reconstruct BPE merge rules by brute force. For every ordered pair
        # of pieces whose concatenation is itself a piece, record a merge and rank
        # merges by the id of the merged piece (a lower id means an earlier merge).
        # This is O(vocab_size^2), so it can be slow for large vocabularies.
        merges = []
        for piece_l in tqdm(vocab.keys(), total=sp.GetPieceSize()):
            for piece_r in vocab.keys():
                merge = f"{piece_l}{piece_r}"
                piece_id = vocab.get(merge, None)
                # Explicit None check so a merge mapping to id 0 is not dropped
                if piece_id is not None:
                    merges += [(piece_l, piece_r, piece_id)]
        merges = sorted(merges, key=lambda val: val[2])
        merges = [(val[0], val[1]) for val in merges]

        return vocab, merges


class YouTokenToMeExtractor:
    """
    Extractor implementation for YouTokenToMe trained model files.
    The model file is laid out as follows:
        vocab_size nb_merges
        piece piece_id
        ...(repeated vocab_size times)
        piece_id_left piece_id_right piece_id
        ...(repeated nb_merges times)
        unk_id pad_id bos_id eos_id
    """

    def __init__(self, model: str):
        self._model = model

    def extract(self) -> Tuple[Dict[str, int], List[Tuple]]:
        with open(self._model, "r") as model_f:
            # Retrieve information
            nb_pieces, nb_merges = map(int, model_f.readline().split())
            vocab, merges = {}, []

            # Vocab: each line holds a character code point and its piece id
            for _ in trange(nb_pieces):
                piece, piece_id = map(int, model_f.readline().split())
                vocab[piece_id] = chr(piece)

            # Merges: each line combines two existing piece ids into a new piece id
            for _ in trange(nb_merges):
                piece_id_l, piece_id_r, piece = map(int, model_f.readline().split())
                piece_l, piece_r = vocab[piece_id_l], vocab[piece_id_r]
                vocab[piece] = f"{piece_l}{piece_r}"
                merges += [(piece_l, piece_r)]

            # Special tokens (final line: unk, pad, bos and eos ids)
            unk, pad, bos, eos = map(int, model_f.readline().split())
            vocab[unk] = "<unk>"
            vocab[pad] = "<pad>"
            vocab[bos] = "<bos>"
            vocab[eos] = "<eos>"

            # Invert the mapping so the returned vocab maps piece -> id
            vocab = dict(zip(vocab.values(), vocab.keys()))
            return vocab, merges


if __name__ == "__main__":
    parser = ArgumentParser("SentencePiece vocab extractor")
    parser.add_argument(
        "--provider",
        type=str,
        required=True,
        choices=["sentencepiece", "youtokentome"],
        help="Indicate the format of the model file.",
    )
    parser.add_argument("--model", type=str, required=True, help="SentencePiece model to extract vocab from.")
    parser.add_argument(
        "--vocab-output-path",
        type=str,
        required=True,
        help="Path where the extracted vocab.json file will be written.",
    )
    parser.add_argument(
        "--merges-output-path",
        type=str,
        required=True,
        help="Path where the extracted merges file will be written.",
    )
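
    # Example invocation (hypothetical paths, for illustration only):
    #   python sentencepiece_extractor.py --provider sentencepiece \
    #       --model spm.model --vocab-output-path vocab.json --merges-output-path merges.txt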

    # Parse cli arguments
    args = parser.parse_args()

    try:
        if args.model.startswith("http"):
            # Download the remote model to a temporary file
            with NamedTemporaryFile("wb", delete=False) as f:
                logger.info("Writing content from {} to {}".format(args.model, f.name))
                response = get(args.model, allow_redirects=True)
                f.write(response.content)

                args.remote_model = args.model
                args.model = f.name

        # Allocate extractor
        extractor = SentencePieceExtractor if args.provider == "sentencepiece" else YouTokenToMeExtractor
        extractor = extractor(args.model)

        logger.info(f"Using {type(extractor).__name__}")

        # Open output files and extract the model information
        with open(args.vocab_output_path, "w") as vocab_f:
            with open(args.merges_output_path, "w") as merges_f:
                # Do the extraction
                vocab, merges = extractor.extract()

                # Save content
                dump(vocab, vocab_f)
                merges_f.writelines(map(lambda x: f"{x[0]} {x[1]}{linesep}", merges))
    finally:
        # If the model was downloaded from the internet, clean up the temporary file.
        if hasattr(args, "remote_model") and exists(args.model):
            remove(args.model)
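
# Follow-up sketch (not part of this script): the vocab.json / merges files written
# above use the piece-to-id JSON map and "left right" merge-per-line layout expected
# by the Hugging Face `tokenizers` BPE model, so assuming that library is installed,
# loading them back could look roughly like:
#
#   from tokenizers import Tokenizer
#   from tokenizers.models import BPE
#
#   tokenizer = Tokenizer(BPE.from_file("vocab.json", "merges.txt"))
#   print(tokenizer.encode("Hello world").tokens)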