# Captured from a Hugging Face Spaces page; the Space was showing "Build error" at capture time.
| import pickle | |
| import os | |
| import gradio as gr | |
| import gradio as gr | |
| import pandas as pd | |
| from sentence_transformers import SentenceTransformer, util | |
def encode_column(model, filename, col_name):
    """Load a CSV file and attach an embedding to every row.

    Parameters:
        model: encoder exposing ``encode(texts)`` (a SentenceTransformer here)
        filename (str): path of the CSV file to read
        col_name (str): column whose text is passed to ``model.encode``

    Returns:
        pandas.DataFrame: the CSV contents plus an ``embedding`` column holding
        the elements of ``model.encode``'s output, one per row
    """
    table = pd.read_csv(filename)
    texts = table[col_name]
    vectors = model.encode(texts)
    table["embedding"] = [vector for vector in vectors]
    return table
def item_level_ccr(data_encoded_df, questionnaire_encoded_df):
    """Add one cosine-similarity column per questionnaire item to the data frame.

    Parameters:
        data_encoded_df (pandas.DataFrame): participant rows with an
            ``embedding`` column (as produced by ``encode_column``)
        questionnaire_encoded_df (pandas.DataFrame): questionnaire items with
            an ``embedding`` column

    Returns:
        pandas.DataFrame: ``data_encoded_df`` itself, modified in place, with
        columns ``sim_item_1`` .. ``sim_item_<k>`` where ``k`` is the number of
        questionnaire rows; ``sim_item_i`` holds each row's cosine similarity
        to the i-th questionnaire row.
    """
    import numpy as np  # local import: the module header does not import numpy

    n_rows = len(data_encoded_df)
    n_items = len(questionnaire_encoded_df)
    if n_rows == 0 or n_items == 0:
        # Nothing to compare; still add the (empty) item columns so the
        # output schema matches the non-empty case instead of crashing.
        similarities = np.empty((n_rows, n_items), dtype=np.float32)
    else:
        # Stack the per-row vectors into explicit 2-D matrices rather than
        # handing torch a Series of arrays, and bring the result back as a
        # plain numpy array so the new columns are ordinary float columns
        # regardless of how the installed pandas handles torch tensors.
        d_matrix = np.vstack(data_encoded_df["embedding"].tolist())
        q_matrix = np.vstack(questionnaire_encoded_df["embedding"].tolist())
        similarities = util.pytorch_cos_sim(d_matrix, q_matrix).cpu().numpy()

    for item_idx in range(n_items):
        data_encoded_df["sim_item_{}".format(item_idx + 1)] = similarities[:, item_idx]
    return data_encoded_df
# Batch CCR: encode the questionnaire and the participant texts, then score every text against every item.
def ccr_wrapper(data_file, data_col, q_file, q_col, model='all-MiniLM-L6-v2'):
    """Compute item-level CCR scores for uploaded files and save them to CSV.

    Every row of ``data_file`` is compared with every question in ``q_file``
    via ``item_level_ccr``.

    Parameters:
        data_file: uploaded file with user text -- a path string or a
            file-like object exposing ``.name``
        data_col (str): column that includes user text
        q_file: uploaded questionnaire file, same accepted forms as ``data_file``
        q_col (str): column that includes questions
        model (str): name of the SBERT model to use for CCR, see
            https://www.sbert.net/docs/pretrained_models.html for the full list.
            Falls back to 'all-MiniLM-L6-v2' when empty/None or when loading fails.

    Returns:
        str: path of a file named ``ccr_results.csv`` containing the data
        file's columns, an ``embedding`` column, and one ``sim_item_<i>``
        column per question.
    """
    import tempfile

    default_model = 'all-MiniLM-L6-v2'
    if not model:
        # An untouched gr.Dropdown sends None, and SentenceTransformer(None)
        # builds an empty model instead of raising, so the except branch
        # below would never catch this case.
        print("model name was not included, using all-MiniLM-L6-v2")
        model = default_model
    try:
        encoder = SentenceTransformer(model)
    except Exception as err:  # UI boundary: report and fall back, don't crash
        print("could not load model {!r} ({}), using all-MiniLM-L6-v2".format(model, err))
        encoder = SentenceTransformer(default_model)

    # Depending on the Gradio version / File ``type``, uploads arrive either as
    # file wrappers exposing ``.name`` or as plain path strings.
    questionnaire_filename = getattr(q_file, "name", q_file)
    data_filename = getattr(data_file, "name", data_file)

    q_encoded_df = encode_column(encoder, questionnaire_filename, q_col)
    data_encoded_df = encode_column(encoder, data_filename, data_col)
    ccr_df = item_level_ccr(data_encoded_df, q_encoded_df)

    # A fresh directory per call keeps concurrent sessions from overwriting
    # each other's results while the download keeps the name ccr_results.csv.
    # The directory is not cleaned up here.
    out_path = os.path.join(tempfile.mkdtemp(), "ccr_results.csv")
    ccr_df.to_csv(out_path)
    return out_path
