theformatisvalid committed on
Commit
34f32f7
·
verified ·
1 Parent(s): 0463151

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +294 -38
src/streamlit_app.py CHANGED
@@ -1,40 +1,296 @@
1
- import altair as alt
2
- import numpy as np
3
- import pandas as pd
4
  import streamlit as st
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- """
7
- # Welcome to Streamlit!
8
-
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
12
-
13
- In the meantime, below is an example of what you can do with just a few lines of code:
14
- """
15
-
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
 
 
 
1
  import streamlit as st
2
+ import os
3
+ import pandas as pd
4
+ import matplotlib.pyplot as plt
5
+ import seaborn as sns
6
+ from collections import Counter
7
+ import json
8
+ from io import StringIO, BytesIO
9
+ import tempfile
10
+ import re
11
+ import base64
12
+
13
+ from tokenizers_trainer import train_bpe, train_wordpiece, train_unigram
14
+ from tokenizers_analysis import calculate_oov
15
+
16
st.set_page_config(page_title='Tokenizer Explorer', layout="wide")

st.title('Tokenizer Explorer')

# Working directory for the files this script writes.
UPLOAD_DIR = 'uploads'
os.makedirs(UPLOAD_DIR, exist_ok=True)

SAMPLE_DATA_PATH = 'core/united_core.txt'


def _split_nonempty_lines(content):
    """Return the lines of ``content`` stripped of whitespace, skipping blank ones."""
    stripped = (line.strip() for line in content.splitlines())
    return [line for line in stripped if line]


st.sidebar.header('Data loading')

data_source = st.sidebar.radio('Data source', ['Upload your file', 'Use example'])

# Stays empty when nothing usable was loaded; the script then stops below.
text_lines = []

if data_source == 'Upload your file':
    uploaded_file = st.sidebar.file_uploader('Upload file (.txt)', type=['txt'])
    if uploaded_file is not None:
        try:
            content = uploaded_file.read().decode('utf-8')
        except UnicodeDecodeError:
            # A .txt upload is not guaranteed to be UTF-8; report it instead
            # of letting the traceback reach the user.
            st.error('Could not decode the uploaded file as UTF-8.')
        else:
            text_lines = _split_nonempty_lines(content)
            st.session_state['raw_text'] = content
    else:
        st.info('Upload your file.')
else:
    try:
        with open(SAMPLE_DATA_PATH, encoding='utf-8') as f:
            content = f.read()
    except FileNotFoundError:
        st.error(f'Example file not found: {SAMPLE_DATA_PATH}')
    else:
        text_lines = _split_nonempty_lines(content)
        st.session_state['raw_text'] = content
        st.sidebar.success(f'Example uploaded: {SAMPLE_DATA_PATH}')

# Nothing to analyse yet: halt this script run here.
if not text_lines:
    st.stop()
51
+
52
# Sidebar controls: tokenizer hyper-parameters and the preprocessing switch.
st.sidebar.header('Settings')

vocab_size = st.sidebar.slider(
    'Vocabulary size',
    min_value=5000,
    max_value=50000,
    value=20000,
    step=1000,
)
min_freq = st.sidebar.slider(
    'Minimal token frequency',
    min_value=1,
    max_value=10,
    value=2,
)
model_type = st.sidebar.selectbox(
    'Tokenizer',
    options=['BPE', 'WordPiece', 'Unigram'],
)

normalize_text = st.sidebar.checkbox('Normalize text', value=True)
59
+
60
def normalize(line):
    """Return ``line`` stripped, optionally lowercased and without punctuation.

    When the module-level ``normalize_text`` flag is set, the line is
    lowercased and every character that is neither a word character nor
    whitespace is removed before stripping. Otherwise the line is only
    stripped.
    """
    if not normalize_text:
        return line.strip()
    cleaned = re.sub(r'[^\w\s]', '', line.lower())
    return cleaned.strip()
65
+
66
# Normalize every line exactly once and keep only the non-empty results
# (calling normalize() in both the filter and the value doubles the work).
texts = [normalized for normalized in map(normalize, text_lines) if normalized]
if not texts:
    st.error('Text is empty after normalization.')
    st.stop()

# The training routines take a file path, so the corpus is written to disk.
# NOTE(review): this path is fixed, so anything else writing to the same
# directory at the same time would overwrite it — confirm that is acceptable.
corpus_path = os.path.join(UPLOAD_DIR, 'temp_corpus.txt')
with open(corpus_path, 'w', encoding='utf-8') as f:
    f.write("\n".join(texts))
74
+
75
st.sidebar.header('Training')

# Training routine for each selectable model type.
trainers = {
    'BPE': train_bpe,
    'WordPiece': train_wordpiece,
    'Unigram': train_unigram,
}

