Jellyfish042 committed
Commit 01ede16 · Parent: ef158b3
Files changed (2)
  1. core/helpers.py +28 -0
  2. visualization/html_generator.py +7 -6
core/helpers.py CHANGED
@@ -233,6 +233,34 @@ class TokenizerBytesConverter:
 
         return result
 
+    def encode_to_ids_and_bytes(
+        self,
+        text: str,
+        add_special_tokens: bool = False,
+        strip_leading_space: bool = True,
+    ) -> List[tuple]:
+        """
+        Encode text to (token_id, token_bytes) pairs.
+
+        This is useful when the caller needs both the vocab token id and the exact
+        byte sequence used by the tokenizer for alignment/visualization.
+        """
+        token_ids = self._tokenizer.encode(text, add_special_tokens=add_special_tokens)
+
+        result = []
+        for idx, token_id in enumerate(token_ids):
+            token_bytes = self.token_to_bytes(token_id)
+            if token_bytes is None:
+                continue
+
+            # Match encode_to_bytes() behavior for SentencePiece ByteFallback tokenizers.
+            if idx == 0 and self._decoder_type == "sentencepiece" and strip_leading_space and token_bytes and token_bytes[0] == 0x20:
+                token_bytes = token_bytes[1:]
+
+            result.append((token_id, token_bytes))
+
+        return result
+
     def encode_to_flat_bytes(
         self,
         text: str,
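A minimal usage sketch of the new helper (the tokenizer loading, the model name, and the TokenizerBytesConverter constructor call below are assumptions for illustration; they are not part of this commit):

# Sketch only: consuming encode_to_ids_and_bytes() for byte-level alignment.
from transformers import AutoTokenizer
from core.helpers import TokenizerBytesConverter

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")  # illustrative model choice
converter = TokenizerBytesConverter(tokenizer)  # hypothetical constructor signature

for token_id, token_bytes in converter.encode_to_ids_and_bytes("Hello world"):
    # token_id is the vocab id; token_bytes is the exact byte sequence the
    # tokenizer produced for this token.
    print(token_id, bytes(token_bytes))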
visualization/html_generator.py CHANGED
@@ -78,16 +78,17 @@ def get_token_info_for_text(text: str) -> dict:
     # Get Qwen tokens with positions
     qwen_tokens = []
     byte_to_qwen = {}
-    qwen_bytes_list = qwen_tokenizer.encode_to_bytes(text)
+    # Keep both token id (vocab id) and decoded bytes so the tooltip can show true token ids.
+    qwen_id_and_bytes = qwen_tokenizer.encode_to_ids_and_bytes(text)
     byte_pos = 0
-    for idx, token_bytes in enumerate(qwen_bytes_list):
+    for idx, (token_id, token_bytes) in enumerate(qwen_id_and_bytes):
         start = byte_pos
         end = byte_pos + len(token_bytes)
         try:
             token_str = bytes(token_bytes).decode("utf-8")
         except UnicodeDecodeError:
             token_str = repr(bytes(token_bytes))
-        qwen_tokens.append((start, end, token_str))
+        qwen_tokens.append((start, end, token_id, token_str))
         byte_to_qwen[start] = idx
         byte_pos = end
 
@@ -109,7 +110,7 @@ def get_token_info_for_text(text: str) -> dict:
             token_str = token_bytes.decode("utf-8")
         except UnicodeDecodeError:
             token_str = repr(token_bytes)
-        rwkv_tokens.append((start, end, token_str))
+        rwkv_tokens.append((start, end, token_id, token_str))
         byte_to_rwkv[start] = idx
         byte_pos = end
 
@@ -249,9 +250,9 @@ def generate_comparison_html(
 
     def get_tokens_for_range(byte_start, byte_end, token_list):
         result = []
-        for idx, (t_start, t_end, t_str) in enumerate(token_list):
+        for t_start, t_end, token_id, t_str in token_list:
             if t_start < byte_end and t_end > byte_start:
-                result.append((idx, t_str))
+                result.append((token_id, t_str))
         return result
 
     # Build tokens based on common boundaries
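The rewritten get_tokens_for_range treats token spans as half-open byte intervals [t_start, t_end), so `t_start < byte_end and t_end > byte_start` is the standard interval-overlap test over the new 4-tuples. A self-contained sketch with invented sample spans and placeholder ids:

# Standalone sketch of the updated lookup over (start, end, token_id, token_str) tuples.
def get_tokens_for_range(byte_start, byte_end, token_list):
    result = []
    for t_start, t_end, token_id, t_str in token_list:
        if t_start < byte_end and t_end > byte_start:  # half-open interval overlap
            result.append((token_id, t_str))
    return result

# Invented sample spans; ids 101-103 are placeholders, not real vocab ids.
tokens = [(0, 5, 101, "Hello"), (5, 6, 102, " "), (6, 11, 103, "world")]
assert get_tokens_for_range(4, 7, tokens) == [(101, "Hello"), (102, " "), (103, "world")]
assert get_tokens_for_range(0, 3, tokens) == [(101, "Hello")]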