Spaces:
Sleeping
Sleeping
File size: 6,631 Bytes
import streamlit as st
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain.chains import ConversationalRetrievalChain
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms import HuggingFaceEndpoint
from langchain.memory import ConversationBufferMemory
from pathlib import Path
import chromadb
from unidecode import unidecode
import re
import os
# List of available LLMs
# Hugging Face Hub repo ids offered in the model selector, in display order.
list_llm = [
    "mistralai/Mistral-7B-Instruct-v0.2",
    "mistralai/Mixtral-8x7B-Instruct-v0.1",
    "mistralai/Mistral-7B-Instruct-v0.1",
    "google/gemma-7b-it",
    "google/gemma-2b-it",
    "HuggingFaceH4/zephyr-7b-beta",
    "HuggingFaceH4/zephyr-7b-gemma-v0.1",
    "meta-llama/Llama-2-7b-chat-hf",
    "microsoft/phi-2",
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    "mosaicml/mpt-7b-instruct",
    "tiiuae/falcon-7b-instruct",
    "google/flan-t5-xxl",
]
# Short names (the part after the "org/" prefix), index-aligned with list_llm.
list_llm_simple = list(map(os.path.basename, list_llm))
# Load and split PDF document
def load_doc(file_paths, chunk_size, chunk_overlap):
    """Load every PDF in *file_paths* and split the pages into chunks.

    Args:
        file_paths: Paths handed one by one to ``PyPDFLoader``.
        chunk_size: ``chunk_size`` passed to ``RecursiveCharacterTextSplitter``.
        chunk_overlap: ``chunk_overlap`` passed to the same splitter.

    Returns:
        The result of ``split_documents`` over the pages of all loaders.
    """
    # All loaders are constructed before any of them is asked to load.
    pdf_loaders = list(map(PyPDFLoader, file_paths))
    all_pages = []
    for pdf_loader in pdf_loaders:
        all_pages.extend(pdf_loader.load())
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
    )
    return splitter.split_documents(all_pages)
# Create vector database
def create_db(docs, collection_name):
    """Build a Chroma vector store from *docs* under *collection_name*.

    Uses a default-constructed ``HuggingFaceEmbeddings`` and a
    ``chromadb.EphemeralClient``.

    Args:
        docs: Documents passed to ``Chroma.from_documents``.
        collection_name: Name of the Chroma collection to populate.

    Returns:
        The object returned by ``Chroma.from_documents``.
    """
    return Chroma.from_documents(
        documents=docs,
        embedding=HuggingFaceEmbeddings(),
        client=chromadb.EphemeralClient(),
        collection_name=collection_name,
    )
# Initialize LLM chain
def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db):
    """Create a conversational retrieval chain for *llm_model* over *vector_db*.

    Args:
        llm_model: Repo id passed to ``HuggingFaceEndpoint``.
        temperature: Sampling temperature forwarded to the endpoint.
        max_tokens: Forwarded to the endpoint as ``max_new_tokens``.
        top_k: Forwarded to the endpoint as ``top_k``.
        vector_db: Vector store; its ``as_retriever()`` feeds the chain.

    Returns:
        A ``ConversationalRetrievalChain`` with buffer memory that also
        returns its source documents.

    Raises:
        ValueError: If *llm_model* is one of the models rejected up front.
    """
    rejected_models = (
        "mistralai/Mixtral-8x7B-Instruct-v0.1",
        "HuggingFaceH4/zephyr-7b-gemma-v0.1",
        "mosaicml/mpt-7b-instruct",
    )
    if llm_model in rejected_models:
        raise ValueError("LLM model is too large to be loaded automatically on free inference endpoint")

    llm = HuggingFaceEndpoint(
        repo_id=llm_model,
        temperature=temperature,
        max_new_tokens=max_tokens,
        top_k=top_k,
    )
    chat_memory = ConversationBufferMemory(
        memory_key="chat_history",
        output_key='answer',
        return_messages=True,
    )
    return ConversationalRetrievalChain.from_llm(
        llm,
        retriever=vector_db.as_retriever(),
        chain_type="stuff",
        memory=chat_memory,
        return_source_documents=True,
        verbose=False,
    )
