swe-faq-chatbot / app.py
zk1tty's picture
fork chatpdf
b75187f
import os
import tempfile
import streamlit as st
from langchain.chat_models import ChatOpenAI
from langchain.document_loaders import PyPDFLoader
from langchain.memory import ConversationBufferMemory
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.callbacks.base import BaseCallbackHandler
from langchain.chains import ConversationalRetrievalChain
from langchain.vectorstores import DocArrayInMemorySearch
from langchain.text_splitter import RecursiveCharacterTextSplitter
st.set_page_config(page_title="GA: Chat with Documents", page_icon="")
st.title(" GAI Hackathon : Chat with Documents")
@st.cache_resource(ttl="1h")
def configure_retriever(uploaded_files):
# Read documents
docs = []
temp_dir = tempfile.TemporaryDirectory()
for file in uploaded_files:
temp_filepath = os.path.join(temp_dir.name, file.name)
with open(temp_filepath, "wb") as f:
f.write(file.getvalue())
loader = PyPDFLoader(temp_filepath)
docs.extend(loader.load())
# Split documents
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1500, chunk_overlap=200)
splits = text_splitter.split_documents(docs)
# Create embeddings and store in vectordb
embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
vectordb = DocArrayInMemorySearch.from_documents(splits, embeddings)
# Define retriever
retriever = vectordb.as_retriever(search_type="mmr", search_kwargs={"k": 5, "fetch_k": 10})
return retriever
class StreamHandler(BaseCallbackHandler):
def __init__(self, container: st.delta_generator.DeltaGenerator, initial_text: str = ""):
self.container = container
self.text = initial_text
def on_llm_new_token(self, token: str, **kwargs) -> None:
self.text += token
self.container.markdown(self.text)
class PrintRetrievalHandler(BaseCallbackHandler):
def __init__(self, container):
self.container = container.expander("Context Retrieval")
def on_retriever_start(self, query: str, **kwargs):
self.container.write(f"**Question:** {query}")
def on_retriever_end(self, documents, **kwargs):
# self.container.write(documents)
for idx, doc in enumerate(documents):
source = os.path.basename(doc.metadata["source"])
self.container.write(f"**Document {idx} from {source}**")
self.container.markdown(doc.page_content)
#openai_api_key = st.sidebar.text_input("OpenAI API Key", type="password")
openai_api_key = os.getenv('OPENAI_KEY')
if not openai_api_key:
st.info("Please add your OpenAI API key to continue.")
st.stop()
uploaded_files = st.sidebar.file_uploader(
label="Upload PDF files", type=["pdf"], accept_multiple_files=True
)
if not uploaded_files:
st.info("Please upload PDF documents to continue.")
st.stop()
retriever = configure_retriever(uploaded_files)
# Setup memory for contextual conversation
memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
# Setup LLM and QA chain
llm = ChatOpenAI(
model_name="gpt-3.5-turbo", openai_api_key=openai_api_key, temperature=0, streaming=True
)
qa_chain = ConversationalRetrievalChain.from_llm(
llm, retriever=retriever, memory=memory, verbose=True
)
if "messages" not in st.session_state or st.sidebar.button("Clear message history"):
st.session_state["messages"] = [{"role": "assistant", "content": "How can I help you?"}]
for msg in st.session_state.messages:
st.chat_message(msg["role"]).write(msg["content"])
user_query = st.chat_input(placeholder="Ask me anything!")
if user_query:
st.session_state.messages.append({"role": "user", "content": user_query})
st.chat_message("user").write(user_query)
with st.chat_message("assistant"):
retrieval_handler = PrintRetrievalHandler(st.container())
stream_handler = StreamHandler(st.empty())
response = qa_chain.run(user_query, callbacks=[retrieval_handler, stream_handler])
st.session_state.messages.append({"role": "assistant", "content": response})