theformatisvalid committed on
Commit
0b9dd2a
·
verified ·
1 Parent(s): 084e1b8

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +449 -37
src/streamlit_app.py CHANGED
@@ -1,40 +1,452 @@
1
- import altair as alt
 
2
  import numpy as np
3
  import pandas as pd
4
- import streamlit as st
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- """
7
- # Welcome to Streamlit!
8
-
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
12
-
13
- In the meantime, below is an example of what you can do with just a few lines of code:
14
- """
15
-
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
1
+ import fasttext
2
+ import streamlit as st
3
  import numpy as np
4
  import pandas as pd
5
+ from gensim.models import Word2Vec
6
+ from sklearn.metrics.pairwise import cosine_similarity
7
+ import plotly.express as px
8
+ import plotly.graph_objects as go
9
+ from collections import Counter
10
+ import os
11
+ import glob
12
+
13
+
14
+ class UnifiedVectorModel:
15
+ def __init__(self, backend_model, model_type="w2v"):
16
+ self.model = backend_model
17
+ self.model_type = model_type.lower()
18
+
19
+ if self.model_type == "w2v":
20
+ self.wv = backend_model.wv
21
+ self.key_to_index = self.wv.key_to_index
22
+ self.vector_size = self.wv.vector_size
23
+ self._words = set(self.wv.key_to_index.keys())
24
+
25
+ elif self.model_type == "ft":
26
+ # Для fasttext-wheel
27
+ self.key_to_index = {word: i for i, word in enumerate(backend_model.get_words())}
28
+ self.vector_size = backend_model.get_dimension()
29
+ self._words = set(self.key_to_index.keys())
30
+ else:
31
+ raise ValueError("model_type must be 'w2v' or 'ft'")
32
+
33
+ def __contains__(self, word):
34
+ return word in self._words
35
+
36
+ def __getitem__(self, word):
37
+ if self.model_type == "w2v":
38
+ return self.wv[word]
39
+ elif self.model_type == "ft":
40
+ return self.model.get_word_vector(word)
41
+
42
+ def most_similar(self, positive=None, negative=None, topn=10):
43
+ from sklearn.metrics.pairwise import cosine_similarity
44
+
45
+ if not positive:
46
+ positive = []
47
+ if not negative:
48
+ negative = []
49
+
50
+ try:
51
+ if self.model_type == "w2v":
52
+ return self.wv.most_similar(positive=positive, negative=negative, topn=topn)
53
+
54
+ elif self.model_type == "ft":
55
+ vec = np.zeros(self.vector_size)
56
+ for w in positive:
57
+ if w in self:
58
+ vec += self[w]
59
+ else:
60
+ continue
61
+ for w in negative:
62
+ if w in self:
63
+ vec -= self[w]
64
+ else:
65
+ continue
66
+
67
+ if np.allclose(vec, 0):
68
+ return []
69
+
70
+ words = list(self._words)
71
+ vectors = np.array([self[w] for w in words])
72
+
73
+ sims = cosine_similarity([vec], vectors)[0]
74
+ best = np.argsort(sims)[::-1][:topn + len(positive) + len(negative)]
75
+
76
+ result = []
77
+ for i in best:
78
+ word = words[i]
79
+ if word not in positive and word not in negative:
80
+ result.append((word, float(sims[i])))
81
+ if len(result) >= topn:
82
+ break
83
+ return result
84
+
85
+ except Exception as e:
86
+ print(f"Error in most_similar: {e}")
87
+ return []
88
+
89
+ def similar_by_vector(self, vector, topn=10):
90
+ from sklearn.metrics.pairwise import cosine_similarity
91
+
92
+ words = list(self._words)
93
+ vectors = np.array([self[w] for w in words])
94
+ sims = cosine_similarity([vector], vectors)[0]
95
+ best = np.argsort(sims)[::-1][:topn]
96
+
97
+ return [(words[i], float(sims[i])) for i in best]
98
+
99
+ def get_words(self):
100
+ return list(self._words)
101
+
102
+ @property
103
+ def vectors(self):
104
+ if not hasattr(self, '_cached_vectors'):
105
+ words = list(self._words)
106
+ self._cached_words = words
107
+ self._cached_vectors = np.array([self[w] for w in words])
108
+ return self._cached_vectors
109
+
110
+ @property
111
+ def index_to_key(self):
112
+ if not hasattr(self, '_index_to_key'):
113
+ self._index_to_key = list(self._words)
114
+ return self._index_to_key
115
+
116
+
117
@st.cache_resource
def load_model(model_path):
    """Load a word-vector model from disk, cached across Streamlit reruns.

    ``.model`` files are loaded as gensim Word2Vec, ``.bin`` as fasttext.
    Returns a UnifiedVectorModel, or None (after showing a Streamlit error)
    when loading fails.
    """
    try:
        if model_path.endswith(".model"):
            raw_model = Word2Vec.load(model_path)
            current_model = UnifiedVectorModel(raw_model, model_type="w2v")

        elif model_path.endswith(".bin"):
            raw_model = fasttext.load_model(model_path)
            current_model = UnifiedVectorModel(raw_model, model_type="ft")
        else:
            # NOTE(review): the sidebar globs *.vec files too, but there is no
            # loader for them yet — they end up here.
            raise ValueError(f"unsupported model format: {model_path}")
        return current_model
    except Exception as e:
        st.error(f"error loading model {model_path}: {e}")
        return None
