yuanjunchai commited on
Commit ·
9d0e21e
1
Parent(s): da456c0
add application files
Browse files
app.py
CHANGED
|
@@ -285,7 +285,7 @@ def averaged_glove_embeddings_gdrive(sentence, word_index_dict, embeddings, mode
|
|
| 285 |
|
| 286 |
|
| 287 |
# Task III: Sort the cosine similarity
|
| 288 |
-
def get_sorted_cosine_similarity(embeddings_metadata):
|
| 289 |
"""
|
| 290 |
Get sorted cosine similarity between input sentence and categories
|
| 291 |
Steps:
|
|
@@ -296,14 +296,14 @@ def get_sorted_cosine_similarity(embeddings_metadata):
|
|
| 296 |
5. Return sorted cosine similarity
|
| 297 |
(50 pts)
|
| 298 |
"""
|
| 299 |
-
categories =
|
| 300 |
cosine_sim = {}
|
| 301 |
if embeddings_metadata["embedding_model"] == "glove":
|
| 302 |
word_index_dict = embeddings_metadata["word_index_dict"]
|
| 303 |
embeddings = embeddings_metadata["embeddings"]
|
| 304 |
model_type = embeddings_metadata["model_type"]
|
| 305 |
-
|
| 306 |
-
input_embedding = averaged_glove_embeddings_gdrive(
|
| 307 |
word_index_dict,
|
| 308 |
embeddings, model_type)
|
| 309 |
|
|
@@ -315,16 +315,16 @@ def get_sorted_cosine_similarity(embeddings_metadata):
|
|
| 315 |
|
| 316 |
else:
|
| 317 |
model_name = embeddings_metadata["model_name"]
|
| 318 |
-
if not "cat_embed_" + model_name in
|
| 319 |
get_category_embeddings(embeddings_metadata)
|
| 320 |
|
| 321 |
-
category_embeddings =
|
| 322 |
|
| 323 |
-
print("text_search = ",
|
| 324 |
if model_name:
|
| 325 |
-
input_embedding = get_sentence_transformer_embeddings(
|
| 326 |
else:
|
| 327 |
-
input_embedding = get_sentence_transformer_embeddings(
|
| 328 |
for index in range(len(categories)):
|
| 329 |
category = categories[index]
|
| 330 |
|
|
@@ -332,7 +332,7 @@ def get_sorted_cosine_similarity(embeddings_metadata):
|
|
| 332 |
category_embedding = category_embeddings[category]
|
| 333 |
else:
|
| 334 |
category_embedding = get_sentence_transformer_embeddings(category, model_name=model_name)
|
| 335 |
-
|
| 336 |
|
| 337 |
cosine_sim[index] = cosine_similarity(input_embedding, category_embedding)
|
| 338 |
|
|
@@ -406,7 +406,6 @@ if __name__ == "__main__":
|
|
| 406 |
|
| 407 |
# Find closest word to an input word
|
| 408 |
if st.session_state.text_search:
|
| 409 |
-
print(f"Debug Text Search = {st.session_state.text_search}")
|
| 410 |
# Glove embeddings
|
| 411 |
print("Glove Embedding")
|
| 412 |
embeddings_metadata = {
|
|
@@ -417,7 +416,7 @@ if __name__ == "__main__":
|
|
| 417 |
}
|
| 418 |
with st.spinner("Obtaining Cosine similarity for Glove..."):
|
| 419 |
sorted_cosine_sim_glove = get_sorted_cosine_similarity(
|
| 420 |
-
embeddings_metadata
|
| 421 |
)
|
| 422 |
|
| 423 |
# Sentence transformer embeddings
|
|
@@ -425,7 +424,7 @@ if __name__ == "__main__":
|
|
| 425 |
embeddings_metadata = {"embedding_model": "transformers", "model_name": ""}
|
| 426 |
with st.spinner("Obtaining Cosine similarity for 384d sentence transformer..."):
|
| 427 |
sorted_cosine_sim_transformer = get_sorted_cosine_similarity(
|
| 428 |
-
embeddings_metadata
|
| 429 |
)
|
| 430 |
|
| 431 |
# Results and Plot Pie Chart for Glove
|
|
|
|
| 285 |
|
| 286 |
|
| 287 |
# Task III: Sort the cosine similarity
|
| 288 |
+
def get_sorted_cosine_similarity(st_session_state, embeddings_metadata):
|
| 289 |
"""
|
| 290 |
Get sorted cosine similarity between input sentence and categories
|
| 291 |
Steps:
|
|
|
|
| 296 |
5. Return sorted cosine similarity
|
| 297 |
(50 pts)
|
| 298 |
"""
|
| 299 |
+
categories = st_session_state.categories.split(" ")
|
| 300 |
cosine_sim = {}
|
| 301 |
if embeddings_metadata["embedding_model"] == "glove":
|
| 302 |
word_index_dict = embeddings_metadata["word_index_dict"]
|
| 303 |
embeddings = embeddings_metadata["embeddings"]
|
| 304 |
model_type = embeddings_metadata["model_type"]
|
| 305 |
+
print(f'Debug: {st_session_state.text_search}')
|
| 306 |
+
input_embedding = averaged_glove_embeddings_gdrive(st_session_state.text_search,
|
| 307 |
word_index_dict,
|
| 308 |
embeddings, model_type)
|
| 309 |
|
|
|
|
| 315 |
|
| 316 |
else:
|
| 317 |
model_name = embeddings_metadata["model_name"]
|
| 318 |
+
if not "cat_embed_" + model_name in st_session_state:
|
| 319 |
get_category_embeddings(embeddings_metadata)
|
| 320 |
|
| 321 |
+
category_embeddings = st_session_state["cat_embed_" + model_name]
|
| 322 |
|
| 323 |
+
print("text_search = ", st_session_state.text_search)
|
| 324 |
if model_name:
|
| 325 |
+
input_embedding = get_sentence_transformer_embeddings(st_session_state.text_search, model_name=model_name)
|
| 326 |
else:
|
| 327 |
+
input_embedding = get_sentence_transformer_embeddings(st_session_state.text_search)
|
| 328 |
for index in range(len(categories)):
|
| 329 |
category = categories[index]
|
| 330 |
|
|
|
|
| 332 |
category_embedding = category_embeddings[category]
|
| 333 |
else:
|
| 334 |
category_embedding = get_sentence_transformer_embeddings(category, model_name=model_name)
|
| 335 |
+
st_session_state["cat_embed_" + model_name][category] = category_embedding
|
| 336 |
|
| 337 |
cosine_sim[index] = cosine_similarity(input_embedding, category_embedding)
|
| 338 |
|
|
|
|
| 406 |
|
| 407 |
# Find closest word to an input word
|
| 408 |
if st.session_state.text_search:
|
|
|
|
| 409 |
# Glove embeddings
|
| 410 |
print("Glove Embedding")
|
| 411 |
embeddings_metadata = {
|
|
|
|
| 416 |
}
|
| 417 |
with st.spinner("Obtaining Cosine similarity for Glove..."):
|
| 418 |
sorted_cosine_sim_glove = get_sorted_cosine_similarity(
|
| 419 |
+
st.session_state.text_search, embeddings_metadata
|
| 420 |
)
|
| 421 |
|
| 422 |
# Sentence transformer embeddings
|
|
|
|
| 424 |
embeddings_metadata = {"embedding_model": "transformers", "model_name": ""}
|
| 425 |
with st.spinner("Obtaining Cosine similarity for 384d sentence transformer..."):
|
| 426 |
sorted_cosine_sim_transformer = get_sorted_cosine_similarity(
|
| 427 |
+
st.session_state.text_search, embeddings_metadata
|
| 428 |
)
|
| 429 |
|
| 430 |
# Results and Plot Pie Chart for Glove
|