AveMujica commited on
Commit
59199b1
·
verified ·
1 Parent(s): 36f4962

Upload 6 files

Browse files
app.py ADDED
@@ -0,0 +1,425 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import numpy as np
3
+ import numpy.linalg as la
4
+ import pickle
5
+ import os
6
+ import gdown
7
+ from sentence_transformers import SentenceTransformer
8
+ import matplotlib.pyplot as plt
9
+ import math
10
+
11
def load_glove_embeddings(glove_path="Data/embeddings.pkl"):
    """Read a pickled embeddings object from disk.

    Args:
        glove_path: Path of the pickle file to read.

    Returns:
        Whatever object the pickle file contains (presumably a dictionary of
        embeddings, going by the default file name).
    """
    # NOTE: unpickling can execute arbitrary code; only load trusted files.
    with open(glove_path, mode="rb") as pickle_file:
        return pickle.load(pickle_file, encoding="latin1")
16
+
17
+
18
def get_model_id_gdrive(model_type):
    """Return the Google Drive file ids for one GloVe model size.

    Args:
        model_type: One of ``"25d"``, ``"50d"`` or ``"100d"``.

    Returns:
        A ``(word_index_id, embeddings_id)`` tuple of Google Drive file ids.

    Raises:
        ValueError: If ``model_type`` is not one of the supported sizes.
    """
    # Each entry is (word_index_id, embeddings_id).
    drive_ids = {
        "25d": (
            "13qMXs3-oB9C6kfSRMwbAtzda9xuAUtt8",
            "1-RXcfBvWyE-Av3ZHLcyJVsps0RYRRr_2",
        ),
        "50d": (
            "1rB4ksHyHZ9skes-fJHMa2Z8J1Qa7awQ9",
            "1DBaVpJsitQ1qxtUvV1Kz7ThDc3az16kZ",
        ),
        "100d": (
            "1-oWV0LqG3fmrozRZ7WB1jzeTJHRUI3mq",
            "1SRHfX130_6Znz7zbdfqboKosz-PfNvNp",
        ),
    }
    try:
        word_index_id, embeddings_id = drive_ids[model_type]
    except (KeyError, TypeError):
        # Previously an unknown size fell through every branch and crashed
        # with an UnboundLocalError on the return statement.
        supported = ", ".join(drive_ids)
        raise ValueError(
            f"Unsupported GloVe model type {model_type!r}; expected one of: {supported}"
        ) from None

    return word_index_id, embeddings_id
30
+
31
+
32
def download_glove_embeddings_gdrive(model_type):
    """Download the GloVe word index and embeddings files from Google Drive.

    The files are written to the current working directory as
    ``word_index_dict_<model_type>_temp.pkl`` and
    ``embeddings_<model_type>_temp.npy``.

    Args:
        model_type: GloVe size label, forwarded to ``get_model_id_gdrive``.
    """
    word_index_id, embeddings_id = get_model_id_gdrive(model_type)

    # Fetch the word index pickle first, then the embeddings array.
    print("Downloading word index dictionary....\n")
    gdown.download(
        id=word_index_id,
        output=f"word_index_dict_{model_type}_temp.pkl",
        quiet=False,
    )

    print("Downloading embeddings...\n\n")
    gdown.download(
        id=embeddings_id,
        output=f"embeddings_{model_type}_temp.npy",
        quiet=False,
    )
47
+
48
+
49
+ # @st.cache_data()
50
def load_glove_embeddings_gdrive(model_type):
    """Load the downloaded GloVe word index and embeddings array.

    Both files are read from the current working directory, using the file
    names ``word_index_dict_<model_type>_temp.pkl`` and
    ``embeddings_<model_type>_temp.npy``.

    Args:
        model_type: GloVe size label used to build the file names.

    Returns:
        A ``(word_index_dict, embeddings)`` tuple: the unpickled word index
        object and the array returned by ``np.load``.
    """
    word_index_path = f"word_index_dict_{model_type}_temp.pkl"
    embeddings_path = f"embeddings_{model_type}_temp.npy"

    # Use a context manager so the file handle is closed deterministically
    # (the previous bare ``open`` call leaked it).
    # NOTE: unpickling can execute arbitrary code; only load trusted files.
    with open(word_index_path, "rb") as word_index_file:
        word_index_dict = pickle.load(word_index_file, encoding="latin")

    embeddings = np.load(embeddings_path)

    return word_index_dict, embeddings