133
+
134
+
135
# ---- Model discovery and sidebar selection -------------------------------
MODELS_DIR = "models"

if not os.path.exists(MODELS_DIR):
    st.error(f"Folder `{MODELS_DIR}` not found.")
    st.stop()

# Collect candidate model files for every supported extension.
model_files = [
    path
    for pattern in ("*.bin", "*.model", "*.vec")
    for path in glob.glob(os.path.join(MODELS_DIR, pattern))
    if os.path.isfile(path)
]
model_names = [os.path.basename(path) for path in model_files]

if not model_names:
    st.error(f"No models in folder `{MODELS_DIR}` (.bin, .model, .vec).")
    st.info("Supported formats: Word2Vec (binary/text), FastText.")
    st.stop()

selected_model_name = st.sidebar.selectbox(
    "Choose pretrained model",
    model_names
)

selected_model_path = os.path.join(MODELS_DIR, selected_model_name)

st.sidebar.info(f"loading: `{selected_model_name}`")

model = load_model(selected_model_path)

if model is None:
    st.stop()
else:
    st.sidebar.success(f"Model '{selected_model_name}' loaded")
    st.sidebar.write(f"Voc size: {len(model.key_to_index):,}")
    st.sidebar.write(f"Vector size: {model.vector_size}")
169
+
170
def analogy_accuracy(model, file_name):
    """Evaluate word analogies of the form ``a b c d`` (a - b + c ≈ d).

    Each valid line of `file_name` contributes one record. Returns
    ``(accuracy, results)`` where accuracy is the fraction of analogies whose
    target word appears among the top-10 predictions.
    """
    right = 0
    count = 0
    results = []
    with open(file_name, encoding='utf-8') as file:
        for line in file:
            words = line.strip().split()
            if len(words) != 4:
                continue  # malformed line
            try:
                most_similar = model.most_similar(positive=[words[0], words[2]], negative=[words[1]], topn=10)
            except KeyError:
                continue  # out-of-vocabulary query word
            if not most_similar:
                # Fix: an empty prediction list (the model returns [] on
                # internal errors) previously raised an uncaught IndexError
                # at predicted[0].
                continue
            predicted = [x[0] for x in most_similar]
            correct = words[3]
            if correct in predicted:
                rank = predicted.index(correct) + 1
                right += 1
            else:
                rank = None
            count += 1
            results.append({
                "query": f"{words[0]} - {words[1]} + {words[2]}",
                "target": correct,
                "predicted": predicted[0],
                "rank": rank,
                "in_top10": bool(rank)
            })
    accuracy = right / count if count > 0 else 0
    return accuracy, results
