teganmosi commited on
Commit
b1eeebb
·
1 Parent(s): ad934fd

Upload utils.py

Browse files
Files changed (1) hide show
  1. utils.py +424 -0
utils.py ADDED
@@ -0,0 +1,424 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """utils.ipynb
3
+
4
+ Automatically generated by Colaboratory.
5
+
6
+ Original file is located at
7
+ https://colab.research.google.com/drive/1Nh7BlDmV5_ZCWOQO0GxMOn597Ztc_k71
8
+ """
9
+
10
+ pip install deeplake openai streamlit python-dotenv
11
+
12
+ pip install langchain==0.0.208 deeplake openai tiktoken
13
+
14
+ import logging
15
+ import os
16
+ import re
17
+ import shutil
18
+ import sys
19
+ from typing import List
20
+
21
+ import deeplake
22
+ import openai
23
+ import streamlit as st
24
+ from dotenv import load_dotenv
25
+ from langchain.callbacks import OpenAICallbackHandler, get_openai_callback
26
+ from langchain.chains import ConversationalRetrievalChain
27
+ from langchain.chat_models import ChatOpenAI
28
+ from langchain.document_loaders import (
29
+ CSVLoader,
30
+ DirectoryLoader,
31
+ GitLoader,
32
+ NotebookLoader,
33
+ OnlinePDFLoader,
34
+ PythonLoader,
35
+ TextLoader,
36
+ UnstructuredFileLoader,
37
+ UnstructuredHTMLLoader,
38
+ UnstructuredPDFLoader,
39
+ UnstructuredWordDocumentLoader,
40
+ WebBaseLoader,
41
+ )
42
+ from langchain.embeddings.openai import OpenAIEmbeddings
43
+ from langchain.schema import Document
44
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
45
+ from langchain.vectorstores import DeepLake, VectorStore
46
+ from streamlit.uploaded_file_manager import UploadedFile
47
+
48
+ from constants import (
49
+ APP_NAME,
50
+ CHUNK_SIZE,
51
+ DATA_PATH,
52
+ FETCH_K,
53
+ MAX_TOKENS,
54
+ MODEL,
55
+ PAGE_ICON,
56
+ REPO_URL,
57
+ TEMPERATURE,
58
+ K,
59
+ )
60
+
61
+ # loads environment variables
62
+ load_dotenv()
63
+
64
+ logger = logging.getLogger(APP_NAME)
65
+
66
+
67
+ def configure_logger(debug: int = 0) -> None:
68
+ # boilerplate code to enable logging in the streamlit app console
69
+ log_level = logging.DEBUG if debug == 1 else logging.INFO
70
+ logger.setLevel(log_level)
71
+
72
+ stream_handler = logging.StreamHandler(stream=sys.stdout)
73
+ stream_handler.setLevel(log_level)
74
+
75
+ formatter = logging.Formatter("%(message)s")
76
+
77
+ stream_handler.setFormatter(formatter)
78
+
79
+ logger.addHandler(stream_handler)
80
+ logger.propagate = False
81
+
82
+
83
+ configure_logger(0)
84
+
85
+
86
+ def authenticate(
87
+ openai_api_key: str, activeloop_token: str, activeloop_org_name: str
88
+ ) -> None:
89
+ # Validate all credentials are set and correct
90
+ # Check for env variables to enable local dev and deployments with shared credentials
91
+ openai_api_key = (
92
+ openai_api_key
93
+ or os.environ.get("OPENAI_API_KEY")
94
+ or st.secrets.get("OPENAI_API_KEY")
95
+ )
96
+ activeloop_token = (
97
+ activeloop_token
98
+ or os.environ.get("ACTIVELOOP_TOKEN")
99
+ or st.secrets.get("ACTIVELOOP_TOKEN")
100
+ )
101
+ activeloop_org_name = (
102
+ activeloop_org_name
103
+ or os.environ.get("ACTIVELOOP_ORG_NAME")
104
+ or st.secrets.get("ACTIVELOOP_ORG_NAME")
105
+ )
106
+ if not (openai_api_key and activeloop_token and activeloop_org_name):
107
+ st.session_state["auth_ok"] = False
108
+ st.error("Credentials neither set nor stored", icon=PAGE_ICON)
109
+ return
110
+ try:
111
+ # Try to access openai and deeplake
112
+ with st.spinner("Authenticating..."):
113
+ openai.api_key = openai_api_key
114
+ openai.Model.list()
115
+ deeplake.exists(
116
+ f"hub://{activeloop_org_name}/DataChad-Authentication-Check",
117
+ token=activeloop_token,
118
+ )
119
+ except Exception as e:
120
+ logger.error(f"Authentication failed with {e}")
121
+ st.session_state["auth_ok"] = False
122
+ st.error("Authentication failed", icon=PAGE_ICON)
123
+ return
124
+ # store credentials in the session state
125
+ st.session_state["auth_ok"] = True
126
+ st.session_state["openai_api_key"] = openai_api_key
127
+ st.session_state["activeloop_token"] = activeloop_token
128
+ st.session_state["activeloop_org_name"] = activeloop_org_name
129
+ logger.info("Authentication successful!")
130
+
131
+
132
+ def advanced_options_form() -> None:
133
+ # Input Form that takes advanced options and rebuilds chain with them
134
+ advanced_options = st.checkbox(
135
+ "Advanced Options", help="Caution! This may break things!"
136
+ )
137
+ if advanced_options:
138
+ with st.form("advanced_options"):
139
+ temperature = st.slider(
140
+ "temperature",
141
+ min_value=0.0,
142
+ max_value=1.0,
143
+ value=TEMPERATURE,
144
+ help="Controls the randomness of the language model output",
145
+ )
146
+ col1, col2 = st.columns(2)
147
+ fetch_k = col1.number_input(
148
+ "k_fetch",
149
+ min_value=1,
150
+ max_value=1000,
151
+ value=FETCH_K,
152
+ help="The number of documents to pull from the vector database",
153
+ )
154
+ k = col2.number_input(
155
+ "k",
156
+ min_value=1,
157
+ max_value=100,
158
+ value=K,
159
+ help="The number of most similar documents to build the context from",
160
+ )
161
+ chunk_size = col1.number_input(
162
+ "chunk_size",
163
+ min_value=1,
164
+ max_value=100000,
165
+ value=CHUNK_SIZE,
166
+ help=(
167
+ "The size at which the text is divided into smaller chunks "
168
+ "before being embedded.\n\nChanging this parameter makes re-embedding "
169
+ "and re-uploading the data to the database necessary "
170
+ ),
171
+ )
172
+ max_tokens = col2.number_input(
173
+ "max_tokens",
174
+ min_value=1,
175
+ max_value=4069,
176
+ value=MAX_TOKENS,
177
+ help="Limits the documents returned from database based on number of tokens",
178
+ )
179
+ applied = st.form_submit_button("Apply")
180
+ if applied:
181
+ st.session_state["k"] = k
182
+ st.session_state["fetch_k"] = fetch_k
183
+ st.session_state["chunk_size"] = chunk_size
184
+ st.session_state["temperature"] = temperature
185
+ st.session_state["max_tokens"] = max_tokens
186
+ update_chain()
187
+
188
+
189
+ def save_uploaded_file(uploaded_file: UploadedFile) -> str:
190
+ # streamlit uploaded files need to be stored locally
191
+ # before embedded and uploaded to the hub
192
+ if not os.path.exists(DATA_PATH):
193
+ os.makedirs(DATA_PATH)
194
+ file_path = str(DATA_PATH / uploaded_file.name)
195
+ uploaded_file.seek(0)
196
+ file_bytes = uploaded_file.read()
197
+ file = open(file_path, "wb")
198
+ file.write(file_bytes)
199
+ file.close()
200
+ logger.info(f"Saved: {file_path}")
201
+ return file_path
202
+
203
+
204
+ def delete_uploaded_file(uploaded_file: UploadedFile) -> None:
205
+ # cleanup locally stored files
206
+ file_path = str(DATA_PATH / uploaded_file.name)
207
+ if os.path.exists(file_path):
208
+ os.remove(file_path)
209
+ logger.info(f"Removed: {file_path}")
210
+
211
+
212
+ def handle_load_error(e: str = None) -> None:
213
+ e = e or f"No Loader found for your data source. Consider contributing: {REPO_URL}!"
214
+ error_msg = f"Failed to load {st.session_state['data_source']} with Error:\n{e}"
215
+ st.error(error_msg, icon=PAGE_ICON)
216
+ logger.info(error_msg)
217
+ st.stop()
218
+
219
+
220
+ def load_git(data_source: str, chunk_size: int = CHUNK_SIZE) -> List[Document]:
221
+ # We need to try both common main branches
222
+ # Thank you GitHub for the "master" to "main" switch
223
+ repo_name = data_source.split("/")[-1].split(".")[0]
224
+ repo_path = str(DATA_PATH / repo_name)
225
+ text_splitter = RecursiveCharacterTextSplitter(
226
+ chunk_size=chunk_size, chunk_overlap=0
227
+ )
228
+ branches = ["main", "master"]
229
+ for branch in branches:
230
+ if os.path.exists(repo_path):
231
+ data_source = None
232
+ try:
233
+ docs = GitLoader(repo_path, data_source, branch).load_and_split(
234
+ text_splitter
235
+ )
236
+ break
237
+ except Exception as e:
238
+ logger.info(f"Error loading git: {e}")
239
+ if os.path.exists(repo_path):
240
+ # cleanup repo afterwards
241
+ shutil.rmtree(repo_path)
242
+ try:
243
+ return docs
244
+ except Exception as e:
245
+ handle_load_error()
246
+
247
+
248
+ def load_any_data_source(
249
+ data_source: str, chunk_size: int = CHUNK_SIZE
250
+ ) -> List[Document]:
251
+ # Ugly thing that decides how to load data
252
+ # It ain't much, but it's honest work
253
+ is_text = data_source.endswith(".txt")
254
+ is_web = data_source.startswith("http")
255
+ is_pdf = data_source.endswith(".pdf")
256
+ is_csv = data_source.endswith(".csv")
257
+ is_html = data_source.endswith(".html")
258
+ is_git = data_source.endswith(".git")
259
+ is_notebook = data_source.endswith(".ipynb")
260
+ is_doc = data_source.endswith(".doc")
261
+ is_py = data_source.endswith(".py")
262
+ is_dir = os.path.isdir(data_source)
263
+ is_file = os.path.isfile(data_source)
264
+
265
+ loader = None
266
+ if is_dir:
267
+ loader = DirectoryLoader(data_source, recursive=True, silent_errors=True)
268
+ elif is_git:
269
+ return load_git(data_source, chunk_size)
270
+ elif is_web:
271
+ if is_pdf:
272
+ loader = OnlinePDFLoader(data_source)
273
+ else:
274
+ loader = WebBaseLoader(data_source)
275
+ elif is_file:
276
+ if is_text:
277
+ loader = TextLoader(data_source)
278
+ elif is_notebook:
279
+ loader = NotebookLoader(data_source)
280
+ elif is_pdf:
281
+ loader = UnstructuredPDFLoader(data_source)
282
+ elif is_html:
283
+ loader = UnstructuredHTMLLoader(data_source)
284
+ elif is_doc:
285
+ loader = UnstructuredWordDocumentLoader(data_source)
286
+ elif is_csv:
287
+ loader = CSVLoader(data_source, encoding="utf-8")
288
+ elif is_py:
289
+ loader = PythonLoader(data_source)
290
+ else:
291
+ loader = UnstructuredFileLoader(data_source)
292
+ try:
293
+ # Chunk size is a major trade-off parameter to control result accuracy over computation
294
+ text_splitter = RecursiveCharacterTextSplitter(
295
+ chunk_size=chunk_size, chunk_overlap=0
296
+ )
297
+ docs = loader.load_and_split(text_splitter)
298
+ logger.info(f"Loaded: {len(docs)} document chunks")
299
+ return docs
300
+ except Exception as e:
301
+ handle_load_error(e if loader else None)
302
+
303
+
304
+ def clean_data_source_string(data_source_string: str) -> str:
305
+ # replace all non-word characters with dashes
306
+ # to get a string that can be used to create a new dataset
307
+ dashed_string = re.sub(r"\W+", "-", data_source_string)
308
+ cleaned_string = re.sub(r"--+", "- ", dashed_string).strip("-")
309
+ return cleaned_string
310
+
311
+
312
+ def setup_vector_store(data_source: str, chunk_size: int = CHUNK_SIZE) -> VectorStore:
313
+ # either load existing vector store or upload a new one to the hub
314
+ embeddings = OpenAIEmbeddings(
315
+ disallowed_special=(), openai_api_key=st.session_state["openai_api_key"]
316
+ )
317
+ data_source_name = clean_data_source_string(data_source)
318
+ dataset_path = f"hub://{st.session_state['activeloop_org_name']}/{data_source_name}-{chunk_size}"
319
+ if deeplake.exists(dataset_path, token=st.session_state["activeloop_token"]):
320
+ with st.spinner("Loading vector store..."):
321
+ logger.info(f"Dataset '{dataset_path}' exists -> loading")
322
+ vector_store = DeepLake(
323
+ dataset_path=dataset_path,
324
+ read_only=True,
325
+ embedding_function=embeddings,
326
+ token=st.session_state["activeloop_token"],
327
+ )
328
+ else:
329
+ with st.spinner("Reading, embedding and uploading data to hub..."):
330
+ logger.info(f"Dataset '{dataset_path}' does not exist -> uploading")
331
+ docs = load_any_data_source(data_source, chunk_size)
332
+ vector_store = DeepLake.from_documents(
333
+ docs,
334
+ embeddings,
335
+ dataset_path=dataset_path,
336
+ token=st.session_state["activeloop_token"],
337
+ )
338
+ return vector_store
339
+
340
+
341
+ def build_chain(
342
+ data_source: str,
343
+ k: int = K,
344
+ fetch_k: int = FETCH_K,
345
+ chunk_size: int = CHUNK_SIZE,
346
+ temperature: float = TEMPERATURE,
347
+ max_tokens: int = MAX_TOKENS,
348
+ ) -> ConversationalRetrievalChain:
349
+ # create the langchain that will be called to generate responses
350
+ vector_store = setup_vector_store(data_source, chunk_size)
351
+ retriever = vector_store.as_retriever()
352
+ # Search params "fetch_k" and "k" define how many documents are pulled from the hub
353
+ # and selected after the document matching to build the context
354
+ # that is fed to the model together with your prompt
355
+ search_kwargs = {
356
+ "maximal_marginal_relevance": True,
357
+ "distance_metric": "cos",
358
+ "fetch_k": fetch_k,
359
+ "k": k,
360
+ }
361
+ retriever.search_kwargs.update(search_kwargs)
362
+ model = ChatOpenAI(
363
+ model_name=MODEL,
364
+ temperature=temperature,
365
+ openai_api_key=st.session_state["openai_api_key"],
366
+ )
367
+ chain = ConversationalRetrievalChain.from_llm(
368
+ model,
369
+ retriever=retriever,
370
+ chain_type="stuff",
371
+ verbose=True,
372
+ # we limit the maximum number of used tokens
373
+ # to prevent running into the model's token limit of 4096
374
+ max_tokens_limit=max_tokens,
375
+ )
376
+ logger.info(f"Data source '{data_source}' is ready to go!")
377
+ return chain
378
+
379
+
380
+ def update_chain() -> None:
381
+ # Build chain with parameters from session state and store it back
382
+ # Also delete chat history to not confuse the bot with old context
383
+ try:
384
+ st.session_state["chain"] = build_chain(
385
+ data_source=st.session_state["data_source"],
386
+ k=st.session_state["k"],
387
+ fetch_k=st.session_state["fetch_k"],
388
+ chunk_size=st.session_state["chunk_size"],
389
+ temperature=st.session_state["temperature"],
390
+ max_tokens=st.session_state["max_tokens"],
391
+ )
392
+ st.session_state["chat_history"] = []
393
+ except Exception as e:
394
+ msg = f"Failed to build chain for data source {st.session_state['data_source']} with error: {e}"
395
+ logger.error(msg)
396
+ st.error(msg, icon=PAGE_ICON)
397
+
398
+
399
+ def update_usage(cb: OpenAICallbackHandler) -> None:
400
+ # Accumulate API call usage via callbacks
401
+ logger.info(f"Usage: {cb}")
402
+ callback_properties = [
403
+ "total_tokens",
404
+ "prompt_tokens",
405
+ "completion_tokens",
406
+ "total_cost",
407
+ ]
408
+ for prop in callback_properties:
409
+ value = getattr(cb, prop, 0)
410
+ st.session_state["usage"].setdefault(prop, 0)
411
+ st.session_state["usage"][prop] += value
412
+
413
+
414
+ def generate_response(prompt: str) -> str:
415
+ # call the chain to generate responses and add them to the chat history
416
+ with st.spinner("Generating response"), get_openai_callback() as cb:
417
+ response = st.session_state["chain"](
418
+ {"question": prompt, "chat_history": st.session_state["chat_history"]}
419
+ )
420
+ update_usage(cb)
421
+ logger.info(f"Response: '{response}'")
422
+ st.session_state["chat_history"].append((prompt, response["answer"]))
423
+ return response["answer"]
424
+