praneeth dodedu commited on
Commit
8c51a26
·
1 Parent(s): 6a9410c
Files changed (2) hide show
  1. app.py +152 -251
  2. privategpt.py +1 -6
app.py CHANGED
@@ -1,266 +1,167 @@
1
  #!/usr/bin/env python3
 
 
 
2
  from dotenv import load_dotenv
3
- from langchain.chains import RetrievalQA
4
- from langchain.embeddings import HuggingFaceEmbeddings
5
- from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  from langchain.vectorstores import Chroma
7
- from langchain.llms import GPT4All, LlamaCpp
8
- import os
9
- import argparse
10
- from pathlib import Path
11
- import base64
12
- import gradio as gr
13
 
14
- load_dotenv()
15
 
16
- embeddings_model_name = os.environ.get("EMBEDDINGS_MODEL_NAME")
17
- persist_directory = os.environ.get('PERSIST_DIRECTORY')
18
 
19
- model_type = os.environ.get('MODEL_TYPE')
20
- model_path = os.environ.get('MODEL_PATH')
21
- model_n_ctx = os.environ.get('MODEL_N_CTX')
22
 
23
- from constants import CHROMA_SETTINGS
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
  def main():
26
- # Parse the command line arguments
27
- args = parse_arguments()
28
  embeddings = HuggingFaceEmbeddings(model_name=embeddings_model_name)
29
- db = Chroma(persist_directory=persist_directory, embedding_function=embeddings, client_settings=CHROMA_SETTINGS)
30
- retriever = db.as_retriever()
31
- # activate/deactivate the streaming StdOut callback for LLMs
32
- callbacks = [] if args.mute_stream else [StreamingStdOutCallbackHandler()]
33
- # Prepare the LLM
34
- '''match model_type:
35
- case "LlamaCpp":
36
- llm = LlamaCpp(model_path=model_path, n_ctx=model_n_ctx, callbacks=callbacks, verbose=False)
37
- case "GPT4All":
38
- llm = GPT4All(model=model_path, n_ctx=model_n_ctx, backend='gptj', callbacks=callbacks, verbose=False)
39
- case _default:
40
- print(f"Model {model_type} not supported!")
41
- exit;'''
42
- if model_type == "LlamaCpp":
43
- llm = LlamaCpp(model_path=model_path, n_ctx=model_n_ctx, callbacks=callbacks, verbose=False)
44
- elif model_type == "GPT4All":
45
- llm = GPT4All(model=model_path, n_ctx=model_n_ctx, backend='gptj', callbacks=callbacks, verbose=False)
46
- else:
47
- print(f"Model {model_type} not supported!")
48
- exit;
49
- qa = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=retriever, return_source_documents= not args.hide_source)
50
- # Interactive questions and answers
51
- while True:
52
- query = input("\nEnter a query: ")
53
- if query == "exit":
54
- break
55
-
56
- # Get the answer from the chain
57
- res = qa(query)
58
- answer, docs = res['result'], [] if args.hide_source else res['source_documents']
59
-
60
- # Print the result
61
- print("\n\n> Question:")
62
- print(query)
63
- print("\n> Answer:")
64
- print(answer)
65
-
66
- # Print the relevant sources used for the answer
67
- for document in docs:
68
- print("\n> " + document.metadata["source"] + ":")
69
- print(document.page_content)
70
 