200
+
201
+
202
def avg_similarity(model, file_name):
    """Mean pairwise cosine similarity over word groups listed in `file_name`.

    Each line is a whitespace-separated group of words. Lines with fewer than
    two words, or containing any out-of-vocabulary word, are skipped.
    Returns 0 when no pair could be evaluated.
    """
    res = []
    with open(file_name, encoding='utf-8') as file:
        for line in file:
            words = line.strip().split()
            if len(words) < 2:
                # Fix: blank/single-word lines previously reached
                # cosine_similarity with <2 rows and crashed.
                continue
            try:
                vectors = [model[word] for word in words]
            except KeyError:
                continue
            # Pairwise cosine matrix in numpy (equivalent to sklearn's
            # cosine_similarity); zero-norm vectors guarded against 0/0.
            vecs = np.asarray(vectors, dtype=float)
            norms = np.linalg.norm(vecs, axis=1)
            norms = np.where(norms == 0, 1.0, norms)
            sims = (vecs @ vecs.T) / np.outer(norms, norms)
            for i in range(len(words) - 1):
                for j in range(i + 1, len(words)):
                    res.append(sims[i][j])
    return sum(res) / len(res) if res else 0
216
+
217
+
218
def projection(word_vec, axis):
    """Scalar projection of `word_vec` onto the direction of `axis`."""
    unit = axis / np.linalg.norm(axis)
    return np.dot(word_vec, unit)
221
+
222
+
223
def get_projection_row(model, axis):
    """Every vocabulary word paired with its projection onto `axis`, sorted ascending."""
    pairs = [(w, projection(model[w], axis)) for w in model.key_to_index.keys()]
    pairs.sort(key=lambda item: item[1])
    return pairs
228
+
229
+
230
st.title("Vector embeddings")

# Top-level navigation: one tab per analysis tool.
tab1, tab2, tab3, tab4, tab5 = st.tabs([
    "Vector arithmetic",  # fix: label was misspelled "Vector ariphmetics"
    "Semantic consistency",
    "Semantic axis",
    "Distribution analysis",
    "Report"
])
239
+
240
with tab1:
    # Tab 1: evaluate expressions like "king - man + woman" against the model.
    st.header("Vector arithmetic")  # fix: was misspelled "Vector ariphmetics"
    expr = st.text_input("Insert expression", value="рубль - россия + сша")

    if st.button("Compute"):
        # Tokenize "a - b + c" into positive / negative word lists.
        words = expr.replace('+', ' + ').replace('-', ' - ').split()
        positive, negative = [], []
        current = 'pos'

        for w in words:
            if w == '+':
                current = 'pos'
            elif w == '-':
                current = 'neg'
            else:
                (positive if current == 'pos' else negative).append(w)

        missing = [w for w in positive + negative if w not in model]
        if missing:
            st.warning(f"Words not found in voc: {', '.join(missing)}")
            st.stop()

        try:
            similar = model.most_similar(
                positive=positive,
                negative=negative,
                topn=10
            )

            st.write("### Result:")
            result_words = [f"{w} ({s:.3f})" for w, s in similar]
            st.write("Nearest words: " + ", ".join(result_words))

            st.write("### In-between steps")

            cum_vec = np.zeros(model.vector_size)

            steps_data = []

            for i in range(len(positive)):
                # Fix: accumulate the current word's vector. The original used
                # model[w], where `w` was a stale variable left over from the
                # tokenizing loop above (always the last token).
                cum_vec += model[positive[i]]
                nearest = model.most_similar(positive=positive[:i + 1], topn=1)
                steps_data.append({
                    "step": f"+ {positive[i]}",
                    "nearest word": nearest[0][0],
                    "similarity": nearest[0][1]
                })

            for i in range(len(negative)):
                # Fix: subtract the current negative word (was model[w], see above).
                cum_vec -= model[negative[i]]
                nearest = model.most_similar(positive=positive, negative=negative[:i + 1], topn=1)
                steps_data.append({
                    "step": f"- {negative[i]}",
                    "nearest word": nearest[0][0],
                    "similarity": nearest[0][1]
                })

            df_steps = pd.DataFrame(steps_data)
            st.dataframe(df_steps[["step", "nearest word", "similarity"]])

            # Scatter of the first two raw components — illustrative only.
            result_word = similar[0][0]
            fig = px.scatter(
                x=[cum_vec[0]], y=[cum_vec[1]],
                text=[result_word],
                title="Result (first 2 components)"
            )
            fig.update_traces(textposition='top center', marker=dict(size=12, color='red'))
            st.plotly_chart(fig)

        except Exception as e:
            st.error(f"Error computing: {e}")
