Shu-vi commited on
Commit
073c284
·
verified ·
1 Parent(s): 0f20446

Upload streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +437 -38
src/streamlit_app.py CHANGED
@@ -1,40 +1,439 @@
1
- import altair as alt
2
- import numpy as np
3
- import pandas as pd
4
  import streamlit as st
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- """
7
- # Welcome to Streamlit!
8
-
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
12
-
13
- In the meantime, below is an example of what you can do with just a few lines of code:
14
- """
15
-
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
1
+ # Запуск: streamlit run streamlit_app.py
 
 
2
  import streamlit as st
3
+ from gensim.models import Word2Vec, FastText, Doc2Vec
4
+ from gensim.utils import simple_preprocess
5
+ from sklearn.metrics.pairwise import cosine_similarity
6
+ from sklearn.decomposition import PCA
7
+ import umap
8
+ import os
9
+ import pandas as pd
10
+ import numpy as np
11
+ import networkx as nx
12
+ import plotly.graph_objs as go
13
+ import plotly.express as px
14
+
15
+ #Загрузка обученной модели
16
+ st.set_page_config(layout="wide", page_title="Исследование векторов")
17
+
18
+ st.title("Интерактивное изучение векторных представлений")
19
+
20
+ #sidebar: загрузка модели
21
+ st.sidebar.header("Выберите модель и затем загрузите обученную модель")
22
+ model_type = st.sidebar.selectbox("Тип модели", ["Word2Vec", "FastText", "Doc2Vec"])
23
+ model_file = st.sidebar.file_uploader("Загрузить обученную модель")
24
+
25
+
26
+ #инициализация/загрузка модели
27
+ model_w2v = None
28
+ model_fasttext = None
29
+ model_doc2vec = None
30
+ df_steps = None
31
+ if "df_steps" in st.session_state and st.session_state["df_steps"] is not None:
32
+ df_steps = st.session_state["df_steps"]
33
+ df_proj = None
34
+ if "df_proj" in st.session_state and st.session_state["df_proj"] is not None:
35
+ df_proj = st.session_state["df_proj"]
36
+ df = None
37
+ if "df" in st.session_state and st.session_state["df"] is not None:
38
+ df = st.session_state["df"]
39
+ if model_type == "Word2Vec":
40
+ if model_file and st.session_state.get("model_w2v") is None:
41
+ with open("temp_model.model", "wb") as f:
42
+ f.write(model_file.getbuffer())
43
+ model_w2v = Word2Vec.load("temp_model.model")
44
+ try:
45
+ os.remove("temp_model.model")
46
+ except OSError:
47
+ pass
48
+ st.session_state["model_w2v"] = model_w2v
49
+ else:
50
+ model_w2v = st.session_state.get("model_w2v")
51
+ elif model_type == "FastText":
52
+ if model_file and st.session_state.get("model_fasttext") is None:
53
+ with open("temp_model.model", "wb") as f:
54
+ f.write(model_file.getbuffer())
55
+ model_fasttext = FastText.load("temp_model.model")
56
+ try:
57
+ os.remove("temp_model.model")
58
+ except OSError:
59
+ pass
60
+ st.session_state["model_fasttext"] = model_fasttext
61
+ else:
62
+ model_fasttext = st.session_state.get("model_fasttext")
63
+ else:#Doc2Vec
64
+ if model_file and st.session_state.get("model_doc2vec") is None:
65
+ with open("temp_model.model", "wb") as f:
66
+ f.write(model_file.getbuffer())
67
+ model_fasttext = Doc2Vec.load("temp_model.model")
68
+ try:
69
+ os.remove("temp_model.model")
70
+ except OSError:
71
+ pass
72
+ st.session_state["model_doc2vec"] = model_doc2vec
73
+ else:
74
+ model_doc2vec = st.session_state.get("model_doc2vec")
75
+
76
+ #вспомогательные функции
77
+ def in_vocab(model, word):
78
+ """
79
+ проверка слова на наличие в словаре
80
+ """
81
+ if model is None:
82
+ return False
83
+ try:
84
+ return word in model.wv
85
+ except Exception:
86
+ return False
87
+
88
+ def most_similar(model, positive=None, negative=None, topn=10):
89
+ """
90
+ возвращает результат из выражения вида король - мужчина + женщина (= королева)
91
+ """
92
+ try:
93
+ return model.wv.most_similar(positive=positive or [], negative=negative or [], topn=topn)
94
+ except Exception as e:
95
+ return []
96
+
97
+ def build_html_report(title: str,
98
+ df_steps: pd.DataFrame | None = None,
99
+ df_proj: pd.DataFrame | None = None,
100
+ df_matrix: pd.DataFrame | None = None,
101
+ figs: list = None) -> str:
102
+ """
103
+ Формирует HTML отчёт: таблицы и графики.
104
+ """
105
+ figs = figs or []
106
+ html_parts = [f"<h1>{title}</h1>",
107
+ "<p>Отчёт сформирован автоматически из последних доступных данных.</p>"]
108
+
109
+ if df_steps is not None and not df_steps.empty:
110
+ html_parts.append("<h2>Промежуточные шаги выражения</h2>")
111
+ html_parts.append(df_steps.to_html(index=False))
112
+ else:
113
+ html_parts.append("<p><em>Нет данных о промежуточных шагах</em></p>")
114
+
115
+ if df_proj is not None and not df_proj.empty:
116
+ html_parts.append("<h2>Проекции слов на ось</h2>")
117
+ html_parts.append(df_proj.to_html(index=True))
118
+ else:
119
+ html_parts.append("<p><em>Нет данных о проекциях</em></p>")
120
+
121
+ if df_matrix is not None and not df_matrix.empty:
122
+ html_parts.append("<h2>Матрица сходств</h2>")
123
+ html_parts.append(df_matrix.to_html(index=True))
124
+ else:
125
+ html_parts.append("<p><em>Нет матрицы сходств</em></p>")
126
+
127
+ # вставляем графики Plotly: первый с include_plotlyjs="cdn"
128
+ for i, f in enumerate(figs):
129
+ html_parts.append(f"<h3>График {i+1}</h3>")
130
+ html_parts.append(f.to_html(full_html=False, include_plotlyjs=("cdn" if i == 0 else False)))
131
+
132
+ return "\n".join(html_parts)
133
+
134
+ def cosine_between_vecs(a, b):
135
+ """
136
+ угол косинуса между векторами
137
+ """
138
+ if a is None or b is None:
139
+ return None
140
+ val = cosine_similarity([a], [b])[0][0]
141
+ return float(val)
142
+
143
+ def infer_docvec(model, text):
144
+ """
145
+ возвращает вектор документа
146
+ """
147
+ if model is None:
148
+ return None
149
+ try:
150
+ return model.infer_vector(simple_preprocess(text))
151
+ except Exception:
152
+ return None
153
+
154
+ def word_vector(model, word):
155
+ """
156
+ возвращает вектор слова
157
+ """
158
+ try:
159
+ return model.wv[word]
160
+ except Exception:
161
+ return None
162
+
163
+ #UI: векторная арифметика
164
+ st.header("Интерактивная векторная арифметика")
165
+ col1, col2 = st.columns([2,1])
166
+
167
+ with col1:
168
+ expr = st.text_input("Введите выражение (пример: сша - трамп + путин)", value="сша - трамп + путин")
169
+ topn = st.number_input("Количество ближайших соседей (topn)", min_value=1, max_value=15, value=3)
170
+ run_expr = st.button("Вычислить выражение")
171
+
172
+ with col2:
173
+ st.write(f"Тип модели: {model_type}")
174
+
175
+ def parse_expression(expr_str):
176
+ """
177
+ парсинг выражений вида: w1 - w2 + w3 - w4
178
+ """
179
+ # Простая лексическая парсировка: слова и +/-
180
+ tokens = expr_str.replace("+", " + ").replace("-", " - ").split()
181
+ ops = []
182
+ current = None
183
+ # схема: первый токен может быть +/- или словом
184
+ sign = 1
185
+ vec_ops = []
186
+ for t in tokens:
187
+ if t == "+":
188
+ sign = 1
189
+ elif t == "-":
190
+ sign = -1
191
+ else:
192
+ vec_ops.append((t, sign))
193
+ sign = 1
194
+ return vec_ops
195
+
196
+ def compute_intermediate_vectors(model, expr_ops):
197
+ #статистика
198
+ intermediate = []
199
+ #результирующий вектор со всеми вычислениями, здесь будет храниться вычисления вида сша-трамп+путин
200
+ result = np.zeros(model.wv.vector_size)
201
+ for word, sign in expr_ops:
202
+ if not in_vocab(model, word):
203
+ intermediate.append({"word": word, "present": False, "vec": None, "result_after": None})
204
+ continue
205
+ vec = word_vector(model, word) * sign
206
+ result = result + vec
207
+ intermediate.append({"word": word, "present": True, "vec": vec.copy(), "result_after": result.copy()})
208
+ return intermediate, result
209
+
210
+ #подсчёт векторной арифметики
211
+ if run_expr:
212
+ #выбрать активную модель
213
+ active_model = model_w2v if model_type=="Word2Vec" else (model_fasttext if model_type=="FastText" else model_doc2vec)
214
+ if active_model is None:
215
+ st.error("Модель не загружена")
216
+ else:
217
+ ops = parse_expression(expr)
218
+ intermediate, final_vec = compute_intermediate_vectors(active_model, ops)
219
+
220
+ # показываем таблицу промежуточных шагов
221
+ rows = []
222
+ for i, s in enumerate(intermediate):
223
+ if not s["present"]:
224
+ rows.append({"шаг": i+1, "слово": s["word"], "в словаре": False, "наиболее похожие": None})
225
+ else:
226
+ ms = most_similar(active_model, positive=[s["vec"]], topn=topn)
227
+ rows.append({
228
+ "шаг": i+1,
229
+ "слово": s["word"],
230
+ "в словаре": True,
231
+ "наиболее похожие": ", ".join([f"{w} ({float(sim):.3f})" for w, sim in ms])
232
+ })
233
+ df_steps = pd.DataFrame(rows)
234
+ st.session_state["df_steps"] = df_steps
235
+ st.subheader("Промежуточные шаги")
236
+ st.dataframe(df_steps)
237
+
238
+ #ближайшие соседи для финального вектора
239
+ st.subheader("Результат выражения — ближайшие слова")
240
+ try:
241
+ final_neighbors = active_model.wv.similar_by_vector(final_vec, topn=topn)
242
+ except Exception:
243
+ final_neighbors = []
244
+ st.write(final_neighbors)
245
+
246
+ #визуализация финального вектора в 2D
247
+ st.subheader("2D проекция: промежуточные и итоговый векторы")
248
+ #соберём векторы для рисования: все оригинальные слов-векторов и результат
249
+ vis_vectors = []
250
+ vis_labels = []
251
+ for s in intermediate:
252
+ if s["present"]:
253
+ vis_vectors.append(s["vec"])
254
+ vis_labels.append(f"{s['word']} (шаг)")
255
+ vis_vectors.append(final_vec)
256
+ vis_labels.append("финальный вектор")
257
+ vis_vectors_np = np.array(vis_vectors)
258
+ reducer = UMAP_OR_PCA = None
259
+ try:
260
+ reducer = umap.UMAP(n_components=2, random_state=42)
261
+ proj = reducer.fit_transform(vis_vectors_np)
262
+ except Exception:
263
+ reducer = PCA(n_components=2)
264
+ proj = reducer.fit_transform(vis_vectors_np)
265
+ fig = px.scatter(x=proj[:,0], y=proj[:,1], text=vis_labels, title="2D проекция")
266
+ st.plotly_chart(fig, use_container_width=True)
267
+
268
+ #UI: косинусное расстояние и матрица сходств
269
+ st.header("Калькулятор косинусного сходства и матрица близостей")
270
+ col1, col2 = st.columns(2)
271
+ with col1:
272
+ word_a = st.text_input("Слово A", value="путин", key="cos_a")
273
+ word_b = st.text_input("Слово B", value="президент", key="cos_b")
274
+ calc_cos = st.button("Посчитать косинусное сходство")
275
+ with col2:
276
+ words_for_matrix = st.text_area("Список слов для матрицы (через запятую)", value="россия,трамп,китай,спорт")
277
+ calc_matrix = st.button("Построить матрицу сходств")
278
+
279
+ if calc_cos:
280
+ active_model = model_w2v if model_type=="Word2Vec" else (model_fasttext if model_type=="FastText" else model_doc2vec)
281
+ if active_model is None:
282
+ st.error("Модель не загружена")
283
+ else:
284
+ if in_vocab(active_model, word_a) and in_vocab(active_model, word_b):
285
+ va = word_vector(active_model, word_a)
286
+ vb = word_vector(active_model, word_b)
287
+ cosv = cosine_between_vecs(va, vb)
288
+ st.metric("Косинусное сходство", f"{cosv:.4f}")
289
+ else:
290
+ st.error("Одно из слов отсутствует в словаре модели")
291
+
292
+ if calc_matrix:
293
+ active_model = model_w2v if model_type=="Word2Vec" else (model_fasttext if model_type=="FastText" else model_doc2vec)
294
+ words = [w.strip() for w in words_for_matrix.split(",") if w.strip()]
295
+ present = [w for w in words if in_vocab(active_model, w)]
296
+ if not present:
297
+ st.error("Нет слов из списка в словаре модели")
298
+ else:
299
+ mat = np.array([word_vector(active_model, w) for w in present])
300
+ simm = cosine_similarity(mat)
301
+ df = pd.DataFrame(simm, index=present, columns=present)
302
+ st.session_state["df"] = df
303
+ st.subheader("Heatmap семантической близости")
304
+ fig = px.imshow(df.values, x=present, y=present, color_continuous_scale='RdBu_r', zmin=-1, zmax=1)
305
+ st.plotly_chart(fig, use_container_width=True)
306
+ st.dataframe(df.style.background_gradient(cmap='RdBu_r', axis=None))
307
+
308
+ #UI: семантическая ось и проекция
309
+ st.header("Семантические оси и проекция")
310
+ axis_left = st.text_input("Слово A (лево оси)", value="мужчина", key="axis_a")
311
+ axis_right = st.text_input("Слово B (право оси)", value="женщина", key="axis_b")
312
+ words_for_proj = st.text_area("Слова для проекции (через запятую)", value="король,королева,президент,работник,няня")
313
+ do_proj = st.button("Произвести проекцию на ось")
314
+
315
+ def project_on_axis(model, left, right, targets):
316
+ axis = word_vector(model, left) - word_vector(model, right)
317
+ scores = {}
318
+ for w in targets:
319
+ if in_vocab(model, w):
320
+ vec = word_vector(model, w)
321
+ #если score > 0 то относится к левому, иначе к правому
322
+ score = cosine_similarity([vec], [axis])[0][0]
323
+ scores[w] = float(score)
324
+ else:
325
+ scores[w] = None
326
+ return scores, axis
327
+
328
+ if do_proj:
329
+ active_model = model_w2v if model_type=="Word2Vec" else (model_fasttext if model_type=="FastText" else model_doc2vec)
330
+ targets = [w.strip() for w in words_for_proj.split(",") if w.strip()]
331
+ if not in_vocab(active_model, axis_left) or not in_vocab(active_model, axis_right):
332
+ st.error("Одна из опорных слов отсутствует в модели")
333
+ else:
334
+ scores, axis_vec = project_on_axis(active_model, axis_left, axis_right, targets)
335
+ df_proj = pd.DataFrame.from_dict(scores, orient='index', columns=['projection']).sort_values('projection', ascending=False)
336
+ st.session_state["df_proj"] = df_proj
337
+ st.dataframe(df_proj)
338
+ st.subheader("График проекций")
339
+ fig = px.bar(df_proj.reset_index().rename(columns={'index':'word'}), x='word', y='projection', color='projection', color_continuous_scale='RdBu')
340
+ st.plotly_chart(fig, use_container_width=True)
341
+
342
+ #UI: граф семантических связей
343
+ st.header("Граф семантических связей")
344
+ graph_seed = st.text_input("Слово (центр графа)", value="россия", key="graph_seed")
345
+ graph_depth = st.slider("Глубина (уровней соседей)", 1, 3, 2)
346
+ graph_topn = st.slider("TopN соседей на уровень", 1, 8, 5)
347
+
348
+ def build_similarity_graph(model, seed, depth=2, topn=5):
349
+ G = nx.Graph()
350
+ visited = set()
351
+ def expand(node, d):
352
+ if d>depth:
353
+ return
354
+ visited.add(node)
355
+ if not in_vocab(model, node):
356
+ return
357
+ try:
358
+ neighbors = model.wv.most_similar(node, topn=topn)
359
+ except Exception:
360
+ neighbors = []
361
+ for nb, sim in neighbors:
362
+ G.add_node(node)
363
+ G.add_node(nb)
364
+ G.add_edge(node, nb, weight=float(sim))
365
+ if nb not in visited:
366
+ expand(nb, d+1)
367
+ expand(seed, 1)
368
+ return G
369
+
370
+ if st.button("Построить граф"):
371
+ active_model = model_w2v if model_type=="Word2Vec" else (model_fasttext if model_type=="FastText" else model_doc2vec)
372
+ if not in_vocab(active_model, graph_seed):
373
+ st.error("Корневое слово отсутствует в модели")
374
+ else:
375
+ G = build_similarity_graph(active_model, graph_seed, depth=graph_depth, topn=graph_topn)
376
+ st.write(f"Узлы: {len(G.nodes())}, Рёбра: {len(G.edges())}")
377
+ #визуализация через plotly
378
+ pos = nx.spring_layout(G, seed=42)
379
+ edge_x = []
380
+ edge_y = []
381
+ for e in G.edges():
382
+ x0, y0 = pos[e[0]]
383
+ x1, y1 = pos[e[1]]
384
+ edge_x += [x0, x1, None]
385
+ edge_y += [y0, y1, None]
386
+ node_x = []
387
+ node_y = []
388
+ texts = []
389
+ for n in G.nodes():
390
+ x, y = pos[n]
391
+ node_x.append(x)
392
+ node_y.append(y)
393
+ texts.append(n)
394
+ edge_trace = go.Scatter(x=edge_x, y=edge_y, mode='lines', line=dict(width=0.5, color='#888'), hoverinfo='none')
395
+ node_trace = go.Scatter(
396
+ x=node_x, y=node_y, mode='markers+text', text=texts, textposition="top center",
397
+ hoverinfo='text', marker=dict(showscale=False, size=10, color='skyblue', line_width=2)
398
+ )
399
+ fig = go.Figure(data=[edge_trace, node_trace])
400
+ fig.update_layout(showlegend=False, margin=dict(b=20,l=5,r=5,t=40))
401
+ st.plotly_chart(fig, use_container_width=True)
402
+
403
+
404
+ #UI: генерация отчёта
405
+ st.header("Генерация отчёта")
406
+ report_title = st.text_input("Заголовок отчёта", value="Отчёт")
407
+ report_btn = st.button("Сгенерировать отчёт")
408
+
409
+
410
+ if report_btn:
411
+ try:
412
+ last_steps = df_steps
413
+ except Exception:
414
+ last_steps = pd.DataFrame()
415
+ try:
416
+ last_proj = df_proj
417
+ except Exception:
418
+ last_proj = pd.DataFrame()
419
+ try:
420
+ last_mat = df
421
+ except Exception:
422
+ last_mat = pd.DataFrame()
423
+
424
+ # добавляем последние графики, если есть
425
+ figs_to_add = []
426
+ if "fig" in globals() and fig is not None:
427
+ figs_to_add.append(fig)
428
+
429
+ html_report = build_html_report(report_title, last_steps, last_proj, last_mat, figs_to_add)
430
+
431
+ st.download_button(
432
+ label="Скачать HTML отчёт",
433
+ data=html_report.encode("utf-8"),
434
+ file_name="report.html",
435
+ mime="text/html",
436
+ )
437
+
438
 
439
+ st.sidebar.header("Для doc2vec только схожести предложений")