File size: 3,589 Bytes
7eea2bf
 
 
 
96769b9
7eea2bf
d85ff42
7eea2bf
 
d85ff42
db2cb32
 
d85ff42
 
 
7eea2bf
 
 
 
 
 
65973e7
 
93ab5aa
 
7eea2bf
26977ad
7eea2bf
 
88a1e29
7eea2bf
 
 
 
26977ad
01668e5
 
7af92a3
01668e5
 
 
 
 
 
 
7af92a3
 
 
 
 
 
 
 
 
 
7eea2bf
 
 
 
 
 
 
26977ad
176a001
 
 
 
57982ad
7eea2bf
 
 
 
59503a1
 
 
7eea2bf
d85ff42
 
7eea2bf
d85ff42
7eea2bf
 
26339c2
26977ad
db2cb32
26339c2
0caa784
db2cb32
d85ff42
7e74485
 
01668e5
db2cb32
01668e5
db2cb32
01668e5
26977ad
7eea2bf
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
from InstructorEmbedding import INSTRUCTOR
import streamlit as st
import pandas as pd
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np

# Seed the session-state entries used by the upload flow below. Each key is
# only written when it is missing, so values already stored in the session
# are preserved.
if 'query_bool' not in st.session_state:
    # Flag read further down to decide whether the query form is displayed.
    st.session_state.query_bool = False

if 'df' not in st.session_state:
    # Empty placeholder until an uploaded dataset replaces it.
    st.session_state.df = pd.DataFrame()

@st.cache_resource
def load_model():
    """Build the 'hkunlp/instructor-large' INSTRUCTOR model.

    Wrapped in ``st.cache_resource`` so the constructed model object is
    cached by Streamlit instead of being rebuilt on each call.

    Returns:
        The ``INSTRUCTOR`` instance for 'hkunlp/instructor-large'.
    """
    instructor = INSTRUCTOR('hkunlp/instructor-large')
    return instructor

# Module-level embedding model shared by embed_data() and process_data().
model = load_model()

@st.cache_data(persist='disk')
def embed_data(data):
    """Encode ``data`` with the module-level ``model``.

    The result is cached through ``st.cache_data`` with ``persist='disk'``.

    Args:
        data: Input handed unchanged to ``model.encode``; the caller in this
            file passes a list of ``[instruction, text]`` pairs.

    Returns:
        Whatever ``model.encode(data)`` returns.
    """
    embeddings = model.encode(data)
    return embeddings
    

def process_data(df, desc, message, query):
    """Display the ``message`` text of the row that best matches ``query``.

    Every row of ``df`` becomes an ``[instruction, text]`` pair: the
    instruction embeds the row's ``desc`` value and the text is the row's
    ``message`` value. The pairs (via ``embed_data``) and the query (via the
    module-level ``model``) are embedded, compared with cosine similarity,
    and the text of the highest-scoring row is written with ``st.markdown``.

    Args:
        df: DataFrame holding the candidate rows.
        desc: Name of the column interpolated into each row's instruction.
        message: Name of the column that is embedded and displayed.
        query: List of ``[instruction, question]`` pairs passed to
            ``model.encode``.

    Returns:
        None. Output is written to the Streamlit page. When ``df`` has no
        rows a warning is shown and nothing is embedded.
    """
    # Nothing to rank: report it instead of handing an empty corpus to the
    # encoder and the similarity computation.
    if df.empty:
        st.warning("No rows to search. Please check your dataset")
        return

    data = [
        [
            f'Represent the document for retrieval of {row[desc]} information : ',
            row[message],
        ] for _, row in df.iterrows()
    ]

    # The embedding calls are the slow part, so they run inside the spinner
    # together with the similarity computation.
    with st.spinner('Wait for it...'):
        corpus_embeddings = embed_data(data)
        query_embeddings = model.encode(query)
        similarities = cosine_similarity(query_embeddings, corpus_embeddings)
        # similarities is (n_queries, n_documents). Reduce over the query
        # axis first so the argmax is always a valid document index, also
        # when more than one query row is supplied.
        best_doc_id = int(np.argmax(similarities.max(axis=0)))

    # NOTE(review): the selected text comes from the dataset (possibly an
    # uploaded file) and is rendered as raw HTML. Confirm that HTML in the
    # response templates is intended before exposing this to untrusted data.
    st.markdown(f"{data[best_doc_id][-1]}", unsafe_allow_html=True)
    

# Page body: choose between the bundled intent.csv and a user-uploaded file,
# then ask a question and show the best matching response via process_data().
opt = st.radio(
    "Choose Data : ",
    ["intent.csv", "upload file CSV"],
    captions=["LMD CSV intent data", "Custom upload CSV data"],
)


if opt == "intent.csv":
    # Bundled dataset: semicolon-separated, with 'description' and 'message'
    # columns used for retrieval.
    df = pd.read_csv("intent.csv", delimiter=";")
    st.dataframe(df)
    question = st.text_input("Question (Press Enter to query) :")
    btn_q = st.button("Submit", key="submit_query")
    if btn_q:
        query = [['Represent the question for retrieving supporting documents: ', question]]
        process_data(df, desc='description', message='message', query=query)

else:
    f = st.file_uploader("Upload CSV File with at least 2 columns", ['xlsx', 'csv'])
    delim = st.text_input('CSV File Delimiter')
    btn = st.button("Submit", key="submit_first")

    if btn:
        if f is None:
            # Submit pressed before any file was chosen.
            st.write("FAILED! No file uploaded. Please choose a file first")
        else:
            # The uploader accepts both extensions, so pick the matching
            # pandas reader instead of always parsing as CSV.
            if f.name.lower().endswith('.xlsx'):
                uploaded_df = pd.read_excel(f)
            else:
                # A blank delimiter box falls back to a comma.
                uploaded_df = pd.read_csv(f, delimiter=delim or ',')

            if len(uploaded_df.columns) < 2:
                st.write("FAILED! At least 2 columns needed. Please check your dataset")
            else:
                # Only commit a validated frame, so the query form below is
                # never shown for a dataset that failed the column check.
                st.session_state.df = uploaded_df
                st.session_state.query_bool = True

    if st.session_state.query_bool:
        st.dataframe(st.session_state.df)
        with st.form("my_form"):
            cols = list(st.session_state.df.columns)
            desc = st.radio("Description Column (e.g. description)", cols)
            message = st.radio("Response Template Column (e.g. message)", cols)
            question = st.text_input("Question (Press Enter to query) :")
            submitted = st.form_submit_button("submit")

            if submitted:
                query = [['Represent the question for retrieving supporting documents: ', question]]
                process_data(st.session_state.df, desc, message, query=query)