aniruddhdoki commited on
Commit
10f1f60
·
1 Parent(s): b67de0f

refreshed ui, added support for response streaming, and added new database.

Browse files
.streamlit/secrets.toml CHANGED
@@ -3,4 +3,5 @@ LANGCHAIN_ENDPOINT="https://api.smith.langchain.com"
3
  LANGCHAIN_API_KEY="ls__3382b1f40a7f4eefa6959cb2b03dd687"
4
  LANGCHAIN_PROJECT="ConsultAI v1"
5
  OPENAI_API_KEY = "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"
6
- APIFY_CLIENT_KEY = "apify_api_GhFIqZgUf2BGqO46OdBcQOyk2rekQt0ns3Wv"
 
 
3
  LANGCHAIN_API_KEY="ls__3382b1f40a7f4eefa6959cb2b03dd687"
4
  LANGCHAIN_PROJECT="ConsultAI v1"
5
  OPENAI_API_KEY = "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"
6
+ APIFY_CLIENT_KEY = "apify_api_GhFIqZgUf2BGqO46OdBcQOyk2rekQt0ns3Wv"
7
+ DEV_DATASET = "https://huggingface.co/datasets/aniruddhdoki/Atticus_DEV"
app.py CHANGED
@@ -1,142 +1,93 @@
1
- import os
2
- import openai
3
  import streamlit as st
4
- from streamlit_chat import message
5
- from utils.db import Database
6
 
7
- st.title("ATTICUS")
8
- st.subheader("I know everything")
 
 
 
 
 
9
 
10
  if "messages" not in st.session_state:
11
- st.session_state["messages"] = []
12
 
13
  if "history" not in st.session_state:
14
- st.session_state["history"] = []
15
-
16
- if "input_disabled" not in st.session_state:
17
- st.session_state["input_disabled"] = False
18
-
19
-
20
- openai.api_key = st.secrets.get('OPENAI_API_KEY')
21
- if not openai.api_key:
22
- openai_api_key = st.text_input(
23
- "Enter your OpenAI API key here:",
24
- disabled=st.session_state["input_disabled"],
25
- type="password"
26
- )
27
- st.write("Please enter your OpenAI API key above")
28
- st.stop()
29
-
30
- @st.cache_resource
31
- def initialize_db():
32
- return Database()
33
- db = initialize_db()
34
-
35
- def ingest(uploaded_files):
36
- db.add_files(uploaded_files)
37
-
38
- def disable_input_cb():
39
- st.session_state["input_disabled"] = True
40
-
41
- def enable_input_cb():
42
- st.session_state["input_disabled"] = False
43
-
44
- with st.sidebar:
45
- st.write('Upload knowledge here. Can digest .pdf files. May take a while to process')
46
- uploaded_files = st.file_uploader(
47
- "Upload Files",
48
- accept_multiple_files=True,
49
- type=['pdf'],
50
- disabled=st.session_state["input_disabled"]
51
- )
52
-
53
- if uploaded_files:
54
- if st.button("Ingest PDFs", on_click=disable_input_cb):
55
- with st.spinner("Analyzing... DON'T INTERACT WITH THE PAGE UNTIL IT'S DONE!"):
56
- disable_input_cb()
57
- ingest(uploaded_files)
58
- st.success("Done!")
59
- st.balloons()
60
- enable_input_cb()
61
- st.experimental_rerun()
62
-
63
- st.subheader('current files in db')
64
- for file in db.view_db().get('files', {}).get('documents', []):
65
- st.write(file)
66
-
67
- def get_text():
68
- input_text = st.text_input(
69
- "You: ",
70
- help="Ask your questions here!",
71
- key="input",
72
- disabled=st.session_state["input_disabled"]
73
- )
74
- return input_text
75
-
76
- user_input = get_text()
77
-
78
- def generate_response(query):
79
- search = db.similarity_search(query, 5)
80
- sources = """"""
81
- print(search)
82
- for metadata in search.get('metadatas', [])[0]:
83
- if metadata.get('page_num', False):
84
- sources += f"\nHarvard CBE Consultants"
85
- else:
86
- sources += f"\n{metadata['link']}"
87
-
88
- context = '\n'.join([' '.join(document) for document in search['documents']])
89
- print(context)
90
- template = f"""
91
- You are a smart, witty, hipster, know-it-all young generative AI consultant out of Harvard University. You will answer a question given the following context:
92
-
93
- Here's the context: {context}
94
-
95
- Here's the question: {query}
96
-
97
- Now generate an answer to the question using the context provided.
98
- """
99
- st.session_state["history"].append({
100
- 'role': 'user',
101
- 'content': template.format(context=context, query=query),
102
- })
103
- st.session_state["messages"].append("You: " + query)
104
-
105
- response = openai.ChatCompletion.create(
106
- model='gpt-3.5-turbo-0613',
107
- messages=st.session_state["history"]
108
- )
109
- formatted_response = f"""
110
- Assistant: {response.get('choices', [])[0].get('message', {}).get('content', 'Response generated but error displaying. Please try again')}
111
-
112
- Sources: {sources}
113
  """