61
+
62
+
63
@st.cache_resource()
def load_sentence_transformer_model(model_name):
    """Create a ``SentenceTransformer`` for the given model name.

    The function is wrapped in ``st.cache_resource`` so Streamlit can reuse
    the loaded model instead of rebuilding it on every call.

    Args:
        model_name: Model identifier passed to ``SentenceTransformer``.

    Returns:
        The constructed ``SentenceTransformer`` instance.
    """
    return SentenceTransformer(model_name)
67
+
68
+
69
def get_sentence_transformer_embeddings(sentence, model_name="all-MiniLM-L6-v2"):
    """Encode a sentence with a Sentence Transformer model.

    Args:
        sentence: Text to encode.
        model_name: Name of the model to use; it is loaded through
            ``load_sentence_transformer_model``. Default model:
            https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2

    Returns:
        The value returned by the model's ``encode`` call. If encoding fails,
        a zero vector is returned instead (best effort), sized to the model's
        reported embedding dimension.
    """
    sentenceTransformer = load_sentence_transformer_model(model_name)

    try:
        return sentenceTransformer.encode(sentence)
    except Exception:
        # Deliberate best-effort fallback, but no longer a bare ``except`` so
        # KeyboardInterrupt/SystemExit still propagate.
        # Ask the model for its real output size: a hard-coded 512 does not
        # match e.g. the mpnet models and would break later dot products.
        dimension = None
        get_dimension = getattr(
            sentenceTransformer, "get_sentence_embedding_dimension", None
        )
        if callable(get_dimension):
            dimension = get_dimension()
        if not dimension:
            # Legacy constants, kept for models that do not report a size.
            dimension = 384 if model_name == "all-MiniLM-L6-v2" else 512
        return np.zeros(dimension)
85
+
86
+
87
def get_glove_embeddings(word, word_index_dict, embeddings, model_type):
    """Look up the GloVe vector for one word.

    Args:
        word: Word to look up; it is lower-cased before the lookup.
        word_index_dict: Mapping from lower-case word to a row index.
        embeddings: Indexable collection of vectors (one per row index).
        model_type: Size label such as ``"50d"``; the number before the
            ``"d"`` gives the length of the fallback vector.

    Returns:
        ``embeddings[word_index_dict[word.lower()]]`` when the word is known,
        otherwise a zero vector of the model's dimension.
    """
    key = word.lower()
    if key not in word_index_dict:
        dimension = int(model_type.split("d")[0])
        return np.zeros(dimension)
    return embeddings[word_index_dict[key]]
95
+
96
+
97
def get_category_embeddings(embeddings_metadata):
    """Compute and store a sentence-transformer embedding per category.

    The categories are read from ``st.session_state.categories`` (split on
    single spaces). The results are stored in
    ``st.session_state["cat_embed_" + model_name]`` as a fresh dictionary
    mapping each category string to its embedding.

    Args:
        embeddings_metadata: Mapping that must contain ``"model_name"``.
    """
    model_name = embeddings_metadata["model_name"]
    state_key = "cat_embed_" + model_name
    st.session_state[state_key] = {}

    # An empty model name falls back to the encoder's default model.
    encoder_kwargs = {"model_name": model_name} if model_name else {}

    for category in st.session_state.categories.split(" "):
        if category in st.session_state[state_key]:
            # Repeated category in the input: already embedded in this pass.
            continue
        st.session_state[state_key][category] = get_sentence_transformer_embeddings(
            category, **encoder_kwargs
        )
112
+
113
+
114
def update_category_embeddings(embeddings_metadata):
    """Recompute the stored category embeddings.

    Thin wrapper that delegates to ``get_category_embeddings``.

    Args:
        embeddings_metadata: Passed through unchanged.
    """
    get_category_embeddings(embeddings_metadata)
119
+
120
+
121
+ ### Plotting utility functions
122
+
123
def plot_piechart(sorted_cosine_scores_items):
    """Render a single pie chart of category scores in the Streamlit app.

    Args:
        sorted_cosine_scores_items: Sequence of ``(category_index, score)``
            items. The index is a position in the space-separated
            ``st.session_state.categories`` string.
    """
    categories = st.session_state.categories.split(" ")
    labels = [categories[item[0]] for item in sorted_cosine_scores_items]
    scores = np.array([item[1] for item in sorted_cosine_scores_items])

    fig, ax = plt.subplots()
    try:
        ax.pie(scores, labels=labels, autopct="%1.1f%%")
        st.pyplot(fig)  # Display figure
    finally:
        # pyplot keeps every figure alive until it is closed; without this,
        # figures pile up each time the Streamlit script reruns.
        plt.close(fig)
