Spaces:
Running
Running
Commit
·
52ba00f
1
Parent(s):
a1e2fc4
Fix Top 10 predictions display error
Browse files
- Improve decode_token function with better error handling
- Add detailed logging for debugging token decoding issues
- Enhance build_byte_to_token_map with more informative error messages
- Add docstring to build_byte_to_token_map function
Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
visualization/html_generator.py
CHANGED
|
@@ -182,17 +182,25 @@ def generate_comparison_html(
|
|
| 182 |
"""
|
| 183 |
|
| 184 |
def decode_token(token_id: int, tokenizer, model_type: str) -> str:
|
|
|
|
| 185 |
if tokenizer is None:
|
| 186 |
return f"[{token_id}]"
|
| 187 |
try:
|
| 188 |
if model_type in ["rwkv", "rwkv7"]:
|
| 189 |
-
|
|
|
|
|
|
|
| 190 |
else:
|
| 191 |
-
|
| 192 |
-
|
|
|
|
|
|
|
|
|
|
| 193 |
return f"[{token_id}]"
|
| 194 |
|
| 195 |
def build_byte_to_token_map(text: str, tokenizer, model_type: str):
|
|
|
|
|
|
|
| 196 |
if tokenizer is None:
|
| 197 |
return []
|
| 198 |
|
|
@@ -200,6 +208,7 @@ def generate_comparison_html(
|
|
| 200 |
|
| 201 |
try:
|
| 202 |
if model_type in ["rwkv", "rwkv7"]:
|
|
|
|
| 203 |
tokenized = tokenizer.encode(text)
|
| 204 |
if hasattr(tokenized, "ids"):
|
| 205 |
token_ids = tokenized.ids
|
|
@@ -212,9 +221,11 @@ def generate_comparison_html(
|
|
| 212 |
token_bytes = tokenizer.decodeBytes([token_id])
|
| 213 |
token_ranges.append((byte_pos, byte_pos + len(token_bytes), idx))
|
| 214 |
byte_pos += len(token_bytes)
|
| 215 |
-
except:
|
|
|
|
| 216 |
pass
|
| 217 |
else:
|
|
|
|
| 218 |
tokenizer_name = getattr(tokenizer, "name_or_path", None)
|
| 219 |
if tokenizer_name:
|
| 220 |
converter = TokenizerBytesConverter(tokenizer_name, trust_remote_code=True)
|
|
@@ -223,6 +234,8 @@ def generate_comparison_html(
|
|
| 223 |
for idx, token_bytes in enumerate(token_bytes_list):
|
| 224 |
token_ranges.append((byte_pos, byte_pos + len(token_bytes), idx))
|
| 225 |
byte_pos += len(token_bytes)
|
|
|
|
|
|
|
| 226 |
except Exception as e:
|
| 227 |
print(f"Warning: Could not build byte-to-token map ({model_type}): {e}")
|
| 228 |
return []
|
|
|
|
| 182 |
"""
|
| 183 |
|
| 184 |
def decode_token(token_id: int, tokenizer, model_type: str) -> str:
|
| 185 |
+
"""Decode a single token ID to text using the appropriate tokenizer."""
|
| 186 |
if tokenizer is None:
|
| 187 |
return f"[{token_id}]"
|
| 188 |
try:
|
| 189 |
if model_type in ["rwkv", "rwkv7"]:
|
| 190 |
+
# RWKV tokenizer uses decode method
|
| 191 |
+
decoded = tokenizer.decode([token_id])
|
| 192 |
+
return decoded if decoded else f"[{token_id}]"
|
| 193 |
else:
|
| 194 |
+
# HuggingFace tokenizer
|
| 195 |
+
decoded = tokenizer.decode([token_id])
|
| 196 |
+
return decoded if decoded else f"[{token_id}]"
|
| 197 |
+
except Exception as e:
|
| 198 |
+
print(f"Warning: Failed to decode token {token_id} ({model_type}): {e}")
|
| 199 |
return f"[{token_id}]"
|
| 200 |
|
| 201 |
def build_byte_to_token_map(text: str, tokenizer, model_type: str):
|
| 202 |
+
"""Build mapping from byte position to token index using the correct tokenizer.
|
| 203 |
+
Returns a list of (start, end, token_idx) tuples for range-based lookup."""
|
| 204 |
if tokenizer is None:
|
| 205 |
return []
|
| 206 |
|
|
|
|
| 208 |
|
| 209 |
try:
|
| 210 |
if model_type in ["rwkv", "rwkv7"]:
|
| 211 |
+
# RWKV tokenizer
|
| 212 |
tokenized = tokenizer.encode(text)
|
| 213 |
if hasattr(tokenized, "ids"):
|
| 214 |
token_ids = tokenized.ids
|
|
|
|
| 221 |
token_bytes = tokenizer.decodeBytes([token_id])
|
| 222 |
token_ranges.append((byte_pos, byte_pos + len(token_bytes), idx))
|
| 223 |
byte_pos += len(token_bytes)
|
| 224 |
+
except Exception as e:
|
| 225 |
+
print(f"Warning: Failed to decode RWKV token {token_id}: {e}")
|
| 226 |
pass
|
| 227 |
else:
|
| 228 |
+
# HuggingFace tokenizer - use TokenizerBytesConverter
|
| 229 |
tokenizer_name = getattr(tokenizer, "name_or_path", None)
|
| 230 |
if tokenizer_name:
|
| 231 |
converter = TokenizerBytesConverter(tokenizer_name, trust_remote_code=True)
|
|
|
|
| 234 |
for idx, token_bytes in enumerate(token_bytes_list):
|
| 235 |
token_ranges.append((byte_pos, byte_pos + len(token_bytes), idx))
|
| 236 |
byte_pos += len(token_bytes)
|
| 237 |
+
else:
|
| 238 |
+
print(f"Warning: Could not get tokenizer name for HF model")
|
| 239 |
except Exception as e:
|
| 240 |
print(f"Warning: Could not build byte-to-token map ({model_type}): {e}")
|
| 241 |
return []
|