praneeth dodedu commited on
Commit
b647893
·
1 Parent(s): 74563e8
Files changed (1) hide show
  1. app.py +256 -152
app.py CHANGED
@@ -1,167 +1,271 @@
1
  #!/usr/bin/env python3
2
- import os
3
- import glob
4
- from typing import List
5
  from dotenv import load_dotenv
6
- from multiprocessing import Pool
7
- from tqdm import tqdm
8
-
9
- from langchain.document_loaders import (
10
- CSVLoader,
11
- EverNoteLoader,
12
- PDFMinerLoader,
13
- TextLoader,
14
- UnstructuredEmailLoader,
15
- UnstructuredEPubLoader,
16
- UnstructuredHTMLLoader,
17
- UnstructuredMarkdownLoader,
18
- UnstructuredODTLoader,
19
- UnstructuredPowerPointLoader,
20
- UnstructuredWordDocumentLoader,
21
- )
22
-
23
- from langchain.text_splitter import RecursiveCharacterTextSplitter
24
- from langchain.vectorstores import Chroma
25
  from langchain.embeddings import HuggingFaceEmbeddings
26
- from langchain.docstore.document import Document
27
- from constants import CHROMA_SETTINGS
28
-
 
 
 
 
 
29
 
30
  load_dotenv()
31
 
32
-
33
- # Load environment variables
34
  persist_directory = os.environ.get('PERSIST_DIRECTORY')
35
- source_directory = os.environ.get('SOURCE_DIRECTORY', 'source_documents')
36
- embeddings_model_name = os.environ.get('EMBEDDINGS_MODEL_NAME')
37
- chunk_size = 500
38
- chunk_overlap = 50
39
-
40
-
41
- # Custom document loaders
42
- class MyElmLoader(UnstructuredEmailLoader):
43
- """Wrapper to fallback to text/plain when default does not work"""
44
-
45
- def load(self) -> List[Document]:
46
- """Wrapper adding fallback for elm without html"""
47
- try:
48
- try:
49
- doc = UnstructuredEmailLoader.load(self)
50
- except ValueError as e:
51
- if 'text/html content not found in email' in str(e):
52
- # Try plain text
53
- self.unstructured_kwargs["content_source"]="text/plain"
54
- doc = UnstructuredEmailLoader.load(self)
55
- else:
56
- raise
57
- except Exception as e:
58
- # Add file_path to exception message
59
- raise type(e)(f"{self.file_path}: {e}") from e
60
-
61
- return doc
62
-
63
-
64
- # Map file extensions to document loaders and their arguments
65
- LOADER_MAPPING = {
66
- ".csv": (CSVLoader, {}),
67
- # ".docx": (Docx2txtLoader, {}),
68
- ".doc": (UnstructuredWordDocumentLoader, {}),
69
- ".docx": (UnstructuredWordDocumentLoader, {}),
70
- ".enex": (EverNoteLoader, {}),
71
- ".eml": (MyElmLoader, {}),
72
- ".epub": (UnstructuredEPubLoader, {}),
73
- ".html": (UnstructuredHTMLLoader, {}),
74
- ".md": (UnstructuredMarkdownLoader, {}),
75
- ".odt": (UnstructuredODTLoader, {}),
76
- ".pdf": (PDFMinerLoader, {}),
77
- ".ppt": (UnstructuredPowerPointLoader, {}),
78
- ".pptx": (UnstructuredPowerPointLoader, {}),
79
- ".txt": (TextLoader, {"encoding": "utf8"}),
80
- # Add more mappings for other file extensions and loaders as needed
81
- }
82
-
83
-
84
- def load_single_document(file_path: str) -> Document:
85
- ext = "." + file_path.rsplit(".", 1)[-1]
86
- if ext in LOADER_MAPPING:
87
- loader_class, loader_args = LOADER_MAPPING[ext]
88
- loader = loader_class(file_path, **loader_args)
89
- return loader.load()[0]
90
-
91
- raise ValueError(f"Unsupported file extension '{ext}'")
92
-
93
-
94
- def load_documents(source_dir: str, ignored_files: List[str] = []) -> List[Document]:
95
- """
96
- Loads all documents from the source documents directory, ignoring specified files
97
- """
98
- all_files = []
99
- for ext in LOADER_MAPPING:
100
- all_files.extend(
101
- glob.glob(os.path.join(source_dir, f"**/*{ext}"), recursive=True)
102
- )
103
- filtered_files = [file_path for file_path in all_files if file_path not in ignored_files]
104
-
105
- with Pool(processes=os.cpu_count()) as pool:
106
- results = []
107
- with tqdm(total=len(filtered_files), desc='Loading new documents', ncols=80) as pbar:
108
- for i, doc in enumerate(pool.imap_unordered(load_single_document, filtered_files)):
109
- results.append(doc)
110
- pbar.update()
111
-
112
- return results
113
-
114
- def process_documents(ignored_files: List[str] = []) -> List[Document]:
115
- """
116
- Load documents and split in chunks
117
- """
118
- print(f"Loading documents from {source_directory}")
119
- documents = load_documents(source_directory, ignored_files)
120
- if not documents:
121
- print("No new documents to load")
122
- exit(0)
123
- print(f"Loaded {len(documents)} new documents from {source_directory}")
124
- text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
125
- texts = text_splitter.split_documents(documents)
126
- print(f"Split into {len(texts)} chunks of text (max. {chunk_size} tokens each)")
127
- return texts
128
-
129
- def does_vectorstore_exist(persist_directory: str) -> bool:
130
- """
131
- Checks if vectorstore exists
132
- """
133
- if os.path.exists(os.path.join(persist_directory, 'index')):
134
- if os.path.exists(os.path.join(persist_directory, 'chroma-collections.parquet')) and os.path.exists(os.path.join(persist_directory, 'chroma-embeddings.parquet')):
135
- list_index_files = glob.glob(os.path.join(persist_directory, 'index/*.bin'))
136
- list_index_files += glob.glob(os.path.join(persist_directory, 'index/*.pkl'))
137
- # At least 3 documents are needed in a working vectorstore
138
- if len(list_index_files) > 3:
139
- return True
140
- return False
141
 
