Spaces:
Sleeping
Sleeping
Sandaruth
commited on
Commit
·
2ffda8f
1
Parent(s):
870ee5f
update model
Browse files- Retrieval.py +34 -0
- app.py +19 -15
- model.py +1 -11
Retrieval.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
from model import llm, vectorstore, splitter, embedding, QA_PROMPT
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
# Chain for Web
|
| 6 |
+
from langchain.chains import RetrievalQA
|
| 7 |
+
|
| 8 |
+
bsic_chain = RetrievalQA.from_chain_type(
|
| 9 |
+
llm=llm,
|
| 10 |
+
chain_type="stuff",
|
| 11 |
+
retriever = vectorstore.as_retriever(search_kwargs={"k": 4}),
|
| 12 |
+
return_source_documents= True,
|
| 13 |
+
input_key="question",
|
| 14 |
+
chain_type_kwargs={"prompt": QA_PROMPT},
|
| 15 |
+
)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
from langchain.retrievers.multi_query import MultiQueryRetriever
|
| 20 |
+
# from kk import MultiQueryRetriever
|
| 21 |
+
|
| 22 |
+
retriever_from_llm = MultiQueryRetriever.from_llm(
|
| 23 |
+
retriever=vectorstore.as_retriever(search_kwargs={"k": 3}),
|
| 24 |
+
llm=llm,
|
| 25 |
+
)
|
| 26 |
+
|
| 27 |
+
multiQuery_chain = RetrievalQA.from_chain_type(
|
| 28 |
+
llm=llm,
|
| 29 |
+
chain_type="stuff",
|
| 30 |
+
retriever = retriever_from_llm,
|
| 31 |
+
return_source_documents= True,
|
| 32 |
+
input_key="question",
|
| 33 |
+
chain_type_kwargs={"prompt": QA_PROMPT},
|
| 34 |
+
)
|
app.py
CHANGED
|
@@ -1,10 +1,10 @@
|
|
| 1 |
import streamlit as st
|
| 2 |
-
from
|
| 3 |
import time
|
| 4 |
|
| 5 |
from htmlTemplates import css, bot_template, user_template, source_template
|
| 6 |
|
| 7 |
-
st.set_page_config(page_title="Chat with ATrad",page_icon=":currency_exchange:")
|
| 8 |
st.write(css, unsafe_allow_html=True)
|
| 9 |
|
| 10 |
def main():
|
|
@@ -17,8 +17,11 @@ def main():
|
|
| 17 |
4. Source documents will be displayed in the sidebar.
|
| 18 |
""")
|
| 19 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
# Button to connect to Google link ------------------------------------------------
|
| 21 |
-
|
| 22 |
st.sidebar.markdown('<a href="https://drive.google.com/drive/folders/13v6LsaYH9wEwvqVtlLG1U4OiUHgZ7hY4?usp=sharing" target="_blank" style="display: inline-block;'
|
| 23 |
'background-color: #475063; color: white; padding: 10px 20px; text-align: center;border: 1px solid white;'
|
| 24 |
'text-decoration: none; cursor: pointer; border-radius: 5px;">Sources</a>',
|
|
@@ -27,8 +30,8 @@ def main():
|
|
| 27 |
st.title("ATrad Chat App")
|
| 28 |
|
| 29 |
# Chat area -----------------------------------------------------------------------
|
| 30 |
-
|
| 31 |
user_input = st.text_input("", key="user_input",placeholder="Type your question here...")
|
|
|
|
| 32 |
# JavaScript code to submit the form on Enter key press
|
| 33 |
js_submit = f"""
|
| 34 |
document.addEventListener("keydown", function(event) {{
|
|
@@ -38,31 +41,32 @@ def main():
|
|
| 38 |
}});
|
| 39 |
"""
|
| 40 |
st.markdown(f'<script>{js_submit}</script>', unsafe_allow_html=True)
|
|
|
|
| 41 |
if st.button("Send"):
|
| 42 |
if user_input:
|
| 43 |
-
|
| 44 |
with st.spinner('Waiting for response...'):
|
| 45 |
-
|
| 46 |
# Add bot response here (you can replace this with your bot logic)
|
| 47 |
-
response, metadata, source_documents = generate_bot_response(user_input)
|
| 48 |
-
st.write(user_template.replace(
|
| 49 |
-
"{{MSG}}",
|
| 50 |
-
st.write(bot_template.replace(
|
| 51 |
-
"{{MSG}}", response ), unsafe_allow_html=True)
|
| 52 |
|
| 53 |
# Source documents
|
| 54 |
-
print("metadata", metadata)
|
| 55 |
st.sidebar.title("Source Documents")
|
| 56 |
for i, doc in enumerate(source_documents, 1):
|
| 57 |
-
tit=metadata[i-1]["source"].split("\\")[-1]
|
| 58 |
with st.sidebar.expander(f"{tit}"):
|
| 59 |
st.write(doc) # Assuming the Document object can be directly written to display its content
|
| 60 |
|
| 61 |
-
def generate_bot_response(user_input):
|
| 62 |
# Simple bot logic (replace with your actual bot logic)
|
| 63 |
start_time = time.time()
|
| 64 |
print(f"User Input: {user_input}")
|
| 65 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 66 |
response = res['result']
|
| 67 |
metadata = [i.metadata for i in res.get("source_documents", [])]
|
| 68 |
end_time = time.time()
|
|
|
|
| 1 |
import streamlit as st
|
| 2 |
+
from Retrieval import bsic_chain, multiQuery_chain
|
| 3 |
import time
|
| 4 |
|
| 5 |
from htmlTemplates import css, bot_template, user_template, source_template
|
| 6 |
|
| 7 |
+
st.set_page_config(page_title="Chat with ATrad", page_icon=":currency_exchange:")
|
| 8 |
st.write(css, unsafe_allow_html=True)
|
| 9 |
|
| 10 |
def main():
|
|
|
|
| 17 |
4. Source documents will be displayed in the sidebar.
|
| 18 |
""")
|
| 19 |
|
| 20 |
+
# Dropdown to select model --------------------------------------------------------
|
| 21 |
+
model_selection = st.sidebar.selectbox("Select Model", ["Basic", "MultiQuery"])
|
| 22 |
+
print(model_selection)
|
| 23 |
+
|
| 24 |
# Button to connect to Google link ------------------------------------------------
|
|
|
|
| 25 |
st.sidebar.markdown('<a href="https://drive.google.com/drive/folders/13v6LsaYH9wEwvqVtlLG1U4OiUHgZ7hY4?usp=sharing" target="_blank" style="display: inline-block;'
|
| 26 |
'background-color: #475063; color: white; padding: 10px 20px; text-align: center;border: 1px solid white;'
|
| 27 |
'text-decoration: none; cursor: pointer; border-radius: 5px;">Sources</a>',
|
|
|
|
| 30 |
st.title("ATrad Chat App")
|
| 31 |
|
| 32 |
# Chat area -----------------------------------------------------------------------
|
|
|
|
| 33 |
user_input = st.text_input("", key="user_input",placeholder="Type your question here...")
|
| 34 |
+
|
| 35 |
# JavaScript code to submit the form on Enter key press
|
| 36 |
js_submit = f"""
|
| 37 |
document.addEventListener("keydown", function(event) {{
|
|
|
|
| 41 |
}});
|
| 42 |
"""
|
| 43 |
st.markdown(f'<script>{js_submit}</script>', unsafe_allow_html=True)
|
| 44 |
+
|
| 45 |
if st.button("Send"):
|
| 46 |
if user_input:
|
|
|
|
| 47 |
with st.spinner('Waiting for response...'):
|
|
|
|
| 48 |
# Add bot response here (you can replace this with your bot logic)
|
| 49 |
+
response, metadata, source_documents = generate_bot_response(user_input, model_selection)
|
| 50 |
+
st.write(user_template.replace("{{MSG}}", user_input), unsafe_allow_html=True)
|
| 51 |
+
st.write(bot_template.replace("{{MSG}}", response), unsafe_allow_html=True)
|
|
|
|
|
|
|
| 52 |
|
| 53 |
# Source documents
|
|
|
|
| 54 |
st.sidebar.title("Source Documents")
|
| 55 |
for i, doc in enumerate(source_documents, 1):
|
| 56 |
+
tit = metadata[i-1]["source"].split("\\")[-1]
|
| 57 |
with st.sidebar.expander(f"{tit}"):
|
| 58 |
st.write(doc) # Assuming the Document object can be directly written to display its content
|
| 59 |
|
| 60 |
+
def generate_bot_response(user_input, model):
|
| 61 |
# Simple bot logic (replace with your actual bot logic)
|
| 62 |
start_time = time.time()
|
| 63 |
print(f"User Input: {user_input}")
|
| 64 |
+
|
| 65 |
+
if model == "Basic":
|
| 66 |
+
res = bsic_chain(user_input)
|
| 67 |
+
elif model == "MultiQuery":
|
| 68 |
+
res = multiQuery_chain(user_input)
|
| 69 |
+
|
| 70 |
response = res['result']
|
| 71 |
metadata = [i.metadata for i in res.get("source_documents", [])]
|
| 72 |
end_time = time.time()
|
model.py
CHANGED
|
@@ -68,15 +68,5 @@ from langchain.prompts import PromptTemplate
|
|
| 68 |
QA_PROMPT = PromptTemplate(input_variables=["context", "question"],template=qa_template_V2,)
|
| 69 |
|
| 70 |
|
| 71 |
-
|
| 72 |
-
from langchain.chains import RetrievalQA
|
| 73 |
-
|
| 74 |
-
Web_qa = RetrievalQA.from_chain_type(
|
| 75 |
-
llm=llm,
|
| 76 |
-
chain_type="stuff",
|
| 77 |
-
retriever = vectorstore.as_retriever(search_kwargs={"k": 4}),
|
| 78 |
-
return_source_documents= True,
|
| 79 |
-
input_key="question",
|
| 80 |
-
chain_type_kwargs={"prompt": QA_PROMPT},
|
| 81 |
-
)
|
| 82 |
|
|
|
|
| 68 |
QA_PROMPT = PromptTemplate(input_variables=["context", "question"],template=qa_template_V2,)
|
| 69 |
|
| 70 |
|
| 71 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
|