# Generate collection name for vector database
def create_collection_name(filepath):
    """Derive a vector-database collection name from *filepath*.

    The file stem is passed through ``unidecode``, every run of characters
    outside ``A-Za-z0-9`` becomes a single ``-``, and the result is cut to
    50 characters. Names shorter than three characters get ``xyz`` appended,
    and a non-alphanumeric first or last character is replaced by ``A`` or
    ``Z`` respectively.

    Args:
        filepath: Path whose stem seeds the name.

    Returns:
        The sanitised collection name.
    """
    ascii_stem = unidecode(Path(filepath).stem)
    name = re.sub('[^A-Za-z0-9]+', '-', ascii_stem)[:50]
    if len(name) < 3:
        name += 'xyz'
    # The name is at least three characters long here, so the first and last
    # characters can be fixed independently of each other.
    first = name[0] if name[0].isalnum() else 'A'
    last = name[-1] if name[-1].isalnum() else 'Z'
    return first + name[1:-1] + last
# Initialize database
def initialize_database(files, chunk_size, chunk_overlap):
    """Build the vector database for the uploaded PDF files.

    In-memory uploads (objects exposing ``getvalue()``, such as Streamlit's
    ``UploadedFile``) only carry the original filename in ``.name``, not a
    path on disk. They are written to a temporary directory first so the PDF
    loader can open them. Objects whose ``.name`` already is a filesystem
    path, and plain path strings, are used as they are.

    Args:
        files: Uploaded file objects (or paths); must not be empty.
        chunk_size: Chunk size forwarded to ``load_doc``.
        chunk_overlap: Chunk overlap forwarded to ``load_doc``.

    Returns:
        Tuple ``(vector_db, collection_name, "Complete!")``. The collection
        name is derived from the first file's name.

    Raises:
        ValueError: If *files* is empty.
    """
    import tempfile

    files = list(files)
    if not files:
        raise ValueError("At least one PDF file is required to build the database")

    # load_doc reads every document before returning, so the temporary copies
    # can be removed as soon as it is done.
    with tempfile.TemporaryDirectory() as upload_dir:
        file_paths = [
            _materialize_upload(file, upload_dir, index)
            for index, file in enumerate(files)
        ]
        collection_name = create_collection_name(file_paths[0])
        doc_splits = load_doc(file_paths, chunk_size, chunk_overlap)

    vector_db = create_db(doc_splits, collection_name)
    return vector_db, collection_name, "Complete!"


def _materialize_upload(file, directory, index):
    """Return a filesystem path for *file*, writing in-memory uploads to *directory*."""
    if isinstance(file, (str, os.PathLike)):
        return os.fspath(file)
    name = str(file.name)
    if not hasattr(file, "getvalue"):
        # No in-memory content available: ``.name`` is taken to be a real path.
        return name
    # One sub-directory per upload keeps the original filename (used for the
    # collection name) while avoiding clashes between equally named uploads.
    slot = os.path.join(directory, str(index))
    os.makedirs(slot, exist_ok=True)
    target = os.path.join(slot, os.path.basename(name) or "upload.pdf")
    with open(target, "wb") as handle:
        handle.write(file.getvalue())
    return target
# Initialize LLM
def initialize_LLM(llm_option, llm_temperature, max_tokens, top_k, vector_db):
    """Build the QA chain for the model at index *llm_option* of ``list_llm``.

    Args:
        llm_option: Index into ``list_llm`` selecting the model repo id.
        llm_temperature: Temperature forwarded to ``initialize_llmchain``.
        max_tokens: Max-token setting forwarded to ``initialize_llmchain``.
        top_k: Top-k setting forwarded to ``initialize_llmchain``.
        vector_db: Vector store forwarded to ``initialize_llmchain``.

    Returns:
        Tuple ``(qa_chain, "Complete!")``.
    """
    chain = initialize_llmchain(
        list_llm[llm_option],
        llm_temperature,
        max_tokens,
        top_k,
        vector_db,
    )
    return chain, "Complete!"
# Format chat history
def format_chat_history(message, chat_history):
    """Render each ``(user, bot)`` pair of *chat_history* as one string.

    Args:
        message: Unused; kept so existing callers keep working.
        chat_history: Iterable of ``(user_message, bot_message)`` pairs.

    Returns:
        A list with one ``"User: ...\\nAssistant: ..."`` string per pair.
    """
    transcript = []
    for user_turn, bot_turn in chat_history:
        transcript.append(f"User: {user_turn}\nAssistant: {bot_turn}")
    return transcript
