File size: 2,468 Bytes
a0f7b9d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
diff --git a/exllamav2/tokenizer/hf.py b/exllamav2/tokenizer/hf.py
index 56134d0..9fde261 100644
--- a/exllamav2/tokenizer/hf.py
+++ b/exllamav2/tokenizer/hf.py
@@ -1,4 +1,5 @@
 from __future__ import annotations
+import json
 from typing import List, Union
 from exllamav2.tokenizer.base import ExLlamaV2TokenizerBase
 from tokenizers import Tokenizer
@@ -10,6 +11,7 @@ class ExLlamaV2TokenizerHF(ExLlamaV2TokenizerBase):
 
     space_char_: str
     newline_char_: str
+    unk_token_: str | None
     vocab: list[str] | None
 
     def __init__(self, tokenizer_json: str) -> None:
@@ -18,6 +20,7 @@ class ExLlamaV2TokenizerHF(ExLlamaV2TokenizerBase):
         self.vocab = None
         self.space_char_ = " "
         self.newline_char_ = "\n"
+        self.unk_token_ = None
 
         self.hf_tokenizer = Tokenizer.from_file(tokenizer_json)
 
@@ -26,11 +29,18 @@ class ExLlamaV2TokenizerHF(ExLlamaV2TokenizerBase):
             self.space_char_ = self.deduce_char_map(" ")  # "Ġ"
             self.newline_char_ = self.deduce_char_map("\n")  # "Ċ"
 
+        if isinstance(m, models.Unigram):
+            unk_id = self._get_unk_id_from_tokenizer_json(tokenizer_json)
+            self.unk_token_ = None if unk_id is None else self.id_to_piece(unk_id)
+        else:
+            self.unk_token_ = getattr(m, "unk_token", None)
+
     def unk_id(self) -> int or None: return None if self.unk_token() is None else self.piece_to_id(self.unk_token())
     def pad_id(self) -> int or None: return None
     def bos_id(self) -> int or None: return None
     def eos_id(self) -> int or None: return None
-    def unk_token(self) -> str or None: return self.hf_tokenizer.model.unk_token
+    # def unk_token(self) -> str or None: return self.hf_tokenizer.model.unk_token
+    def unk_token(self) -> str or None: return self.unk_token_
     def pad_token(self) -> str or None: return None
     def bos_token(self) -> str or None: return None
     def eos_token(self) -> str or None: return None
@@ -84,3 +94,9 @@ class ExLlamaV2TokenizerHF(ExLlamaV2TokenizerBase):
     def encode(self, text: list or str) -> list:
         encoding = self.hf_tokenizer.encode(text, add_special_tokens = False)
         return encoding.ids
+
+    @staticmethod
+    def _get_unk_id_from_tokenizer_json(tokenizer_json: str) -> int | None:
+        with open(tokenizer_json, "r", encoding="utf-8") as f:
+            data = json.load(f)
+        return data.get("model", {}).get("unk_id", None)