if st.sidebar.button('Train tokenizer'):
    with st.spinner('training...'):
        try:
            train = trainers.get(model_type)
            if train is not None:
                # Keep the trained tokenizer and its type in the session
                # state, which is where the rest of the script reads them.
                st.session_state['tokenizer'] = train(vocab_size, min_freq, corpus_path)
                st.session_state['model_name'] = model_type
            st.sidebar.success('Training complete')
        except Exception as e:
            st.sidebar.error(f'Training error: {e}')

# Everything below needs a trained tokenizer.
if 'tokenizer' not in st.session_state:
    st.info('Setup parameters and press "Train tokenizer" on left panel')
    st.stop()

tokenizer = st.session_state['tokenizer']
model_name = st.session_state['model_name']
102
+
103
def tokenize_text(text):
    """Tokenize ``text`` with the module-level ``tokenizer``.

    For the 'BPE' and 'WordPiece' model names the result is the ``tokens``
    attribute of ``tokenizer.encode(text)``; for any other model name it is
    whatever ``tokenizer.encode_as_pieces(text)`` returns.
    """
    if model_name not in ['BPE', 'WordPiece']:
        return tokenizer.encode_as_pieces(text)
    encoding = tokenizer.encode(text)
    return encoding.tokens
108
+
109
+
110
def get_vocabulary(tokenizer):
    """Return the vocabulary of ``tokenizer``.

    If the object exposes ``get_vocab()``, its result is returned as is.
    Otherwise a ``{piece: id}`` dict is built from ``get_piece_size()`` and
    ``id_to_piece()``.
    """
    if not hasattr(tokenizer, 'get_vocab'):
        piece_count = tokenizer.get_piece_size()
        return {tokenizer.id_to_piece(idx): idx for idx in range(piece_count)}
    return tokenizer.get_vocab()
116
+
117
+
118
# Token statistics are computed over the first 1000 normalized lines only.
all_tokens = []
for line in texts[:1000]:
    all_tokens.extend(tokenize_text(line))

token_counter = Counter(all_tokens)
# One row per distinct token, most frequent first.
df_tokens = pd.DataFrame(token_counter.items(), columns=['token', 'frequency']).sort_values('frequency', ascending=False)

st.header(f'Report: {model_name} (Vocab={vocab_size}, MinFreq={min_freq})')

col1, col2 = st.columns(2)

with col1:
    st.subheader('Token length distribution')
    token_lengths = [len(t) for t in all_tokens]
    fig1, ax1 = plt.subplots()
    sns.histplot(token_lengths, bins=30, kde=True, ax=ax1)
    ax1.set_xlabel('Token length, chars')
    ax1.set_ylabel('Frequency')
    st.pyplot(fig1)
    # pyplot keeps a reference to every open figure; close it once rendered
    # so figures do not pile up across script runs.
    plt.close(fig1)

with col2:
    st.subheader('Most frequent tokens')
    top20 = df_tokens.head(20)
    fig2, ax2 = plt.subplots(figsize=(8, 6))
    sns.barplot(data=top20, x='frequency', y='token', ax=ax2)
    ax2.set_xlabel('Frequency')
    ax2.set_ylabel('Token')
    st.pyplot(fig2)
    plt.close(fig2)