# Handle conversation
def conversation(qa_chain, message, history):
    """Ask *qa_chain* one question and collect the answer with up to three sources.

    Args:
        qa_chain: Callable chain; invoked with a dict holding ``question`` and
            ``chat_history`` and expected to return a mapping with ``answer``
            and ``source_documents``.
        message: The user's question.
        history: List of ``(user_message, bot_message)`` pairs so far.

    Returns:
        A 9-tuple ``(qa_chain, "", new_history, source1, page1, source2,
        page2, source3, page3)``. Pages are 1-based. When the chain returns
        fewer than three source documents, the missing slots are filled with
        ``""`` and page ``0`` instead of raising ``IndexError``. A source
        without a ``page`` metadata entry is also reported as page ``0``.
    """
    formatted_chat_history = format_chat_history(message, history)
    response = qa_chain({"question": message, "chat_history": formatted_chat_history})

    # Keep only the text after the prompt's "Helpful Answer:" marker, if any.
    answer = response["answer"]
    if "Helpful Answer:" in answer:
        answer = answer.split("Helpful Answer:")[-1]

    # The retriever can return fewer than three chunks (e.g. a short PDF), so
    # take what is there and pad the remaining slots.
    references = [
        (doc.page_content.strip(), doc.metadata.get("page", -1) + 1)
        for doc in response.get("source_documents", [])[:3]
    ]
    references += [("", 0)] * (3 - len(references))
    (source1, page1), (source2, page2), (source3, page3) = references

    new_history = history + [(message, answer)]
    return qa_chain, "", new_history, source1, page1, source2, page2, source3, page3
# Streamlit app
def main():
    """Render the Streamlit UI: upload PDFs, configure the model, and chat.

    Streamlit re-executes this function on every widget interaction, so the
    vector database, the QA chain and the chat history are kept in
    ``st.session_state``. The database is rebuilt only when the uploaded
    files or chunking settings change, and the chain only when the database
    or a model setting changes. The chat history is reset whenever a new
    chain (with fresh memory) is created.
    """
    st.title("PDF-based Chatbot")
    st.write("Upload your PDF documents and interact with the chatbot to get insights from your PDFs.")
    uploaded_files = st.file_uploader("Upload PDF files", type="pdf", accept_multiple_files=True)
    if not uploaded_files:
        return

    state = st.session_state

    chunk_size = st.slider("Chunk Size", 100, 1000, 600)
    chunk_overlap = st.slider("Chunk Overlap", 10, 200, 40)
    db_key = (
        tuple((upload.name, upload.size) for upload in uploaded_files),
        chunk_size,
        chunk_overlap,
    )
    if state.get("db_key") != db_key:
        vector_db, _, db_status = initialize_database(uploaded_files, chunk_size, chunk_overlap)
        state["vector_db"] = vector_db
        state["db_status"] = db_status
        # Written last so a failed build is retried on the next rerun.
        state["db_key"] = db_key
    st.write(f"Vector Database Initialized: {state['db_status']}")

    llm_option = st.selectbox("Select LLM Model", options=list_llm_simple)
    llm_temperature = st.slider("Temperature", 0.01, 1.0, 0.7)
    max_tokens = st.slider("Max Tokens", 224, 4096, 1024)
    top_k = st.slider("Top-K Samples", 1, 10, 3)
    chain_key = (db_key, llm_option, llm_temperature, max_tokens, top_k)
    if state.get("chain_key") != chain_key:
        try:
            qa_chain, llm_status = initialize_LLM(
                list_llm_simple.index(llm_option),
                llm_temperature,
                max_tokens,
                top_k,
                state["vector_db"],
            )
        except ValueError as exc:
            # Some selectable models are rejected by initialize_llmchain;
            # report that in the page instead of crashing the script run.
            st.error(str(exc))
            return
        state["qa_chain"] = qa_chain
        state["llm_status"] = llm_status
        state["chat_history"] = []
        state["chain_key"] = chain_key
    st.write(f"QA Chain Initialized: {state['llm_status']}")

    st.write("Chat with the bot:")
    user_message = st.text_input("Your Message:")
    if st.button("Submit") and user_message:
        qa_chain, _, chat_history, *source_fields = conversation(
            state["qa_chain"], user_message, state["chat_history"]
        )
        state["qa_chain"] = qa_chain
        state["chat_history"] = chat_history
        st.write(f"**Bot's Response:** {chat_history[-1][1]}")
        # source_fields alternates text, page, text, page, ...
        references = zip(source_fields[0::2], source_fields[1::2])
        for number, (source_text, source_page) in enumerate(references, start=1):
            if source_text:
                st.write(f"**Reference {number}:** {source_text} (Page {source_page})")


if __name__ == "__main__":
    main()
|