| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import json |
| | import logging |
| | import tempfile |
| | from typing import Dict, List, Optional, Tuple |
| |
|
| | import tokenizers |
| | import tokenizers.models |
| | import torch |
| | import tqdm |
| | import transformers |
| | from pydantic import BaseModel |
| |
|
| | from mergekit.common import ModelPath, ModelReference |
| | from mergekit.graph import Task |
| |
|
| |
|
def get_vocab_size(model_path: ModelPath, trust_remote_code: bool) -> Optional[int]:
    """Read the declared vocabulary size from a model's config.

    Returns the config's ``vocab_size``, or None (after logging a warning)
    when the config cannot be loaded or has no such attribute.
    """
    try:
        config = transformers.AutoConfig.from_pretrained(
            model_path.path,
            revision=model_path.revision,
            trust_remote_code=trust_remote_code,
        )
        return config.vocab_size
    except Exception as e:
        # Best-effort: callers fall back to the tokenizer's own vocab length.
        logging.warning(f"Unable to get vocab size for {model_path}", exc_info=e)
        return None
| |
|
| |
|
def get_stripped_tokenizer(
    path: ModelPath, trust_remote_code: bool = False
) -> transformers.PreTrainedTokenizerFast:
    """
    Return a tokenizer for a model that only contains used tokens.

    Strips any tokens with indices >= model.vocab_size. Only fast BPE
    tokenizers can be stripped; a tokenizer with no out-of-range tokens is
    returned unchanged.

    Raises:
        RuntimeError: if unused tokens exist but the tokenizer is not fast,
            or its underlying model is not BPE.
    """
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        path.path,
        revision=path.revision,
        trust_remote_code=trust_remote_code,
        use_fast=True,
    )
    vocab_size = get_vocab_size(path, trust_remote_code=trust_remote_code) or len(
        tokenizer.get_vocab()
    )

    # Use a set: membership is tested once per vocab entry and once per merge
    # component below (a list made that O(len(merges) * len(unused))).
    unused_toks = {
        tok for tok, idx in tokenizer.get_vocab().items() if idx >= vocab_size
    }
    if not unused_toks:
        # Nothing to strip.
        return tokenizer

    if not tokenizer.is_fast:
        raise RuntimeError(
            f"Model {path} has unused tokens and does not support fast "
            "tokenizer - can not be used in tokenizer merge"
        )

    tok_dict = json.loads(tokenizer._tokenizer.to_str())
    if tok_dict["model"]["type"] != "BPE":
        raise RuntimeError(
            f"Tokenizer for {path} has type {tok_dict['model']['type']}, "
            "but only BPE is currently supported for tokenizer merge"
        )

    tok_dict["added_tokens"] = [
        e for e in tok_dict["added_tokens"] if e["id"] < vocab_size
    ]

    for tok in unused_toks:
        if tok in tok_dict["model"]["vocab"]:
            del tok_dict["model"]["vocab"][tok]

    def _keep_merge(m):
        # Keep a merge rule only if none of its components were stripped.
        # Depending on the tokenizers library version, a merge serializes as
        # either the string "left right" or a two-element list [left, right];
        # the original str-only .split(" ") would crash on the list form.
        toks = m.split(" ") if isinstance(m, str) else m
        return not any(tok in unused_toks for tok in toks)

    tok_dict["model"]["merges"] = [
        e for e in tok_dict["model"]["merges"] if _keep_merge(e)
    ]
    tokenizer._tokenizer = tokenizers.Tokenizer.from_str(json.dumps(tok_dict))
    return tokenizer
| |
|
| |
|
def build_union_tokenizer(
    base_tok: transformers.PreTrainedTokenizerBase,
    tokenizers: Dict[ModelReference, transformers.PreTrainedTokenizerBase],
    trust_remote_code: bool = False,
) -> transformers.PreTrainedTokenizerBase:
    """Build a tokenizer containing the union of all models' vocabularies.

    Starts from a copy of `base_tok` and adds every in-range token found in
    any of `tokenizers`. Added (special) tokens keep their AddedToken
    settings; on conflicting settings the first encountered wins and a
    warning is logged once per token.
    """
    out_added_tokens = {}
    out_vocab = {}

    warned_added_tokens = set()

    for model, tokenizer in tokenizers.items():
        vocab_size = (
            get_vocab_size(model.model, trust_remote_code=trust_remote_code)
            or tokenizer.vocab_size
        )
        added_tokens = tokenizer.added_tokens_decoder
        # added_tokens_decoder maps token *index* -> AddedToken, so collect the
        # token strings for membership tests against vocab entries. (The
        # original tested `tok in added_tokens`, comparing a str against int
        # keys — always False, so added tokens leaked into the regular vocab.)
        added_token_contents = {info.content for info in added_tokens.values()}

        vocab = tokenizer.get_vocab()
        for tok, idx in vocab.items():
            if idx >= vocab_size:
                logging.warning(
                    f"Token {repr(tok)} present in {str(model)} tokenizer but >= vocab_size"
                )
                continue
            if tok in added_token_contents:
                # Handled with full AddedToken settings in the pass below.
                continue

            if tok not in out_vocab:
                out_vocab[tok] = len(out_vocab)

        for tok_idx, info in tokenizer.added_tokens_decoder.items():
            tok = info.content
            if tok_idx >= vocab_size:
                continue

            if tok in out_added_tokens:
                if (out_added_tokens[tok] != info) and tok not in warned_added_tokens:
                    logging.warning(
                        f"Token '{tok}' added with multiple different settings, using first"
                    )
                    warned_added_tokens.add(tok)

                continue
            out_added_tokens[tok] = info

    # Round-trip the base tokenizer through disk to get an independent copy
    # that can be mutated without affecting the caller's instance.
    with tempfile.TemporaryDirectory() as p:
        base_tok.save_pretrained(p, legacy_format=False, safe_serialization=True)
        res = transformers.AutoTokenizer.from_pretrained(
            p, use_fast=True, trust_remote_code=trust_remote_code
        )

    orig_base_vocab = base_tok.get_vocab()
    for tok in out_vocab:
        if tok in out_added_tokens:
            continue

        if tok not in orig_base_vocab:
            res.add_tokens(tok)

    for info in out_added_tokens.values():
        res.add_tokens(info)
    return res
