Spaces:
Runtime error
Runtime error
| from InstructorEmbedding import INSTRUCTOR | |
| import streamlit as st | |
| import pandas as pd | |
| from sklearn.metrics.pairwise import cosine_similarity | |
| import numpy as np | |
| # if 'model' not in st.session_state: | |
| # st.session_state['model'] = INSTRUCTOR('hkunlp/instructor-large') | |
# Seed the per-session defaults exactly once; Streamlit keeps session_state
# across reruns, so existing values must not be overwritten.
for _state_key, _state_default in (
    ('query_bool', False),
    ('df', pd.DataFrame()),
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default
@st.cache_resource
def load_model():
    """Return the shared INSTRUCTOR embedding model.

    Streamlit re-executes this script on every widget interaction. Without
    caching, the model would be re-instantiated on each rerun;
    ``st.cache_resource`` keeps a single instance alive and returns it on
    subsequent calls.
    """
    return INSTRUCTOR('hkunlp/instructor-large')


model = load_model()
def embed_data(data):
    """Encode ``data`` with the module-level ``model`` and return the result."""
    embeddings = model.encode(data)
    return embeddings
def process_data(df, desc, message, query):
    """Show the row of ``df`` whose embedding is most similar to ``query``.

    Args:
        df: DataFrame holding the candidate documents.
        desc: Name of the column whose value is interpolated into the
            per-document embedding instruction.
        message: Name of the column whose value is embedded and, for the
            best match, rendered with ``st.markdown``.
        query: List of ``[instruction, text]`` pairs passed to the encoder.
            Only the first pair is used to pick the best match.

    Returns:
        None. The best-matching ``message`` value is written to the page.
    """
    # Nothing to rank: encoding an empty corpus and calling cosine_similarity
    # on it would raise, so report the situation instead.
    if df.empty:
        st.warning("The dataset has no rows to search.")
        return

    data = [
        [f'Represent the document for retrieval of {topic} information : ', text]
        for topic, text in zip(df[desc], df[message])
    ]

    # Embedding is the slow step, so keep it inside the spinner as well.
    with st.spinner('Wait for it...'):
        corpus_embeddings = embed_data(data)
        query_embeddings = embed_data(query)
        similarities = cosine_similarity(query_embeddings, corpus_embeddings)

    # similarities has one row per query; index into the first row so the
    # result is always a valid position in ``data``.
    retrieved_doc_id = int(np.argmax(similarities[0]))
    st.markdown(f"{data[retrieved_doc_id][-1]}", unsafe_allow_html=True)
| # question = st.text_input("Question (Press Enter to query) :") | |
| # query = [['Represent the question for retrieving supporting documents: ',question]] | |
| # btn_q = st.button("Submit", key="submit_query") | |
| # if btn_q : | |
| # query = [['Represent the question for retrieving supporting documents: ',question]] | |
| # query_embeddings = model.encode(query) | |
| # with st.spinner('Wait for it...'): | |
| # similarities = cosine_similarity(query_embeddings,corpus_embeddings) | |
| # retrieved_doc_id = np.argmax(similarities) | |
| # st.markdown(f"{data[retrieved_doc_id][-1]}",unsafe_allow_html=True) | |
# Instruction prefix paired with the user's question before encoding.
QUERY_INSTRUCTION = 'Represent the question for retrieving supporting documents: '


def _read_uploaded_table(uploaded, delimiter):
    """Parse an uploaded file into a DataFrame.

    Files whose name ends in ``.xlsx`` are read with ``pd.read_excel``;
    anything else is read as CSV. A blank ``delimiter`` falls back to a comma.
    """
    if uploaded.name.lower().endswith('.xlsx'):
        return pd.read_excel(uploaded)
    return pd.read_csv(uploaded, delimiter=delimiter or ',')


opt = st.radio(
    "Choose Data : ",
    ["intent.csv", "upload file CSV"],
    captions=["LMD CSV intent data", "Custom upload CSV data"],
)

if opt == "intent.csv":
    # Bundled dataset: semicolon-separated, queried via the
    # 'description' and 'message' columns.
    df = pd.read_csv("intent.csv", delimiter=";")
    st.dataframe(df)
    question = st.text_input("Question (Press Enter to query) :")
    btn_q = st.button("Submit", key="submit_query")
    if btn_q:
        if question.strip():
            query = [[QUERY_INSTRUCTION, question]]
            process_data(df, desc='description', message='message', query=query)
        else:
            st.warning("Please enter a question first.")
else:
    f = st.file_uploader("Upload CSV File with at least 2 columns", ['xlsx', 'csv'])
    delim = st.text_input('CSV File Delimiter')
    btn = st.button("Submit", key="submit_first")
    if btn:
        if f is None:
            # Submit pressed before any file was chosen.
            st.warning("Please upload a file first.")
        else:
            try:
                st.session_state.df = _read_uploaded_table(f, delim)
            except (pd.errors.ParserError, pd.errors.EmptyDataError, ImportError) as exc:
                st.session_state.query_bool = False
                st.error(f"Could not read the uploaded file: {exc}")
            else:
                if len(st.session_state.df.columns) < 2:
                    # Reset the flag so a previously accepted upload does not
                    # leave the query form visible for an invalid dataset.
                    st.session_state.query_bool = False
                    st.write("FAILED! At least 2 columns needed. Please check your dataset")
                else:
                    st.session_state.query_bool = True

    # The flag lives in session_state so the form survives the reruns
    # triggered by its own widgets.
    if st.session_state.query_bool:
        st.dataframe(st.session_state.df)
        with st.form("my_form"):
            cols = list(st.session_state.df.columns)
            desc = st.radio("Description Column (e.g. description)", cols)
            message = st.radio("Response Template Column (e.g. message)", cols)
            question = st.text_input("Question (Press Enter to query) :")
            submitted = st.form_submit_button("submit")
            if submitted:
                if question.strip():
                    query = [[QUERY_INSTRUCTION, question]]
                    process_data(st.session_state.df, desc, message, query=query)
                else:
                    st.warning("Please enter a question first.")