exllamav2_patch / hf.py.patch
tokutsu
Update README & Add patch, script
a0f7b9d
diff --git a/exllamav2/tokenizer/hf.py b/exllamav2/tokenizer/hf.py
index 56134d0..9fde261 100644
--- a/exllamav2/tokenizer/hf.py
+++ b/exllamav2/tokenizer/hf.py
@@ -1,4 +1,5 @@
from __future__ import annotations
+import json
from typing import List, Union
from exllamav2.tokenizer.base import ExLlamaV2TokenizerBase
from tokenizers import Tokenizer
@@ -10,6 +11,7 @@ class ExLlamaV2TokenizerHF(ExLlamaV2TokenizerBase):
space_char_: str
newline_char_: str
+ unk_token_: str | None
vocab: list[str] | None
def __init__(self, tokenizer_json: str) -> None:
@@ -18,6 +20,7 @@ class ExLlamaV2TokenizerHF(ExLlamaV2TokenizerBase):
self.vocab = None
self.space_char_ = " "
self.newline_char_ = "\n"
+ self.unk_token_ = None
self.hf_tokenizer = Tokenizer.from_file(tokenizer_json)
@@ -26,11 +29,18 @@ class ExLlamaV2TokenizerHF(ExLlamaV2TokenizerBase):
self.space_char_ = self.deduce_char_map(" ") # "Ġ"
self.newline_char_ = self.deduce_char_map("\n") # "Ċ"
+        if isinstance(m, models.Unigram):
+            unk_id = self._get_unk_id_from_tokenizer_json(tokenizer_json)
+            self.unk_token_ = None if unk_id is None else self.id_to_piece(unk_id)
+        else:
+            self.unk_token_ = getattr(m, "unk_token", None)
+
def unk_id(self) -> int or None: return None if self.unk_token() is None else self.piece_to_id(self.unk_token())
def pad_id(self) -> int or None: return None
def bos_id(self) -> int or None: return None
def eos_id(self) -> int or None: return None
- def unk_token(self) -> str or None: return self.hf_tokenizer.model.unk_token
+ # def unk_token(self) -> str or None: return self.hf_tokenizer.model.unk_token
+ def unk_token(self) -> str or None: return self.unk_token_
def pad_token(self) -> str or None: return None
def bos_token(self) -> str or None: return None
def eos_token(self) -> str or None: return None
@@ -84,3 +94,9 @@ class ExLlamaV2TokenizerHF(ExLlamaV2TokenizerBase):
def encode(self, text: list or str) -> list:
encoding = self.hf_tokenizer.encode(text, add_special_tokens = False)
return encoding.ids
+
+    @staticmethod
+    def _get_unk_id_from_tokenizer_json(tokenizer_json: str) -> int | None:
+        with open(tokenizer_json, "r", encoding="utf-8") as f:
+            config = json.load(f)
+        return config.get("model", {}).get("unk_id", None)