| |
|
| |
|
def build_tokenizer(
    base_model: Optional[ModelReference],
    referenced_models: List[ModelReference],
    tokenizer_source: str,
    trust_remote_code: bool,
) -> Tuple[transformers.PreTrainedTokenizer, Dict[ModelReference, torch.IntTensor]]:
    """Construct the output tokenizer plus per-model index permutations.

    `tokenizer_source` selects the output tokenizer: "base" (the stripped
    base tokenizer), "union" (union of all referenced vocabularies), or
    "model:<path>" (an explicit tokenizer). Each permutation maps every
    output-vocab index to the model's original index, or -1 when the model
    lacks that token.
    """
    if base_model is None:
        base_model = referenced_models[0]
    if base_model is None:
        raise RuntimeError("No models referenced")

    tokenizer_base = get_stripped_tokenizer(
        base_model.model, trust_remote_code=trust_remote_code
    )

    logging.info("Loading tokenizers")
    # Renamed from `tokenizers` to avoid shadowing the module of that name.
    loaded_tokenizers = {base_model: tokenizer_base}
    for model in referenced_models:
        if model == base_model:
            continue

        try:
            loaded_tokenizers[model] = transformers.AutoTokenizer.from_pretrained(
                model.model.path,
                revision=model.model.revision,
                trust_remote_code=trust_remote_code,
            )
        except Exception as e:
            logging.error(e)
            logging.warning(
                f"Unable to load tokenizer for {model}. Assuming same as {base_model}."
            )

    logging.info("Building output tokenizer")
    if tokenizer_source == "base":
        tokenizer_out = tokenizer_base
    elif tokenizer_source == "union":
        tokenizer_out = build_union_tokenizer(
            tokenizer_base, loaded_tokenizers, trust_remote_code=trust_remote_code
        )
    elif tokenizer_source.startswith("model:"):
        tokenizer_out = transformers.AutoTokenizer.from_pretrained(
            tokenizer_source[len("model:") :],
            trust_remote_code=trust_remote_code,
        )
    else:
        raise RuntimeError(f"Unimplemented tokenizer source: {tokenizer_source}")

    vocab_out = tokenizer_out.get_vocab()

    logging.info("Building permutations")
    permutations = {}
    for model in tqdm.tqdm(referenced_models):
        # Models whose tokenizer failed to load fall back to the base vocab.
        source_tok = loaded_tokenizers.get(model, loaded_tokenizers[base_model])
        model_vocab = source_tok.get_vocab()

        vocab_size = get_vocab_size(model.model, trust_remote_code=trust_remote_code)
        if vocab_size is None:
            vocab_size = len(model_vocab)

        mapping = {}
        for tok, new_idx in vocab_out.items():
            if tok not in model_vocab:
                mapping[new_idx] = -1
                continue

            orig_idx = model_vocab[tok]
            if orig_idx >= vocab_size:
                # NOTE(review): out-of-range entries leave new_idx absent from
                # the mapping (not -1) — preserved from the original; confirm
                # downstream consumers handle missing keys.
                logging.warning(
                    f"{model} token {repr(tok)} has index {orig_idx}>{vocab_size-1} (padding?)"
                )
                continue

            mapping[new_idx] = orig_idx

        permutations[model] = mapping

    return tokenizer_out, permutations
| |
|
| |
|
class TokenizerInfo(BaseModel, arbitrary_types_allowed=True):
    """Result of tokenizer construction: the output tokenizer plus, for each
    model, a mapping from output-vocab index to that model's original index
    (-1 where the model lacks the token — see build_tokenizer)."""

    # The constructed output tokenizer.
    tokenizer: transformers.PreTrainedTokenizerBase
    # Per-model {output_index: original_index} maps produced by build_tokenizer.
    permutations: Optional[Dict[ModelReference, Dict[int, int]]]
| |
|
| |
|
class BuildTokenizer(Task[TokenizerInfo]):
    """Graph task that builds the output tokenizer and vocabulary
    permutations for a set of referenced models."""

    base_model: Optional[ModelReference]
    referenced_models: Tuple[ModelReference, ...]
    tokenizer_source: str
    trust_remote_code: bool = False

    def arguments(self) -> Dict[str, Task]:
        # This task has no upstream dependencies.
        return {}

    def execute(self, **_kwargs) -> TokenizerInfo:
        tokenizer, permutations = build_tokenizer(
            base_model=self.base_model,
            referenced_models=self.referenced_models,
            tokenizer_source=self.tokenizer_source,
            trust_remote_code=self.trust_remote_code,
        )
        return TokenizerInfo(tokenizer=tokenizer, permutations=permutations)
| |
|