File size: 4,598 Bytes
244c875
b4deae8
 
59c9bb2
 
a3bb3ff
 
 
 
 
 
 
 
59c9bb2
a3bb3ff
59c9bb2
 
 
a3bb3ff
59c9bb2
 
 
 
 
244c875
a3bb3ff
 
 
244c875
a3bb3ff
b4deae8
823d1b2
244c875
b4deae8
 
244c875
b4deae8
 
 
 
 
 
244c875
b4deae8
 
 
 
 
 
a3bb3ff
 
 
b4deae8
59c9bb2
a3bb3ff
b4deae8
 
 
 
 
 
 
 
 
 
 
a3bb3ff
 
b4deae8
 
a3bb3ff
 
 
b4deae8
 
a3bb3ff
b4deae8
 
a3bb3ff
b4deae8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a3bb3ff
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import streamlit as st
import os
import time
from langchain_community.llms import HuggingFaceEndpoint
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.vectorstores import Chroma
from langchain.prompts import PromptTemplate
from langchain.memory import ConversationBufferMemory
from langchain.chains import RetrievalQA
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.callbacks.manager import CallbackManager

# Model and Embedding Configuration
model_name = "sentence-transformers/all-mpnet-base-v2"
model_kwargs = {'device': 'cpu'}
encode_kwargs = {'normalize_embeddings': False}


@st.cache_resource
def _load_embeddings():
    """Build the HuggingFace embedding model once per server process.

    Streamlit re-executes this script from the top on every user
    interaction; without caching, the embedding model would be
    re-instantiated on each rerun. ``st.cache_resource`` keeps a single
    shared instance alive across reruns.

    Returns:
        The ``HuggingFaceEmbeddings`` instance configured from the
        module-level ``model_name``, ``model_kwargs`` and ``encode_kwargs``.
    """
    return HuggingFaceEmbeddings(
        model_name=model_name,
        model_kwargs=model_kwargs,
        encode_kwargs=encode_kwargs,
    )


embeddings = _load_embeddings()

# Make sure the working directories exist before anything reads or writes
# them: 'files' receives the uploaded PDFs and 'jj' is the directory handed to
# Chroma as its persist_directory. exist_ok=True makes this safe to repeat.
for _dirname in ('files', 'jj'):
    os.makedirs(_dirname, exist_ok=True)

# Streamlit session state setup
#
# Every object below is created only when its key is missing from
# st.session_state, so it is built on the first run of a session and reused
# on later reruns of the script.

# Raw prompt text; {context}, {history} and {question} are the placeholders
# declared as input_variables on the PromptTemplate below.
if 'template' not in st.session_state:
    st.session_state.template = """You are a knowledgeable chatbot, here to help with questions of the user. Your tone should be professional and informative.Try to give answer in tabular and shortcut.

    Context: {context}
    History: {history}

    User: {question}
    Chatbot:"""
if 'prompt' not in st.session_state:
    st.session_state.prompt = PromptTemplate(
        input_variables=["history", "context", "question"],
        template=st.session_state.template,
    )
# Conversation memory: stored under the "history" key and fed by the
# "question" input, matching the placeholders in the prompt above.
# NOTE(review): return_messages=True hands message objects to a plain string
# prompt; confirm the rendered {history} looks as intended.
if 'memory' not in st.session_state:
    st.session_state.memory = ConversationBufferMemory(
        memory_key="history",
        return_messages=True,
        input_key="question")
if 'vectorstore' not in st.session_state:
    # Proper embedding configuration, avoids meta tensor errors
    st.session_state.vectorstore = Chroma(persist_directory='jj', embedding_function=embeddings)

if 'llm' not in st.session_state:
    # The keyword must be lowercase `temperature`; the previous
    # `Temperature=0.9` did not match the endpoint's field name, so the
    # requested sampling temperature was never applied.
    st.session_state.llm = HuggingFaceEndpoint(repo_id="mistralai/Mistral-7B-Instruct-v0.2", temperature=0.9)

