Spaces:
Build error
Syed Junaid Iqbal committed on
Commit · 030d46c
1 Parent(s): a7ce0dd
Upload 5 files
Browse files
- app.py +61 -0
- bm25 +0 -0
- retriever.py +38 -0
- streaming.py +11 -0
- utils.py +39 -0
app.py
ADDED
@@ -0,0 +1,61 @@
import streamlit as st
from streaming import StreamHandler
import utils
from langchain.callbacks.manager import CallbackManager
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from retriever import retriever
from langchain.chains import RetrievalQA
from langchain.llms import LlamaCpp
from dotenv import load_dotenv

class CustomDataChatbot:

    def __init__(self):
        # Initialize the chat history once per session instead of clearing it on every rerun
        if "messages" not in st.session_state:
            st.session_state.messages = []

    def setup_qa_chain(self):
        # Setup memory for contextual conversation
        # memory = ConversationBufferMemory(
        #     memory_key='chat_history',
        #     return_messages=True
        # )

        callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])

        # Setup LLM and QA chain
        llm = LlamaCpp(model_path="./models/openhermes-2.5-neural-chat-7b-v3-1-7b.Q5_K_M.gguf",
                       temperature=0.34,
                       max_tokens=4000,
                       n_ctx=4096,
                       top_p=1,
                       callback_manager=callback_manager,
                       verbose=True)

        # qa_chain = ConversationalRetrievalChain.from_llm(llm, retriever=retriever(), memory=memory, verbose=True)

        return RetrievalQA.from_chain_type(llm, retriever=retriever())

    @utils.enable_chat_history
    def main(self):
        load_dotenv()
        st.set_page_config(page_title="ChatPDF", page_icon="📄")
        st.header('Chat with your documents')
        st.write('Has access to custom documents and can respond to user queries by referring to the content within those documents.')
        st.write('[](https://github.com/shashankdeshpande/langchain-chatbot/blob/master/pages/4_%F0%9F%93%84_chat_with_your_documents.py)')

        user_query = st.chat_input(placeholder="Ask me anything!")

        if user_query:
            # st.spinner is a context manager, so it wraps the chain setup rather than decorating it
            with st.spinner('Analyzing documents..'):
                qa_chain = self.setup_qa_chain()
            utils.display_msg(user_query, 'user')

            with st.chat_message("assistant"):
                st_cb = StreamHandler(st.empty())
                response = qa_chain.run(user_query, callbacks=[st_cb])
                st.session_state.messages.append({"role": "assistant", "content": response})

if __name__ == "__main__":
    obj = CustomDataChatbot()
    obj.main()
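The app expects the quantized GGUF weights at ./models/openhermes-2.5-neural-chat-7b-v3-1-7b.Q5_K_M.gguf, plus the ./bm25 pickle and ./Vector_DB/ index used by retriever.py (sketched after that file). Below is a minimal sketch of one way the weights could be fetched; the repo_id is an assumption and is not part of this commit.

from huggingface_hub import hf_hub_download

# Assumed source repository; any repo that publishes this exact GGUF filename would work
hf_hub_download(
    repo_id="TheBloke/OpenHermes-2.5-neural-chat-7B-v3-1-7B-GGUF",  # assumption, not from the commit
    filename="openhermes-2.5-neural-chat-7b-v3-1-7b.Q5_K_M.gguf",
    local_dir="./models",
)

llama-cpp-python also has to be installed for LlamaCpp to load the file.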
bm25
ADDED
Binary file (184 kB).
retriever.py
ADDED
@@ -0,0 +1,38 @@
import pickle
from langchain.retrievers import EnsembleRetriever
from langchain.vectorstores import FAISS
from langchain.embeddings import HuggingFaceEmbeddings
from transformers import AutoModel


def retriever():

    # Embeddings
    # Define our embedding model
    model_name = "jinaai/jina-embeddings-v2-base-en"
    model_kwargs = {'device': 'cpu'}
    encode_kwargs = {'normalize_embeddings': False}

    # Downloads and caches the custom Jina model; note that this instance is not
    # used directly by HuggingFaceEmbeddings below, which loads its own copy
    model = AutoModel.from_pretrained(model_name, trust_remote_code=True)

    embeddings = HuggingFaceEmbeddings(model_name=model_name,
                                       model_kwargs=model_kwargs,
                                       encode_kwargs=encode_kwargs)

    # Read the pickled BM25 retriever
    with open('./bm25', 'rb') as file:
        bm25_retriever = pickle.load(file)
    bm25_retriever.k = 2

    # Load the FAISS vector store
    faiss_vectorstore = FAISS.load_local("./Vector_DB/", embeddings)
    faiss_retriever = faiss_vectorstore.as_retriever(search_kwargs={"k": 1})

    # Initialize the ensemble retriever (equal weighting of keyword and dense results)
    return EnsembleRetriever(retrievers=[bm25_retriever, faiss_retriever], weights=[0.5, 0.5])
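retriever() only loads the ./bm25 pickle and the ./Vector_DB/ FAISS index; neither is built in this commit (bm25 is uploaded as a binary). A hypothetical sketch of how those two artifacts could be produced, assuming rank_bm25 and faiss-cpu are installed; the sample texts are placeholders for real document chunks.

import pickle
from langchain.retrievers import BM25Retriever
from langchain.vectorstores import FAISS
from langchain.embeddings import HuggingFaceEmbeddings

texts = ["chunk one of a document", "chunk two of a document"]  # placeholder chunks

# Keyword-based retriever, serialized to the ./bm25 file that retriever() reads
bm25_retriever = BM25Retriever.from_texts(texts)
with open("./bm25", "wb") as f:
    pickle.dump(bm25_retriever, f)

# Dense index saved to ./Vector_DB/, loaded later with FAISS.load_local
embeddings = HuggingFaceEmbeddings(model_name="jinaai/jina-embeddings-v2-base-en",
                                   model_kwargs={"device": "cpu"})
faiss_vectorstore = FAISS.from_texts(texts, embeddings)
faiss_vectorstore.save_local("./Vector_DB/")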
streaming.py
ADDED
@@ -0,0 +1,11 @@
from langchain.callbacks.base import BaseCallbackHandler

class StreamHandler(BaseCallbackHandler):

    def __init__(self, container, initial_text=""):
        self.container = container
        self.text = initial_text

    def on_llm_new_token(self, token: str, **kwargs):
        # Called by LangChain for every generated token: append it and re-render
        self.text += token
        self.container.markdown(self.text)
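A small illustrative sketch of how StreamHandler behaves inside a Streamlit script: each on_llm_new_token call appends the token and rewrites the st.empty() placeholder, so the text appears incrementally. The hand-fed token list is only a stand-in for the callbacks that LlamaCpp issues in app.py.

import streamlit as st
from streaming import StreamHandler

# Feed a few tokens manually to show the incremental rendering
handler = StreamHandler(st.empty())
for token in ["Streaming ", "output ", "appears ", "piece ", "by ", "piece."]:
    handler.on_llm_new_token(token)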
utils.py
ADDED
@@ -0,0 +1,39 @@
import os
import streamlit as st

# decorator
def enable_chat_history(func):
    if os.environ.get("OPENAI_API_KEY"):

        # to clear chat history after switching chatbot pages
        current_page = func.__qualname__
        if "current_page" not in st.session_state:
            st.session_state["current_page"] = current_page
        if st.session_state["current_page"] != current_page:
            try:
                st.cache_resource.clear()
                del st.session_state["current_page"]
                del st.session_state["messages"]
            except Exception:
                pass

        # to show chat history on the UI
        if "messages" not in st.session_state:
            st.session_state["messages"] = [{"role": "assistant", "content": "How can I help you?"}]
        for msg in st.session_state["messages"]:
            st.chat_message(msg["role"]).write(msg["content"])

    def execute(*args, **kwargs):
        func(*args, **kwargs)
    return execute

def display_msg(msg, author):
    """Method to display a message on the UI

    Args:
        msg (str): message to display
        author (str): author of the message - user/assistant
    """
    st.session_state.messages.append({"role": author, "content": msg})
    st.chat_message(author).write(msg)
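A hypothetical page sketch showing how the two helpers are meant to combine: enable_chat_history re-renders the stored messages on each rerun (note its body only runs when OPENAI_API_KEY is set), while display_msg records and draws a new user message. The echoed reply is a placeholder for a real chain call.

import streamlit as st
import utils

@utils.enable_chat_history
def page():
    # Defensive init in case the decorator's history block was skipped
    if "messages" not in st.session_state:
        st.session_state.messages = []

    user_query = st.chat_input("Say something")
    if user_query:
        utils.display_msg(user_query, "user")
        reply = f"You said: {user_query}"  # placeholder for a real LLM/chain response
        st.session_state.messages.append({"role": "assistant", "content": reply})
        st.chat_message("assistant").write(reply)

page()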