yuanjunchai committed on
Commit
9d0e21e
·
1 Parent(s): da456c0

add application files

Browse files
Files changed (1) hide show
  1. app.py +12 -13
app.py CHANGED
@@ -285,7 +285,7 @@ def averaged_glove_embeddings_gdrive(sentence, word_index_dict, embeddings, mode
285
 
286
 
287
  # Task III: Sort the cosine similarity
288
- def get_sorted_cosine_similarity(embeddings_metadata):
289
  """
290
  Get sorted cosine similarity between input sentence and categories
291
  Steps:
@@ -296,14 +296,14 @@ def get_sorted_cosine_similarity(embeddings_metadata):
296
  5. Return sorted cosine similarity
297
  (50 pts)
298
  """
299
- categories = st.session_state.categories.split(" ")
300
  cosine_sim = {}
301
  if embeddings_metadata["embedding_model"] == "glove":
302
  word_index_dict = embeddings_metadata["word_index_dict"]
303
  embeddings = embeddings_metadata["embeddings"]
304
  model_type = embeddings_metadata["model_type"]
305
-
306
- input_embedding = averaged_glove_embeddings_gdrive(st.session_state.text_search,
307
  word_index_dict,
308
  embeddings, model_type)
309
 
@@ -315,16 +315,16 @@ def get_sorted_cosine_similarity(embeddings_metadata):
315
 
316
  else:
317
  model_name = embeddings_metadata["model_name"]
318
- if not "cat_embed_" + model_name in st.session_state:
319
  get_category_embeddings(embeddings_metadata)
320
 
321
- category_embeddings = st.session_state["cat_embed_" + model_name]
322
 
323
- print("text_search = ", st.session_state.text_search)
324
  if model_name:
325
- input_embedding = get_sentence_transformer_embeddings(st.session_state.text_search, model_name=model_name)
326
  else:
327
- input_embedding = get_sentence_transformer_embeddings(st.session_state.text_search)
328
  for index in range(len(categories)):
329
  category = categories[index]
330
 
@@ -332,7 +332,7 @@ def get_sorted_cosine_similarity(embeddings_metadata):
332
  category_embedding = category_embeddings[category]
333
  else:
334
  category_embedding = get_sentence_transformer_embeddings(category, model_name=model_name)
335
- st.session_state["cat_embed_" + model_name][category] = category_embedding
336
 
337
  cosine_sim[index] = cosine_similarity(input_embedding, category_embedding)
338
 
@@ -406,7 +406,6 @@ if __name__ == "__main__":
406
 
407
  # Find closest word to an input word
408
  if st.session_state.text_search:
409
- print(f"Debug Text Search = {st.session_state.text_search}")
410
  # Glove embeddings
411
  print("Glove Embedding")
412
  embeddings_metadata = {
@@ -417,7 +416,7 @@ if __name__ == "__main__":
417
  }
418
  with st.spinner("Obtaining Cosine similarity for Glove..."):
419
  sorted_cosine_sim_glove = get_sorted_cosine_similarity(
420
- embeddings_metadata
421
  )
422
 
423
  # Sentence transformer embeddings
@@ -425,7 +424,7 @@ if __name__ == "__main__":
425
  embeddings_metadata = {"embedding_model": "transformers", "model_name": ""}
426
  with st.spinner("Obtaining Cosine similarity for 384d sentence transformer..."):
427
  sorted_cosine_sim_transformer = get_sorted_cosine_similarity(
428
- embeddings_metadata
429
  )
430
 
431
  # Results and Plot Pie Chart for Glove
 
285
 
286
 
287
  # Task III: Sort the cosine similarity
288
+ def get_sorted_cosine_similarity(st_session_state, embeddings_metadata):
289
  """
290
  Get sorted cosine similarity between input sentence and categories
291
  Steps:
 
296
  5. Return sorted cosine similarity
297
  (50 pts)
298
  """
299
+ categories = st_session_state.categories.split(" ")
300
  cosine_sim = {}
301
  if embeddings_metadata["embedding_model"] == "glove":
302
  word_index_dict = embeddings_metadata["word_index_dict"]
303
  embeddings = embeddings_metadata["embeddings"]
304
  model_type = embeddings_metadata["model_type"]
305
+ print(f'Debug: {st_session_state.text_search}')
306
+ input_embedding = averaged_glove_embeddings_gdrive(st_session_state.text_search,
307
  word_index_dict,
308
  embeddings, model_type)
309
 
 
315
 
316
  else:
317
  model_name = embeddings_metadata["model_name"]
318
+ if not "cat_embed_" + model_name in st_session_state:
319
  get_category_embeddings(embeddings_metadata)
320
 
321
+ category_embeddings = st_session_state["cat_embed_" + model_name]
322
 
323
+ print("text_search = ", st_session_state.text_search)
324
  if model_name:
325
+ input_embedding = get_sentence_transformer_embeddings(st_session_state.text_search, model_name=model_name)
326
  else:
327
+ input_embedding = get_sentence_transformer_embeddings(st_session_state.text_search)
328
  for index in range(len(categories)):
329
  category = categories[index]
330
 
 
332
  category_embedding = category_embeddings[category]
333
  else:
334
  category_embedding = get_sentence_transformer_embeddings(category, model_name=model_name)
335
+ st_session_state["cat_embed_" + model_name][category] = category_embedding
336
 
337
  cosine_sim[index] = cosine_similarity(input_embedding, category_embedding)
338
 
 
406
 
407
  # Find closest word to an input word
408
  if st.session_state.text_search:
 
409
  # Glove embeddings
410
  print("Glove Embedding")
411
  embeddings_metadata = {
 
416
  }
417
  with st.spinner("Obtaining Cosine similarity for Glove..."):
418
  sorted_cosine_sim_glove = get_sorted_cosine_similarity(
419
+ st.session_state.text_search, embeddings_metadata
420
  )
421
 
422
  # Sentence transformer embeddings
 
424
  embeddings_metadata = {"embedding_model": "transformers", "model_name": ""}
425
  with st.spinner("Obtaining Cosine similarity for 384d sentence transformer..."):
426
  sorted_cosine_sim_transformer = get_sorted_cosine_similarity(
427
+ st.session_state.text_search, embeddings_metadata
428
  )
429
 
430
  # Results and Plot Pie Chart for Glove