# List of {"role": ..., "message": ...} dicts used to redraw the chat.
if 'chat_history' not in st.session_state:
    st.session_state.chat_history = []

st.title("PDF Chatbot")

uploaded_file = st.file_uploader("Upload your PDF", type='pdf')

# Redraw the conversation stored in the session so earlier turns stay on
# screen after a rerun. Each entry carries a "role" and a "message".
for entry in st.session_state.chat_history:
    speaker = entry["role"]
    text = entry["message"]
    with st.chat_message(speaker):
        st.markdown(text)

# Main flow: once a PDF is uploaded, index it (first time only), make sure a
# retrieval QA chain exists, then answer the question typed into the chat box.
if uploaded_file is not None:
    # basename() drops any directory components from the client-supplied name
    # so the file can only land inside 'files'. The extra ".pdf" suffix is
    # kept so uploads saved by earlier runs are still recognised.
    file_path = os.path.join("files", os.path.basename(uploaded_file.name) + ".pdf")
    # The saved copy doubles as the "already indexed" marker.
    if not os.path.isfile(file_path):
        with st.status("Analyzing your document..."):
            bytes_data = uploaded_file.read()
            with open(file_path, "wb") as f:
                f.write(bytes_data)
            try:
                loader = PyPDFLoader(file_path)
                data = loader.load()

                text_splitter = RecursiveCharacterTextSplitter(chunk_size=1500, chunk_overlap=0, length_function=len)
                all_splits = text_splitter.split_documents(data)

                # persist_directory must be given here: persist() raises
                # ValueError on a store created without one. 'jj' is the same
                # directory the session's initial vector store is opened from.
                st.session_state.vectorstore = Chroma.from_documents(
                    documents=all_splits,
                    embedding=embeddings,
                    persist_directory='jj',
                )
                st.session_state.vectorstore.persist()
            except Exception:
                # Remove the marker file so the next rerun retries indexing
                # instead of silently skipping a document that never made it
                # into the vector store.
                os.remove(file_path)
                raise
        # A new store was just created; discard the cached chain so it is
        # rebuilt below around the new retriever.
        if 'qa_chain' in st.session_state:
            del st.session_state['qa_chain']

    st.session_state.retriever = st.session_state.vectorstore.as_retriever()
    if 'qa_chain' not in st.session_state:
        st.session_state.qa_chain = RetrievalQA.from_chain_type(
            llm=st.session_state.llm,
            chain_type='stuff',
            retriever=st.session_state.retriever,
            verbose=True,
            chain_type_kwargs={
                "verbose": True,
                "prompt": st.session_state.prompt,
                "memory": st.session_state.memory,
            }
        )

    if user_input := st.chat_input("You:", key="user_input"):
        user_message = {"role": "user", "message": user_input}
        st.session_state.chat_history.append(user_message)
        with st.chat_message("user"):
            st.markdown(user_input)
        with st.chat_message("assistant"):
            with st.spinner("Assistant is typing..."):
                response = st.session_state.qa_chain(user_input)
            answer = response['result']
            message_placeholder = st.empty()
            full_response = ""
            # Split on single spaces only (not all whitespace) so newlines
            # survive the typing animation; collapsing them would break
            # markdown tables and lists in the answer.
            for chunk in answer.split(" "):
                full_response += chunk + " "
                if not chunk:
                    # Consecutive spaces yield empty chunks; keep the space
                    # but do not spend a frame on it.
                    continue
                time.sleep(0.05)
                message_placeholder.markdown(full_response + "▌")
            # Final render uses the exact text stored in the history, so the
            # message looks the same after the next rerun.
            message_placeholder.markdown(answer)

        chatbot_message = {"role": "assistant", "message": answer}
        st.session_state.chat_history.append(chatbot_message)
else:
    st.write("Please upload a PDF... file.")