137
+
138
+
139
def plot_piechart_helper(sorted_cosine_scores_items):
    """Build a donut-style pie chart figure of category scores.

    Args:
        sorted_cosine_scores_items: Sequence of ``(category_index, score)``
            pairs. The index is a position in the space-separated
            ``st.session_state.categories`` string; the first pair is drawn
            pulled out from the rest.

    Returns:
        The matplotlib figure. It is not displayed or closed here; that is
        left to the caller.
    """
    # Matplotlib's built-in Pastel1 colormap colours.
    colors = plt.cm.Pastel1.colors
    categories = st.session_state.categories.split(" ")

    # Create the figure with a square canvas.
    fig, ax = plt.subplots(figsize=(6, 6))

    # Prepare the data.
    labels = [categories[i] for i, _ in sorted_cosine_scores_items]
    sizes = [score for _, score in sorted_cosine_scores_items]
    explode = np.zeros(len(labels))
    if len(labels) > 0:
        # Pull out the first (highest-scoring) wedge. Guarded so an empty
        # input no longer fails with an IndexError.
        explode[0] = 0.1

    # Draw the pie chart.
    wedges, texts, autotexts = ax.pie(
        sizes,
        explode=explode,
        labels=labels,
        colors=colors,
        autopct=lambda p: f'{p:.1f}%',
        startangle=90,
        shadow=True,  # drop shadow
        pctdistance=0.85,  # position of the percentage labels
        wedgeprops={'edgecolor': 'white', 'linewidth': 1},  # white wedge borders
        textprops={'fontsize': 10}
    )

    # Style the percentage labels.
    for autotext in autotexts:
        autotext.set_color('black')
        autotext.set_fontsize(10)
        autotext.set_fontweight('bold')

    # A white circle in the centre turns the pie into a donut.
    centre_circle = plt.Circle((0, 0), 0.42, fc='white')
    ax.add_artist(centre_circle)

    # Title.
    ax.set_title('Category Distribution', fontsize=14, pad=20)

    # Equal aspect ratio keeps the pie circular.
    ax.axis('equal')

    return fig
184
+
185
+
186
def plot_piecharts(sorted_cosine_scores_models):
    """Render one pie chart per model, stacked vertically, in Streamlit.

    Args:
        sorted_cosine_scores_models: Mapping from model name to a sequence of
            ``(category_index, score)`` items. The index is a position in the
            space-separated ``st.session_state.categories`` string.

    Previously only the two-model case was drawn and any other number of
    models silently produced nothing; now every model gets its own subplot.
    An empty mapping still draws nothing.
    """
    categories = st.session_state.categories.split(" ")
    scores_list = list(sorted_cosine_scores_models.values())
    if not scores_list:
        return

    # squeeze=False always yields a 2-D array of axes, even for one model.
    fig, axes = plt.subplots(len(scores_list), squeeze=False)
    try:
        for ax, model_scores in zip(axes.ravel(), scores_list):
            labels = [categories[item[0]] for item in model_scores]
            sizes = np.array([item[1] for item in model_scores])
            ax.pie(sizes, labels=labels, autopct="%1.1f%%")

        st.pyplot(fig)
    finally:
        # pyplot keeps every figure alive until it is closed; without this,
        # figures pile up each time the Streamlit script reruns.
        plt.close(fig)
214
+
215
+
216
def plot_alatirchart(sorted_cosine_scores_models):
    """Show one pie chart per model, each in its own Streamlit tab.

    Args:
        sorted_cosine_scores_models: Mapping from model name (used as the tab
            label) to the ``(category_index, score)`` pairs handed to
            ``plot_piechart_helper``.
    """
    models = list(sorted_cosine_scores_models.keys())
    tabs = st.tabs(models)
    figs = {}
    try:
        # Build every figure first, then render each one inside its tab.
        for model in models:
            figs[model] = plot_piechart_helper(sorted_cosine_scores_models[model])

        for tab, model in zip(tabs, models):
            with tab:
                st.pyplot(figs[model])
    finally:
        # pyplot keeps figures alive until closed; release them once they
        # have been rendered so they do not pile up across Streamlit reruns.
        for fig in figs.values():
            plt.close(fig)
