ROHAN181 committed on
Commit
6cedaef
·
1 Parent(s): f19fa77

steamlitconvert

Browse files
Files changed (1) hide show
  1. app.py +216 -142
app.py CHANGED
@@ -1,156 +1,230 @@
1
  import streamlit as st
2
- from dotenv import load_dotenv
3
- from PyPDF2 import PdfReader
4
- from langchain.text_splitter import CharacterTextSplitter
5
- from langchain.embeddings import OpenAIEmbeddings, HuggingFaceInstructEmbeddings
6
- from langchain.vectorstores import FAISS
7
- from langchain.chat_models import ChatOpenAI
8
- from langchain.memory import ConversationBufferMemory
9
  from langchain.chains import ConversationalRetrievalChain
10
-
 
 
 
11
  from langchain.llms import HuggingFaceHub
12
 
13
-
14
-
15
-
16
-
17
- css = '''
18
- <style>
19
- .chat-message {
20
- padding: 1.5rem; border-radius: 0.5rem; margin-bottom: 1rem; display: flex
21
- }
22
- .chat-message.user {
23
- background-color: #2b313e
24
- }
25
- .chat-message.bot {
26
- background-color: #475063
27
- }
28
- .chat-message .avatar {
29
- width: 20%;
30
- }
31
- .chat-message .avatar img {
32
- max-width: 78px;
33
- max-height: 78px;
34
- border-radius: 50%;
35
- object-fit: cover;
36
- }
37
- .chat-message .message {
38
- width: 80%;
39
- padding: 0 1.5rem;
40
- color: #fff;
41
- }
42
- '''
43
-
44
- bot_template = '''
45
- <div class="chat-message bot">
46
- <div class="avatar">
47
- <img src="https://i.ibb.co/cN0nmSj/Screenshot-2023-05-28-at-02-37-21.png" style="max-height: 78px; max-width: 78px; border-radius: 50%; object-fit: cover;">
48
- </div>
49
- <div class="message">{{MSG}}</div>
50
- </div>
51
- '''
52
-
53
- user_template = '''
54
- <div class="chat-message user">
55
- <div class="avatar">
56
- <img src="https://i.ibb.co/rdZC7LZ/Photo-logo-1.png">
57
- </div>
58
- <div class="message">{{MSG}}</div>
59
- </div>
60
- '''
61
-
62
- def get_pdf_text(pdf_docs):
63
- text = ""
64
- for pdf in pdf_docs:
65
- pdf_reader = PdfReader(pdf)
66
- for page in pdf_reader.pages:
67
- text += page.extract_text()
68
- return text
69
-
70
-
71
-
72
-
73
-
74
- def get_text_chunks(text):
75
- text_splitter = CharacterTextSplitter(
76
- separator="\n",
77
- chunk_size=1000,
78
- chunk_overlap=200,
79
- length_function=len
80
  )
81
- chunks = text_splitter.split_text(text)
82
- return chunks
83
-
84
-
85
- def get_vectorstore(text_chunks):
86
- #embeddings = OpenAIEmbeddings()
87
- embeddings = HuggingFaceInstructEmbeddings(model_name="hkunlp/instructor-xl")
88
- vectorstore = FAISS.from_texts(texts=text_chunks, embedding=embeddings)
89
- return vectorstore
90
-
91
-
92
- def get_conversation_chain(vectorstore):
93
- #llm = ChatOpenAI()
94
- llm = HuggingFaceHub(repo_id="google/flan-t5-xxl", model_kwargs={"temperature":0.5, "max_length":512})
95
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
  memory = ConversationBufferMemory(
97
- memory_key='chat_history', return_messages=True)
98
- conversation_chain = ConversationalRetrievalChain.from_llm(
99
- llm=llm,
100
- retriever=vectorstore.as_retriever(),
101
- memory=memory
 
 
 
 
 
 
 
 
 
 
 
102
  )
103
- return conversation_chain
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
104
 
105
 
106
- def handle_userinput(user_question):
107
- response = st.session_state.conversation({'question': user_question})
108
- st.session_state.chat_history = response['chat_history']
109
 
110
- for i, message in enumerate(st.session_state.chat_history):
111
- if i % 2 == 0:
112
- st.write(user_template.replace(
113
- "{{MSG}}", message.content), unsafe_allow_html=True)
114
- else:
115
- st.write(bot_template.replace(
116
- "{{MSG}}", message.content), unsafe_allow_html=True)
117
 
118
 
119
  def main():