311
+
312
with tab2:
    # Tab 2: cosine similarity of a word pair plus a small neighbor graph.
    st.header("Similarity calculator")
    col1, col2 = st.columns(2)
    with col1:
        word1 = st.text_input("word 1", value="мужчина")
    with col2:
        word2 = st.text_input("word 2", value="женщина")

    if st.button("Compute similarity"):
        try:
            v1, v2 = model[word1], model[word2]
            sim = cosine_similarity([v1], [v2])[0][0]
            st.metric("Cosine similarity", f"{sim:.4f}")

            st.write("### Nearest neighbors graph")
            # Fix: query each word's neighbors once and reuse the result —
            # the original called most_similar twice per word.
            neighbors1 = model.most_similar(word1, topn=5)
            neighbors2 = model.most_similar(word2, topn=5)
            neighbors = neighbors1 + neighbors2
            nodes = list(set([word1, word2] + [n[0] for n in neighbors]))
            edges = [(word1, n[0]) for n in neighbors1] + \
                    [(word2, n[0]) for n in neighbors2]

            G = go.Figure()
            # Random 2-D layout in [-1, 1]; node positions are decorative only.
            pos = np.random.rand(len(nodes), 2) * 2 - 1
            node_x = pos[:, 0]
            node_y = pos[:, 1]

            for edge in edges:
                x0, y0 = pos[nodes.index(edge[0])]
                x1, y1 = pos[nodes.index(edge[1])]
                G.add_trace(go.Scatter(x=[x0, x1], y=[y0, y1], mode='lines', line=dict(width=1, color='gray'), showlegend=False))

            G.add_trace(go.Scatter(x=node_x, y=node_y, mode='text+markers',
                                   marker=dict(size=10, color='lightblue'),
                                   text=nodes, textposition="top center"))
            G.update_layout(title="Semantic links graph", showlegend=False)
            st.plotly_chart(G)

        except KeyError as e:
            st.error(f"Word not found: {e}")
350
+
351
with tab3:
    # Tab 3: project the whole vocabulary onto a word-difference axis.
    st.header("Semantic axis projection")
    col1, col2 = st.columns(2)
    with col1:
        pos_axis = st.text_input("positive", value="мужчина")
    with col2:
        neg_axis = st.text_input("negative", value="женщина")

    if st.button("Build axis"):
        try:
            # The semantic axis is the difference of the two pole vectors.
            axis = model[pos_axis] - model[neg_axis]

            ranked = get_projection_row(model, axis)
            top_pos = ranked[-10:][::-1]  # largest projections, descending
            top_neg = ranked[:10]         # smallest projections, ascending

            st.write(f"Axis: **{pos_axis} – {neg_axis}**")
            st.write("### Top 10 positive:")
            st.write(", ".join([f"{w} ({p:.3f})" for w, p in top_pos]))

            st.write("### Top 10 negative:")
            st.write(", ".join([f"{w} ({p:.3f})" for w, p in top_neg]))

            df_proj = pd.DataFrame(top_pos + top_neg, columns=["word", "projection"])
            fig = px.bar(df_proj, x="projection", y="word", orientation='h', title=f"Projection on axis: {pos_axis}–{neg_axis}")
            st.plotly_chart(fig)

        except KeyError as e:
            st.error(f"Error: {e}")