114
- return formatted_response
115
-
116
- if user_input:
117
- res = generate_response(user_input)
118
- st.session_state["history"].append({
119
- 'role': 'assistant',
120
- 'content': res
121
- })
122
- st.session_state["messages"].append("Assistant: " + res)
123
-
124
- print(st.session_state["messages"])
125
-
126
- for msg in st.session_state["messages"]:
127
- st.write(msg)
128
-
129
- with st.expander("debug db"):
130
- if st.button(
131
- "empty db. this is very dangerous. pls don't.",
132
- on_click=lambda: disable_input_cb,
133
- disabled=st.session_state["input_disabled"],
134
- ):
135
- db.empty()
136
- st.write("emptied db")
137
- enable_input_cb()
138
- st.experimental_rerun()
139
-
140
- st.write(db.view_db())
141
-
142
- db._save_db()
 
 
 
1
  import streamlit as st
2
+ import openai
3
+ from supabase import client, create_client
4
 
5
+ openai.api_key="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"
6
+ supabase = create_client(
7
+ "https://djfytlaeuxwbcztcwrzt.supabase.co",
8
+ "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJzdXBhYmFzZSIsInJlZiI6ImRqZnl0bGFldXh3YmN6dGN3cnp0Iiwicm9sZSI6ImFub24iLCJpYXQiOjE3MDA1MTM4NTEsImV4cCI6MjAxNjA4OTg1MX0.G0ni6ZeuU1NxnoLhXyY99agZz6TkufxFc5-LAat2sk4",
9
+ )
10
+
11
+ st.title("Atticus")
12
 
13
  if "messages" not in st.session_state:
14
+ st.session_state.messages = []
15
 
16
  if "history" not in st.session_state:
17
+ st.session_state.history = []
18
+
19
+ for message in st.session_state.history:
20
+ with st.chat_message(message["role"]):
21
+ st.markdown(message["content"])
22
+
23
+ prompt = st.chat_input("I literally know everything about financial AI. Ask me anything.")
24
+
25
+ # def example_prompt_cb(text):
26
+ # global prompt
27
+ # prompt = text
28
+
29
+ # if not prompt:
30
+ # st.button(
31
+ # "What are top financial companies thinking about when it comes to Artificial Intelligence and large language models?",
32
+ # on_click=example_prompt_cb,
33
+ # args=["What are top financial companies thinking about when it comes to Artificial Intelligence and large language models?"]
34
+ # )
35
+
36
+
37
+ if prompt:
38
+ # embed question
39
+ embedding = openai.embeddings.create(
40
+ model='text-embedding-ada-002',
41
+ input=prompt,
42
+ encoding_format='float'
43
+ ).data[0].embedding
44
+
45
+ # perform similarity search
46
+ data = supabase.rpc("match_documents", {
47
+ "match_count": 10,
48
+ "query_embedding": embedding
49
+ }).execute().data
50
+
51
+ # construct prompt
52
+ context = '\n'.join([doc.get('content', '') for doc in data])
53
+ sources = '\n'.join(set([doc.get('metadata', {}).get('source', '') for doc in data]))
54
+ query = """
55
+ You are a smart, witty, hipster, know-it-all young finance and generative AI consultant out of Harvard University. Be snarky, know-it-all, hipster, and above all, be cool. Give as much information as possible and be as helpful as possible. Be detailed. Give as much detail as possible. Produce your answer in MARKDOWN format. Only answer the question asked. Only answer about finance.
56
+
57
+ LLMS = Large Language Models
58
+
59
+ Answer the following question given the context below. Do not use any information outside of this context. If you don't have the answer in the given context, say you do not know. DO NOT PRODUCE URLS.
60
+
61
+ Here is the question: {prompt}
62
+
63
+ Here is the context: {context}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
  """