120
- load_dotenv()
121
- st.set_page_config(page_title="Chat with multiple PDFs",
122
- page_icon=":books:")
123
- st.write(css, unsafe_allow_html=True)
124
-
125
- if "conversation" not in st.session_state:
126
- st.session_state.conversation = None
127
- if "chat_history" not in st.session_state:
128
- st.session_state.chat_history = None
129
-
130
- st.header("Chat with multiple PDFs :books:")
131
- user_question = st.text_input("Ask a question about your documents:")
132
- if user_question:
133
- handle_userinput(user_question)
134
-
135
- with st.sidebar:
136
- st.subheader("Your documents")
137
- pdf_docs = st.file_uploader(
138
- "Upload your PDFs here and click on 'Process'", accept_multiple_files=True)
139
- if st.button("Process"):
140
- with st.spinner("Processing"):
141
- # get pdf text
142
- raw_text = get_pdf_text(pdf_docs)
143
-
144
- # get the text chunks
145
- text_chunks = get_text_chunks(raw_text)
146
-
147
- # create vector store
148
- vectorstore = get_vectorstore(text_chunks)
149
-
150
- # create conversation chain
151
- st.session_state.conversation = get_conversation_chain(
152
- vectorstore)
153
-
154
-
155
- if __name__ == '__main__':
 
156
  main()
 
1
  import streamlit as st
2
+ import os
3
+ from langchain.document_loaders import PyPDFLoader
4
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
5
+ from langchain.vectorstores import Chroma
 
 
 
6
  from langchain.chains import ConversationalRetrievalChain
7
+ from langchain.embeddings import HuggingFaceEmbeddings
8
+ from langchain.llms import HuggingFacePipeline
9
+ from langchain.chains import ConversationChain
10
+ from langchain.memory import ConversationBufferMemory
11
  from langchain.llms import HuggingFaceHub
12
 
13
+ from transformers import AutoTokenizer
14
+ import transformers
15
+ import torch
16
+ import tqdm
17
+ import accelerate
18
+
19
# Directory in which the Chroma vector store is persisted.
default_persist_directory = './chroma_HF/'

# Hugging Face Hub repo ids of the LLMs offered in the UI.
list_llm = [
    "mistralai/Mistral-7B-Instruct-v0.2",
    "mistralai/Mistral-7B-Instruct-v0.1",
    "meta-llama/Llama-2-7b-chat-hf",
    "microsoft/phi-2",
    "mosaicml/mpt-7b-instruct",
    "tiiuae/falcon-7b-instruct",
    "google/flan-t5-xxl",
]
# Individual names kept for any code that refers to a specific model.
(llm_name1, llm_name2, llm_name3, llm_name4,
 llm_name5, llm_name6, llm_name7) = list_llm

# Display names: the repo id with the owner prefix stripped.
list_llm_simple = list(map(os.path.basename, list_llm))
30
+
31
+
32
+
33
+ # Load PDF document and create doc splits
34
def load_doc(list_file_path, chunk_size, chunk_overlap):
    """Load the given PDFs and split their pages into chunks.

    Args:
        list_file_path: Iterable of PDF paths, each handed to ``PyPDFLoader``.
        chunk_size: ``chunk_size`` passed to ``RecursiveCharacterTextSplitter``.
        chunk_overlap: ``chunk_overlap`` passed to the same splitter.

    Returns:
        The result of ``split_documents`` applied to the pages of all PDFs,
        in the order the paths were given.
    """
    # Build every loader first, then load, so construction happens for all
    # paths before any page is read.
    pdf_loaders = [PyPDFLoader(path) for path in list_file_path]
    all_pages = [page for pdf_loader in pdf_loaders for page in pdf_loader.load()]

    splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
    )
    return splitter.split_documents(all_pages)
48
+
49
+
50
+ # Create vector database
51
def create_db(splits):
    """Create a Chroma vector store from document splits.

    The splits are embedded with a default-constructed
    ``HuggingFaceEmbeddings`` and the store is persisted under
    ``default_persist_directory``.

    Args:
        splits: Documents passed to ``Chroma.from_documents``.

    Returns:
        The Chroma vector store returned by ``Chroma.from_documents``.
    """
    embedder = HuggingFaceEmbeddings()
    return Chroma.from_documents(
        documents=splits,
        embedding=embedder,
        persist_directory=default_persist_directory,
    )
59
+
60
+
61
+ # Load vector database
62
def load_db():
    """Open the Chroma vector store persisted in ``default_persist_directory``.

    Returns:
        A ``Chroma`` instance bound to that directory, using a
        default-constructed ``HuggingFaceEmbeddings`` as embedding function.
    """
    embedder = HuggingFaceEmbeddings()
    return Chroma(
        persist_directory=default_persist_directory,
        embedding_function=embedder,
    )