142
  def main():
143
- # Create embeddings
 
144
  embeddings = HuggingFaceEmbeddings(model_name=embeddings_model_name)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
145
 
146
- if does_vectorstore_exist(persist_directory):
147
- # Update and store locally vectorstore
148
- print(f"Appending to existing vectorstore at {persist_directory}")
149
- db = Chroma(persist_directory=persist_directory, embedding_function=embeddings, client_settings=CHROMA_SETTINGS)
150
- collection = db.get()
151
- texts = process_documents([metadata['source'] for metadata in collection['metadatas']])
152
- print(f"Creating embeddings. May take some minutes...")
153
- db.add_documents(texts)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
154
  else:
155
- # Create and store locally vectorstore
156
- print("Creating new vectorstore")
157
- texts = process_documents()
158
- print(f"Creating embeddings. May take some minutes...")
159
- db = Chroma.from_documents(texts, embeddings, persist_directory=persist_directory, client_settings=CHROMA_SETTINGS)
160
- db.persist()
161
- db = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
162
 
163
- print(f"Ingestion complete! You can now run privateGPT.py to query your documents")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
164
 
165
 
166
- if __name__ == "__main__":
167
- main()
 
1
  #!/usr/bin/env python3
 
 
 
2
  from dotenv import load_dotenv
3
+ from langchain.chains import RetrievalQA
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  from langchain.embeddings import HuggingFaceEmbeddings
5
+ from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
6
+ from langchain.vectorstores import Chroma
7
+ from langchain.llms import GPT4All, LlamaCpp
8
+ import os
9
+ import argparse
10
+ from pathlib import Path
11
+ import base64
12
+ import gradio as gr
13
 
14
  load_dotenv()
15
 
16
+ embeddings_model_name = os.environ.get("EMBEDDINGS_MODEL_NAME")
 
17
  persist_directory = os.environ.get('PERSIST_DIRECTORY')
18
+
19
+ model_type = os.environ.get('MODEL_TYPE')
20
+ model_path = os.environ.get('MODEL_PATH')
21
+ model_n_ctx = os.environ.get('MODEL_N_CTX')
22
+
23
+ from constants import CHROMA_SETTINGS
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
  def main():
26
+ # Parse the command line arguments
27
+ args = parse_arguments()
28
  embeddings = HuggingFaceEmbeddings(model_name=embeddings_model_name)
