cmpatino HF Staff committed on
Commit
35f1842
·
1 Parent(s): f8df45d

Init version

Browse files
Files changed (2) hide show
  1. app.py +172 -0
  2. requirements.txt +3 -0
app.py ADDED
@@ -0,0 +1,172 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import html
2
+
3
+ import gradio as gr
4
+ from datasets import load_dataset
5
+ from transformers import AutoTokenizer
6
+
7
+
8
def build_alignment_groups_from_ids(student_tokenizer, teacher_tokenizer, student_token_ids, teacher_token_ids):
    """Align two tokenizations of the same text into matching span groups.

    Greedily grows a decoded-text buffer on each side until both buffers hold
    the same substring, then emits the accumulated token indices on each side
    as one aligned group. Adapted from TRL's
    GoldTrainer._build_alignment_groups_from_ids.

    Args:
        student_tokenizer: Tokenizer exposing ``decode(ids, ...)``.
        teacher_tokenizer: Tokenizer exposing ``decode(ids, ...)``.
        student_token_ids: Token ids produced by the student tokenizer.
        teacher_token_ids: Token ids produced by the teacher tokenizer.

    Returns:
        Tuple ``(s_groups, t_groups)`` of parallel lists: ``s_groups[k]`` and
        ``t_groups[k]`` are the index lists of student/teacher tokens that
        decode to the same substring. A trailing unmatched remainder, if any,
        is emitted as a final (possibly one-sided) group.
    """

    def to_canonical_pieces(tok, ids):
        # Decode incrementally so each piece is exactly the text the k-th
        # token contributes in context (robust to tokenizers whose per-token
        # decode differs from the decode of the full sequence).
        pieces = []
        prev = ""
        for k in range(len(ids)):
            cur = tok.decode(ids[: k + 1], skip_special_tokens=False, clean_up_tokenization_spaces=False)
            pieces.append(cur[len(prev):])
            prev = cur
        return pieces

    s_pieces = to_canonical_pieces(student_tokenizer, student_token_ids)
    t_pieces = to_canonical_pieces(teacher_tokenizer, teacher_token_ids)

    i = j = 0
    s_buf = t_buf = ""
    s_group = []
    t_group = []
    s_groups = []
    t_groups = []

    def flush():
        # Emit the current aligned pair; only meaningful when both sides
        # contributed at least one token.
        if s_group and t_group:
            s_groups.append(s_group.copy())
            t_groups.append(t_group.copy())

    while i < len(s_pieces) or j < len(t_pieces):
        # Buffers match on non-empty text: close out this aligned group.
        if s_buf == t_buf and s_buf != "":
            flush()
            s_buf = t_buf = ""
            s_group = []
            t_group = []
            continue

        # Seed an empty buffer before any length comparison.
        if s_buf == "" and i < len(s_pieces):
            s_buf += s_pieces[i]
            s_group.append(i)
            i += 1
            continue
        if t_buf == "" and j < len(t_pieces):
            t_buf += t_pieces[j]
            t_group.append(j)
            j += 1
            continue

        # Grow the shorter buffer so both sides converge on the same text.
        if len(s_buf) <= len(t_buf):
            if i < len(s_pieces):
                s_buf += s_pieces[i]
                s_group.append(i)
                i += 1
            elif j < len(t_pieces):
                t_buf += t_pieces[j]
                t_group.append(j)
                j += 1
        else:
            if j < len(t_pieces):
                t_buf += t_pieces[j]
                t_group.append(j)
                j += 1
            elif i < len(s_pieces):
                s_buf += s_pieces[i]
                s_group.append(i)
                i += 1

    if s_buf == t_buf and s_group and t_group:
        flush()
    elif s_group or t_group:
        # Unmatched trailing remainder: emit as-is (one side may be empty)
        # so callers still see every consumed token index. The original's
        # extra "if not s_group: s_group = []" resets and conditional
        # copies were no-ops and have been removed.
        s_groups.append(s_group.copy())
        t_groups.append(t_group.copy())

    return s_groups, t_groups
88
+
89
+
90
def highlight_groups(student_tokenizer, student_token_ids, s_groups, t_groups):
    """Render the student-side text as HTML, color-coding misaligned regions.

    Args:
        student_tokenizer: Tokenizer used to decode student token ids.
        student_token_ids: Full list of student token ids for the text.
        s_groups: Student-side alignment groups (lists of indices into
            ``student_token_ids``).
        t_groups: Teacher-side alignment groups, parallel to ``s_groups``.

    Returns:
        An HTML string where a group spanning multiple student tokens is
        orange, multiple teacher tokens blue, and both purple; 1:1 aligned
        groups are left unstyled. All decoded text is HTML-escaped.
    """
    parts = []
    # Iterate the parallel group lists in lockstep instead of indexing by k.
    for s_idx, t_idx in zip(s_groups, t_groups):
        s_ids = [student_token_ids[idx] for idx in s_idx]
        text = student_tokenizer.decode(s_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False)
        escaped = html.escape(text)

        s_multi = len(s_idx) > 1
        t_multi = len(t_idx) > 1

        if s_multi and t_multi:
            parts.append(f'<span style="background-color: #b388ff;">{escaped}</span>')
        elif s_multi:
            parts.append(f'<span style="background-color: #ffcc80;">{escaped}</span>')
        elif t_multi:
            parts.append(f'<span style="background-color: #90caf9;">{escaped}</span>')
        else:
            parts.append(escaped)

    return "".join(parts)
