Darshika94 commited on
Commit
d19b243
Β·
verified Β·
1 Parent(s): 8b79a4c

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +236 -0
app.py ADDED
@@ -0,0 +1,236 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from langchain_core.messages import AIMessage, HumanMessage
3
+ from langchain_core.prompts import MessagesPlaceholder
4
+
5
+ from langchain_ollama import ChatOllama
6
+ from langchain_openai import ChatOpenAI
7
+
8
+ from langchain_core.output_parsers import StrOutputParser
9
+ from langchain_core.prompts import ChatPromptTemplate, PromptTemplate
10
+
11
+ import torch
12
+ from langchain_huggingface import ChatHuggingFace
13
+ from langchain_huggingface import HuggingFaceEndpoint
14
+
15
+ import faiss
16
+ import tempfile
17
+ import os
18
+ import time
19
+ from langchain_community.vectorstores import FAISS
20
+ from langchain_huggingface import HuggingFaceEmbeddings
21
+ from langchain_text_splitters import RecursiveCharacterTextSplitter
22
+ from langchain.chains import create_history_aware_retriever, create_retrieval_chain
23
+ from langchain.chains.combine_documents import create_stuff_documents_chain
24
+ from langchain_community.document_loaders import PyPDFLoader
25
+
26
+ from dotenv import load_dotenv
27
+
28
# Load environment variables (API keys/tokens for the model providers)
# from a local .env file, if present.
load_dotenv()

# Streamlit Settings
st.set_page_config(page_title="Chat with documents πŸ“š", page_icon="πŸ“š")
st.title("Chat with documents πŸ“š")
# Subtitle
st.subheader("Ask questions and get answers from your documents πŸ’¬") #newline-d

# Selects which LLM backend config_rag_chain() builds below.
model_class = "hf_hub" # @param ["hf_hub", "openai", "ollama"]
43
+
44
+ ## Model Providers
45
def model_hf_hub(model="meta-llama/Meta-Llama-3-8B-Instruct", temperature=0.1):
    """Build an LLM backed by the Hugging Face Inference Endpoint.

    Args:
        model: repo id of the hosted model.
        temperature: sampling temperature.

    Returns:
        A HuggingFaceEndpoint capped at 512 new tokens, returning only
        the completion (not the echoed prompt).
    """
    return HuggingFaceEndpoint(
        repo_id=model,
        temperature=temperature,
        max_new_tokens=512,
        return_full_text=False,
    )
57
+
58
def model_openai(model="gpt-4o-mini", temperature=0.1):
    """Build an OpenAI chat model with the given name and temperature."""
    return ChatOpenAI(model=model, temperature=temperature)
65
+
66
def model_ollama(model="phi3", temperature=0.1):
    """Build a locally-served Ollama chat model."""
    return ChatOllama(model=model, temperature=temperature)
72
+
73
+
74
+ ## Indexing and Retrieval
75
+
76
def config_retriever(uploads):
    """Index the uploaded PDFs and return an MMR retriever over them.

    Each upload is written to a temporary directory so PyPDFLoader can
    read it from disk, then split into overlapping chunks, embedded with
    BAAI/bge-m3 and stored in a FAISS index (also persisted to
    'vectorstore/db_faiss').

    Args:
        uploads: Streamlit UploadedFile objects (PDFs).

    Returns:
        A retriever using Maximal Marginal Relevance search
        (k=3 results drawn from fetch_k=4 candidates).
    """
    # Load. Fix: the original created TemporaryDirectory() without a
    # context manager, so the directory was only removed whenever the
    # object happened to be garbage collected; `with` guarantees cleanup
    # even if a loader raises.
    docs = []
    with tempfile.TemporaryDirectory() as temp_dir:
        for file in uploads:
            temp_filepath = os.path.join(temp_dir, file.name)
            with open(temp_filepath, "wb") as f:
                f.write(file.getvalue())
            loader = PyPDFLoader(temp_filepath)
            docs.extend(loader.load())

    # Split into overlapping chunks sized for embedding.
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=1000,
        chunk_overlap=200
    )
    splits = text_splitter.split_documents(docs)

    # Embeddings
    embeddings = HuggingFaceEmbeddings(model_name="BAAI/bge-m3")

    # Store in FAISS and persist to disk.
    vectorstore = FAISS.from_documents(splits, embeddings)
    vectorstore.save_local('vectorstore/db_faiss')

    # Retrieve with MMR for diverse (less redundant) results.
    retriever = vectorstore.as_retriever(
        search_type='mmr',
        search_kwargs={'k': 3, 'fetch_k': 4}
    )

    return retriever
