|
|
|
|
|
import os
import tempfile

import PyPDF2
import streamlit as st
from dotenv import load_dotenv
from llama_parse import LlamaParse

from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import FlashrankRerank
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import UnstructuredMarkdownLoader
from langchain_community.embeddings.fastembed import FastEmbedEmbeddings
from langchain_community.llms import HuggingFaceEndpoint
from langchain_community.vectorstores import Qdrant
from langchain_groq import ChatGroq

# Load environment variables (API keys) from the local .env file.
load_dotenv()
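
# Expected .env contents (placeholder values shown for illustration):
#   GROQ_API_KEY=...
#   LLAMA_CLOUD_API_KEY=...
#   HF_AUTH_TOKEN=...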
|
|
|
|
|
|
|
|
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
LLAMA_CLOUD_API_KEY = os.getenv("LLAMA_CLOUD_API_KEY")
HF_API_TOKEN = os.getenv("HF_AUTH_TOKEN")

st.set_page_config(page_title="💬 QA Chatbot")
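
# To try the app locally (assuming this script is saved as app.py):
#   streamlit run app.py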
|
|
|
|
|
|
|
|
def read_pdf(uploaded_file):
    """Extract the raw text of an uploaded PDF with PyPDF2."""
    pdf_reader = PyPDF2.PdfReader(uploaded_file)
    text = ""
    for page in pdf_reader.pages:
        # extract_text() can return None on image-only pages.
        text += page.extract_text() or ""
    return text

def split_chunks(text):
    """Split a raw text string into overlapping chunks."""
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=2048,
        chunk_overlap=128,
    )
    return text_splitter.split_text(text)
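
# read_pdf and split_chunks are standalone plain-text helpers; the main app
# flow below parses uploads with LlamaParse and splits Documents instead.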
|
|
|
|
|
def create_db(splits):
    """Embed the document chunks and index them in an in-memory Qdrant collection."""
    embeddings_model = FastEmbedEmbeddings(model_name="BAAI/bge-base-en-v1.5")
    vectordb = Qdrant.from_documents(
        splits,
        embeddings_model,
        location=":memory:",
        collection_name="document_embeddings",
    )
    return vectordb
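
# location=":memory:" keeps the Qdrant index in process memory, so it is
# rebuilt for every new upload. If persistence were needed, Qdrant also
# supports a local on-disk mode (for example, passing path="./qdrant_db"
# instead of location; verify against the qdrant-client version in use).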
|
|
|
|
|
def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db):
    """Build a RetrievalQA chain over vector_db with a reranking retriever."""
    # Llama-3-70B is served by Groq; the smaller models run on HuggingFace
    # inference endpoints. (selected_model is the sidebar global set below.)
    if selected_model == "Llama-3-70B":
        llm = ChatGroq(
            groq_api_key=GROQ_API_KEY,
            model_name=llm_model,
            temperature=temperature,
            max_tokens=max_tokens,
        )
    else:
        llm = HuggingFaceEndpoint(
            repo_id=llm_model,
            huggingfacehub_api_token=HF_API_TOKEN,
            temperature=temperature,
            max_new_tokens=max_tokens,
            top_k=top_k,
        )

    # First-stage retrieval: top-3 chunks by vector similarity.
    retriever = vector_db.as_retriever(search_kwargs={"k": 3})

    # Second stage: rerank those chunks with a FlashRank cross-encoder.
    compressor = FlashrankRerank(model="ms-marco-MiniLM-L-12-v2")
    compression_retriever = ContextualCompressionRetriever(
        base_compressor=compressor, base_retriever=retriever
    )

    qachain = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=compression_retriever,
        return_source_documents=True,
        chain_type_kwargs={"prompt": prompt, "verbose": False},
    )
    return qachain
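
# Design note: retrieval is two-stage. The vector store first returns the top-3
# chunks by embedding similarity, then FlashRank's cross-encoder reranks them
# before they are stuffed into the prompt, trading a little latency for better
# precision in what actually reaches the LLM.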
|
|
|
|
|
def generate_llm_response(prompt_input):
    """Answer the user's question with a freshly initialized QA chain."""
    qa_chain = initialize_llmchain(
        llm_model, temperature, max_tokens, top_k, st.session_state.vector_db
    )
    # RetrievalQA expects only the "query" key; the prompt's {context} and
    # {question} slots are filled in by the chain itself.
    response = qa_chain.invoke({"query": prompt_input})
    return response
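
# With return_source_documents=True, the chain's output dict carries the answer
# under "result" and the retrieved chunks under "source_documents", so callers
# could also surface citations.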
|
|
|
|
|
def clear_chat_history():
    """Reset the conversation to the initial assistant greeting."""
    st.session_state.messages = [{"role": "assistant", "content": "How may I assist you today?"}]

instruction = """The provided document is Meta First Quarter 2024 Results.
This form provides detailed financial information about the company's performance for a specific quarter.
It includes unaudited financial statements, management discussion and analysis, and other relevant disclosures required by the SEC.
It contains many tables.
Try to be precise while answering the questions."""

parser = LlamaParse(
    api_key=LLAMA_CLOUD_API_KEY,
    result_type="markdown",  # markdown preserves the filing's table structure
    parsing_instruction=instruction,
    max_timeout=5000,
)
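
# parser.load_data() returns llama-index Document objects whose .text attribute
# holds the markdown rendition of the PDF; the upload handler below writes that
# text to disk so UnstructuredMarkdownLoader can re-ingest it.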
|
|
|
|
|
|
|
|
# Initialize session state so it survives Streamlit reruns.
if "messages" not in st.session_state:
    st.session_state.messages = [{"role": "assistant", "content": "How may I assist you today?"}]

if "vector_db" not in st.session_state:
    st.session_state.vector_db = None

if "uploaded_file" not in st.session_state:
    st.session_state.uploaded_file = None

prompt_template = """
Use the following pieces of information to answer the user's question.
If you don't know the answer, just say that you don't know; don't try to make up an answer.

Context: {context}
Question: {question}

Answer the question and provide additional helpful information,
based on the pieces of information, if applicable. Be succinct.

Responses should be properly formatted to be easily read.
"""

prompt = PromptTemplate(
    template=prompt_template, input_variables=["context", "question"]
)

st.title("QA Chatbot with Custom PDF")
st.markdown("---")
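
# Ingestion pipeline: save the upload to disk, parse it to markdown with
# LlamaParse, reload the markdown, chunk it, then embed and index the chunks in
# Qdrant. The vector store lives in st.session_state so it survives Streamlit's
# script reruns.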
|
|
|
|
|
|
|
|
with st.sidebar:
    uploaded_file = st.file_uploader("Upload a PDF", type="pdf")

# A new upload invalidates the previous chat and vector store.
if uploaded_file != st.session_state.uploaded_file:
    st.session_state.uploaded_file = uploaded_file
    clear_chat_history()
    st.session_state.vector_db = None

if uploaded_file is not None and st.session_state.vector_db is None:
    with st.spinner("Converting to Vectors..."):
        # Persist the upload to a temp file so LlamaParse can read it from disk.
        temp_dir = tempfile.mkdtemp()
        file_path = os.path.join(temp_dir, uploaded_file.name)
        with open(file_path, "wb") as f:
            f.write(uploaded_file.getvalue())

        # Parse the PDF into markdown with the configured LlamaParse instance.
        documents = parser.load_data(file_path)

        # Write the parsed markdown out so UnstructuredMarkdownLoader can load it.
        document_path = os.path.join(temp_dir, "parsed_document.md")
        with open(document_path, "w", encoding="utf-8") as f:
            f.write("\n".join(doc.text for doc in documents))

        loader = UnstructuredMarkdownLoader(document_path)
        loaded_documents = loader.load()

        # Index the split chunks, not the full documents.
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=2048, chunk_overlap=128)
        docs = text_splitter.split_documents(loaded_documents)

        st.session_state.vector_db = create_db(docs)

    st.sidebar.markdown('<p style="color:green;">PDF processed and vector database created!</p>', unsafe_allow_html=True)

st.sidebar.title("Model Settings")

selected_model = st.sidebar.selectbox(
    'Choose a LLM model',
    ['Llama-3-70B', 'Llama-3-8B', 'Mistral-7B'],
    key='selected_model',
)

# Map the display name to the provider-specific model id.
if selected_model == 'Llama-3-8B':
    llm_model = 'meta-llama/Meta-Llama-3-8B-Instruct'
elif selected_model == 'Mistral-7B':
    llm_model = 'mistralai/Mistral-7B-Instruct-v0.2'
else:  # Llama-3-70B, served by Groq
    llm_model = 'llama3-70b-8192'

temperature = st.sidebar.slider('Temperature', 0.0, 1.0, 0.1)
top_k = st.sidebar.slider('Top_k', 1, 10, 3)
max_tokens = st.sidebar.slider('Max Tokens', 1, 512, 256)
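
# Temperature controls sampling randomness, Top_k narrows the candidate-token
# pool (only the HuggingFace endpoints use it), and Max Tokens caps the length
# of the generated answer.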
|
|
|
|
|
show_clear_button = len(st.session_state.messages) > 1

if st.session_state.vector_db is not None:
    # Replay the conversation so far.
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.write(message["content"])

    user_input = st.chat_input("Ask a question here")
    if user_input:
        st.session_state.messages.append({"role": "user", "content": user_input})
        with st.chat_message("user"):
            st.write(user_input)

    # Generate a reply whenever the last message is from the user.
    if st.session_state.messages[-1]["role"] != "assistant":
        with st.chat_message("assistant"):
            with st.spinner("Thinking..."):
                response = generate_llm_response(user_input)
                # Render the full answer at once; the chain does not stream.
                placeholder = st.empty()
                full_response = response["result"]
                placeholder.markdown(full_response)
        message = {"role": "assistant", "content": full_response}
        st.session_state.messages.append(message)
        show_clear_button = len(st.session_state.messages) > 1
else:
    st.write("Please upload a PDF file to initialize the database.")

if show_clear_button and st.button('Clear Chat History'):
    clear_chat_history()
    st.rerun()
|
|
|
|
|
|
|
|
|