71
- def parse_arguments():
72
- parser = argparse.ArgumentParser(description='privateGPT: Ask questions to your documents without an internet connection, '
73
- 'using the power of LLMs.')
74
- parser.add_argument("--hide-source", "-S", action='store_true',
75
- help='Use this flag to disable printing of source documents used for answers.')
76
-
77
- parser.add_argument("--mute-stream", "-M",
78
- action='store_true',
79
- help='Use this flag to disable the streaming StdOut callback for LLMs.')
80
-
81
- return parser.parse_args()
82
-
83
-
84
- def apply_html(text, color):
85
- if "<table>" in text and "</table>" in text:
86
- # If the text contains table tags, modify the table structure for Gradio
87
- table_start = text.index("<table>")
88
- table_end = text.index("</table>") + len("</table>")
89
- table_content = text[table_start:table_end]
90
-
91
- # Modify the table structure for Gradio
92
- modified_table = table_content.replace("<table>", "<table style='border-collapse: collapse;'>")
93
- modified_table = modified_table.replace("<th>", "<th style='border: 1px solid #ddd; padding: 8px; background-color: #f2f2f2;'>")
94
- modified_table = modified_table.replace("<td>", "<td style='border: 1px solid #ddd; padding: 8px;'>")
95
-
96
- # Replace the modified table back into the original text
97
- modified_text = text[:table_start] + modified_table + text[table_end:]
98
- return modified_text
99
- else:
100
- # Return the plain text as is
101
- return text
102
-
103
- def add_text(history, text):
104
- # Apply selected rules
105
-
106
- if history is not None:
107
- # If all rules pass, add message to chat history with bot's response set to None
108
- history.append([apply_html(text, "blue"), None])
109
-
110
- return history, text
111
-
112
- def bot(query, history, fileListHistory, k=5):
113
- # Parse the command line arguments
114
- args = parse_arguments()
115
- print("QUERY : " + query)
116
- embeddings = HuggingFaceEmbeddings(model_name=embeddings_model_name)
117
- db = Chroma(persist_directory=persist_directory, embedding_function=embeddings, client_settings=CHROMA_SETTINGS)
118
- retriever = db.as_retriever()
119
- # activate/deactivate the streaming StdOut callback for LLMs
120
- callbacks = [] if args.mute_stream else [StreamingStdOutCallbackHandler()]
121
- # Prepare the LLM
122
- '''match model_type:
123
- case "LlamaCpp":
124
- llm = LlamaCpp(model_path=model_path, n_ctx=model_n_ctx, callbacks=callbacks, verbose=False)
125
- case "GPT4All":
126
- llm = GPT4All(model=model_path, n_ctx=model_n_ctx, backend='gptj', callbacks=callbacks, verbose=False)
127
- case _default:
128
- print(f"Model {model_type} not supported!")
129
- exit;'''
130
- if model_type == "LlamaCpp":
131
- llm = LlamaCpp(model_path=model_path, n_ctx=model_n_ctx, callbacks=callbacks, verbose=False)
132
- elif model_type == "GPT4All":
133
- llm = GPT4All(model=model_path, n_ctx=model_n_ctx, backend='gptj', callbacks=callbacks, verbose=False)
134
  else:
135
- print(f"Model {model_type} not supported!")
136
- exit;
137
- qa = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=retriever, return_source_documents= not args.hide_source)
 
 
 
 
138
 
139
- # Get the answer from the chain
140
- res = qa(query)
141
- answer, docs = res['result'], [] if args.hide_source else res['source_documents']
142
-
143
- # Print the result
144
- print("\n\n> Question:")
145
- print(query)
146
- print("\n> Answer:")
147
- print(answer)
148
-
149
- # Print the relevant sources used for the answer
150
- for document in docs:
151
- print("\n> " + document.metadata["source"] + ":")
152
- print(document.page_content)
153
-
154
- # If the call was not successful after 3 attempts, set the response to a timeout message
155
- if answer is None:
156
- print("Unfortunately, the connection to ChatGPT timed out. Please try after some time.")
157
- if history is not None and len(history) > 0:
158
- # Update the chat history with the bot's response
159
- history[-1][1] = apply_html(answer.text.strip(), "black")
160
- else:
161
- # Print the generated response
162
- print("\nGPT RESPONSE:\n")
163
- # print(answer['choices'][0]['message']['content'].strip())
164
-
165
- if history is not None and len(history) > 0:
166
- # Update the chat history with the bot's response
167
- history[-1][1] = apply_html(answer.strip(), "black")
168
- return history, fileListHistory
169
-
170
-
171
-
172
- # Open the image and convert it to base64
173
- with open(Path("rybot_small.png"), "rb") as img_file:
174
- img_str = base64.b64encode(img_file.read()).decode()
175
-
176
- html_code = f'''
177
- <!DOCTYPE html>
178
- <html>
179
- <head>
180
- <style>
181
- .center {{
182
- display: flex;
183
- justify-content: center;
184
- align-items: center;
185
- margin-top: -40px; /* adjust this value as per your requirement */
186
- margin-bottom: 5px;
187
- }}
188
- .large-text {{
189
- font-size: 40px;
190
- font-family: Arial, Helvetica, sans-serif;
191
- font-weight: 900 !important;
192
- margin-left: 5px;
193
- color: #5b5b5b !important;
194
- }}
195
- .image-container {{
196
- display: inline-block;
197
- vertical-align: middle;
198
- height: 50px; /* Twice the font-size */
199
- margin-bottom: 5px;
200
- }}
201
- </style>
202
- </head>
203
- <body>
204
- <div class="center">
205
- <img src="data:image/jpg;base64,{img_str}" alt="RyBOT image" class="image-container" />
206
- <strong class="large-text">RyBOT</strong>
207
- </div>
208
- <br>
209
- <div class="center">
210
- <h3> [ "I'm smart but the humans have me running on a hamster wheel. Please forgive the slow responses." ] </h3>
211
- </div>
212
- </body>
213
- </html>
214
- '''
215
-
216
-
217
- css = """
218
- .feedback textarea {background-color: #e9f0f7}
219
- .gradio-container {background-color: #eeeeee}
220
- """
221
-
222
- def clear_textbox():
223
- print("Calling CLEAR")
224
- return None
225
 