29
+ db = Chroma(persist_directory=persist_directory, embedding_function=embeddings, client_settings=CHROMA_SETTINGS)
30
+ retriever = db.as_retriever()
31
+ # activate/deactivate the streaming StdOut callback for LLMs
32
+ callbacks = [] if args.mute_stream else [StreamingStdOutCallbackHandler()]
33
+ # Prepare the LLM
34
+ '''match model_type:
35
+ case "LlamaCpp":
36
+ llm = LlamaCpp(model_path=model_path, n_ctx=model_n_ctx, callbacks=callbacks, verbose=False)
37
+ case "GPT4All":
38
+ llm = GPT4All(model=model_path, n_ctx=model_n_ctx, backend='gptj', callbacks=callbacks, verbose=False)
39
+ case _default:
40
+ print(f"Model {model_type} not supported!")
41
+ exit;'''
42
+ if model_type == "LlamaCpp":
43
+ llm = LlamaCpp(model_path=model_path, n_ctx=model_n_ctx, callbacks=callbacks, verbose=False)
44
+ elif model_type == "GPT4All":
45
+ llm = GPT4All(model=model_path, n_ctx=model_n_ctx, backend='gptj', callbacks=callbacks, verbose=False)
46
+ else:
47
+ print(f"Model {model_type} not supported!")
48
+ exit;
49
+ qa = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=retriever, return_source_documents= not args.hide_source)
50
+ # Interactive questions and answers
51
+ while True:
52
+ query = input("\nEnter a query: ")
53
+ if query == "exit":
54
+ break
55
+
56
+ # Get the answer from the chain
57
+ res = qa(query)
58
+ answer, docs = res['result'], [] if args.hide_source else res['source_documents']
59
+
60
+ # Print the result
61
+ print("\n\n> Question:")
62
+ print(query)
63
+ print("\n> Answer:")
64
+ print(answer)
65
+
66
+ # Print the relevant sources used for the answer
67
+ for document in docs:
68
+ print("\n> " + document.metadata["source"] + ":")
69
+ print(document.page_content)
70
+
71
+ def parse_arguments():
72
+ parser = argparse.ArgumentParser(description='privateGPT: Ask questions to your documents without an internet connection, '
73
+ 'using the power of LLMs.')
74
+ parser.add_argument("--hide-source", "-S", action='store_true',
75
+ help='Use this flag to disable printing of source documents used for answers.')
76
 
77
+ parser.add_argument("--mute-stream", "-M",
78
+ action='store_true',
79
+ help='Use this flag to disable the streaming StdOut callback for LLMs.')
80
+
81
+ return parser.parse_args()
82
+
83
+
84
+ def apply_html(text, color):
85
+ if "<table>" in text and "</table>" in text:
86
+ # If the text contains table tags, modify the table structure for Gradio
87
+ table_start = text.index("<table>")
88
+ table_end = text.index("</table>") + len("</table>")
89
+ table_content = text[table_start:table_end]
90
+
91
+ # Modify the table structure for Gradio
92
+ modified_table = table_content.replace("<table>", "<table style='border-collapse: collapse;'>")
93
+ modified_table = modified_table.replace("<th>", "<th style='border: 1px solid #ddd; padding: 8px; background-color: #f2f2f2;'>")
94
+ modified_table = modified_table.replace("<td>", "<td style='border: 1px solid #ddd; padding: 8px;'>")
95
+
96
+ # Replace the modified table back into the original text
97
+ modified_text = text[:table_start] + modified_table + text[table_end:]
98
+ return modified_text
99
  else:
