cmpatino HF Staff committed on
Commit
2597b43
·
1 Parent(s): 40112a6

Add tokenization details

Browse files
Files changed (1) hide show
  1. app.py +110 -39
app.py CHANGED
@@ -87,9 +87,24 @@ def build_alignment_groups_from_ids(student_tokenizer, teacher_tokenizer, studen
87
  return s_groups, t_groups
88
 
89
 
90
- def highlight_groups(student_tokenizer, student_token_ids, s_groups, t_groups):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
  """Build an HTML string with highlighted misalignment regions."""
92
  parts = []
 
93
  for k in range(len(s_groups)):
94
  s_ids = [student_token_ids[idx] for idx in s_groups[k]]
95
  text = student_tokenizer.decode(s_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False)
@@ -99,64 +114,94 @@ def highlight_groups(student_tokenizer, student_token_ids, s_groups, t_groups):
99
  t_multi = len(t_groups[k]) > 1
100
 
101
  if s_multi and t_multi:
102
- parts.append(f'<span style="background-color: #b388ff;">{escaped}</span>')
 
 
 
 
 
 
 
103
  elif s_multi:
104
- parts.append(f'<span style="background-color: #ffcc80;">{escaped}</span>')
 
 
105
  elif t_multi:
106
- parts.append(f'<span style="background-color: #90caf9;">{escaped}</span>')
 
 
107
  else:
108
  parts.append(escaped)
109
 
110
  return "".join(parts)
111
 
112
 
113
- def process_texts(student_model_id, teacher_model_id, dataset_id, progress=gr.Progress()):
114
- """Load tokenizers and dataset, compute alignment, return highlighted HTML."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
115
  progress(0, desc="Loading tokenizers...")
116
  student_tokenizer = AutoTokenizer.from_pretrained(student_model_id)
117
  teacher_tokenizer = AutoTokenizer.from_pretrained(teacher_model_id)
118
 
119
- progress(0.1, desc="Loading dataset...")
120
- ds = load_dataset(dataset_id, split="train")
 
121
  rows = ds.select(range(min(10, len(ds))))
 
 
 
 
 
122
 
123
- def make_html_block(row, idx):
124
- text = "".join(msg["content"] for msg in row["messages"])
125
 
126
- s_ids = student_tokenizer.encode(text, add_special_tokens=False)
127
- t_ids = teacher_tokenizer.encode(text, add_special_tokens=False)
128
 
129
- s_groups, t_groups = build_alignment_groups_from_ids(
130
- student_tokenizer, teacher_tokenizer, s_ids, t_ids
131
- )
 
 
 
 
 
132
 
133
- highlighted = highlight_groups(student_tokenizer, s_ids, s_groups, t_groups)
134
- return {
135
- "html_block": (
136
- f'<div style="border:1px solid #ccc; padding:10px; margin:10px 0; '
137
- f'border-radius:5px; white-space:pre-wrap; font-family:monospace; font-size:13px;">'
138
- f"<strong>Text {idx + 1}</strong> "
139
- f"(student tokens: {len(s_ids)}, teacher tokens: {len(t_ids)})<br><br>"
140
- f"{highlighted}</div>"
141
- )
142
- }
143
 
144
- progress(0.2, desc="Processing texts...")
145
- rows = rows.map(make_html_block, num_proc=4, with_indices=True)
146
- html_blocks = rows["html_block"]
 
 
147
 
148
- progress(1, desc="Done!")
149
 
150
- legend = (
151
- '<div style="margin-bottom:15px; font-family:sans-serif;">'
152
- "<strong>Legend:</strong> "
153
- '<span style="background-color:#ffcc80; padding:2px 8px; margin-right:8px;">Student misalignment (orange)</span>'
154
- '<span style="background-color:#90caf9; padding:2px 8px; margin-right:8px;">Teacher misalignment (blue)</span>'
155
- '<span style="background-color:#b388ff; padding:2px 8px;">Both (purple)</span>'
156
- "</div>"
157
- )
158
 
159
- return legend + "\n".join(html_blocks)
 
 
 
 
 
160
 
161
 
162
  with gr.Blocks(title="Tokenization Diff") as demo:
@@ -166,11 +211,37 @@ with gr.Blocks(title="Tokenization Diff") as demo:
166
  student_model = gr.Textbox(label="Student Model", value="Qwen/Qwen3-8B")
167
  teacher_model = gr.Textbox(label="Teacher Model", value="deepseek-ai/DeepSeek-Math-V2")
168
  dataset_id = gr.Textbox(label="Dataset ID", value="lm-provers/FineProofs-SFT")
 
169
 
170
  submit_btn = gr.Button("Submit", variant="primary")
 
 
 
 
 
 
 
171
  output = gr.HTML(label="Tokenization Diff Output")
172
 
173
- submit_btn.click(fn=process_texts, inputs=[student_model, teacher_model, dataset_id], outputs=output)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
174
 
175
  if __name__ == "__main__":
176
  demo.launch()
 
87
  return s_groups, t_groups
88
 
89
 
90
+ def _decode_pieces(tokenizer, token_ids, indices):
91
+ """Decode individual token pieces for a group of token indices."""
92
+ return [
93
+ tokenizer.decode([token_ids[idx]], skip_special_tokens=False, clean_up_tokenization_spaces=False)
94
+ for idx in indices
95
+ ]
96
+
97
+
98
+ def _format_pieces(pieces):
99
+ """Format token pieces as a list, e.g. '["hel", "lo"]'."""
100
+ inner = ", ".join(f'"{p}"' for p in pieces)
101
+ return f"[{inner}]"
102
+
103
+
104
def highlight_groups(student_tokenizer, teacher_tokenizer, student_token_ids, teacher_token_ids, s_groups, t_groups):
    """Build an HTML string with highlighted misalignment regions.

    Orange marks spans the student splits into multiple tokens, blue marks
    teacher splits, purple marks both.  Tooltips list the individual token
    pieces; only the first purple span carries a tooltip.
    """
    parts = []
    first_purple = True
    for group_idx, s_group in enumerate(s_groups):
        group_ids = [student_token_ids[i] for i in s_group]
        text = student_tokenizer.decode(group_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False)
        # NOTE(review): the next two lines are hidden in the diff view and were
        # reconstructed from how `escaped`/`s_multi` are used below — confirm
        # against the full file.
        escaped = html.escape(text)
        s_multi = len(s_group) > 1
        t_multi = len(t_groups[group_idx]) > 1

        if s_multi and t_multi:
            if first_purple:
                first_purple = False
                s_pieces = _decode_pieces(student_tokenizer, student_token_ids, s_group)
                t_pieces = _decode_pieces(teacher_tokenizer, teacher_token_ids, t_groups[group_idx])
                tooltip = html.escape(f'Student: {_format_pieces(s_pieces)} / Teacher: {_format_pieces(t_pieces)}')
                parts.append(f'<span style="background-color: #b388ff;" title="{tooltip}">{escaped}</span>')
            else:
                parts.append(f'<span style="background-color: #b388ff;">{escaped}</span>')
        elif s_multi:
            s_pieces = _decode_pieces(student_tokenizer, student_token_ids, s_group)
            tooltip = html.escape(f'Student: {_format_pieces(s_pieces)}')
            parts.append(f'<span style="background-color: #ffcc80;" title="{tooltip}">{escaped}</span>')
        elif t_multi:
            t_pieces = _decode_pieces(teacher_tokenizer, teacher_token_ids, t_groups[group_idx])
            tooltip = html.escape(f'Teacher: {_format_pieces(t_pieces)}')
            parts.append(f'<span style="background-color: #90caf9;" title="{tooltip}">{escaped}</span>')
        else:
            parts.append(escaped)

    return "".join(parts)
137
 
138
 
139
def make_html_block(student_tokenizer, teacher_tokenizer, text, idx):
    """Tokenize one text with both tokenizers and return its highlighted HTML block."""
    s_ids = student_tokenizer.encode(text, add_special_tokens=False)
    t_ids = teacher_tokenizer.encode(text, add_special_tokens=False)

    s_groups, t_groups = build_alignment_groups_from_ids(
        student_tokenizer, teacher_tokenizer, s_ids, t_ids
    )
    highlighted = highlight_groups(student_tokenizer, teacher_tokenizer, s_ids, t_ids, s_groups, t_groups)

    container_style = (
        "border:1px solid #ccc; padding:10px; margin:10px 0; "
        "border-radius:5px; white-space:pre-wrap; font-family:monospace; font-size:13px;"
    )
    return (
        f'<div style="{container_style}">'
        f"<strong>Text {idx + 1}</strong> "
        f"(student tokens: {len(s_ids)}, teacher tokens: {len(t_ids)})<br><br>"
        f"{highlighted}</div>"
    )
+
157
+
158
def process_texts(student_model_id, teacher_model_id, dataset_id, dataset_config, progress=gr.Progress()):
    """Load tokenizers and the dataset, render the first text, and seed the UI state.

    Returns a 6-tuple matching the Gradio outputs:
    (student_tokenizer, teacher_tokenizer, texts, cache, current_index, rendered_html).
    """
    progress(0, desc="Loading tokenizers...")
    student_tokenizer = AutoTokenizer.from_pretrained(student_model_id)
    teacher_tokenizer = AutoTokenizer.from_pretrained(teacher_model_id)

    progress(0.5, desc="Loading dataset...")
    # An empty/whitespace config box means "no named config".
    config = dataset_config.strip() if dataset_config and dataset_config.strip() else None
    ds = load_dataset(dataset_id, name=config, split="train")
    rows = ds.select(range(min(10, len(ds))))
    texts = ["".join(msg["content"] for msg in row["messages"]) for row in rows]

    # Guard the empty split so texts[0] below cannot raise IndexError.
    if not texts:
        progress(1, desc="Done!")
        return student_tokenizer, teacher_tokenizer, [], {}, 0, "<p>Dataset split is empty.</p>"

    progress(0.8, desc="Processing first text...")
    # Only the first text is tokenized eagerly; go_next fills the cache lazily.
    first_block = make_html_block(student_tokenizer, teacher_tokenizer, texts[0], 0)
    cache = {0: first_block}

    progress(1, desc="Done!")
    return student_tokenizer, teacher_tokenizer, texts, cache, 0, render_page(cache, 0, len(texts))
176
 
 
 
177
 
178
# Legend markup prepended to every rendered page; the background colors must
# stay in sync with the span styles emitted by highlight_groups.
LEGEND = (
    '<div style="margin-bottom:15px; font-family:sans-serif;">'
    "<strong>Legend:</strong> "
    '<span style="background-color:#ffcc80; padding:2px 8px; margin-right:8px;">Student token split (orange)</span>'
    '<span style="background-color:#90caf9; padding:2px 8px; margin-right:8px;">Teacher token split (blue)</span>'
    '<span style="background-color:#b388ff; padding:2px 8px;">Both (purple)</span>'
    "</div>"
)
186
 
 
 
 
 
 
 
 
 
 
 
187
 
188
def render_page(cache, idx, total):
    """Return the full HTML for page *idx*: legend, position counter, cached block.

    An empty cache (nothing submitted yet) renders as an empty string.
    """
    if not cache:
        return ""
    position = f"Text {idx + 1} of {total}"
    counter = f'<div style="font-family:sans-serif; margin-bottom:10px;">{position}</div>'
    return LEGEND + counter + cache[idx]
193
 
 
194
 
195
def go_prev(cache, idx, texts):
    """Step back one text (clamped at the first) and re-render the page."""
    new_idx = idx - 1 if idx > 0 else 0
    return cache, new_idx, render_page(cache, new_idx, len(texts))
 
 
 
 
 
198
 
199
+
200
def go_next(student_tokenizer, teacher_tokenizer, texts, cache, idx):
    """Step forward one text (clamped at the last), tokenizing it lazily on first visit."""
    # Guard: Next pressed before Submit leaves texts empty; without this,
    # min(len(texts) - 1, idx + 1) == -1 and texts[idx] raises IndexError.
    if not texts:
        return cache, idx, render_page(cache, idx, 0)
    idx = min(len(texts) - 1, idx + 1)
    if idx not in cache:
        # Blocks are built on demand so Submit only pays for the first text.
        cache[idx] = make_html_block(student_tokenizer, teacher_tokenizer, texts[idx], idx)
    return cache, idx, render_page(cache, idx, len(texts))
205
 
206
 
207
  with gr.Blocks(title="Tokenization Diff") as demo:
 
211
    # --- Inputs -------------------------------------------------------------
    student_model = gr.Textbox(label="Student Model", value="Qwen/Qwen3-8B")
    teacher_model = gr.Textbox(label="Teacher Model", value="deepseek-ai/DeepSeek-Math-V2")
    dataset_id = gr.Textbox(label="Dataset ID", value="lm-provers/FineProofs-SFT")
    dataset_config = gr.Textbox(label="Dataset Config", value="default")

    submit_btn = gr.Button("Submit", variant="primary")

    # --- Per-session state shared by the submit/prev/next callbacks ---------
    student_tok_state = gr.State(None)
    teacher_tok_state = gr.State(None)
    texts_state = gr.State([])
    cache_state = gr.State({})
    idx_state = gr.State(0)

    output = gr.HTML(label="Tokenization Diff Output")

    with gr.Row():
        prev_btn = gr.Button("Previous")
        next_btn = gr.Button("Next")

    # Submit loads tokenizers/dataset and renders text 0; Prev/Next page
    # through the texts, with go_next filling the HTML cache lazily.
    submit_btn.click(
        fn=process_texts,
        inputs=[student_model, teacher_model, dataset_id, dataset_config],
        outputs=[student_tok_state, teacher_tok_state, texts_state, cache_state, idx_state, output],
    )
    prev_btn.click(
        fn=go_prev,
        inputs=[cache_state, idx_state, texts_state],
        outputs=[cache_state, idx_state, output],
    )
    next_btn.click(
        fn=go_next,
        inputs=[student_tok_state, teacher_tok_state, texts_state, cache_state, idx_state],
        outputs=[cache_state, idx_state, output],
    )
245
 
246
if __name__ == "__main__":
    # Start the Gradio server with default settings when run as a script.
    demo.launch()