68
+
69
+
70
+ # Initialize langchain LLM chain
71
def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, progress=None):
    """Build a conversational retrieval chain backed by a Hugging Face Hub LLM.

    Args:
        llm_model: Hub repo id passed to ``HuggingFaceHub`` as ``repo_id``.
        temperature: Sampling temperature forwarded in ``model_kwargs``.
        max_tokens: Forwarded as ``max_new_tokens`` in ``model_kwargs``.
        top_k: Forwarded as ``top_k`` in ``model_kwargs``.
        vector_db: Vector store; its ``as_retriever()`` feeds the chain.
        progress: Optional callable ``progress(fraction, desc=...)`` used to
            report progress. When omitted, progress reports are discarded.

    Returns:
        The ``ConversationalRetrievalChain`` created by ``from_llm``. It keeps
        its own ``ConversationBufferMemory`` and returns source documents.
    """
    if progress is None:
        # The previous default, gr.Progress(), referenced a name that is not
        # among this file's visible imports and so failed when the function
        # was defined. Fall back to a reporter that does nothing.
        def progress(fraction, desc=None):
            return None

    progress(0.1, desc="Initializing HF tokenizer...")

    # NOTE: a local transformers pipeline wrapped in HuggingFacePipeline is the
    # alternative to the Hub endpoint, but it downloads the model locally, so
    # the hosted inference endpoint is used instead.
    progress(0.5, desc="Initializing HF Hub...")
    llm = HuggingFaceHub(
        repo_id=llm_model,
        model_kwargs={
            "temperature": temperature,
            "max_new_tokens": max_tokens,
            "top_k": top_k,
            "trust_remote_code": True,
            "torch_dtype": "auto",
        },
    )

    progress(0.75, desc="Defining buffer memory...")
    memory = ConversationBufferMemory(
        memory_key="chat_history",
        # The chain returns both the answer and the source documents, so the
        # memory is told explicitly which output key to record.
        output_key='answer',
        return_messages=True,
    )
    retriever = vector_db.as_retriever()

    progress(0.8, desc="Defining retrieval chain...")
    qa_chain = ConversationalRetrievalChain.from_llm(
        llm,
        retriever=retriever,
        chain_type="stuff",
        memory=memory,
        return_source_documents=True,
    )
    progress(0.9, desc="Done!")
    return qa_chain
122
+
123
+
124
+ # Initialize database
125
def initialize_database(list_file_obj, chunk_size, chunk_overlap, progress=None):
    """Load the uploaded PDFs, split them and build the vector database.

    Args:
        list_file_obj: Iterable of file objects; ``None`` entries are skipped
            and each remaining object's ``.name`` is used as the PDF path.
        chunk_size: Chunk size forwarded to ``load_doc``.
        chunk_overlap: Chunk overlap forwarded to ``load_doc``.
        progress: Optional callable ``progress(fraction, desc=...)`` used to
            report progress. When omitted, progress reports are discarded.

    Returns:
        A ``(vector_db, status)`` tuple where ``status`` is ``"Complete!"``.
    """
    if progress is None:
        # The previous default, gr.Progress(), referenced a name that is not
        # among this file's visible imports and so failed when the function
        # was defined. Fall back to a reporter that does nothing.
        def progress(fraction, desc=None):
            return None

    # NOTE(review): ``.name`` must be a readable filesystem path here, since
    # load_doc hands it to PyPDFLoader - confirm against the caller.
    list_file_path = [x.name for x in list_file_obj if x is not None]

    progress(0.25, desc="Loading document...")
    doc_splits = load_doc(list_file_path, chunk_size, chunk_overlap)

    progress(0.5, desc="Generating vector database...")
    vector_db = create_db(doc_splits)

    progress(0.9, desc="Done!")
    return vector_db, "Complete!"