226
+
227
+ # Task I: Compute Cosine Similarity
228
def cosine_similarity(x, y):
    """Return the exponentiated cosine similarity of two vectors.

    The cosine of the angle between ``x`` and ``y`` is computed and then
    passed through ``exp``. If either vector has zero norm, ``0.0`` is
    returned instead, which avoids a division by zero.

    Args:
        x: First vector.
        y: Second vector.

    Returns:
        ``exp(cos(x, y))``, or ``0.0`` when either vector is all zeros.
    """
    inner = np.dot(x, y)
    length_x = np.linalg.norm(x)
    length_y = np.linalg.norm(y)

    if length_x == 0 or length_y == 0:
        # Cosine is undefined for a zero vector.
        return 0.0

    return np.exp(inner / (length_x * length_y))
243
+
244
+ # Task II: Average Glove Embedding Calculation
245
def averaged_glove_embeddings_gdrive(sentence, word_index_dict, embeddings, model_type):
    """Average the GloVe vectors of the words in a sentence.

    The sentence is split on whitespace, each word is looked up with
    ``get_glove_embeddings``, and the vectors are summed and divided by the
    number of words. Words without an embedding contribute a zero vector but
    still count towards the divisor.

    Args:
        sentence: Input text.
        word_index_dict: Mapping from word to embedding row index.
        embeddings: Embedding vectors indexed by row.
        model_type: Size label such as ``"50d"``; the number before the
            ``"d"`` is the vector dimension.

    Returns:
        The averaged vector, or a zero vector when the sentence has no words.
    """
    dimension = int(model_type.split('d')[0])
    tokens = sentence.split()
    total = np.zeros(dimension)

    if len(tokens) == 0:
        return total

    for token in tokens:
        total += get_glove_embeddings(token, word_index_dict, embeddings, model_type)

    return total / len(tokens)
265
+
266
+ # Task III: Sort the cosine similarity
267
def get_sorted_cosine_similarity(embeddings_metadata):
    """Score every category against the input sentence, best match first.

    The categories come from ``st.session_state.categories`` (split on single
    spaces) and the sentence from ``st.session_state.text_search``.

    Args:
        embeddings_metadata: Mapping describing the embedding backend. When
            ``"embedding_model"`` is ``"glove"`` it must also provide
            ``"word_index_dict"``, ``"embeddings"`` and ``"model_type"``;
            otherwise ``"model_name"`` selects the sentence-transformer model.

    Returns:
        A list of ``(category_index, score)`` tuples sorted by descending
        score, where the score is the value from ``cosine_similarity``.
    """
    categories = st.session_state.categories.split(" ")

    if embeddings_metadata["embedding_model"] == "glove":
        word_index_dict = embeddings_metadata["word_index_dict"]
        embeddings = embeddings_metadata["embeddings"]
        model_type = embeddings_metadata["model_type"]

        sentence_vector = averaged_glove_embeddings_gdrive(
            st.session_state.text_search, word_index_dict, embeddings, model_type
        )

        # Each category is treated as a single GloVe word.
        scored = [
            (
                position,
                cosine_similarity(
                    sentence_vector,
                    get_glove_embeddings(name, word_index_dict, embeddings, model_type),
                ),
            )
            for position, name in enumerate(categories)
        ]
    else:
        model_name = embeddings_metadata.get("model_name", "")
        state_key = f"cat_embed_{model_name}"
        if state_key not in st.session_state:
            get_category_embeddings(embeddings_metadata)

        cached_embeddings = st.session_state[state_key]

        sentence_vector = get_sentence_transformer_embeddings(
            st.session_state.text_search, model_name=model_name
        )

        scored = []
        for position, name in enumerate(categories):
            if name not in cached_embeddings:
                # Category added since the cache was built: embed it now.
                cached_embeddings[name] = get_sentence_transformer_embeddings(
                    name, model_name=model_name
                )
            scored.append(
                (position, cosine_similarity(sentence_vector, cached_embeddings[name]))
            )

    # Highest score first; the sort is stable, so ties keep category order.
    scored.sort(key=lambda pair: pair[1], reverse=True)
    return scored
