Spaces:
Running
Running
Commit
·
01ede16
1
Parent(s):
ef158b3
fix
Browse files- core/helpers.py +28 -0
- visualization/html_generator.py +7 -6
core/helpers.py
CHANGED
|
@@ -233,6 +233,34 @@ class TokenizerBytesConverter:
|
|
| 233 |
|
| 234 |
return result
|
| 235 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 236 |
def encode_to_flat_bytes(
|
| 237 |
self,
|
| 238 |
text: str,
|
|
|
|
| 233 |
|
| 234 |
return result
|
| 235 |
|
| 236 |
+
def encode_to_ids_and_bytes(
    self,
    text: str,
    add_special_tokens: bool = False,
    strip_leading_space: bool = True,
) -> List[tuple]:
    """
    Encode *text* into a list of ``(token_id, token_bytes)`` pairs.

    Useful when a caller needs both the vocab id and the exact byte
    sequence the tokenizer produced, e.g. for byte-level alignment or
    visualization. Tokens whose byte form cannot be resolved
    (``token_to_bytes`` returns ``None``) are dropped.
    """
    encoded_ids = self._tokenizer.encode(text, add_special_tokens=add_special_tokens)

    # Leading-space stripping only applies to SentencePiece ByteFallback
    # tokenizers, mirroring encode_to_bytes() behavior.
    wants_strip = strip_leading_space and self._decoder_type == "sentencepiece"

    pairs = []
    for position, tid in enumerate(encoded_ids):
        raw = self.token_to_bytes(tid)
        if raw is None:
            # Unresolvable token: skip it entirely.
            continue

        # Only the very first encoded token is a candidate for stripping
        # the synthetic leading space (0x20) SentencePiece prepends.
        if position == 0 and wants_strip and raw and raw[0] == 0x20:
            raw = raw[1:]

        pairs.append((tid, raw))

    return pairs
|
| 263 |
+
|
| 264 |
def encode_to_flat_bytes(
|
| 265 |
self,
|
| 266 |
text: str,
|
visualization/html_generator.py
CHANGED
|
@@ -78,16 +78,17 @@ def get_token_info_for_text(text: str) -> dict:
|
|
| 78 |
# Get Qwen tokens with positions
|
| 79 |
qwen_tokens = []
|
| 80 |
byte_to_qwen = {}
|
| 81 |
-
|
|
|
|
| 82 |
byte_pos = 0
|
| 83 |
-
for idx, token_bytes in enumerate(
|
| 84 |
start = byte_pos
|
| 85 |
end = byte_pos + len(token_bytes)
|
| 86 |
try:
|
| 87 |
token_str = bytes(token_bytes).decode("utf-8")
|
| 88 |
except UnicodeDecodeError:
|
| 89 |
token_str = repr(bytes(token_bytes))
|
| 90 |
-
qwen_tokens.append((start, end, token_str))
|
| 91 |
byte_to_qwen[start] = idx
|
| 92 |
byte_pos = end
|
| 93 |
|
|
@@ -109,7 +110,7 @@ def get_token_info_for_text(text: str) -> dict:
|
|
| 109 |
token_str = token_bytes.decode("utf-8")
|
| 110 |
except UnicodeDecodeError:
|
| 111 |
token_str = repr(token_bytes)
|
| 112 |
-
rwkv_tokens.append((start, end, token_str))
|
| 113 |
byte_to_rwkv[start] = idx
|
| 114 |
byte_pos = end
|
| 115 |
|
|
@@ -249,9 +250,9 @@ def generate_comparison_html(
|
|
| 249 |
|
| 250 |
def get_tokens_for_range(byte_start, byte_end, token_list):
|
| 251 |
result = []
|
| 252 |
-
for
|
| 253 |
if t_start < byte_end and t_end > byte_start:
|
| 254 |
-
result.append((
|
| 255 |
return result
|
| 256 |
|
| 257 |
# Build tokens based on common boundaries
|
|
|
|
| 78 |
# Get Qwen tokens with positions
|
| 79 |
qwen_tokens = []
|
| 80 |
byte_to_qwen = {}
|
| 81 |
+
# Keep both token id (vocab id) and decoded bytes so the tooltip can show true token ids.
|
| 82 |
+
qwen_id_and_bytes = qwen_tokenizer.encode_to_ids_and_bytes(text)
|
| 83 |
byte_pos = 0
|
| 84 |
+
for idx, (token_id, token_bytes) in enumerate(qwen_id_and_bytes):
|
| 85 |
start = byte_pos
|
| 86 |
end = byte_pos + len(token_bytes)
|
| 87 |
try:
|
| 88 |
token_str = bytes(token_bytes).decode("utf-8")
|
| 89 |
except UnicodeDecodeError:
|
| 90 |
token_str = repr(bytes(token_bytes))
|
| 91 |
+
qwen_tokens.append((start, end, token_id, token_str))
|
| 92 |
byte_to_qwen[start] = idx
|
| 93 |
byte_pos = end
|
| 94 |
|
|
|
|
| 110 |
token_str = token_bytes.decode("utf-8")
|
| 111 |
except UnicodeDecodeError:
|
| 112 |
token_str = repr(token_bytes)
|
| 113 |
+
rwkv_tokens.append((start, end, token_id, token_str))
|
| 114 |
byte_to_rwkv[start] = idx
|
| 115 |
byte_pos = end
|
| 116 |
|
|
|
|
| 250 |
|
| 251 |
def get_tokens_for_range(byte_start, byte_end, token_list):
    # Collect (token_id, token_str) for every token whose byte span
    # overlaps the half-open range [byte_start, byte_end).
    return [
        (tok_id, tok_str)
        for tok_start, tok_end, tok_id, tok_str in token_list
        if tok_start < byte_end and tok_end > byte_start
    ]
|
| 257 |
|
| 258 |
# Build tokens based on common boundaries
|