def read_dataframe(data_file, data_col, q_file, q_col):
    """Return the filesystem path of the uploaded participant data file.

    Only ``data_file`` is used. ``data_col``, ``q_file`` and ``q_col`` are
    ignored; presumably they are kept so the signature mirrors the first four
    parameters of ``ccr_wrapper``.

    Parameters:
        data_file: a path string, or a file-like object exposing ``.name``
            (the form Gradio uses depends on its version / File ``type``)

    Returns:
        str: the path of ``data_file``
    """
    return getattr(data_file, "name", data_file)
def single_text_ccr(text, question):
    """Score one text against one question with the 'all-MiniLM-L6-v2' model.

    The model is loaded on every call.

    Returns:
        float: cosine similarity of the two embeddings, rounded to 3 decimals
    """
    encoder = SentenceTransformer('all-MiniLM-L6-v2')
    # Encoded one at a time, text first, exactly as two separate calls.
    text_vec, question_vec = (encoder.encode(snippet) for snippet in (text, question))
    score = util.pytorch_cos_sim(text_vec, question_vec).item()
    return round(score, 3)
with gr.Blocks() as demo:
    gr.Markdown("""<h1><center>Contextual Construct Representations</center></h1>
<h3><center>Ali Omrani and Mohammad Atari</center></h3>""")

    # Section 1: score a single text against a single question.
    gr.Markdown("""<br><h4>Play around with your items!</h4>""")
    with gr.Row():
        user_txt = gr.Textbox(label="Input Text", placeholder="Enter your desired text here ...")
        question = gr.Textbox(label="Question", placeholder="Enter the question text here ...")
    submit2 = gr.Button("Get CCR for this Text!")
    submit2.click(single_text_ccr, inputs=[user_txt, question], outputs=gr.Textbox(label="CCR Value"))

    # Section 2: score every row of an uploaded data file against every item
    # of an uploaded questionnaire file; the result is offered as a CSV file.
    gr.Markdown("""<br><h4>Or process a whole file!</h4>""")
    with gr.Row():
        # Preselect the model ccr_wrapper otherwise defaults to; with no value,
        # a dropdown the user never touches is presumably sent as None.
        model_name = gr.Dropdown(
            label="Choose the Model",
            choices=[
                "all-mpnet-base-v2", "multi-qa-mpnet-base-dot-v1", "distiluse-base-multilingual-cased-v2",
                "distiluse-base-multilingual-cased-v1", "paraphrase-MiniLM-L3-v2", "paraphrase-multilingual-MiniLM-L12-v2",
                "paraphrase-albert-small-v2", "paraphrase-multilingual-mpnet-base-v2", "multi-qa-MiniLM-L6-cos-v1",
                "all-MiniLM-L6-v2", "multi-qa-distilbert-cos-v1", "all-MiniLM-L12-v2", "all-distilroberta-v1",
            ],
            value="all-MiniLM-L6-v2",
        )
    with gr.Row():
        with gr.Column():
            user_data = gr.File(label="Participant Data File")
            text_col = gr.Textbox(label="Text Column", placeholder="text column ... ")
        with gr.Column():
            questionnaire_data = gr.File(label="Questionnaire File")
            q_col = gr.Textbox(label="Question Column", placeholder="questionnaire column ... ")
    submit = gr.Button("Get CCR!")
    outputs = gr.File()
    submit.click(ccr_wrapper, inputs=[user_data, text_col, questionnaire_data, q_col, model_name], outputs=[outputs])

# Launch only when run as a script (`python app.py`), not when imported.
if __name__ == "__main__":
    demo.launch()