# lmd_chatbot / app.py
# jonathanjordan21 - "Update app.py" (57982ad)
from InstructorEmbedding import INSTRUCTOR
import streamlit as st
import pandas as pd
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
# if 'model' not in st.session_state:
# st.session_state['model'] = INSTRUCTOR('hkunlp/instructor-large')
# Seed the per-session state on first run: 'query_bool' starts as False and
# 'df' starts as an empty DataFrame. Existing values are left untouched.
for _state_key, _make_default in (('query_bool', bool), ('df', pd.DataFrame)):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _make_default()
@st.cache_resource
def load_model():
    """Construct the 'hkunlp/instructor-large' INSTRUCTOR embedding model.

    Wrapped in ``st.cache_resource`` so the returned object is cached by
    Streamlit rather than rebuilt on every call.
    """
    instructor = INSTRUCTOR('hkunlp/instructor-large')
    return instructor


# Shared module-level handle used by the embedding helpers below.
model = load_model()
@st.cache_data(persist='disk')
def embed_data(data):
    """Encode ``data`` with the module-level ``model`` and return the result.

    The function is wrapped in ``st.cache_data(persist='disk')``, so Streamlit
    caches the return value for a given ``data`` argument.
    """
    embeddings = model.encode(data)
    return embeddings
def process_data(df, desc, message, query):
    """Display the ``message`` value of the row that best matches ``query``.

    Every row of ``df`` is turned into an ``[instruction, text]`` pair: the
    instruction embeds the row's ``desc`` value and the text is the row's
    ``message`` value. The pairs and the query are embedded, compared with
    cosine similarity, and the text of the highest-scoring pair is rendered
    with ``st.markdown``.

    Args:
        df: DataFrame holding the candidate rows.
        desc: Name of the column interpolated into the retrieval instruction.
        message: Name of the column that is embedded and displayed.
        query: Instruction/text pairs for the question, passed unchanged to
            ``model.encode``.

    Returns:
        None. The result (or a warning for an empty dataset) is written to
        the Streamlit page.
    """
    # Nothing to rank: encoding/comparing an empty corpus would only raise.
    if df.empty:
        st.warning("The dataset has no rows to search.")
        return

    data = [
        [f'Represent the document for retrieval of {description} information : ', text]
        for description, text in zip(df[desc], df[message])
    ]

    # Keep the spinner visible for the whole computation; the encode calls
    # are the part that takes time, not the similarity lookup.
    with st.spinner('Wait for it...'):
        corpus_embeddings = embed_data(data)
        query_embeddings = model.encode(query)
        similarities = cosine_similarity(query_embeddings, corpus_embeddings)

    # argmax over the flattened similarity matrix; with a single query row
    # this is the index of the best-matching document.
    retrieved_doc_id = np.argmax(similarities)

    # NOTE(review): unsafe_allow_html renders dataset content as raw HTML.
    # Kept as-is because templates may rely on it; confirm this is acceptable
    # for user-uploaded CSV files.
    st.markdown(f"{data[retrieved_doc_id][-1]}", unsafe_allow_html=True)
# question = st.text_input("Question (Press Enter to query) :")
# query = [['Represent the question for retrieving supporting documents: ',question]]
# btn_q = st.button("Submit", key="submit_query")
# if btn_q :
# query = [['Represent the question for retrieving supporting documents: ',question]]
# query_embeddings = model.encode(query)
# with st.spinner('Wait for it...'):
# similarities = cosine_similarity(query_embeddings,corpus_embeddings)
# retrieved_doc_id = np.argmax(similarities)
# st.markdown(f"{data[retrieved_doc_id][-1]}",unsafe_allow_html=True)
# --- Page body -------------------------------------------------------------
# Let the user pick the bundled intent.csv or upload a CSV of their own, then
# ask a question that process_data answers from the chosen dataset.
opt = st.radio(
    "Choose Data : ",
    ["intent.csv", "upload file CSV"],
    captions=["LMD CSV intent data", "Custom upload CSV data"],
)

if opt == "intent.csv":
    df = pd.read_csv("intent.csv", delimiter=";")
    st.dataframe(df)

    question = st.text_input("Question (Press Enter to query) :")
    btn_q = st.button("Submit", key="submit_query")
    if btn_q:
        query = [['Represent the question for retrieving supporting documents: ', question]]
        process_data(df, desc='description', message='message', query=query)
else:
    # Only CSV is offered: the upload is always parsed with pd.read_csv.
    f = st.file_uploader("Upload CSV File with at least 2 columns", ['csv'])
    delim = st.text_input('CSV File Delimiter')
    btn = st.button("Submit", key="submit_first")

    if btn:
        uploaded_df = None
        if f is None:
            # Submit was pressed before any file was chosen.
            st.write("FAILED! Please upload a CSV file first")
        else:
            try:
                # A blank delimiter box falls back to a comma instead of
                # handing an empty separator to the parser.
                uploaded_df = pd.read_csv(f, delimiter=delim or ',')
            except (pd.errors.EmptyDataError, pd.errors.ParserError, UnicodeDecodeError) as err:
                st.write(f"FAILED! Could not parse the uploaded file: {err}")

        if uploaded_df is not None and len(uploaded_df.columns) < 2:
            st.write("FAILED! At least 2 columns needed. Please check your dataset")
            uploaded_df = None

        if uploaded_df is None:
            # Hide the query form so it never runs against an invalid upload.
            st.session_state.query_bool = False
        else:
            # Store the frame only after validation so session state never
            # holds a dataset that failed the column check.
            st.session_state.df = uploaded_df
            st.session_state.query_bool = True

    if st.session_state.query_bool:
        st.dataframe(st.session_state.df)
        with st.form("my_form"):
            cols = list(st.session_state.df.columns)
            desc = st.radio("Description Column (e.g. description)", cols)
            message = st.radio("Response Template Column (e.g. message)", cols)
            question = st.text_input("Question (Press Enter to query) :")
            submitted = st.form_submit_button("submit")
        if submitted:
            query = [['Represent the question for retrieving supporting documents: ', question]]
            process_data(st.session_state.df, desc, message, query=query)