yuanjunchai
committed on
Commit
·
2596ac1
1
Parent(s):
9d0e21e
add application files
Browse files
app.py
CHANGED
|
@@ -285,7 +285,7 @@ def averaged_glove_embeddings_gdrive(sentence, word_index_dict, embeddings, mode
|
|
| 285 |
|
| 286 |
|
| 287 |
# Task III: Sort the cosine similarity
|
| 288 |
-
def get_sorted_cosine_similarity(
|
| 289 |
"""
|
| 290 |
Get sorted cosine similarity between input sentence and categories
|
| 291 |
Steps:
|
|
@@ -296,14 +296,14 @@ def get_sorted_cosine_similarity(st_session_state, embeddings_metadata):
|
|
| 296 |
5. Return sorted cosine similarity
|
| 297 |
(50 pts)
|
| 298 |
"""
|
| 299 |
-
categories =
|
| 300 |
cosine_sim = {}
|
| 301 |
if embeddings_metadata["embedding_model"] == "glove":
|
| 302 |
word_index_dict = embeddings_metadata["word_index_dict"]
|
| 303 |
embeddings = embeddings_metadata["embeddings"]
|
| 304 |
model_type = embeddings_metadata["model_type"]
|
| 305 |
-
print(f'Debug: {
|
| 306 |
-
input_embedding = averaged_glove_embeddings_gdrive(
|
| 307 |
word_index_dict,
|
| 308 |
embeddings, model_type)
|
| 309 |
|
|
@@ -315,16 +315,16 @@ def get_sorted_cosine_similarity(st_session_state, embeddings_metadata):
|
|
| 315 |
|
| 316 |
else:
|
| 317 |
model_name = embeddings_metadata["model_name"]
|
| 318 |
-
if not "cat_embed_" + model_name in
|
| 319 |
get_category_embeddings(embeddings_metadata)
|
| 320 |
|
| 321 |
-
category_embeddings =
|
| 322 |
|
| 323 |
-
print("text_search = ",
|
| 324 |
if model_name:
|
| 325 |
-
input_embedding = get_sentence_transformer_embeddings(
|
| 326 |
else:
|
| 327 |
-
input_embedding = get_sentence_transformer_embeddings(
|
| 328 |
for index in range(len(categories)):
|
| 329 |
category = categories[index]
|
| 330 |
|
|
@@ -332,7 +332,7 @@ def get_sorted_cosine_similarity(st_session_state, embeddings_metadata):
|
|
| 332 |
category_embedding = category_embeddings[category]
|
| 333 |
else:
|
| 334 |
category_embedding = get_sentence_transformer_embeddings(category, model_name=model_name)
|
| 335 |
-
|
| 336 |
|
| 337 |
cosine_sim[index] = cosine_similarity(input_embedding, category_embedding)
|
| 338 |
|
|
|
|
| 285 |
|
| 286 |
|
| 287 |
# Task III: Sort the cosine similarity
|
| 288 |
+
def get_sorted_cosine_similarity(st.session_state.text_search, embeddings_metadata):
|
| 289 |
"""
|
| 290 |
Get sorted cosine similarity between input sentence and categories
|
| 291 |
Steps:
|
|
|
|
| 296 |
5. Return sorted cosine similarity
|
| 297 |
(50 pts)
|
| 298 |
"""
|
| 299 |
+
categories = st.session_state.categories.split(" ")
|
| 300 |
cosine_sim = {}
|
| 301 |
if embeddings_metadata["embedding_model"] == "glove":
|
| 302 |
word_index_dict = embeddings_metadata["word_index_dict"]
|
| 303 |
embeddings = embeddings_metadata["embeddings"]
|
| 304 |
model_type = embeddings_metadata["model_type"]
|
| 305 |
+
print(f'Debug: {st.session_state.text_search}')
|
| 306 |
+
input_embedding = averaged_glove_embeddings_gdrive(st.session_state.text_search,
|
| 307 |
word_index_dict,
|
| 308 |
embeddings, model_type)
|
| 309 |
|
|
|
|
| 315 |
|
| 316 |
else:
|
| 317 |
model_name = embeddings_metadata["model_name"]
|
| 318 |
+
if not "cat_embed_" + model_name in st.session_state:
|
| 319 |
get_category_embeddings(embeddings_metadata)
|
| 320 |
|
| 321 |
+
category_embeddings = st.session_state["cat_embed_" + model_name]
|
| 322 |
|
| 323 |
+
print("text_search = ", st.session_state.text_search)
|
| 324 |
if model_name:
|
| 325 |
+
input_embedding = get_sentence_transformer_embeddings(st.session_state.text_search, model_name=model_name)
|
| 326 |
else:
|
| 327 |
+
input_embedding = get_sentence_transformer_embeddings(st.session_state.text_search)
|
| 328 |
for index in range(len(categories)):
|
| 329 |
category = categories[index]
|
| 330 |
|
|
|
|
| 332 |
category_embedding = category_embeddings[category]
|
| 333 |
else:
|
| 334 |
category_embedding = get_sentence_transformer_embeddings(category, model_name=model_name)
|
| 335 |
+
st.session_state["cat_embed_" + model_name][category] = category_embedding
|
| 336 |
|
| 337 |
cosine_sim[index] = cosine_similarity(input_embedding, category_embedding)
|
| 338 |
|