Spaces:
Sleeping
Sleeping
Add tokenization details
Browse files
app.py
CHANGED
|
@@ -87,9 +87,24 @@ def build_alignment_groups_from_ids(student_tokenizer, teacher_tokenizer, studen
|
|
| 87 |
return s_groups, t_groups
|
| 88 |
|
| 89 |
|
| 90 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 91 |
"""Build an HTML string with highlighted misalignment regions."""
|
| 92 |
parts = []
|
|
|
|
| 93 |
for k in range(len(s_groups)):
|
| 94 |
s_ids = [student_token_ids[idx] for idx in s_groups[k]]
|
| 95 |
text = student_tokenizer.decode(s_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False)
|
|
@@ -99,64 +114,94 @@ def highlight_groups(student_tokenizer, student_token_ids, s_groups, t_groups):
|
|
| 99 |
t_multi = len(t_groups[k]) > 1
|
| 100 |
|
| 101 |
if s_multi and t_multi:
|
| 102 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 103 |
elif s_multi:
|
| 104 |
-
|
|
|
|
|
|
|
| 105 |
elif t_multi:
|
| 106 |
-
|
|
|
|
|
|
|
| 107 |
else:
|
| 108 |
parts.append(escaped)
|
| 109 |
|
| 110 |
return "".join(parts)
|
| 111 |
|
| 112 |
|
| 113 |
-
def
|
| 114 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 115 |
progress(0, desc="Loading tokenizers...")
|
| 116 |
student_tokenizer = AutoTokenizer.from_pretrained(student_model_id)
|
| 117 |
teacher_tokenizer = AutoTokenizer.from_pretrained(teacher_model_id)
|
| 118 |
|
| 119 |
-
progress(0.
|
| 120 |
-
|
|
|
|
| 121 |
rows = ds.select(range(min(10, len(ds))))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 122 |
|
| 123 |
-
|
| 124 |
-
|
| 125 |
|
| 126 |
-
s_ids = student_tokenizer.encode(text, add_special_tokens=False)
|
| 127 |
-
t_ids = teacher_tokenizer.encode(text, add_special_tokens=False)
|
| 128 |
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 132 |
|
| 133 |
-
highlighted = highlight_groups(student_tokenizer, s_ids, s_groups, t_groups)
|
| 134 |
-
return {
|
| 135 |
-
"html_block": (
|
| 136 |
-
f'<div style="border:1px solid #ccc; padding:10px; margin:10px 0; '
|
| 137 |
-
f'border-radius:5px; white-space:pre-wrap; font-family:monospace; font-size:13px;">'
|
| 138 |
-
f"<strong>Text {idx + 1}</strong> "
|
| 139 |
-
f"(student tokens: {len(s_ids)}, teacher tokens: {len(t_ids)})<br><br>"
|
| 140 |
-
f"{highlighted}</div>"
|
| 141 |
-
)
|
| 142 |
-
}
|
| 143 |
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
|
|
|
|
|
|
| 147 |
|
| 148 |
-
progress(1, desc="Done!")
|
| 149 |
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
'<span style="background-color:#ffcc80; padding:2px 8px; margin-right:8px;">Student misalignment (orange)</span>'
|
| 154 |
-
'<span style="background-color:#90caf9; padding:2px 8px; margin-right:8px;">Teacher misalignment (blue)</span>'
|
| 155 |
-
'<span style="background-color:#b388ff; padding:2px 8px;">Both (purple)</span>'
|
| 156 |
-
"</div>"
|
| 157 |
-
)
|
| 158 |
|
| 159 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 160 |
|
| 161 |
|
| 162 |
with gr.Blocks(title="Tokenization Diff") as demo:
|
|
@@ -166,11 +211,37 @@ with gr.Blocks(title="Tokenization Diff") as demo:
|
|
| 166 |
student_model = gr.Textbox(label="Student Model", value="Qwen/Qwen3-8B")
|
| 167 |
teacher_model = gr.Textbox(label="Teacher Model", value="deepseek-ai/DeepSeek-Math-V2")
|
| 168 |
dataset_id = gr.Textbox(label="Dataset ID", value="lm-provers/FineProofs-SFT")
|
|
|
|
| 169 |
|
| 170 |
submit_btn = gr.Button("Submit", variant="primary")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 171 |
output = gr.HTML(label="Tokenization Diff Output")
|
| 172 |
|
| 173 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 174 |
|
| 175 |
if __name__ == "__main__":
|
| 176 |
demo.launch()
|
|
|
|
| 87 |
return s_groups, t_groups
|
| 88 |
|
| 89 |
|
| 90 |
+
def _decode_pieces(tokenizer, token_ids, indices):
|
| 91 |
+
"""Decode individual token pieces for a group of token indices."""
|
| 92 |
+
return [
|
| 93 |
+
tokenizer.decode([token_ids[idx]], skip_special_tokens=False, clean_up_tokenization_spaces=False)
|
| 94 |
+
for idx in indices
|
| 95 |
+
]
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def _format_pieces(pieces):
|
| 99 |
+
"""Format token pieces as a list, e.g. '["hel", "lo"]'."""
|
| 100 |
+
inner = ", ".join(f'"{p}"' for p in pieces)
|
| 101 |
+
return f"[{inner}]"
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def highlight_groups(student_tokenizer, teacher_tokenizer, student_token_ids, teacher_token_ids, s_groups, t_groups):
|
| 105 |
"""Build an HTML string with highlighted misalignment regions."""
|
| 106 |
parts = []
|
| 107 |
+
first_purple = True
|
| 108 |
for k in range(len(s_groups)):
|
| 109 |
s_ids = [student_token_ids[idx] for idx in s_groups[k]]
|
| 110 |
text = student_tokenizer.decode(s_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False)
|
|
|
|
| 114 |
t_multi = len(t_groups[k]) > 1
|
| 115 |
|
| 116 |
if s_multi and t_multi:
|
| 117 |
+
if first_purple:
|
| 118 |
+
s_pieces = _decode_pieces(student_tokenizer, student_token_ids, s_groups[k])
|
| 119 |
+
t_pieces = _decode_pieces(teacher_tokenizer, teacher_token_ids, t_groups[k])
|
| 120 |
+
tooltip = html.escape(f'Student: {_format_pieces(s_pieces)} / Teacher: {_format_pieces(t_pieces)}')
|
| 121 |
+
parts.append(f'<span style="background-color: #b388ff;" title="{tooltip}">{escaped}</span>')
|
| 122 |
+
first_purple = False
|
| 123 |
+
else:
|
| 124 |
+
parts.append(f'<span style="background-color: #b388ff;">{escaped}</span>')
|
| 125 |
elif s_multi:
|
| 126 |
+
s_pieces = _decode_pieces(student_tokenizer, student_token_ids, s_groups[k])
|
| 127 |
+
tooltip = html.escape(f'Student: {_format_pieces(s_pieces)}')
|
| 128 |
+
parts.append(f'<span style="background-color: #ffcc80;" title="{tooltip}">{escaped}</span>')
|
| 129 |
elif t_multi:
|
| 130 |
+
t_pieces = _decode_pieces(teacher_tokenizer, teacher_token_ids, t_groups[k])
|
| 131 |
+
tooltip = html.escape(f'Teacher: {_format_pieces(t_pieces)}')
|
| 132 |
+
parts.append(f'<span style="background-color: #90caf9;" title="{tooltip}">{escaped}</span>')
|
| 133 |
else:
|
| 134 |
parts.append(escaped)
|
| 135 |
|
| 136 |
return "".join(parts)
|
| 137 |
|
| 138 |
|
| 139 |
+
def make_html_block(student_tokenizer, teacher_tokenizer, text, idx):
|
| 140 |
+
"""Process a single text and return its highlighted HTML block."""
|
| 141 |
+
s_ids = student_tokenizer.encode(text, add_special_tokens=False)
|
| 142 |
+
t_ids = teacher_tokenizer.encode(text, add_special_tokens=False)
|
| 143 |
+
|
| 144 |
+
s_groups, t_groups = build_alignment_groups_from_ids(
|
| 145 |
+
student_tokenizer, teacher_tokenizer, s_ids, t_ids
|
| 146 |
+
)
|
| 147 |
+
|
| 148 |
+
highlighted = highlight_groups(student_tokenizer, teacher_tokenizer, s_ids, t_ids, s_groups, t_groups)
|
| 149 |
+
return (
|
| 150 |
+
f'<div style="border:1px solid #ccc; padding:10px; margin:10px 0; '
|
| 151 |
+
f'border-radius:5px; white-space:pre-wrap; font-family:monospace; font-size:13px;">'
|
| 152 |
+
f"<strong>Text {idx + 1}</strong> "
|
| 153 |
+
f"(student tokens: {len(s_ids)}, teacher tokens: {len(t_ids)})<br><br>"
|
| 154 |
+
f"{highlighted}</div>"
|
| 155 |
+
)
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
def process_texts(student_model_id, teacher_model_id, dataset_id, dataset_config, progress=gr.Progress()):
|
| 159 |
+
"""Load tokenizers and dataset, compute first row only."""
|
| 160 |
progress(0, desc="Loading tokenizers...")
|
| 161 |
student_tokenizer = AutoTokenizer.from_pretrained(student_model_id)
|
| 162 |
teacher_tokenizer = AutoTokenizer.from_pretrained(teacher_model_id)
|
| 163 |
|
| 164 |
+
progress(0.5, desc="Loading dataset...")
|
| 165 |
+
config = dataset_config.strip() if dataset_config and dataset_config.strip() else None
|
| 166 |
+
ds = load_dataset(dataset_id, name=config, split="train")
|
| 167 |
rows = ds.select(range(min(10, len(ds))))
|
| 168 |
+
texts = ["".join(msg["content"] for msg in row["messages"]) for row in rows]
|
| 169 |
+
|
| 170 |
+
progress(0.8, desc="Processing first text...")
|
| 171 |
+
first_block = make_html_block(student_tokenizer, teacher_tokenizer, texts[0], 0)
|
| 172 |
+
cache = {0: first_block}
|
| 173 |
|
| 174 |
+
progress(1, desc="Done!")
|
| 175 |
+
return student_tokenizer, teacher_tokenizer, texts, cache, 0, render_page(cache, 0, len(texts))
|
| 176 |
|
|
|
|
|
|
|
| 177 |
|
| 178 |
+
LEGEND = (
|
| 179 |
+
'<div style="margin-bottom:15px; font-family:sans-serif;">'
|
| 180 |
+
"<strong>Legend:</strong> "
|
| 181 |
+
'<span style="background-color:#ffcc80; padding:2px 8px; margin-right:8px;">Student token split (orange)</span>'
|
| 182 |
+
'<span style="background-color:#90caf9; padding:2px 8px; margin-right:8px;">Teacher token split (blue)</span>'
|
| 183 |
+
'<span style="background-color:#b388ff; padding:2px 8px;">Both (purple)</span>'
|
| 184 |
+
"</div>"
|
| 185 |
+
)
|
| 186 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 187 |
|
| 188 |
+
def render_page(cache, idx, total):
|
| 189 |
+
if not cache:
|
| 190 |
+
return ""
|
| 191 |
+
counter = f'<div style="font-family:sans-serif; margin-bottom:10px;">Text {idx + 1} of {total}</div>'
|
| 192 |
+
return LEGEND + counter + cache[idx]
|
| 193 |
|
|
|
|
| 194 |
|
| 195 |
+
def go_prev(cache, idx, texts):
|
| 196 |
+
idx = max(0, idx - 1)
|
| 197 |
+
return cache, idx, render_page(cache, idx, len(texts))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 198 |
|
| 199 |
+
|
| 200 |
+
def go_next(student_tokenizer, teacher_tokenizer, texts, cache, idx):
|
| 201 |
+
idx = min(len(texts) - 1, idx + 1)
|
| 202 |
+
if idx not in cache:
|
| 203 |
+
cache[idx] = make_html_block(student_tokenizer, teacher_tokenizer, texts[idx], idx)
|
| 204 |
+
return cache, idx, render_page(cache, idx, len(texts))
|
| 205 |
|
| 206 |
|
| 207 |
with gr.Blocks(title="Tokenization Diff") as demo:
|
|
|
|
| 211 |
student_model = gr.Textbox(label="Student Model", value="Qwen/Qwen3-8B")
|
| 212 |
teacher_model = gr.Textbox(label="Teacher Model", value="deepseek-ai/DeepSeek-Math-V2")
|
| 213 |
dataset_id = gr.Textbox(label="Dataset ID", value="lm-provers/FineProofs-SFT")
|
| 214 |
+
dataset_config = gr.Textbox(label="Dataset Config", value="default")
|
| 215 |
|
| 216 |
submit_btn = gr.Button("Submit", variant="primary")
|
| 217 |
+
|
| 218 |
+
student_tok_state = gr.State(None)
|
| 219 |
+
teacher_tok_state = gr.State(None)
|
| 220 |
+
texts_state = gr.State([])
|
| 221 |
+
cache_state = gr.State({})
|
| 222 |
+
idx_state = gr.State(0)
|
| 223 |
+
|
| 224 |
output = gr.HTML(label="Tokenization Diff Output")
|
| 225 |
|
| 226 |
+
with gr.Row():
|
| 227 |
+
prev_btn = gr.Button("Previous")
|
| 228 |
+
next_btn = gr.Button("Next")
|
| 229 |
+
|
| 230 |
+
submit_btn.click(
|
| 231 |
+
fn=process_texts,
|
| 232 |
+
inputs=[student_model, teacher_model, dataset_id, dataset_config],
|
| 233 |
+
outputs=[student_tok_state, teacher_tok_state, texts_state, cache_state, idx_state, output],
|
| 234 |
+
)
|
| 235 |
+
prev_btn.click(
|
| 236 |
+
fn=go_prev,
|
| 237 |
+
inputs=[cache_state, idx_state, texts_state],
|
| 238 |
+
outputs=[cache_state, idx_state, output],
|
| 239 |
+
)
|
| 240 |
+
next_btn.click(
|
| 241 |
+
fn=go_next,
|
| 242 |
+
inputs=[student_tok_state, teacher_tok_state, texts_state, cache_state, idx_state],
|
| 243 |
+
outputs=[cache_state, idx_state, output],
|
| 244 |
+
)
|
| 245 |
|
| 246 |
if __name__ == "__main__":
|
| 247 |
demo.launch()
|