Ali-Omrani committed on
Commit
9e7d9ab
·
1 Parent(s): 25d05a5

Added dropdown for models

Browse files
Files changed (1) hide show
  1. app.py +35 -19
app.py CHANGED
@@ -6,32 +6,44 @@ import pandas as pd
6
  from sentence_transformers import SentenceTransformer, util
7
 
8
  def encode_column(model, filename, col_name):
9
- df = pd.read_csv(filename)
10
- df["embedding"] = list(model.encode(df[col_name]))
11
- return df
12
 
13
  def item_level_ccr(data_encoded_df, questionnaire_encoded_df):
14
- q_embeddings = questionnaire_encoded_df.embedding
15
- d_embeddings = data_encoded_df.embedding
16
- similarities = util.pytorch_cos_sim(d_embeddings, q_embeddings)
17
- for i in range(1,len(questionnaire_encoded_df)+1):
18
- data_encoded_df["sim_item_{}".format(i)] = similarities[:, i-1]
19
- return data_encoded_df
20
 
21
  # encoding questionnaire
22
- def ccr_wrapper(data_file, data_col, q_file, q_col):
23
- model = SentenceTransformer('all-MiniLM-L6-v2')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
  questionnaire_filename = q_file.name
26
  data_filename = data_file.name
27
-
28
 
29
  q_encoded_df = encode_column(model, questionnaire_filename, q_col)
30
-
31
  data_encoded_df = encode_column(model, data_filename, data_col)
32
-
33
  ccr_df = item_level_ccr(data_encoded_df, q_encoded_df)
34
- # ccr_df = ccr_df.drop(columns=["embeddings"])
35
 
36
  ccr_df.to_csv("ccr_results.csv")
37
  return "ccr_results.csv"
@@ -72,6 +84,12 @@ with gr.Blocks() as demo:
72
 
73
  gr.Markdown("""<br><h4>Or process a whole file!</h4>""")
74
 
 
 
 
 
 
 
75
  with gr.Row():
76
  with gr.Column():
77
  user_data = gr.File(label="Participant Data File")
@@ -83,7 +101,5 @@ with gr.Blocks() as demo:
83
  submit = gr.Button("Get CCR!")
84
 
85
  outputs=gr.File()
86
- submit.click(ccr_wrapper, inputs=[user_data, text_col,questionnaire_data,q_col], outputs=[outputs])
87
- demo.launch()
88
-
89
-
 
6
  from sentence_transformers import SentenceTransformer, util
7
 
8
  def encode_column(model, filename, col_name):
9
+ df = pd.read_csv(filename)
10
+ df["embedding"] = list(model.encode(df[col_name]))
11
+ return df
12
 
13
  def item_level_ccr(data_encoded_df, questionnaire_encoded_df):
14
+ q_embeddings = questionnaire_encoded_df.embedding
15
+ d_embeddings = data_encoded_df.embedding
16
+ similarities = util.pytorch_cos_sim(d_embeddings, q_embeddings)
17
+ for i in range(1,len(questionnaire_encoded_df)+1):
18
+ data_encoded_df["sim_item_{}".format(i)] = similarities[:, i-1]
19
+ return data_encoded_df
20
 
21
  # encoding questionnaire
22
+ def ccr_wrapper(data_file, data_col, q_file, q_col, model='all-MiniLM-L6-v2'):
23
+ """
24
+ Writes the content of data_file, plus one added similarity column (sim_item_<i>) per questionnaire item, to "ccr_results.csv" and returns that file path
25
+
26
+ Parameters:
27
+ data_file (file object): uploaded CSV file containing user text; its .name attribute is used as the path to read
28
+ data_col (str): column that includes user text
29
+ q_file (file object): uploaded CSV file containing questionnaires; its .name attribute is used as the path to read
30
+ q_col (str): column that includes questions
31
+ model (str): name of the SBERT model to use for CCR; see https://www.sbert.net/docs/pretrained_models.html for the full list
32
+
33
+ """
34
+ try:
35
+ model = SentenceTransformer(model)
36
+ except:
37
+ print("model name was not included, using all-MiniLM-L6-v2")
38
+ model = SentenceTransformer('all-MiniLM-L6-v2')
39
 
40
  questionnaire_filename = q_file.name
41
  data_filename = data_file.name
 
42
 
43
  q_encoded_df = encode_column(model, questionnaire_filename, q_col)
 
44
  data_encoded_df = encode_column(model, data_filename, data_col)
 
45
  ccr_df = item_level_ccr(data_encoded_df, q_encoded_df)
46
+
47
 
48
  ccr_df.to_csv("ccr_results.csv")
49
  return "ccr_results.csv"
 
84
 
85
  gr.Markdown("""<br><h4>Or process a whole file!</h4>""")
86
 
87
+ with gr.Row():
88
+ model_name = gr.Dropdown(label="Choose the Model",
89
+ choices=["all-mpnet-base-v2","multi-qa-mpnet-base-dot-v1", "distiluse-base-multilingual-cased-v2",
90
+ "distiluse-base-multilingual-cased-v1", "paraphrase-MiniLM-L3-v2", "paraphrase-multilingual-MiniLM-L12-v2",
91
+ "paraphrase-albert-small-v2", "paraphrase-multilingual-mpnet-base-v2", "multi-qa-MiniLM-L6-cos-v1",
92
+ "all-MiniLM-L6-v2", "multi-qa-distilbert-cos-v1", "all-MiniLM-L12-v2", "all-distilroberta-v1"])
93
  with gr.Row():
94
  with gr.Column():
95
  user_data = gr.File(label="Participant Data File")
 
101
  submit = gr.Button("Get CCR!")
102
 
103
  outputs=gr.File()
104
+ submit.click(ccr_wrapper, inputs=[user_data, text_col,questionnaire_data,q_col, model_name], outputs=[outputs])
105
+ demo.launch()