yuanjunchai committed on
Commit
2596ac1
·
1 Parent(s): 9d0e21e

add application files

Browse files
Files changed (1) hide show
  1. app.py +10 -10
app.py CHANGED
@@ -285,7 +285,7 @@ def averaged_glove_embeddings_gdrive(sentence, word_index_dict, embeddings, mode
285
 
286
 
287
  # Task III: Sort the cosine similarity
288
- def get_sorted_cosine_similarity(st_session_state, embeddings_metadata):
289
  """
290
  Get sorted cosine similarity between input sentence and categories
291
  Steps:
@@ -296,14 +296,14 @@ def get_sorted_cosine_similarity(st_session_state, embeddings_metadata):
296
  5. Return sorted cosine similarity
297
  (50 pts)
298
  """
299
- categories = st_session_state.categories.split(" ")
300
  cosine_sim = {}
301
  if embeddings_metadata["embedding_model"] == "glove":
302
  word_index_dict = embeddings_metadata["word_index_dict"]
303
  embeddings = embeddings_metadata["embeddings"]
304
  model_type = embeddings_metadata["model_type"]
305
- print(f'Debug: {st_session_state.text_search}')
306
- input_embedding = averaged_glove_embeddings_gdrive(st_session_state.text_search,
307
  word_index_dict,
308
  embeddings, model_type)
309
 
@@ -315,16 +315,16 @@ def get_sorted_cosine_similarity(st_session_state, embeddings_metadata):
315
 
316
  else:
317
  model_name = embeddings_metadata["model_name"]
318
- if not "cat_embed_" + model_name in st_session_state:
319
  get_category_embeddings(embeddings_metadata)
320
 
321
- category_embeddings = st_session_state["cat_embed_" + model_name]
322
 
323
- print("text_search = ", st_session_state.text_search)
324
  if model_name:
325
- input_embedding = get_sentence_transformer_embeddings(st_session_state.text_search, model_name=model_name)
326
  else:
327
- input_embedding = get_sentence_transformer_embeddings(st_session_state.text_search)
328
  for index in range(len(categories)):
329
  category = categories[index]
330
 
@@ -332,7 +332,7 @@ def get_sorted_cosine_similarity(st_session_state, embeddings_metadata):
332
  category_embedding = category_embeddings[category]
333
  else:
334
  category_embedding = get_sentence_transformer_embeddings(category, model_name=model_name)
335
- st_session_state["cat_embed_" + model_name][category] = category_embedding
336
 
337
  cosine_sim[index] = cosine_similarity(input_embedding, category_embedding)
338
 
 
285
 
286
 
287
  # Task III: Sort the cosine similarity
288
+ def get_sorted_cosine_similarity(st.session_state.text_search, embeddings_metadata):
289
  """
290
  Get sorted cosine similarity between input sentence and categories
291
  Steps:
 
296
  5. Return sorted cosine similarity
297
  (50 pts)
298
  """
299
+ categories = st.session_state.categories.split(" ")
300
  cosine_sim = {}
301
  if embeddings_metadata["embedding_model"] == "glove":
302
  word_index_dict = embeddings_metadata["word_index_dict"]
303
  embeddings = embeddings_metadata["embeddings"]
304
  model_type = embeddings_metadata["model_type"]
305
+ print(f'Debug: {st.session_state.text_search}')
306
+ input_embedding = averaged_glove_embeddings_gdrive(st.session_state.text_search,
307
  word_index_dict,
308
  embeddings, model_type)
309
 
 
315
 
316
  else:
317
  model_name = embeddings_metadata["model_name"]
318
+ if not "cat_embed_" + model_name in st.session_state:
319
  get_category_embeddings(embeddings_metadata)
320
 
321
+ category_embeddings = st.session_state["cat_embed_" + model_name]
322
 
323
+ print("text_search = ", st.session_state.text_search)
324
  if model_name:
325
+ input_embedding = get_sentence_transformer_embeddings(st.session_state.text_search, model_name=model_name)
326
  else:
327
+ input_embedding = get_sentence_transformer_embeddings(st.session_state.text_search)
328
  for index in range(len(categories)):
329
  category = categories[index]
330
 
 
332
  category_embedding = category_embeddings[category]
333
  else:
334
  category_embedding = get_sentence_transformer_embeddings(category, model_name=model_name)
335
+ st.session_state["cat_embed_" + model_name][category] = category_embedding
336
 
337
  cosine_sim[index] = cosine_similarity(input_embedding, category_embedding)
338