109
+
110
+
111
def config_rag_chain(model_class, retriever):
    """Assemble the conversational RAG pipeline.

    Args:
        model_class: one of "hf_hub", "openai", "ollama" — which LLM
            backend to instantiate.
        retriever: document retriever from config_retriever().

    Returns:
        A retrieval chain; invoke() with {"input", "chat_history"} and
        read 'answer' and the retrieved 'context' from the result.

    Raises:
        ValueError: if model_class is not a supported backend. (The
        original code left `llm` unbound in that case and crashed later
        with a NameError.)
    """
    ### Loading the LLM
    if model_class == "hf_hub":
        llm = model_hf_hub()
    elif model_class == "openai":
        llm = model_openai()
    elif model_class == "ollama":
        llm = model_ollama()
    else:
        # Fail fast with a clear message instead of an opaque NameError.
        raise ValueError(f"Unknown model_class: {model_class!r}")

    # Prompt definition: the raw HF endpoint needs explicit Llama-3 chat
    # markers; the chat-native clients format messages themselves.
    if model_class.startswith("hf"):
        token_s, token_e = "<|begin_of_text|><|start_header_id|>system<|end_header_id|>", "<|eot_id|><|start_header_id|>assistant<|end_header_id|>"
    else:
        token_s, token_e = "", ""

    # Contextualization prompt: rewrite a follow-up question into a
    # standalone one using the chat history (no answering here).
    context_q_system_prompt = "Given the following chat history and the follow-up question which might reference context in the chat history, formulate a standalone question which can be understood without the chat history. Do NOT answer the question, just reformulate it if needed and otherwise return it as is."

    context_q_system_prompt = token_s + context_q_system_prompt
    context_q_user_prompt = "Question: {input}" + token_e
    context_q_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", context_q_system_prompt),
            MessagesPlaceholder("chat_history"),
            ("human", context_q_user_prompt),
        ]
    )

    # Chain for contextualization
    history_aware_retriever = create_history_aware_retriever(
        llm=llm, retriever=retriever, prompt=context_q_prompt
    )

    # Q&A Prompt
    qa_prompt_template = """You are a helpful virtual assistant answering general questions.
Use the following bits of retrieved context to answer the question.
If you don't know the answer, just say you don't know. Keep your answer concise.
Answer in English. \n\n
Question: {input} \n
Context: {context}"""

    qa_prompt = PromptTemplate.from_template(token_s + qa_prompt_template + token_e)

    # Chain that stuffs the retrieved documents into the Q&A prompt.
    qa_chain = create_stuff_documents_chain(llm, qa_prompt)

    # Full pipeline: question rewriting -> retrieval -> answer generation.
    rag_chain = create_retrieval_chain(
        history_aware_retriever,
        qa_chain,
    )

    return rag_chain
165
+
166
+
167
## Creates side panel in the interface
# Sidebar widget: accepts one or more PDF files to chat with.
uploads = st.sidebar.file_uploader(
    label="Upload files", type=["pdf"],
    accept_multiple_files=True
)
# No documents means nothing to retrieve from — stop this script run
# until the user uploads something.
if not uploads:
    st.info("Please send some file to continue!")
    st.stop()
175
+
176
+
177
# Per-session state (persists across Streamlit reruns).
# chat_history: the full conversation, seeded with an assistant greeting.
if "chat_history" not in st.session_state:
    st.session_state.chat_history = [
        AIMessage(content="Hi, I'm your virtual assistant! How can I help you?"),
    ]

# docs_list: the uploads the cached retriever was built from; compared
# against the current uploads to detect when the files change.
if "docs_list" not in st.session_state:
    st.session_state.docs_list = None

# retriever: cached so the FAISS index is not rebuilt on every rerun.
if "retriever" not in st.session_state:
    st.session_state.retriever = None
187
+
188
# Replay the stored conversation so it is visible after every rerun.
for msg in st.session_state.chat_history:
    if isinstance(msg, AIMessage):
        role = "AI"
    elif isinstance(msg, HumanMessage):
        role = "Human"
    else:
        continue  # any other message type is skipped, as before
    with st.chat_message(role):
        st.write(msg.content)
195
+
196
# we use time to measure how long it took for generation
start = time.time()
user_query = st.chat_input("Enter your message here...")

# One question/answer round per rerun, when the user submitted a
# non-empty query and documents are present.
if user_query is not None and user_query != "" and uploads is not None:

    # Record and display the user's turn.
    st.session_state.chat_history.append(HumanMessage(content=user_query))

    with st.chat_message("Human"):
        st.markdown(user_query)

    with st.chat_message("AI"):

        # Rebuild the retriever only when the uploaded files changed;
        # otherwise reuse the cached one from session state.
        if st.session_state.docs_list != uploads:
            print(uploads)
            st.session_state.docs_list = uploads
            st.session_state.retriever = config_retriever(uploads)

        # NOTE(review): the RAG chain is rebuilt on every query even when
        # the retriever is cached; caching it too would avoid repeated
        # LLM setup — confirm before changing.
        rag_chain = config_rag_chain(model_class, st.session_state.retriever)

        result = rag_chain.invoke({"input": user_query, "chat_history": st.session_state.chat_history})

        resp = result['answer']
        st.write(resp)

        # show the source documents the answer was grounded on
        sources = result['context']
        for idx, doc in enumerate(sources):
            source = doc.metadata['source']
            file = os.path.basename(source)
            # 'page' may be absent from metadata; fall back to a label.
            page = doc.metadata.get('page', 'Page not specified')

            ref = f":link: Source {idx}: *{file} - p. {page}*"
            print(ref)
            with st.popover(ref):
                st.caption(doc.page_content)

    # Persist the assistant's turn for the next rerun.
    st.session_state.chat_history.append(AIMessage(content=resp))

# Elapsed wall-clock time for this whole script run (logged to stdout).
end = time.time()
print("Time: ", end - start)