139
+
140
+
141
def initialize_LLM(llm_option, llm_temperature, max_tokens, top_k, vector_db, progress=None):
    """Resolve the selected model and build its question-answering chain.

    Args:
        llm_option: Either an integer index into ``list_llm``, a short display
            name from ``list_llm_simple``, or a full repo id from ``list_llm``.
        llm_temperature: Sampling temperature forwarded to the chain builder.
        max_tokens: Maximum new tokens forwarded to the chain builder.
        top_k: Top-k sampling value forwarded to the chain builder.
        vector_db: Vector store forwarded to the chain builder.
        progress: Optional callable ``progress(fraction, desc=...)`` used to
            report progress. When omitted, progress reports are discarded.

    Returns:
        A ``(qa_chain, status)`` tuple where ``status`` is ``"Complete!"``.

    Raises:
        ValueError: If ``llm_option`` is a string that names no known model.
    """
    if progress is None:
        # The previous default, gr.Progress(), referenced a name that is not
        # among this file's visible imports and so failed when the function
        # was defined. Fall back to a reporter that does nothing.
        def progress(fraction, desc=None):
            return None

    if isinstance(llm_option, str):
        # A radio widget built from list_llm_simple yields the label, not its
        # position, so map names back to the full repo id.
        if llm_option in list_llm:
            llm_name = llm_option
        elif llm_option in list_llm_simple:
            llm_name = list_llm[list_llm_simple.index(llm_option)]
        else:
            raise ValueError(f"Unknown LLM option: {llm_option!r}")
    else:
        llm_name = list_llm[llm_option]

    qa_chain = initialize_llmchain(llm_name, llm_temperature, max_tokens, top_k, vector_db, progress)
    return qa_chain, "Complete!"
147
+
148
+
149
def format_chat_history(message, chat_history):
    """Flatten ``(user, bot)`` message pairs into labelled strings.

    Args:
        message: The current user message. It is not used by this function and
            is kept only so existing callers keep working.
        chat_history: Iterable of ``(user_message, bot_message)`` pairs, or
            ``None`` / empty when there is no history yet.

    Returns:
        A list alternating ``"User: ..."`` and ``"Assistant: ..."`` entries in
        conversation order; empty when there is no history.
    """
    formatted_chat_history = []
    # ``or ()`` lets a not-yet-initialised (None) history behave like an
    # empty one instead of raising TypeError.
    for user_message, bot_message in chat_history or ():
        formatted_chat_history.append(f"User: {user_message}")
        formatted_chat_history.append(f"Assistant: {bot_message}")
    return formatted_chat_history
155
+
156
+
157
def conversation(qa_chain, message, history):
    """Ask the QA chain one question and return the updated conversation.

    Args:
        qa_chain: Callable chain; invoked with a dict holding ``"question"``
            and ``"chat_history"`` and expected to return a mapping with
            ``"answer"`` and ``"source_documents"``.
        message: The user's question.
        history: List of ``(user_message, bot_message)`` pairs so far, or
            ``None`` when the conversation has not started.

    Returns:
        A 7-tuple ``(qa_chain, cleared_input, new_history, source1,
        source1_page, source2, source2_page)``. ``cleared_input`` is an empty
        string meant to reset the message box. Pages are one-based. A source
        the retriever did not return is reported as ``("", 0)``.
    """
    history = history or []
    formatted_chat_history = format_chat_history(message, history)

    # Generate response using QA chain
    response = qa_chain({"question": message, "chat_history": formatted_chat_history})
    response_answer = response["answer"]
    response_sources = response["source_documents"]

    def source_at(index):
        """Return (text, one-based page) of a source, or ("", 0) if absent."""
        if index >= len(response_sources):
            # The retriever may return fewer than two documents.
            return "", 0
        document = response_sources[index]
        # Langchain page numbers are zero-based; a missing page maps to 0.
        return document.page_content.strip(), document.metadata.get("page", -1) + 1

    response_source1, response_source1_page = source_at(0)
    response_source2, response_source2_page = source_at(1)

    # Append user message and response to chat history
    new_history = history + [(message, response_answer)]

    # The second element used to be gr.update(value=""), but ``gr`` is not
    # among this file's visible imports; an empty string carries the same
    # "clear the input" meaning.
    return (
        qa_chain,
        "",
        new_history,
        response_source1,
        response_source1_page,
        response_source2,
        response_source2_page,
    )
177
+
178
+
179
def upload_file(file_obj):
    """Collect the ``.name`` of every uploaded file.

    Args:
        file_obj: Iterable of file objects exposing a ``.name`` attribute, or
            ``None`` / empty when nothing was uploaded.

    Returns:
        A list with each file's ``.name``, in upload order.
    """
    # The original read ``file_obj.name`` inside the loop, i.e. the attribute
    # of the whole collection rather than of the current file.
    return [file.name for file in file_obj or ()]
187
 
188
 
 
 
 
189
 
 
 
 
 
 
 
 
190
 
191
 
