Jellyfish042 Claude Sonnet 4.5 committed on
Commit
52ba00f
·
1 Parent(s): a1e2fc4

Fix Top 10 predictions display error

Browse files

- Improve decode_token function with better error handling
- Add detailed logging for debugging token decoding issues
- Enhance build_byte_to_token_map with more informative error messages
- Add docstring to build_byte_to_token_map function

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>

Files changed (1) hide show
  1. visualization/html_generator.py +17 -4
visualization/html_generator.py CHANGED
@@ -182,17 +182,25 @@ def generate_comparison_html(
182
  """
183
 
184
  def decode_token(token_id: int, tokenizer, model_type: str) -> str:
 
185
  if tokenizer is None:
186
  return f"[{token_id}]"
187
  try:
188
  if model_type in ["rwkv", "rwkv7"]:
189
- return tokenizer.decode([token_id])
 
 
190
  else:
191
- return tokenizer.decode([token_id])
192
- except:
 
 
 
193
  return f"[{token_id}]"
194
 
195
  def build_byte_to_token_map(text: str, tokenizer, model_type: str):
 
 
196
  if tokenizer is None:
197
  return []
198
 
@@ -200,6 +208,7 @@ def generate_comparison_html(
200
 
201
  try:
202
  if model_type in ["rwkv", "rwkv7"]:
 
203
  tokenized = tokenizer.encode(text)
204
  if hasattr(tokenized, "ids"):
205
  token_ids = tokenized.ids
@@ -212,9 +221,11 @@ def generate_comparison_html(
212
  token_bytes = tokenizer.decodeBytes([token_id])
213
  token_ranges.append((byte_pos, byte_pos + len(token_bytes), idx))
214
  byte_pos += len(token_bytes)
215
- except:
 
216
  pass
217
  else:
 
218
  tokenizer_name = getattr(tokenizer, "name_or_path", None)
219
  if tokenizer_name:
220
  converter = TokenizerBytesConverter(tokenizer_name, trust_remote_code=True)
@@ -223,6 +234,8 @@ def generate_comparison_html(
223
  for idx, token_bytes in enumerate(token_bytes_list):
224
  token_ranges.append((byte_pos, byte_pos + len(token_bytes), idx))
225
  byte_pos += len(token_bytes)
 
 
226
  except Exception as e:
227
  print(f"Warning: Could not build byte-to-token map ({model_type}): {e}")
228
  return []
 
182
  """
183
 
184
  def decode_token(token_id: int, tokenizer, model_type: str) -> str:
185
+ """Decode a single token ID to text using the appropriate tokenizer."""
186
  if tokenizer is None:
187
  return f"[{token_id}]"
188
  try:
189
  if model_type in ["rwkv", "rwkv7"]:
190
+ # RWKV tokenizer uses decode method
191
+ decoded = tokenizer.decode([token_id])
192
+ return decoded if decoded else f"[{token_id}]"
193
  else:
194
+ # HuggingFace tokenizer
195
+ decoded = tokenizer.decode([token_id])
196
+ return decoded if decoded else f"[{token_id}]"
197
+ except Exception as e:
198
+ print(f"Warning: Failed to decode token {token_id} ({model_type}): {e}")
199
  return f"[{token_id}]"
200
 
201
  def build_byte_to_token_map(text: str, tokenizer, model_type: str):
202
+ """Build mapping from byte position to token index using the correct tokenizer.
203
+ Returns a list of (start, end, token_idx) tuples for range-based lookup."""
204
  if tokenizer is None:
205
  return []
206
 
 
208
 
209
  try:
210
  if model_type in ["rwkv", "rwkv7"]:
211
+ # RWKV tokenizer
212
  tokenized = tokenizer.encode(text)
213
  if hasattr(tokenized, "ids"):
214
  token_ids = tokenized.ids
 
221
  token_bytes = tokenizer.decodeBytes([token_id])
222
  token_ranges.append((byte_pos, byte_pos + len(token_bytes), idx))
223
  byte_pos += len(token_bytes)
224
+ except Exception as e:
225
+ print(f"Warning: Failed to decode RWKV token {token_id}: {e}")
226
  pass
227
  else:
228
+ # HuggingFace tokenizer - use TokenizerBytesConverter
229
  tokenizer_name = getattr(tokenizer, "name_or_path", None)
230
  if tokenizer_name:
231
  converter = TokenizerBytesConverter(tokenizer_name, trust_remote_code=True)
 
234
  for idx, token_bytes in enumerate(token_bytes_list):
235
  token_ranges.append((byte_pos, byte_pos + len(token_bytes), idx))
236
  byte_pos += len(token_bytes)
237
+ else:
238
+ print(f"Warning: Could not get tokenizer name for HF model")
239
  except Exception as e:
240
  print(f"Warning: Could not build byte-to-token map ({model_type}): {e}")
241
  return []