226
- with gr.Blocks(theme=gr.themes.Soft(), css=css, title="RyBOT") as demo:
227
-
228
- gr.HTML(html_code)
229
- chatbot = gr.Chatbot([], elem_id="chatbot", label="Chat", color_map=["blue","grey"]).style(height=450)
230
- fileListBot = gr.Chatbot([], elem_id="fileListBot", label="References", color_map=["blue","grey"]).style(height=150)
231
-
232
- txt = gr.Textbox(
233
- label="Type your query here:",
234
- placeholder="What would you like to find today?"
235
- ).style(container=True)
236
-
237
- txt.submit(
238
- add_text,
239
- [chatbot, txt],
240
- [chatbot, txt]
241
- ).then(
242
- bot,
243
- [txt, chatbot, fileListBot],
244
- [chatbot, fileListBot]
245
- ).then(
246
- clear_textbox,
247
- inputs=None,
248
- outputs=[txt]
249
- )
250
 
251
- btn = gr.Button(value="Send")
252
- btn.click(
253
- add_text,
254
- [chatbot, txt],
255
- [chatbot, txt],
256
- ).then(
257
- bot,
258
- [txt, chatbot, fileListBot],
259
- [chatbot, fileListBot]
260
- ).then(
261
- clear_textbox,
262
- inputs=None,
263
- outputs=[txt]
264
- )
265
-
266
- demo.launch()
 
1
  #!/usr/bin/env python3
2
+ import os
3
+ import glob
4
+ from typing import List
5
  from dotenv import load_dotenv
6
+ from multiprocessing import Pool
7
+ from tqdm import tqdm
8
+
9
+ from langchain.document_loaders import (
10
+ CSVLoader,
11
+ EverNoteLoader,
12
+ PDFMinerLoader,
13
+ TextLoader,
14
+ UnstructuredEmailLoader,
15
+ UnstructuredEPubLoader,
16
+ UnstructuredHTMLLoader,
17
+ UnstructuredMarkdownLoader,
18
+ UnstructuredODTLoader,
19
+ UnstructuredPowerPointLoader,
20
+ UnstructuredWordDocumentLoader,
21
+ )
22
+
23
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
24
  from langchain.vectorstores import Chroma
25
+ from langchain.embeddings import HuggingFaceEmbeddings
26
+ from langchain.docstore.document import Document
27
+ from constants import CHROMA_SETTINGS
 
 
 
28
 
 
29
 
30
+ load_dotenv()
 
31
 
 
 
 
32
 