192
def main():
    """Render the three-step Streamlit UI for the PDF chatbot.

    Step 1 uploads PDFs and builds the vector database, step 2 initialises the
    question-answering chain, and step 3 sends messages to it. The database,
    the chain, the chat history and the last sources are kept in
    ``st.session_state`` so they survive Streamlit's script reruns.
    """
    import tempfile  # function-scope import: only the upload handling needs it

    def make_reporter():
        """Return a progress(fraction, desc=...) callable backed by a progress bar."""
        bar = st.progress(0.0)

        def report(fraction, desc=None):
            bar.progress(fraction)

        return report

    st.title("PDF-based chatbot (powered by LangChain and open-source LLMs)")
    st.markdown(
        "## Ask any questions about your PDF documents, along with follow-ups\n"
        "**Note:** This AI assistant performs retrieval-augmented generation from your PDF documents.\n"
        "When generating answers, it takes past questions into account (via conversational memory),\n"
        "and includes document references for clarity purposes.\n"
        "\n**Warning:** This space uses the free CPU Basic hardware from Hugging Face. "
        "Some steps and LLM models used below (free inference endpoints) can take some time "
        "to generate an output.\n"
    )

    # Step 1 - Document pre-processing
    st.header("Step 1 - Document pre-processing")
    uploaded_files = st.file_uploader(
        "Upload your PDF documents (single or multiple)",
        type="pdf",
        accept_multiple_files=True,
    )
    st.radio("Vector database type", ["ChromaDB"])
    chunk_size = st.slider("Chunk size", 100, 1000, 600, 20, key="chunk_size")
    chunk_overlap = st.slider("Chunk overlap", 10, 200, 40, 10, key="chunk_overlap")

    if st.button("Generating vector database..."):
        if not uploaded_files:
            st.warning("Please upload at least one PDF document first.")
        else:
            # Uploads live in memory, while initialize_database reads each
            # object's ``.name`` as a path, so spool them to temporary files.
            temp_files = []
            try:
                for uploaded in uploaded_files:
                    tmp = tempfile.NamedTemporaryFile(suffix=".pdf", delete=False)
                    temp_files.append(tmp)
                    tmp.write(uploaded.getvalue())
                    tmp.close()
                with st.spinner("Generating vector database..."):
                    vector_db, status = initialize_database(
                        temp_files, chunk_size, chunk_overlap, make_reporter()
                    )
            finally:
                for tmp in temp_files:
                    tmp.close()
                    os.unlink(tmp.name)
            st.session_state["vector_db"] = vector_db
            # A new database invalidates any chain built on the previous one.
            st.session_state["qa_chain"] = None
            st.session_state["chat_history"] = []
            st.session_state["sources"] = []
            st.success(status)

    # Step 2 - QA chain initialization
    st.header("Step 2 - QA chain initialization")
    llm_option = st.radio("LLM models", list_llm_simple)
    llm_temperature = st.slider("Temperature", 0.0, 1.0, 0.7, 0.1, key="llm_temperature")
    max_tokens = st.slider("Max Tokens", 224, 4096, 1024, 32, key="max_tokens")
    top_k = st.slider("Top-k samples", 1, 10, 3, 1, key="top_k")

    if st.button("Initializing question-answering chain..."):
        vector_db = st.session_state.get("vector_db")
        if vector_db is None:
            st.warning("Please generate the vector database first (step 1).")
        else:
            with st.spinner("Initializing question-answering chain..."):
                # The radio yields the display name; initialize_LLM indexes
                # list_llm, so pass the position of the selected label.
                qa_chain, status = initialize_LLM(
                    list_llm_simple.index(llm_option),
                    llm_temperature,
                    max_tokens,
                    top_k,
                    vector_db,
                    make_reporter(),
                )
            st.session_state["qa_chain"] = qa_chain
            st.session_state["chat_history"] = []
            st.session_state["sources"] = []
            st.success(status)

    # Step 3 - Conversation with chatbot
    st.header("Step 3 - Conversation with chatbot")
    msg = st.text_input("Type message", key="message")
    if st.button("Submit"):
        qa_chain = st.session_state.get("qa_chain")
        if qa_chain is None:
            st.warning("Please initialize the question-answering chain first (step 2).")
        elif not msg.strip():
            st.warning("Please type a message.")
        else:
            history = st.session_state.get("chat_history") or []
            with st.spinner("Generating answer..."):
                result = conversation(qa_chain, msg, history)
            qa_chain, _, history, source1, page1, source2, page2 = result
            st.session_state["qa_chain"] = qa_chain
            st.session_state["chat_history"] = history
            st.session_state["sources"] = [(source1, page1), (source2, page2)]

    # Re-render the conversation on every run, since Streamlit redraws the page.
    for user_message, bot_message in st.session_state.get("chat_history") or []:
        st.markdown(f"**User:** {user_message}")
        st.markdown(f"**Assistant:** {bot_message}")
    for number, (text, page) in enumerate(st.session_state.get("sources") or [], start=1):
        if text:
            with st.expander(f"Reference {number} (page {page})"):
                st.write(text)


if __name__ == "__main__":
    main()