diff --git a/exllamav2/tokenizer/hf.py b/exllamav2/tokenizer/hf.py index 56134d0..9fde261 100644 --- a/exllamav2/tokenizer/hf.py +++ b/exllamav2/tokenizer/hf.py @@ -1,4 +1,5 @@ from __future__ import annotations +import json from typing import List, Union from exllamav2.tokenizer.base import ExLlamaV2TokenizerBase from tokenizers import Tokenizer @@ -10,6 +11,7 @@ class ExLlamaV2TokenizerHF(ExLlamaV2TokenizerBase): space_char_: str newline_char_: str + unk_token_: str | None vocab: list[str] | None def __init__(self, tokenizer_json: str) -> None: @@ -18,6 +20,7 @@ class ExLlamaV2TokenizerHF(ExLlamaV2TokenizerBase): self.vocab = None self.space_char_ = " " self.newline_char_ = "\n" + self.unk_token_ = None self.hf_tokenizer = Tokenizer.from_file(tokenizer_json) @@ -26,11 +29,18 @@ class ExLlamaV2TokenizerHF(ExLlamaV2TokenizerBase): self.space_char_ = self.deduce_char_map(" ") # "Ġ" self.newline_char_ = self.deduce_char_map("\n") # "Ċ" + if isinstance(m, models.Unigram): + unk_id = self._get_unk_id_from_tokenizer_json(tokenizer_json) + self.unk_token_ = self.id_to_piece(unk_id) + else: + self.unk_token_ = getattr(m, "unk_token", None) + def unk_id(self) -> int or None: return None if self.unk_token() is None else self.piece_to_id(self.unk_token()) def pad_id(self) -> int or None: return None def bos_id(self) -> int or None: return None def eos_id(self) -> int or None: return None - def unk_token(self) -> str or None: return self.hf_tokenizer.model.unk_token + # def unk_token(self) -> str or None: return self.hf_tokenizer.model.unk_token + def unk_token(self) -> str or None: return self.unk_token_ def pad_token(self) -> str or None: return None def bos_token(self) -> str or None: return None def eos_token(self) -> str or None: return None @@ -84,3 +94,9 @@ class ExLlamaV2TokenizerHF(ExLlamaV2TokenizerBase): def encode(self, text: list or str) -> list: encoding = self.hf_tokenizer.encode(text, add_special_tokens = False) return encoding.ids + + @staticmethod + def 
_get_unk_id_from_tokenizer_json(tokenizer_json: str) -> int | None: +        with open(tokenizer_json, "r", encoding="utf-8") as f: +            tokenizer_json = json.load(f) +        return tokenizer_json.get("model", {}).get("unk_id", None)