382
+
383
with tab4:
    # Tab 4: histogram of pairwise cosine similarities over a random sample.
    st.header("Distance distribution analysis")
    all_vectors = model.vectors
    # Fix: cap the sample at the vocabulary size — np.random.choice with
    # replace=False raises when asked for more items than exist.
    sample_size = min(1000, all_vectors.shape[0])
    sample = all_vectors[np.random.choice(all_vectors.shape[0], sample_size, replace=False)]

    sims = cosine_similarity(sample)
    # Fix: keep the true off-diagonal entries. Zeroing the diagonal and then
    # filtering `> 0` also discarded genuinely negative and zero
    # similarities, biasing the histogram and statistics upward.
    off_diagonal = ~np.eye(sample_size, dtype=bool)
    flat_dists = sims[off_diagonal]

    fig = px.histogram(flat_dists, nbins=50, title="Cosine similarity distribution between random words")
    st.plotly_chart(fig)

    st.metric("Mean similarity", f"{np.mean(flat_dists):.3f}")
    st.metric("Std deviation", f"{np.std(flat_dists):.3f}")
398
+
399
with tab5:
    # Tab 5: batch evaluation report over files in data/.
    st.header("Report")

    st.subheader("1. Analogy rate")
    analogies_file = "data/analogy.txt"
    if os.path.exists(analogies_file):
        acc, results = analogy_accuracy(model, analogies_file)
        st.metric("Analogy accuracy (in top 10)", f"{acc:.2%}")
        st.dataframe(pd.DataFrame(results))
    else:
        # Fix: warnings now name the actual path that was checked (the
        # originals referenced `analogy.txt` / `similarity_words.txt`).
        st.warning(f"File `{analogies_file}` not found.")

    st.subheader("2. Average synonyms similarity")
    sim_file = "data/synonyms.txt"
    if os.path.exists(sim_file):
        avg_sim = avg_similarity(model, sim_file)
        st.metric("Average similarity", f"{avg_sim:.4f}")
    else:
        st.warning(f"File `{sim_file}` not found.")

    st.subheader("3. Average antonyms similarity")
    sim_file = "data/antonyms.txt"
    if os.path.exists(sim_file):
        avg_sim = avg_similarity(model, sim_file)
        st.metric("Average similarity", f"{avg_sim:.4f}")
    else:
        st.warning(f"File `{sim_file}` not found.")

    st.subheader("4. Heatmap for nearest words")
    query_words = st.text_input("Enter words", value="мужчина женщина мальчик девочка").split()
    if st.button("Build heatmap"):
        try:
            vectors = [model[w] for w in query_words]
            sims = cosine_similarity(vectors)
            fig = px.imshow(sims, x=query_words, y=query_words, color_continuous_scale="Blues", title="Similarity heatmap")
            st.plotly_chart(fig)
        except KeyError as e:
            st.error(f"Error: {e}")

    st.subheader("5. 2D projection")
    sample_words = st.text_input("Input words", value="мужчина женщина мальчик девочка")
    word_list = sample_words.split()
    if st.button("Show clusters"):
        if len(word_list) < 2:
            # Fix: t-SNE requires perplexity >= 1, i.e. at least two points;
            # a single word previously crashed with an uncaught error.
            st.warning("Enter at least two words.")
        else:
            try:
                from sklearn.manifold import TSNE
                vectors = np.array([model[w] for w in word_list])
                tsne = TSNE(n_components=2, perplexity=len(vectors) - 1, random_state=42)
                embedded = tsne.fit_transform(vectors)

                fig = px.scatter(x=embedded[:, 0], y=embedded[:, 1], text=word_list, title="words projection")
                fig.update_traces(textposition='top center')
                st.plotly_chart(fig)
            except KeyError as e:
                st.error(f"Word not found: {e}")