File size: 4,201 Bytes
50b3c03
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87dd156
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import functools
import os
import pickle

import gradio as gr
import numpy as np
import pandas as pd
from sentence_transformers import SentenceTransformer, util

def encode_column(model, filename, col_name):
    """
    Read a CSV file and attach one embedding per row for the given column.

    Parameters:
        model: object exposing ``encode(sentences)`` (a SentenceTransformer in this app)
        filename (str): path to the CSV file to load
        col_name (str): name of the column whose values are encoded

    Returns:
        pandas.DataFrame: the CSV content plus an "embedding" column holding,
        for each row, the corresponding entry of ``model.encode``'s output.

    Raises:
        KeyError: if ``col_name`` is not a column of the CSV file.
    """
    df = pd.read_csv(filename)
    # Hand the encoder a plain list rather than a pandas Series, so encoding
    # never depends on how the Series resolves integer (label-based) lookups.
    sentences = df[col_name].tolist()
    # list(...) splits the encoder output along its first axis: one entry per row.
    df["embedding"] = list(model.encode(sentences))
    return df

def item_level_ccr(data_encoded_df, questionnaire_encoded_df):
    """
    Add one cosine-similarity column per questionnaire item to the data frame.

    Parameters:
        data_encoded_df (pandas.DataFrame): rows to score; must have an "embedding" column
        questionnaire_encoded_df (pandas.DataFrame): questionnaire items; must have an "embedding" column

    Returns:
        pandas.DataFrame: ``data_encoded_df`` itself (modified in place) with
        columns "sim_item_1" ... "sim_item_K", where K is the number of
        questionnaire rows and column i holds the similarity of every data row
        to the i-th questionnaire row.
    """
    n_rows = len(data_encoded_df)
    n_items = len(questionnaire_encoded_df)

    if n_rows == 0 or n_items == 0:
        # np.stack rejects an empty sequence, so build the (empty) matrix directly.
        similarities = np.empty((n_rows, n_items))
    else:
        # Stack the per-row vectors into real 2-D arrays instead of passing
        # object-dtype Series to the similarity helper.
        d_embeddings = np.stack(list(data_encoded_df["embedding"]))
        q_embeddings = np.stack(list(questionnaire_encoded_df["embedding"]))
        # Convert to NumPy so the new columns hold plain floats, not tensor objects.
        similarities = (
            util.pytorch_cos_sim(d_embeddings, q_embeddings).detach().cpu().numpy()
        )

    for item_idx in range(n_items):
        data_encoded_df["sim_item_{}".format(item_idx + 1)] = similarities[:, item_idx]
    return data_encoded_df

# file-level CCR: encode both files and score every text against every questionnaire item
def ccr_wrapper(data_file, data_col, q_file, q_col, model='all-MiniLM-L6-v2'):
    """
    Run CCR on two uploaded CSV files and write the result to "ccr_results.csv".

    The result is the content of data_file with one additional column for CCR value per question.

    Parameters:
        data_file: file containing user text; either an object exposing its path as ``.name`` or a path string
        data_col (str): column that includes user text
        q_file: file containing questionnaires; either an object exposing its path as ``.name`` or a path string
        q_col (str): column that includes questions
        model (str): name of the SBERT model to use for CCR see https://www.sbert.net/docs/pretrained_models.html for full list

    Returns:
        str: the path of the written CSV file, always "ccr_results.csv".

    Raises:
        ValueError: if either file is missing (``None``).
    """
    default_model = 'all-MiniLM-L6-v2'

    # Validate the cheap inputs first, before any model is loaded.
    if data_file is None or q_file is None:
        raise ValueError("both a participant data file and a questionnaire file are required")

    # Accept both upload objects (path in .name) and plain path strings.
    questionnaire_filename = getattr(q_file, "name", q_file)
    data_filename = getattr(data_file, "name", data_file)

    # Check for a missing name explicitly rather than relying on the
    # constructor raising for an empty/None model name.
    if not model:
        print("model name was not included, using all-MiniLM-L6-v2")
        model = default_model
    try:
        encoder = SentenceTransformer(model)
    except Exception as err:
        # Best-effort fallback kept from the original behavior: an unloadable
        # model name degrades to the default model instead of failing the request.
        print("could not load model {!r} ({}), using all-MiniLM-L6-v2".format(model, err))
        encoder = SentenceTransformer(default_model)

    q_encoded_df = encode_column(encoder, questionnaire_filename, q_col)
    data_encoded_df = encode_column(encoder, data_filename, data_col)
    ccr_df = item_level_ccr(data_encoded_df, q_encoded_df)

    ccr_df.to_csv("ccr_results.csv")
    return "ccr_results.csv"
    