111
+
112
+
113
def process_texts(student_model_id, teacher_model_id, dataset_id, progress=gr.Progress()):
    """Load tokenizers and dataset, compute alignment, return highlighted HTML.

    Args:
        student_model_id: Hub id of the model whose tokenizer plays "student".
        teacher_model_id: Hub id of the model whose tokenizer plays "teacher".
        dataset_id: Hub id of a dataset with a "train" split whose rows carry
            a "messages" list of dicts with a "content" field.
        progress: Gradio progress tracker (injected by Gradio at call time).

    Returns:
        An HTML string: a color legend followed by one bordered block per
        processed text with misaligned token regions highlighted.
    """
    progress(0, desc="Loading tokenizers...")
    student_tokenizer = AutoTokenizer.from_pretrained(student_model_id)
    teacher_tokenizer = AutoTokenizer.from_pretrained(teacher_model_id)

    progress(0.1, desc="Loading dataset...")
    ds = load_dataset(dataset_id, split="train")
    # Cap the demo at 10 texts; smaller datasets yield fewer rows.
    n_rows = min(10, len(ds))
    rows = ds.select(range(n_rows))

    html_blocks = []
    for row_idx, row in enumerate(rows):
        # Scale progress into the 0.1-0.9 band reserved for per-text work,
        # using the actual row count (the original hard-coded "/10" in the
        # message and a /12 denominator, which disagreed with each other and
        # with datasets holding fewer than 10 rows).
        progress(0.1 + 0.8 * (row_idx + 1) / n_rows, desc=f"Processing text {row_idx + 1}/{n_rows}...")
        # Concatenate all chat turns into one plain string to tokenize.
        text = "".join(msg["content"] for msg in row["messages"])

        s_ids = student_tokenizer.encode(text, add_special_tokens=False)
        t_ids = teacher_tokenizer.encode(text, add_special_tokens=False)

        s_groups, t_groups = build_alignment_groups_from_ids(
            student_tokenizer, teacher_tokenizer, s_ids, t_ids
        )

        highlighted = highlight_groups(student_tokenizer, s_ids, s_groups, t_groups)
        html_blocks.append(
            f'<div style="border:1px solid #ccc; padding:10px; margin:10px 0; '
            f'border-radius:5px; white-space:pre-wrap; font-family:monospace; font-size:13px;">'
            f"<strong>Text {row_idx + 1}</strong> "
            f"(student tokens: {len(s_ids)}, teacher tokens: {len(t_ids)})<br><br>"
            f"{highlighted}</div>"
        )

    progress(1, desc="Done!")

    legend = (
        '<div style="margin-bottom:15px; font-family:sans-serif;">'
        "<strong>Legend:</strong> "
        '<span style="background-color:#ffcc80; padding:2px 8px; margin-right:8px;">Student misalignment (orange)</span>'
        '<span style="background-color:#90caf9; padding:2px 8px; margin-right:8px;">Teacher misalignment (blue)</span>'
        '<span style="background-color:#b388ff; padding:2px 8px;">Both (purple)</span>'
        "</div>"
    )

    return legend + "\n".join(html_blocks)
156
+
157
+
158
# --- Gradio UI ---------------------------------------------------------------
# Three textboxes (student model, teacher model, dataset id); clicking Submit
# runs process_texts and renders the returned highlighted-diff HTML below.
with gr.Blocks(title="Tokenization Diff") as demo:
    gr.Markdown("# Tokenization Diff\nVisualize where two tokenizers differ in how they tokenize text.")

    with gr.Row():
        student_model = gr.Textbox(label="Student Model", value="Qwen/Qwen3-8B")
        teacher_model = gr.Textbox(label="Teacher Model", value="deepseek-ai/DeepSeek-Math-V2")
        dataset_id = gr.Textbox(label="Dataset ID", value="lm-provers/FineProofs-SFT")

    submit_btn = gr.Button("Submit", variant="primary")
    output = gr.HTML(label="Tokenization Diff Output")

    # process_texts receives the three textbox values positionally; its
    # gr.Progress() parameter is injected by Gradio, not listed in inputs.
    submit_btn.click(fn=process_texts, inputs=[student_model, teacher_model, dataset_id], outputs=output)

if __name__ == "__main__":
    demo.launch()
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ gradio>=6.6.0
2
+ transformers
3
+ datasets