65
+ query = query.format(prompt=prompt, context=context)
66
+
67
+ # generate response
68
+ st.session_state.history.append({"role": "user", "content": prompt})
69
+ st.session_state.messages.append({"role": "user", "content": prompt})
70
+ with st.chat_message("user"):
71
+ st.markdown(prompt)
72
+
73
+ with st.chat_message("assistant"):
74
+ message_placeholder = st.empty()
75
+ full_response = ""
76
+ for response in openai.chat.completions.create(
77
+ model='gpt-3.5-turbo-16k-0613',
78
+ messages=[
79
+ {"role": m["role"], "content": m["content"]}
80
+ for m in st.session_state.messages
81
+ ],
82
+ stream=True,
83
+ ):
84
+ full_response += (response.choices[0].delta.content or "")
85
+ message_placeholder.markdown(full_response + "▌")
86
+ final_message = full_response
87
+ sources = sources.split('\n')
88
+ sources = '\n'.join([f"{num + 1}. {source}" for num, source in enumerate(sources)])
89
+ print(sources)
90
+ full_response += f"\n\nSources: \n{sources}"
91
+ message_placeholder.markdown(full_response)
92
+ st.session_state.messages.append({"role": "assistant", "content": final_message})
93
+ st.session_state.history.append({"role": "assistant", "content": full_response})
deprecated/app.py DELETED
@@ -1,255 +0,0 @@
1
- import os
2
- import PyPDF2
3
- import chromadb
4
- import openai
5
- import uuid
6
- import streamlit as st
7
- from streamlit_chat import message
8
- from bs4 import BeautifulSoup
9
- from apify_client import ApifyClient
10
- from pprint import pprint
11
- from utils.split import split
12
- from utils.db import initialize_db
13
-
14
- ## NOTE: STREAMLIT RUNS THE ENTIRE SCRIPT FROM TOP TO BOTTOM ON EVERY USER INTERACTION
15
-
16
- ## streamlit ui
17
- st.title("Atticus")
18
- st.subheader("STILL IN DEVELOPMENT. DO NOT USE 'UPLOAD FILES' FEATURE IN SIDEBAR YET. IF NO DATA CAN BE SEEN IN DATABASE CONTACT ME.")
19
- openai_api_key = st.text_input("Enter your OpenAI API key here:")
20
- openai.api_key = openai_api_key
21
- if not openai_api_key:
22
- st.write("Please enter your OpenAI API key above")
23
- st.stop()
24
-
25
- # create vectorstore globally (persists data across runs)
26
- db = initialize_db()
27
-
28
- # scrape links
29
- @st.cache_data
30
- def scraper(links):
31
- data = {
32
- 'id': [],
33
- 'embeddings': [],
34
- 'metadatas': [],
35
- 'documents': [],
36
- 'links': []
37
- }
38
-
39
- client = ApifyClient(os.environ.get('APIFY_CLIENT_KEY'))
40
-
41
- run_input = {
42
- "startUrls": [{ "url": link } for link in links],
43
- "crawlerType": "playwright:firefox",
44
- "includeUrlGlobs": [],
45
- "excludeUrlGlobs": [],
46
- "maxCrawlDepth": 0,
47
- "maxCrawlPages": 9999999,
48
- "initialConcurrency": 0,
49
- "maxConcurrency": 200,
50
- "initialCookies": [],
51
- "proxyConfiguration": { "useApifyProxy": True },
52
- "requestTimeoutSecs": 60,
53
- "dynamicContentWaitSecs": 10,
54
- "maxScrollHeightPixels": 5000,
55
- "removeElementsCssSelector": """nav, footer, script, style, noscript, svg,
56
- [role=\"alert\"],
57
- [role=\"banner\"],
58
- [role=\"dialog\"],
59
- [role=\"alertdialog\"],
60
- [role=\"region\"][aria-label*=\"skip\" i],
61
- [aria-modal=\"true\"]""",
62
- "removeCookieWarnings": True,
63
- "clickElementsCssSelector": "[aria-expanded=\"false\"]",
64
- "htmlTransformer": "readableText",
65
- "readableTextCharThreshold": 100,
66
- "aggressivePrune": False,
67
- "debugMode": False,
68
- "debugLog": False,
69
- "saveHtml": False,
70
- "saveMarkdown": False,
71
- "saveFiles": False,
72
- "saveScreenshots": False,
73
- "maxResults": 9999999,
74
- }
75
-
76
- try:
77
- run = client.actor("apify/website-content-crawler").call(run_input=run_input)
78
- for item in client.dataset(run["defaultDatasetId"]).iterate_items():
79
- text = item['text']
80
- text = list(split(text))
81
- for t in text:
82
- metadata = {'link': item['url']}
83
- data['id'].append(str(uuid.uuid4()))
84
- data['documents'].append(t)
85
- data['metadatas'].append(metadata)
86
- # generate embedding
87
- embedding = openai.Embedding.create(
88
- model='text-embedding-ada-002',
89
- input=text,
90
- encoding_format='float'
91
- ).get('data', [])[0].get('embedding', [])
92
- data['embeddings'].append(embedding)
93
- collection = db.get_or_create_collection('documents')
94
- collection.add(
95
- ids=data['id'],
96
- embeddings=data['embeddings'],
97
- documents=data['documents'],
98
- metadatas=data['metadatas'], ## TODO: ADD SUPPORT FOR METADATA
99
- )
100
- except Exception as e:
101
- print(e)
102
-
103
- if "disabled" not in st.session_state:
104
- st.session_state["disabled"] = False
105
-
106
- # process input pdfs
107
- @st.cache_data
108
- def ingest_pdfs(files):
109
- """
110
- process input pdfs and add to a browser cached chromadb
111
- """
112
-
113
- status = st.empty()
114
-
115
- data = {
116
- 'id': [],
117
- 'embeddings': [],
118
- 'metadatas': [],
119
- 'documents': [],
120
- 'links': []
121
- }
122
-
123
- status.write('Processing PDFs...')
124
- print('cooking...')
125
- for file in files:
126
- if '.pdf' not in file.name:
127
- status.write(f'Skipping {file.name} (not a PDF)')
128
- continue
129
- status.write(f'Processing {file.name}...')
130
- reader = PyPDF2.PdfReader(file)
131
- for num in range(len(reader.pages)):
132
- # extract text
133
- page = reader.pages[num]
134
- text = page.extract_text()
135
- metadata = {'page_number': num}
136
- data['id'].append(str(uuid.uuid4()))
137
- data['documents'].append(text)
138
- data['metadatas'].append(metadata)
139
-
140
- # generate embedding
141
- embedding = openai.Embedding.create(
142
- model='text-embedding-ada-002',
143
- input=text,
144
- encoding_format='float'
145
- ).get('data', [])[0].get('embedding', [])
146
- data['embeddings'].append(embedding)
147
-
148
- # extract links
149
- if page.get('/Annots'):
150
- annotations = page['/Annots']
151
- for annotation in annotations:
152
- content = annotation.get_object()
153
- if content['/Subtype'] == '/Link':
154
- uri = content['/A']['/URI']
155
- data['links'].append(uri)
156
-
157
- if len(data['links']) > 0:
158
- status.write('Scraping all source links (may take 10+ minutes)...')
159
- scraper(data['links'])
160
- else:
161
- status.write('No links found in PDFs')
162
-
163
- # add data to db
164
- status.write('Adding data to database...')
165
- collection = db.get_or_create_collection('documents')
166
- collection.add(
167
- ids=data['id'],
168
- embeddings=data['embeddings'],
169
- documents=data['documents'],
170
- metadatas=data['metadatas'], ## TODO: ADD SUPPORT FOR METADATA
171
- )
172
- status.write('Done!')
173
- print('done')
174
-
175
- def processing_callback():
176
- st.session_state["disabled"] = True
177
-
178
- # sidebar for file input
179
- with st.sidebar:
180
- st.write("Upload Files (may take a while to process)")
181
- uploaded_files = st.file_uploader(
182
- "Choose 1 or more .PDF files",
183
- accept_multiple_files=True,
184
- disabled=st.session_state["disabled"]
185
- )
186
- if uploaded_files:
187
- if st.button("Ingest PDFs"):
188
- with st.spinner("Analyzing... DON'T INTERACT WITH THE PAGE UNTIL IT'S DONE!"):
189
- ingest_pdfs(uploaded_files)
190
-
191
- if "messages" not in st.session_state:
192
- st.session_state["messages"] = []
193
-
194
- if "history" not in st.session_state:
195
- st.session_state["history"] = []
196
-
197
- def get_text():
198
- input_text = st.text_input(
199
- "You: ",
200
- help="Ask your questions here!",
201
- key="input",
202
- disabled=st.session_state["disabled"]
203
- )
204
- return input_text
205
-
206
- user_input = get_text()
207
-
208
- def generate_response(query):
209
- # chromadb query
210
- collection = db.get_or_create_collection('documents')
211
- query_embedding = openai.Embedding.create(
212
- model='text-embedding-ada-002',
213
- input=query,
214
- encoding_format='float'
215
- ).get('data', [])[0].get('embedding', [])
216
- context = collection.query(query_embedding, n_results=5)
217
- context = '\n'.join([' '.join(document) for document in context['documents']])
218
-
219
- # generate response
220
- template = f"""
221
- You are a smart, witty, hipster, know-it-all young generative AI consultant out of Harvard University. You will answer a question given the following context:
222
-
223
- Here's the context: {context}
224
-
225
- Here's the question: {query}
226
-
227
- Now generate an answer to the question using the context provided.
228
- """
229
- st.session_state["history"].append({
230
- 'role': 'user',
231
- 'content': template.format(context=context, query=query),
232
- })
233
- st.session_state["messages"].append("You: " + query)
234
-
235
- response = openai.ChatCompletion.create(
236
- model='gpt-4',
237
- messages=st.session_state["history"]
238
- )
239
- return response.get('choices', [])[0].get('message', {}).get('content', '')
240
-
241
- if user_input:
242
- res = generate_response(user_input)
243
- st.session_state["history"].append({
244
- 'role': 'assistant',
245
- 'content': res
246
- })
247
- st.session_state["messages"].append("Assistant: " + res)
248
-
249
- print(st.session_state["messages"])
250
-
251
- for msg in st.session_state["messages"]:
252
- st.write(msg)
253
-
254
- with st.expander("View Database"):
255
- st.write(db.get_or_create_collection('documents').peek(db.get_or_create_collection('documents').count()))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
utils/__pycache__/db.cpython-310.pyc DELETED
Binary file (13.4 kB)
 