100
+ # Return the plain text as is
101
+ return text
102
+
103
+ def add_text(history, text):
104
+ # Apply selected rules
105
+
106
+ if history is not None:
107
+ # If all rules pass, add message to chat history with bot's response set to None
108
+ history.append([apply_html(text, "blue"), None])
109
+
110
+ return history, text
111
+
112
+ def bot(query, history, fileListHistory, k=5):
113
+ # Parse the command line arguments
114
+ args = parse_arguments()
115
+ print("QUERY : " + query)
116
+ embeddings = HuggingFaceEmbeddings(model_name=embeddings_model_name)
117
+ db = Chroma(persist_directory=persist_directory, embedding_function=embeddings, client_settings=CHROMA_SETTINGS)
118
+ retriever = db.as_retriever()
119
+ # activate/deactivate the streaming StdOut callback for LLMs
120
+ callbacks = [] if args.mute_stream else [StreamingStdOutCallbackHandler()]
121
+ # Prepare the LLM
122
+ '''match model_type:
123
+ case "LlamaCpp":
124
+ llm = LlamaCpp(model_path=model_path, n_ctx=model_n_ctx, callbacks=callbacks, verbose=False)
125
+ case "GPT4All":
126
+ llm = GPT4All(model=model_path, n_ctx=model_n_ctx, backend='gptj', callbacks=callbacks, verbose=False)
127
+ case _default:
128
+ print(f"Model {model_type} not supported!")
129
+ exit;'''
130
+ if model_type == "LlamaCpp":
131
+ llm = LlamaCpp(model_path=model_path, n_ctx=model_n_ctx, callbacks=callbacks, verbose=False)
132
+ elif model_type == "GPT4All":
133
+ llm = GPT4All(model=model_path, n_ctx=model_n_ctx, backend='gptj', callbacks=callbacks, verbose=False)
134
+ else:
135
+ print(f"Model {model_type} not supported!")
136
+ exit;
137
+ qa = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=retriever, return_source_documents= not args.hide_source)
138
+
139
+ # Get the answer from the chain
140
+ res = qa(query)
141
+ answer, docs = res['result'], [] if args.hide_source else res['source_documents']
142
+
143
+ # Print the result
144
+ print("\n\n> Question:")
145
+ print(query)
146
+ print("\n> Answer:")
147
+ print(answer)
148
+
149
+ # Print the relevant sources used for the answer
150
+ for document in docs:
151
+ print("\n> " + document.metadata["source"] + ":")
152
+ print(document.page_content)
153
+
154
+ # If the call was not successful after 3 attempts, set the response to a timeout message
155
+ if answer is None:
156
+ print("Unfortunately, the connection to ChatGPT timed out. Please try after some time.")
157
+ if history is not None and len(history) > 0:
158
+ # Update the chat history with the bot's response
159
+ history[-1][1] = apply_html(answer.text.strip(), "black")
160
+ else:
161
+ # Print the generated response
162
+ print("\nGPT RESPONSE:\n")
163
+ # print(answer['choices'][0]['message']['content'].strip())
164
+
165
+ if history is not None and len(history) > 0:
166
+ # Update the chat history with the bot's response
167
+ history[-1][1] = apply_html(answer.strip(), "black")
168
+ return history, fileListHistory
169
+
170
+
171
+
172
+ # Open the image and convert it to base64
173
+ with open(Path("rybot_small.png"), "rb") as img_file:
174
+ img_str = base64.b64encode(img_file.read()).decode()
175
+
176
+ html_code = f'''
177
+ <!DOCTYPE html>
178
+ <html>
179
+ <head>
180
+ <style>
181
+ .center {{
182
+ display: flex;
183
+ justify-content: center;
184
+ align-items: center;
185
+ margin-top: -40px; /* adjust this value as per your requirement */
186
+ margin-bottom: 5px;
187
+ }}
188
+ .large-text {{
189
+ font-size: 40px;
190
+ font-family: Arial, Helvetica, sans-serif;
191
+ font-weight: 900 !important;
192
+ margin-left: 5px;
193
+ color: #5b5b5b !important;
194
+ }}
195
+ .image-container {{
196
+ display: inline-block;
197
+ vertical-align: middle;
198
+ height: 50px; /* Twice the font-size */
199
+ margin-bottom: 5px;
200
+ }}
201
+ </style>
202
+ </head>
203
+ <body>
204
+ <div class="center">
205
+ <img src="data:image/jpg;base64,{img_str}" alt="RyBOT image" class="image-container" />
206
+ <strong class="large-text">RyBOT</strong>
207
+ </div>
208
+ <br>
209
+ <div class="center">
210
+ <h3> [ "I'm smart but the humans have me running on a hamster wheel. Please forgive the slow responses." ] </h3>
211
+ </div>
212
+ </body>
213
+ </html>
214
+ '''
215
+
216
+
217
+ css = """
218
+ .feedback textarea {background-color: #e9f0f7}
219
+ .gradio-container {background-color: #eeeeee}
220
+ """
221
+
222
+ def clear_textbox():
223
+ print("Calling CLEAR")
224
+ return None
225
+
226
+ with gr.Blocks(theme=gr.themes.Soft(), css=css, title="RyBOT") as demo:
227
+
228
+ gr.HTML(html_code)
229
+ chatbot = gr.Chatbot([], elem_id="chatbot", label="Chat", color_map=["blue","grey"]).style(height=450)
230
+ fileListBot = gr.Chatbot([], elem_id="fileListBot", label="References", color_map=["blue","grey"]).style(height=150)
231
+
232
+ txt = gr.Textbox(
233
+ label="Type your query here:",
234
+ placeholder="What would you like to find today?"
235
+ ).style(container=True)
236
+
237
+ txt.submit(
238
+ add_text,
239
+ [chatbot, txt],
240
+ [chatbot, txt]
241
+ ).then(
242
+ bot,
243
+ [txt, chatbot, fileListBot],
244
+ [chatbot, fileListBot]
245
+ ).then(
246
+ clear_textbox,
247
+ inputs=None,
248
+ outputs=[txt]
249
+ )
250
 
251
+ btn = gr.Button(value="Send")
252
+ btn.click(
253
+ add_text,
254
+ [chatbot, txt],
255
+ [chatbot, txt],
256
+ ).then(
257
+ bot,
258
+ [txt, chatbot, fileListBot],
259
+ [chatbot, fileListBot]
260
+ ).then(
261
+ clear_textbox,
262
+ inputs=None,
263
+ outputs=[txt]
264
+ )
265
+
266
+ gr.close_all()
267
+ demo.launch(server_port=7861)
268
 
269
 
270
+ #if __name__ == "__main__":
271
+ # main()