def read_dataframe(data_file, data_col, q_file, q_col):
    """
    Return the path of the uploaded data file.

    Only ``data_file.name`` is read; ``data_col``, ``q_file`` and ``q_col``
    are accepted but not used.
    """
    uploaded_path = data_file.name
    return uploaded_path



@functools.lru_cache(maxsize=1)
def _single_text_model():
    """Load the 'all-MiniLM-L6-v2' encoder once and reuse it on later calls."""
    return SentenceTransformer('all-MiniLM-L6-v2')


def single_text_ccr(text, question):
    """
    Return the cosine similarity between one text and one question.

    Parameters:
        text (str): the text to score
        question (str): the questionnaire item to compare against

    Returns:
        float: cosine similarity of the two 'all-MiniLM-L6-v2' embeddings,
        rounded to 3 decimal places.
    """
    # Reuse the cached encoder instead of reloading the model on every call.
    model = _single_text_model()
    text_embedding = model.encode(text)
    question_embedding = model.encode(question)
    similarity = util.pytorch_cos_sim(text_embedding, question_embedding)
    return round(similarity.item(), 3)
    




# Gradio UI: a single-text CCR playground followed by a whole-file CCR form.
with gr.Blocks() as demo:
    # gr.Markdown('This is the first page for CCR, info goes here!')
    gr.Markdown("""<h1><center>Contextual Construct Representations</center></h1>
    <h3><center>Ali Omrani and Mohammad Atari</center></h3>""")

    # --- Section 1: similarity between one text and one question ---
    gr.Markdown("""<br><h4>Play around with your items!</h4>""")

    with gr.Row():
        user_txt = gr.Textbox(label="Input Text", placeholder="Enter your desired text here ...")
        question = gr.Textbox(label="Question", placeholder="Enter the question text here ...")

    submit2 = gr.Button("Get CCR for this Text!")

    submit2.click(single_text_ccr, inputs=[user_txt, question], outputs=gr.Textbox(label="CCR Value"))

    # --- Section 2: CCR for every row of an uploaded file ---
    gr.Markdown("""<br><h4>Or process a whole file!</h4>""")

    with gr.Row():
        # Preselect the model ccr_wrapper uses as its default, so an untouched
        # dropdown does not send an empty model name to the handler.
        model_name = gr.Dropdown(label="Choose the Model",
                                 choices=["all-mpnet-base-v2", "multi-qa-mpnet-base-dot-v1", "distiluse-base-multilingual-cased-v2",
                                          "distiluse-base-multilingual-cased-v1", "paraphrase-MiniLM-L3-v2", "paraphrase-multilingual-MiniLM-L12-v2",
                                          "paraphrase-albert-small-v2", "paraphrase-multilingual-mpnet-base-v2", "multi-qa-MiniLM-L6-cos-v1",
                                          "all-MiniLM-L6-v2", "multi-qa-distilbert-cos-v1", "all-MiniLM-L12-v2", "all-distilroberta-v1"],
                                 value="all-MiniLM-L6-v2")
    with gr.Row():
        with gr.Column():
            user_data = gr.File(label="Participant Data File")
            text_col = gr.Textbox(label="Text Column", placeholder="text column ... ")
        with gr.Column():
            questionnaire_data = gr.File(label="Questionnaire File")
            q_col = gr.Textbox(label="Question Column", placeholder="questionnaire column ... ")

    submit = gr.Button("Get CCR!")

    # The handler returns the path of the result CSV, offered here as a download.
    outputs = gr.File()
    # Input order must match ccr_wrapper(data_file, data_col, q_file, q_col, model).
    submit.click(ccr_wrapper, inputs=[user_data, text_col, questionnaire_data, q_col, model_name], outputs=[outputs])
demo.launch(share=True)