utils/db.py DELETED
@@ -1,466 +0,0 @@
1
- import streamlit as st
2
- import huggingface_hub as hf_hub
3
- import chromadb
4
- import os
5
- import uuid
6
- import PyPDF2
7
- import openai
8
- from apify_client import ApifyClient
9
- from chromadb.utils import embedding_functions
10
-
11
- os.environ["HF_TOKEN"] = "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"
12
- os.environ["OPENAI_API_KEY"] = "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"
13
- os.environ["APIFY_CLIENT_KEY"] = "apify_api_GhFIqZgUf2BGqO46OdBcQOyk2rekQt0ns3Wv"
14
-
15
- class Database(object):
16
- """
17
- Singleton ChromaDB database with persistence on HuggingFace Dataset repo
18
- """
19
- dataset_repo_url = "https://huggingface.co/datasets/aniruddhdoki/ConsultAI"
20
- db_path = "./database"
21
- db = None
22
- repo = None
23
- embedding_function = None
24
- openai.api_key = st.secrets.get("OPENAI_API_KEY")
25
-
26
- @classmethod
27
- def __new__(cls, *args, **kwargs):
28
- if not hasattr(cls, 'instance'):
29
- print('no database instance created. generating new one...')
30
- cls.instance = super(Database, cls).__new__(cls)
31
- print('database instance already exists. returning existing one...')
32
- return cls.instance
33
-
34
- def __init__(
35
- self,
36
- dataset_repo_url="https://huggingface.co/datasets/aniruddhdoki/ConsultAI",
37
- db_path="../database/.chroma", #"./database/.chroma",
38
- embedding_function="default"
39
- ):
40
- print('initializing instance')
41
- self.dataset_repo_url = dataset_repo_url
42
- self.db_path = db_path
43
- self._init_db()
44
- match embedding_function:
45
- case "default":
46
- self.embedding_function = embedding_functions.DefaultEmbeddingFunction()
47
- case "openai":
48
- self.embedding_function = embedding_functions.OpenAIEmbeddingFunction(
49
- api_key=os.environ["OPENAI_API_KEY"],
50
- model="text-embedding-ada-002",
51
- )
52
- case _:
53
- raise ValueError(f"Invalid embedding function {embedding_function}. Valid options are 'default' and 'openai'")
54
-
55
- def _init_db(self):
56
- """
57
- ### Get data from HuggingFace Dataset repo and initialize
58
- #### Params
59
- - None
60
- #### Returns
61
- - None
62
- """
63
- print('initializing chroma db...')
64
- if self.db is not None:
65
- return
66
- print('pulling saved db from repo')
67
- self.repo = hf_hub.Repository(
68
- local_dir=self.db_path,
69
- clone_from=self.dataset_repo_url,
70
- use_auth_token=os.environ["HF_TOKEN"]
71
- )
72
- print('creating new chroma client...')
73
- self.db = chromadb.PersistentClient(self.db_path)
74
- print('chroma client created!')
75
-
76
- def _save_db(self, commit_message="Database Update"):
77
- """
78
- ### Save database to HuggingFace Dataset repo
79
- #### Params
80
- - commit_message: String, commit message for repo
81
- #### Returns
82
- - None
83
- """
84
- print('saving db...')
85
- self.repo.push_to_hub(commit_message=commit_message)
86
- print('db saved!')
87
-
88
- def add_files(
89
- self,
90
- files,
91
- scrape=True,
92
- chunk_size=1000,
93
- ):
94
- """
95
- ### Add files to database.
96
- #### Params
97
- - files: List, list of UploadedFile objects
98
- #### Returns
99
- - None
100
- """
101
- print(f'adding {len(files)} files to db...')
102
- for file in files:
103
- self._add_file(
104
- file,
105
- scrape=scrape,
106
- chunk_size=chunk_size,
107
- )
108
- print('done adding files to db!')
109
-
110
- def _add_file(
111
- self,
112
- file,
113
- scrape,
114
- chunk_size,
115
- ):
116
- """
117
- ### Adds single file to database. Saves DB on completion.
118
- #### Params
119
- - file: UploadedFile object
120
- #### Returns
121
- - None
122
- """
123
- print(f'adding {file.name} file to db...')
124
- # check if file exists
125
- exists = self._check_if_file_exists(file)
126
- if exists:
127
- print(f"File {file.name} already exists in database, skipping...")
128
- return
129
- print(f'{file.name} not found in db, continuing...')
130
- # ingest file
131
- self._ingest_file(
132
- file,
133
- scrape=scrape,
134
- chunk_size=chunk_size
135
- )
136
- print(f'file {file.name} added to db!')
137
- # add file metadata to files collection
138
- self._add_file_metadata(file)
139
- # save db
140
- self._save_db()
141
- print(f'db saved after adding {file.name}!')
142
-
143
- def _check_if_file_exists(self, file):
144
- """
145
- ### Checks if file exists in database
146
- #### Params
147
- - file: UploadedFile object
148
- #### Returns
149
- - exists: Bool, whether file exists in database or not
150
- """
151
- print(f'checking if {file.name} exists in db...')
152
- files_collection = self.db.get_or_create_collection(
153
- name="files",
154
- embedding_function=self.embedding_function,
155
- )
156
- query = files_collection.get(
157
- where={"name": file.name},
158
- ).get('ids', [])
159
- return False if len(query) == 0 else True
160
-
161
- def _add_file_metadata(self, file):
162
- """
163
- ### Add file metadata to database
164
- #### Params
165
- - file: UploadedFile object
166
- #### Returns
167
- - None
168
- """
169
- print(f'adding metadata for {file.name} to db...')
170
- if self._check_if_file_exists(file):
171
- return
172
- files_collection = self.db.get_or_create_collection(
173
- "files",
174
- embedding_function=self.embedding_function
175
- )
176
- id = str(uuid.uuid4())
177
- files_collection.add(ids=id, documents=file.name, metadatas={"name": file.name})
178
- print(f'metadata for {file.name} added to db!')
179
-
180
- def _ingest_file(self, file, scrape, chunk_size):
181
- """
182
- ### Ingests file into database
183
- ONLY SUPPORTS .PDF FILES
184
- #### Params
185
- - file: UploadedFile object
186
- - scrape: Bool, whether to scrape HREFs from PDF or not
187
- #### Returns
188
- - None
189
- """
190
- print(f'ingesting {file.name} into db...')
191
- # process text from pdf
192
- reader = PyPDF2.PdfReader(file)
193
- print('reading pages...')
194
- for num in range(len(reader.pages)):
195
- page = reader.pages[num]
196
- metadata={
197
- "page_num": num,
198
- }
199
- text = page.extract_text()
200
- print(f'text extracted from page {num}')
201
- text_chunks = self._chunk_text(text, chunk_size)
202
- print(f'chunked text from page {num}')
203
- # process each chunk
204
- for chunk in text_chunks:
205
- self._add_chunk_to_db(chunk, metadata)
206
- print(f'chunk added to db from page {num}')
207
-
208
- if not scrape:
209
- return
210
-
211
- # extract links from pdf
212
- links = self._extract_links(reader)
213
- print(f'extracted {len(links)} links from {file.name}')
214
- # process each link
215
- self._ingest_links(links, chunk_size)
216
-
217
- def _chunk_text(self, text, chunk_size):
218
- """
219
- ### Chunk text into smaller chunks
220
- #### Params
221
- - text: String, page content of chunk
222
- - chunk_size: Int, size of each chunk
223
- #### Returns
224
- - chunks: List, list of chunks
225
- """
226
- print('chunking text...')
227
- return list((text[0+i:chunk_size+i] for i in range(0, len(text), chunk_size)))
228
-
229
- def _add_chunk_to_db(self, text, metadata):
230
- """
231
- ### Add text chunk to vector database
232
- #### Params
233
- - text: String, page content of chunk
234
- - metadatas: Dict, metadata for chunk
235
- """
236
- print(f'adding chunk to db...')
237
- chunks_collection = self.db.get_or_create_collection(
238
- "chunks",
239
- embedding_function=self.embedding_function
240
- )
241
- id = str(uuid.uuid4())
242
- chunks_collection.add(
243
- ids=id,
244
- documents=text,
245
- metadatas=metadata
246
- )
247
-
248
- def _extract_links(self, reader):
249
- """
250
- ### Extracts links from PDF
251
- #### Params
252
- - reader: PdfReader object with file loaded
253
- #### Returns
254
- - links: List, list of links
255
- """
256
- print('extracting links from file...')
257
- links = []
258
- for num in range(len(reader.pages)):
259
- page = reader.pages[num]
260
- if page.get('/Annots'):
261
- annotations = page['/Annots']
262
- for annotation in annotations:
263
- content = annotation.get_object()
264
- if content['/Subtype'] == '/Link':
265
- uri = content['/A']['/URI']
266
- links.append(uri)
267
- # # check how many annotations there are
268
- # print(annotations)
269
- # for annotation in annotations:
270
- # print(annotation)
271
- # content = annotation.get_object()
272
- # print(content)
273
- # if content.get('/Subtype', '') == '/Link':
274
- # uri = content["/A"]["/URI"]
275
- # links.append(uri)
276
- return links
277
-
278
- def _ingest_links(self, links, chunk_size):
279
- """
280
- ### Ingests links into database
281
- #### Params
282
- - links: List, list of links to ingest
283
- #### Returns
284
- - None
285
- """
286
- print('ingesting links into db...')
287
- start_urls = []
288
- for link in links:
289
- if self._check_if_link_exists(link):
290
- print(f'already exists in db, skipping {link}')
291
- continue
292
- start_urls.append({"url": link})
293
- print('generated start urls for apify')
294
-
295
- # scrape links
296
- dataset = self._apify_scrape(start_urls)
297
- # process each link
298
- self._ingest_apify_dataset(dataset, chunk_size)
299
-
300
- def _check_if_link_exists(self, link):
301
- """
302
- ### Checks if link exists in database
303
- #### Params
304
- - link: String, link to check
305
- #### Returns
306
- - exists: Bool, whether link exists in database or not
307
- """
308
- print(f'checking if {link} exists in db...')
309
- links_collection = self.db.get_or_create_collection(
310
- "links",
311
- embedding_function=self.embedding_function
312
- )
313
- query = links_collection.get(
314
- where={"link": link},
315
- ).get("ids", [])
316
- return False if len(query) == 0 else True
317
-
318
- def _add_link_metadata(self, link):
319
- """
320
- ### Add link metadata to database
321
- #### Params
322
- - link: String, link to add
323
- #### Returns
324
- - None
325
- """
326
- print(f'adding metadata for {link} to db...')
327
- if self._check_if_link_exists(link):
328
- return
329
- links_collection = self.db.get_or_create_collection(
330
- "links",
331
- embedding_function=self.embedding_function
332
- )
333
- id = str(uuid.uuid4())
334
- links_collection.add(ids=id, documents=link, metadatas={"link": link})
335
- print(f'metadata for {link} added to db!')
336
-
337
- def _apify_scrape(self, start_urls):
338
- """
339
- ### Scrape links using Apify
340
- #### Params
341
- - start_urls: List, list of urls to scrape
342
- #### Returns
343
- - dataset: DatasetClient, dataset of scraped links
344
- """
345
- print('scraping start_urls using apify...')
346
- client = ApifyClient(os.environ.get("APIFY_CLIENT_KEY"))
347
- print('initialized apify client')
348
- run_input = {
349
- "startUrls": start_urls,
350
- "crawlerType": "playwright:firefox",
351
- "includeUrlGlobs": [],
352
- "excludeUrlGlobs": [],
353
- "maxCrawlDepth": 0,
354
- "maxCrawlPages": 9999999,
355
- "initialConcurrency": 0,
356
- "maxConcurrency": 200,
357
- "initialCookies": [],
358
- "proxyConfiguration": { "useApifyProxy": True },
359
- "requestTimeoutSecs": 60,
360
- "dynamicContentWaitSecs": 10,
361
- "maxScrollHeightPixels": 5000,
362
- "removeElementsCssSelector": """nav, footer, script, style, noscript, svg,
363
- [role=\"alert\"],
364
- [role=\"banner\"],
365
- [role=\"dialog\"],
366
- [role=\"alertdialog\"],
367
- [role=\"region\"][aria-label*=\"skip\" i],
368
- [aria-modal=\"true\"]""",
369
- "removeCookieWarnings": True,
370
- "clickElementsCssSelector": "[aria-expanded=\"false\"]",
371
- "htmlTransformer": "readableText",
372
- "readableTextCharThreshold": 100,
373
- "aggressivePrune": False,
374
- "debugMode": False,
375
- "debugLog": False,
376
- "saveHtml": False,
377
- "saveMarkdown": False,
378
- "saveFiles": False,
379
- "saveScreenshots": False,
380
- "maxResults": 9999999,
381
- }
382
- try:
383
- print('trying to run apify crawler...')
384
- run = client.actor('apify/website-content-crawler').call(run_input=run_input)
385
- print('apify crawler completed!')
386
- return client.dataset(run['defaultDatasetId'])
387
- except Exception as e:
388
- raise e # currently unhandled, may error if out of resources on apify
389
-
390
- def _ingest_apify_dataset(self, dataset, chunk_size):
391
- """
392
- ### Process documents from apify dataset
393
- #### Params
394
- - dataset: DatasetClient, dataset of scraped links
395
- #### Returns
396
- - None
397
- """
398
- print('ingesting apify dataset into db...')
399
- for item in dataset.iterate_items():
400
- text = item.get('text', '')
401
- metadata = {'link': item.get('url', '')}
402
- text_chunks = self._chunk_text(text, chunk_size)
403
- print(f'generated {len(text_chunks)} chunks')
404
- for chunk in text_chunks:
405
- self._add_chunk_to_db(chunk, metadata)
406
- print(f'chunk added to db from apify dataset')
407
- self._add_link_metadata(item.get('url', '')) # add link metadata
408
- print('apify dataset ingested into db!')
409
-
410
- def view_db(self):
411
- """
412
- ### Return data in db as python dict.
413
- #### Params
414
- - None
415
- #### Returns
416
- - data: Dict, data in db
417
- """
418
- print('viewing db...')
419
- data = {}
420
- for collection in self.db.list_collections():
421
- data[collection.name] = collection.get()
422
- return data
423
-
424
- def peek(self, n=5):
425
- """
426
- ### Peek the first n values of the database.
427
- #### Params
428
- - n: Int, number of values to peek
429
- #### Returns
430
- - data: Dict, first n data in db
431
- """
432
- print(f'peeking first {n} values of db...')
433
- data = {}
434
- for collection in self.db.list_collections():
435
- data[collection.name] = collection.peek(limit=n)
436
- return data
437
-
438
- def empty(self):
439
- """
440
- ### Empty database.
441
- #### Params
442
- - None
443
- #### Returns
444
- - None
445
- """
446
- print('emptying db...')
447
- for collection in self.db.list_collections():
448
- self.db.delete_collection(collection.name)
449
- self._save_db(commit_message="Database Emptied")
450
- print('db emptied!')
451
-
452
- def similarity_search(self, query, n):
453
- """
454
- ### Similarity search for top n results.
455
- #### Params
456
- - query: String, query to search for
457
- - n: Int, number of results to return
458
- #### Returns
459
- - results: List, list of results
460
- """
461
- print(f'querying db for {query}...')
462
- results = self.db.get_or_create_collection('chunks').query(
463
- query_texts=query,
464
- n_results=n
465
- )
466
- return results
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
utils/db_tests.py DELETED
@@ -1,22 +0,0 @@
1
- from db import Database
2
- from io import BytesIO
3
- from pprint import pprint
4
-
5
- database = Database()
6
-
7
- files = []
8
- fp = "/Users/aniruddhdoki/Library/Mobile Documents/com~apple~CloudDocs/work/fidelity/fidtest.pdf"
9
- with open(fp, 'rb') as fh:
10
- file = BytesIO(fh.read())
11
- file.name = "fidtest.pdf"
12
- files.append(file)
13
- file.name = "same fidtest file but want to see if it will rpocess two.pdf"
14
- files.append(file)
15
- print(files)
16
-
17
- database.add_files(files)
18
- pprint(database.peek())
19
-
20
- query = database.similarity_search("gimme apify", 2)
21
- pprint(query)
22
-