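"""Streamlit QA chatbot over an uploaded PDF.

Pipeline, as wired below: LlamaParse converts the PDF to markdown,
RecursiveCharacterTextSplitter chunks it, FastEmbed (BAAI/bge-base-en-v1.5)
embeds the chunks into an in-memory Qdrant index, a FlashRank reranker
compresses the retrieved chunks, and a Groq- or HuggingFace-hosted
Llama/Mistral model answers through a RetrievalQA "stuff" chain.
"""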
from langchain_community.vectorstores import FAISS, Qdrant
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.chains import RetrievalQA
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms import HuggingFaceEndpoint
from langchain.prompts import PromptTemplate
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import FlashrankRerank
from langchain_community.embeddings.fastembed import FastEmbedEmbeddings
from langchain_community.document_loaders import UnstructuredMarkdownLoader
from llama_parse import LlamaParse
from langchain_groq import ChatGroq
from dotenv import load_dotenv
import os
import streamlit as st
import PyPDF2
import tempfile
# Load environment variables from .env file
load_dotenv()

# API keys for the hosted services
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
LLAMA_CLOUD_API_KEY = os.getenv("LLAMA_CLOUD_API_KEY")
HF_API_TOKEN = os.getenv("HF_AUTH_TOKEN")

st.set_page_config(page_title="💬 QA Chatbot")
def read_pdf(uploaded_file):
    """Extract plain text from an uploaded PDF with PyPDF2."""
    pdf_reader = PyPDF2.PdfReader(uploaded_file)
    text = ""
    for page in pdf_reader.pages:
        text += page.extract_text() or ""
    return text
def split_chunks(docs):
    """Split raw text into overlapping chunks (returns a list of strings)."""
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=2048,
        chunk_overlap=128,
    )
    return text_splitter.split_text(docs)
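# NOTE: read_pdf() and split_chunks() are a simpler PyPDF2-only ingestion path.
# The active pipeline below parses with LlamaParse instead, which preserves
# table structure that plain text extraction tends to lose; these two helpers
# are kept as a fallback for documents that don't need table-aware parsing.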
def create_db(splits):
    """Embed document chunks and index them in an in-memory Qdrant collection."""
    # Alternative embeddings / vector store, kept for reference:
    # embeddings_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
    # vectordb = FAISS.from_documents(splits, embeddings_model)
    embeddings_model = FastEmbedEmbeddings(model_name="BAAI/bge-base-en-v1.5")
    vectordb = Qdrant.from_documents(
        splits,
        embeddings_model,
        location=":memory:",
        # path="./db",
        collection_name="document_embeddings",
    )
    return vectordb
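# With location=":memory:" the Qdrant index lives in RAM and is rebuilt each
# time a new PDF is uploaded; swapping in the commented-out path argument
# would persist the collection to disk across sessions instead.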
def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db):
    """Build a RetrievalQA chain: vector search, FlashRank rerank, then the LLM."""
    if llm_model == "llama3-70b-8192":
        # Llama-3-70B is served through the Groq API
        llm = ChatGroq(
            model_name=llm_model,
            temperature=temperature,
        )
    else:
        # The smaller models run on HuggingFace Inference Endpoints
        llm = HuggingFaceEndpoint(
            repo_id=llm_model,
            huggingfacehub_api_token=HF_API_TOKEN,
            temperature=temperature,
            max_new_tokens=max_tokens,
            top_k=top_k,
        )
    # First stage: top-3 chunks by embedding similarity
    retriever = vector_db.as_retriever(search_kwargs={"k": 3})
    # Second stage: rerank those chunks with a FlashRank cross-encoder
    compressor = FlashrankRerank(model="ms-marco-MiniLM-L-12-v2")
    compression_retriever = ContextualCompressionRetriever(
        base_compressor=compressor, base_retriever=retriever
    )
    qachain = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=compression_retriever,
        return_source_documents=True,
        chain_type_kwargs={"prompt": prompt, "verbose": False},
    )
    return qachain
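# A minimal sketch of using the chain outside Streamlit (assumes a populated
# `vector_db` and the module-level `prompt` defined below):
#
#   chain = initialize_llmchain("llama3-70b-8192", 0.1, 256, 3, vector_db)
#   result = chain.invoke({"query": "What was Meta's Q1 2024 revenue?"})
#   print(result["result"], result["source_documents"])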
def generate_llm_response(prompt_input):
    """Answer one user question against the indexed document."""
    # llm_model, temperature, max_tokens, and top_k come from the sidebar below
    qa_chain = initialize_llmchain(
        llm_model, temperature, max_tokens, top_k, st.session_state.vector_db
    )
    # RetrievalQA takes a single "query" input; the prompt's {context} and
    # {question} variables are filled in by the chain itself.
    response = qa_chain.invoke({"query": prompt_input})
    return response
def clear_chat_history():
    st.session_state.messages = [{"role": "assistant", "content": "How may I assist you today?"}]
instruction = """The provided document is Meta's First Quarter 2024 Results.
This form provides detailed financial information about the company's performance for a specific quarter.
It includes unaudited financial statements, management discussion and analysis, and other relevant disclosures required by the SEC.
It contains many tables.
Try to be precise while answering the questions."""

parser = LlamaParse(
    api_key=LLAMA_CLOUD_API_KEY,
    result_type="markdown",
    parsing_instruction=instruction,
    max_timeout=5000,
)
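# The parsing_instruction gives LlamaParse domain context, which helps it
# render the filing's financial tables as markdown tables rather than
# flattened text.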
# Initialize session state: chat history, vector store, and the current upload
if "messages" not in st.session_state:
    st.session_state.messages = [{"role": "assistant", "content": "How may I assist you today?"}]
if "vector_db" not in st.session_state:
    st.session_state.vector_db = None
if "uploaded_file" not in st.session_state:
    st.session_state.uploaded_file = None
prompt_template = """
Use the following pieces of information to answer the user's question.
If you don't know the answer, just say that you don't know; don't try to make up an answer.

Context: {context}
Question: {question}

Answer the question and provide additional helpful information,
based on the pieces of information, if applicable. Be succinct.
Responses should be properly formatted to be easily read.
"""

prompt = PromptTemplate(
    template=prompt_template, input_variables=["context", "question"]
)
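# The "stuff" chain fills {context} with the reranked chunks concatenated
# together and {question} with the user's query, so these two names must match
# the input_variables declared above.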
# Main app layout
st.title("QA Chatbot with Custom PDF")
st.markdown("---")

# Sidebar: PDF upload, model selection, and sampling parameters
with st.sidebar:
    uploaded_file = st.file_uploader("Upload a PDF", type="pdf")
    # Reset the chat and the index when the file changes or is removed
    if uploaded_file != st.session_state.uploaded_file:
        st.session_state.uploaded_file = uploaded_file
        clear_chat_history()
        st.session_state.vector_db = None

    if uploaded_file is not None and st.session_state.vector_db is None:
        with st.spinner("Converting to Vectors..."):
            # LlamaParse needs a path on disk, so stage the upload in a temp dir
            temp_dir = tempfile.mkdtemp()
            file_path = os.path.join(temp_dir, uploaded_file.name)
            with open(file_path, "wb") as f:
                f.write(uploaded_file.getvalue())
            # Parse with the configured parser (API key + instruction above)
            documents = parser.load_data(file_path)
            # Write the parsed markdown out so UnstructuredMarkdownLoader can read it
            document_path = os.path.join(temp_dir, "parsed_document.md")
            with open(document_path, "w", encoding="utf-8") as f:
                f.write("\n".join(doc.text for doc in documents))
            loader = UnstructuredMarkdownLoader(document_path)
            loaded_documents = loader.load()
            text_splitter = RecursiveCharacterTextSplitter(chunk_size=2048, chunk_overlap=128)
            docs = text_splitter.split_documents(loaded_documents)
            st.session_state.vector_db = create_db(docs)
        st.markdown('<p style="color:green;">PDF processed and vector database created!</p>', unsafe_allow_html=True)
    st.title("Model Settings")
    selected_model = st.selectbox(
        "Choose a LLM model", ["Llama-3-70B", "Llama-3-8B", "Mistral-7B"], key="selected_model"
    )
    # Map the display name to the provider-specific model id
    model_ids = {
        "Llama-3-70B": "llama3-70b-8192",                     # Groq
        "Llama-3-8B": "meta-llama/Meta-Llama-3-8B-Instruct",  # HF endpoint
        "Mistral-7B": "mistralai/Mistral-7B-Instruct-v0.2",   # HF endpoint
    }
    llm_model = model_ids[selected_model]

    temperature = st.slider("Temperature", 0.0, 1.0, 0.1)
    top_k = st.slider("Top_k", 1, 10, 3)
    max_tokens = st.slider("Max Tokens", 1, 512, 256)
show_clear_button = len(st.session_state.messages) > 1

if st.session_state.vector_db is not None:
    # Replay the conversation so far
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.write(message["content"])

    user_input = st.chat_input("Ask a question here")
    if user_input:
        st.session_state.messages.append({"role": "user", "content": user_input})
        with st.chat_message("user"):
            st.write(user_input)

    # Generate a new response if the last message is not from the assistant
    if st.session_state.messages[-1]["role"] != "assistant":
        with st.chat_message("assistant"):
            with st.spinner("Thinking..."):
                response = generate_llm_response(user_input)
                full_response = response["result"]
                st.markdown(full_response)
        message = {"role": "assistant", "content": full_response}
        st.session_state.messages.append(message)
        show_clear_button = len(st.session_state.messages) > 1
else:
    st.write("Please upload a PDF file to initialize the database.")

if show_clear_button and st.button("Clear Chat History"):
    clear_chat_history()
    st.rerun()  # Rerun so the cleared history is reflected immediately
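# To run locally (assuming this file is saved as app.py; the dependency list
# is a best-guess reconstruction from the imports above):
#   pip install streamlit python-dotenv PyPDF2 llama-parse langchain \
#       langchain-community langchain-groq fastembed flashrank qdrant-client \
#       unstructured markdown huggingface_hub
#   # .env must define GROQ_API_KEY, LLAMA_CLOUD_API_KEY, and HF_AUTH_TOKEN
#   streamlit run app.py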