33
+ # Load environment variables
34
+ persist_directory = os.environ.get('PERSIST_DIRECTORY')
35
+ source_directory = os.environ.get('SOURCE_DIRECTORY', 'source_documents')
36
+ embeddings_model_name = os.environ.get('EMBEDDINGS_MODEL_NAME')
37
+ chunk_size = 500
38
+ chunk_overlap = 50
39
+
40
+
41
+ # Custom document loaders
42
+ class MyElmLoader(UnstructuredEmailLoader):
43
+ """Wrapper to fallback to text/plain when default does not work"""
44
+
45
+ def load(self) -> List[Document]:
46
+ """Wrapper adding fallback for elm without html"""
47
+ try:
48
+ try:
49
+ doc = UnstructuredEmailLoader.load(self)
50
+ except ValueError as e:
51
+ if 'text/html content not found in email' in str(e):
52
+ # Try plain text
53
+ self.unstructured_kwargs["content_source"]="text/plain"
54
+ doc = UnstructuredEmailLoader.load(self)
55
+ else:
56
+ raise
57
+ except Exception as e:
58
+ # Add file_path to exception message
59
+ raise type(e)(f"{self.file_path}: {e}") from e
60
+
61
+ return doc
62
+
63
+
64
+ # Map file extensions to document loaders and their arguments
65
+ LOADER_MAPPING = {
66
+ ".csv": (CSVLoader, {}),
67
+ # ".docx": (Docx2txtLoader, {}),
68
+ ".doc": (UnstructuredWordDocumentLoader, {}),
69
+ ".docx": (UnstructuredWordDocumentLoader, {}),
70
+ ".enex": (EverNoteLoader, {}),
71
+ ".eml": (MyElmLoader, {}),
72
+ ".epub": (UnstructuredEPubLoader, {}),
73
+ ".html": (UnstructuredHTMLLoader, {}),
74
+ ".md": (UnstructuredMarkdownLoader, {}),
75
+ ".odt": (UnstructuredODTLoader, {}),
76
+ ".pdf": (PDFMinerLoader, {}),
77
+ ".ppt": (UnstructuredPowerPointLoader, {}),
78
+ ".pptx": (UnstructuredPowerPointLoader, {}),
79
+ ".txt": (TextLoader, {"encoding": "utf8"}),
80
+ # Add more mappings for other file extensions and loaders as needed
81
+ }
82
+
83
+
84
+ def load_single_document(file_path: str) -> Document:
85
+ ext = "." + file_path.rsplit(".", 1)[-1]
86
+ if ext in LOADER_MAPPING:
87
+ loader_class, loader_args = LOADER_MAPPING[ext]
88
+ loader = loader_class(file_path, **loader_args)
89
+ return loader.load()[0]
90
+
91
+ raise ValueError(f"Unsupported file extension '{ext}'")
92
+
93
+
94
+ def load_documents(source_dir: str, ignored_files: List[str] = []) -> List[Document]:
95
+ """
96
+ Loads all documents from the source documents directory, ignoring specified files
97
+ """
98
+ all_files = []
99
+ for ext in LOADER_MAPPING:
100
+ all_files.extend(
101
+ glob.glob(os.path.join(source_dir, f"**/*{ext}"), recursive=True)
102
+ )
103
+ filtered_files = [file_path for file_path in all_files if file_path not in ignored_files]
104
+
105
+ with Pool(processes=os.cpu_count()) as pool:
106
+ results = []
107
+ with tqdm(total=len(filtered_files), desc='Loading new documents', ncols=80) as pbar:
108
+ for i, doc in enumerate(pool.imap_unordered(load_single_document, filtered_files)):
109
+ results.append(doc)
110
+ pbar.update()
111
+
112
+ return results
113
+
114
+ def process_documents(ignored_files: List[str] = []) -> List[Document]:
115
+ """
116
+ Load documents and split in chunks
117
+ """
118
+ print(f"Loading documents from {source_directory}")
119
+ documents = load_documents(source_directory, ignored_files)
120
+ if not documents:
121
+ print("No new documents to load")
122
+ exit(0)
123
+ print(f"Loaded {len(documents)} new documents from {source_directory}")
124
+ text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
125
+ texts = text_splitter.split_documents(documents)
126
+ print(f"Split into {len(texts)} chunks of text (max. {chunk_size} tokens each)")
127
+ return texts
128
+
129
+ def does_vectorstore_exist(persist_directory: str) -> bool:
130
+ """
131
+ Checks if vectorstore exists
132
+ """
133
+ if os.path.exists(os.path.join(persist_directory, 'index')):
134
+ if os.path.exists(os.path.join(persist_directory, 'chroma-collections.parquet')) and os.path.exists(os.path.join(persist_directory, 'chroma-embeddings.parquet')):
135
+ list_index_files = glob.glob(os.path.join(persist_directory, 'index/*.bin'))
136
+ list_index_files += glob.glob(os.path.join(persist_directory, 'index/*.pkl'))
137
+ # At least 3 documents are needed in a working vectorstore
138
+ if len(list_index_files) > 3:
139
+ return True
140
+ return False
141
 
142
  def main():
143
+ # Create embeddings
 
144
  embeddings = HuggingFaceEmbeddings(model_name=embeddings_model_name)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
145
 
146
+ if does_vectorstore_exist(persist_directory):
147
+ # Update and store locally vectorstore
148
+ print(f"Appending to existing vectorstore at {persist_directory}")
149
+ db = Chroma(persist_directory=persist_directory, embedding_function=embeddings, client_settings=CHROMA_SETTINGS)
150
+ collection = db.get()
151
+ texts = process_documents([metadata['source'] for metadata in collection['metadatas']])
152
+ print(f"Creating embeddings. May take some minutes...")
153
+ db.add_documents(texts)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
154
  else:
155
+ # Create and store locally vectorstore
156
+ print("Creating new vectorstore")
157
+ texts = process_documents()
158
+ print(f"Creating embeddings. May take some minutes...")
159
+ db = Chroma.from_documents(texts, embeddings, persist_directory=persist_directory, client_settings=CHROMA_SETTINGS)
160
+ db.persist()
161
+ db = None
162
 
163
+ print(f"Ingestion complete! You can now run privateGPT.py to query your documents")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
164
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
165
 
166
+ if __name__ == "__main__":
167
+ main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
privategpt.py CHANGED
@@ -263,9 +263,4 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css, title="RyBOT") as demo:
263
  outputs=[txt]
264
  )
265
 
266
- gr.close_all()
267
- demo.launch(server_port=7861)
268
-
269
-
270
- #if __name__ == "__main__":
271
- # main()
 
263
  outputs=[txt]
264
  )
265
 
266
+ demo.launch()