st.subheader('Out-of-Vocabulary percentage')
oov_rate = calculate_oov(' '.join(texts), get_vocabulary(tokenizer))
# Streamlit discourages an empty label, so the metric gets a real one.
st.metric(label='OOV rate', value=f'{oov_rate:.2%}')
151
+
152
st.sidebar.header('Export')
if st.sidebar.button('Export as HTML'):
    # Imported here so this block brings its own dependency into scope.
    from html import escape

    def fig_to_base64(fig):
        """Render ``fig`` as a PNG and return it as an inline ``<img>`` tag."""
        with BytesIO() as buf:
            fig.savefig(buf, format='png', dpi=150, bbox_inches='tight')
            img_str = base64.b64encode(buf.getvalue()).decode()
        return f'<img src="data:image/png;base64,{img_str}" style="max-width:100%;">'

    def token_rows(frame, with_rank=False):
        """Return the ``<tr>`` rows for ``frame`` as one string.

        Tokens come from the user-supplied text, so they are HTML-escaped
        before being placed in the document.
        """
        rows = []
        pairs = zip(frame['token'], frame['frequency'])
        for rank, (token, frequency) in enumerate(pairs, start=1):
            rank_cell = f'<td>{rank}</td>' if with_rank else ''
            rows.append(
                f'<tr>{rank_cell}<td>{escape(str(token))}</td>'
                f'<td>{int(frequency):,}</td></tr>'
            )
        return ''.join(rows)

    # Dedicated figures for the export (own sizes and titles); each one is
    # closed as soon as it has been converted to an image.
    export_lengths = [len(t) for t in all_tokens]
    fig1, ax1 = plt.subplots(figsize=(6, 4))
    sns.histplot(export_lengths, bins=30, kde=True, ax=ax1)
    ax1.set_xlabel('Token length, chars')
    ax1.set_ylabel('Frequency')
    ax1.set_title('Token Length Distribution')
    chart1_html = fig_to_base64(fig1)
    plt.close(fig1)

    top20 = df_tokens.head(20)
    fig2, ax2 = plt.subplots(figsize=(8, 6))
    sns.barplot(data=top20, x='frequency', y='token', ax=ax2)
    ax2.set_xlabel('Frequency')
    ax2.set_ylabel('Token')
    ax2.set_title('Top 20 Most Frequent Tokens')
    chart2_html = fig_to_base64(fig2)
    plt.close(fig2)

    # `oov_rate` was already computed for the on-page metric earlier in this
    # same script run, so it is reused rather than recomputed.
    safe_model_name = escape(str(model_name))
    generated_on = pd.Timestamp.now().strftime('%Y-%m-%d %H:%M:%S')
    normalization_label = 'Yes' if normalize_text else 'No'

    report_head = f'''
<html>
<head>
<meta charset="UTF-8">
<title>Tokenizer Report: {safe_model_name}</title>
<style>
body {{
font-family: Arial, sans-serif;
margin: 40px;
line-height: 1.6;
color: #333;
}}
h1, h2, h3 {{
color: #2c3e50;
}}
table {{
border-collapse: collapse;
width: 100%;
margin: 20px 0;
}}
table th, table td {{
border: 1px solid #bdc3c7;
padding: 8px;
text-align: left;
}}
table th {{
background-color: #ecf0f1;
}}
.chart {{
margin: 30px 0;
}}
.info-box {{
background-color: #f9f9f9;
border-left: 4px solid #3498db;
padding: 15px;
margin: 20px 0;
}}
footer {{
margin-top: 50px;
font-size: 0.9em;
color: #7f8c8d;
text-align: center;
}}
</style>
</head>
<body>
<h1>Tokenizer Report: {safe_model_name}</h1>
<p><strong>Generated on:</strong> {generated_on}</p>

<h2>Model Parameters</h2>
<ul>
<li><strong>Vocabulary size:</strong> {vocab_size}</li>
<li><strong>Minimum frequency:</strong> {min_freq}</li>
<li><strong>Normalization:</strong> {normalization_label}</li>
<li><strong>Total tokens processed:</strong> {len(all_tokens)}</li>
<li><strong>Unique tokens found:</strong> {len(token_counter)}</li>
<li><strong>Out-of-Vocabulary rate:</strong> {oov_rate:.2%}</li>
</ul>

<h2>Token Length Distribution</h2>
<div class="chart">
{chart1_html}
</div>

<h2>Most Frequent Tokens (Top 20)</h2>
<div class="chart">
{chart2_html}
</div>

<h2>Top 10 Most Frequent Tokens</h2>
<table>
<tr><th>Token</th><th>Frequency</th></tr>
'''

    dictionary_head = '''
<h2>Dictionary (First 100 Tokens)</h2>
<table>
<tr><th>Rank</th><th>Token</th><th>Frequency</th></tr>
'''

    report_tail = '''
</table>
</body>
</html>
'''

    # Assemble the document in one join instead of repeated `+=`.
    report_html = ''.join([
        report_head,
        token_rows(df_tokens.head(10)),
        '</table>',
        dictionary_head,
        token_rows(df_tokens.head(100), with_rank=True),
        report_tail,
    ])

    # A copy is still written to disk as before; the download itself is
    # served from memory, so it never depends on re-reading the shared file.
    html_path = os.path.join(UPLOAD_DIR, 'tokenizer_report.html')
    with open(html_path, 'w', encoding='utf-8') as f:
        f.write(report_html)

    st.sidebar.download_button(
        'Download Full Report',
        report_html,
        file_name='tokenizer_report.html',
        mime='text/html',
    )
285
+
286
# Collapsible table with the 100 most frequent tokens.
with st.expander('View dictionary'):
    st.dataframe(df_tokens.head(100))

# Collapsible summary of the current settings and token counts.
with st.expander('Info'):
    normalization_label = "Yes" if normalize_text else "No"
    summary_lines = (
        f'- Method: {model_name}',
        f'- Vocabulary size: {vocab_size}',
        f'- Min. frequency: {min_freq}',
        f'- Normalization: {normalization_label}',
        f'- Unique tokens: {len(token_counter)}',
        f'- Total tokens: {len(all_tokens)}',
    )
    for summary_line in summary_lines:
        st.write(summary_line)
296
+