319
+
320
+
321
if __name__ == "__main__":
    # Streamlit entry point: build the sidebar and inputs, make sure the GloVe
    # files are available, score the categories with both embedding backends
    # and plot the results.

    # Sidebar setup
    st.sidebar.title("Model Configuration")
    st.sidebar.markdown(
        """
        GloVe is an unsupervised learning algorithm for obtaining vector representations for words. Pretrained on
        2 billion tweets with vocabulary size of 1.2 million. Download from [Stanford NLP](http://nlp.stanford.edu/data/glove.twitter.27B.zip).

        Jeffrey Pennington, Richard Socher, and Christopher D. Manning. 2014. *GloVe: Global Vectors for Word Representation*.
        """
    )
    # Sentence Transformer model selection
    st_model = st.sidebar.selectbox(
        "Sentence Transformer Model",
        options=[
            "all-MiniLM-L6-v2",
            "all-mpnet-base-v2",
            "multi-qa-mpnet-base-dot-v1",
            "paraphrase-multilingual-mpnet-base-v2"
        ],
        index=0,
        help="Select pretrained sentence transformer model"
    )

    # GloVe model selection
    model_type = st.sidebar.selectbox(
        "GloVe Dimension",
        ("25d", "50d", "100d"),
        index=1,
        help="Select dimension for GloVe embeddings"
    )

    # Main page setup
    st.title("Semantic Search Demo")

    # Seed the session state with default inputs on the first run
    if "categories" not in st.session_state:
        st.session_state.categories = "Flowers Colors Cars Weather Food"
    if "text_search" not in st.session_state:
        st.session_state.text_search = "Roses are red, trucks are blue, and Seattle is grey right now"

    # Input widgets. They are bound to the session-state keys seeded above, so
    # no ``value=`` is passed: supplying both a default value and a
    # session-state value for the same key makes Streamlit emit a warning.
    st.subheader("Categories (space-separated)")
    st.text_input(
        label="Categories",
        key="categories"
    )

    st.subheader("Input Sentence")
    st.text_input(
        label="Your input",
        key="text_search"
    )

    # Download the GloVe files if they are not already on disk
    embeddings_path = f"embeddings_{model_type}_temp.npy"
    word_index_dict_path = f"word_index_dict_{model_type}_temp.pkl"
    if not os.path.isfile(embeddings_path) or not os.path.isfile(word_index_dict_path):
        with st.spinner(f"Downloading GloVe-{model_type} embeddings..."):
            download_glove_embeddings_gdrive(model_type)

    # Load the GloVe word index and embeddings
    word_index_dict, embeddings = load_glove_embeddings_gdrive(model_type)

    # Only score when the input sentence is not blank
    if st.session_state.text_search.strip():
        # GloVe configuration
        glove_metadata = {
            "embedding_model": "glove",
            "word_index_dict": word_index_dict,
            "embeddings": embeddings,
            "model_type": model_type,
        }

        # Sentence-transformer configuration
        transformer_metadata = {
            "embedding_model": "transformers",
            "model_name": st_model
        }

        # Two columns, one spinner per backend (the work itself runs
        # sequentially, not in parallel)
        col1, col2 = st.columns(2)

        with col1:
            with st.spinner(f"Processing GloVe-{model_type}..."):
                sorted_glove = get_sorted_cosine_similarity(glove_metadata)

        with col2:
            with st.spinner(f"Processing {st_model}..."):
                sorted_transformer = get_sorted_cosine_similarity(transformer_metadata)

        # Visualise the results
        st.subheader(f"Results for: '{st.session_state.text_search}'")
        plot_alatirchart({
            f"Sentence Transformer ({st_model})": sorted_transformer,
            f"GloVe-{model_type}": sorted_glove
        })

    # Credits
    st.markdown("---")
    # The last separator previously lacked its trailing space ("|[GloVe]").
    st.caption("Developed by [Xinghao Chen](https://www.linkedin.com/in/cxh42/) | "
               "Model credits: [Sentence Transformers](https://www.sbert.net/) | "
               "[GloVe](https://nlp.stanford.edu/projects/glove/)")
embeddings_25d_temp.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5eec0acf13b5c7d7c3bd178c1c84332347b9c0d55a474e37f4313e5289aacde3
3
+ size 238702880
embeddings_50d_temp.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e74f88cde3ff2e36c815d13955c67983cf6f81829d2582cb6789c10786e5ef66
3
+ size 477405680
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ streamlit
2
+ numpy
3
+ gdown
4
+ sentence_transformers
5
+ matplotlib
word_index_dict_25d_temp.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:674af352f703098ef122f6a8db7c5e08c5081829d49daea32e5aeac1fe582900
3
+ size 60284151
word_index_dict_50d_temp.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:674af352f703098ef122f6a8db7c5e08c5081829d49daea32e5aeac1fe582900
3
+ size 60284151