achapman commited on
Commit
b306f48
·
1 Parent(s): edc17fa
Files changed (3) hide show
  1. app.py +57 -37
  2. chainlit.md +1 -1
  3. requirements.txt +60 -282
app.py CHANGED
@@ -11,6 +11,9 @@ from langchain_core.prompts import PromptTemplate
11
  from langchain.schema.output_parser import StrOutputParser
12
  from langchain.schema.runnable import RunnablePassthrough
13
  from langchain.schema.runnable.config import RunnableConfig
 
 
 
14
 
15
  # GLOBAL SCOPE - ENTIRE APPLICATION HAS ACCESS TO VALUES SET IN THIS SCOPE #
16
  # ---- ENV VARIABLES ---- #
@@ -43,32 +46,63 @@ documents = document_loader.load()
43
  text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=30)
44
  split_documents = text_splitter.split_documents(documents)
45
 
46
- """print("initializing embeddings")
47
  hf_embeddings = HuggingFaceEndpointEmbeddings(
48
  model=HF_EMBED_ENDPOINT,
49
  task="feature-extraction",
50
  huggingfacehub_api_token=HF_TOKEN,
51
  )
52
 
53
- if os.path.exists("./data/vectorstore"):
54
- vectorstore = FAISS.load_local(
55
- "./data/vectorstore",
56
- hf_embeddings,
57
- allow_dangerous_deserialization=True # this is necessary to load the vectorstore from disk as it's stored as a `.pkl` file.
58
- )
59
- hf_retriever = vectorstore.as_retriever()
60
- print("Loaded Vectorstore")
61
- else:
 
 
 
 
62
  print("Indexing Files")
63
- #os.makedirs("./data/vectorstore", exist_ok=True)
64
- for i in range(0, len(split_documents), 32):
65
- if i == 0:
66
- vectorstore = FAISS.from_documents(split_documents[i:i+32], hf_embeddings)
67
- continue
68
- vectorstore.add_documents(split_documents[i:i+32])
69
- #vectorstore.save_local("./data/vectorstore")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
 
71
- hf_retriever = vectorstore.as_retriever()"""
72
 
73
  # -- AUGMENTED -- #
74
  """
@@ -83,10 +117,7 @@ You are a helpful assistant. You answer user questions based on provided context
83
  User Query:
84
  {query}
85
 
86
- <|start_header_id|>assistant<|end_header_id|>
87
- """
88
-
89
- """Context:
90
  {context}<|eot_id|>
91
 
92
  <|start_header_id|>assistant<|end_header_id|>
@@ -129,20 +160,13 @@ async def start_chat():
129
 
130
  The user session is a dictionary that is unique to each user session, and is stored in the memory of the server.
131
  """
132
- print("entering on_chat_start")
133
- """lcel_rag_chain = (
134
  {"context": itemgetter("query") | hf_retriever, "query": itemgetter("query")}
135
  | rag_prompt | hf_llm
136
- )"""
137
- lcel_rag_chain = {"query": itemgetter("query")} | rag_prompt | hf_llm
138
 
139
- try:
140
- # Attempt to set up session normally
141
- cl.user_session.set("lcel_rag_chain", lcel_rag_chain)
142
- except KeyError:
143
- print("Reinitializing session due to disconnection.")
144
- cl.user_session.clear()
145
- cl.user_session.set("lcel_rag_chain", lcel_rag_chain)
146
 
147
  @cl.on_message
148
  async def main(message: cl.Message):
@@ -153,10 +177,6 @@ async def main(message: cl.Message):
153
 
154
  The LCEL RAG chain is stored in the user session, and is unique to each user session - this is why we can access it here.
155
  """
156
- print("entering on_message")
157
- msg = cl.Message(content="Processing your request... this may take a moment.")
158
- await msg.send()
159
-
160
  lcel_rag_chain = cl.user_session.get("lcel_rag_chain")
161
 
162
  msg = cl.Message(content="")
 
11
  from langchain.schema.output_parser import StrOutputParser
12
  from langchain.schema.runnable import RunnablePassthrough
13
  from langchain.schema.runnable.config import RunnableConfig
14
+ from tqdm.asyncio import tqdm_asyncio
15
+ import asyncio
16
+ from tqdm.asyncio import tqdm
17
 
18
  # GLOBAL SCOPE - ENTIRE APPLICATION HAS ACCESS TO VALUES SET IN THIS SCOPE #
19
  # ---- ENV VARIABLES ---- #
 
46
  text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=30)
47
  split_documents = text_splitter.split_documents(documents)
48
 
 
49
  hf_embeddings = HuggingFaceEndpointEmbeddings(
50
  model=HF_EMBED_ENDPOINT,
51
  task="feature-extraction",
52
  huggingfacehub_api_token=HF_TOKEN,
53
  )
54
 
55
async def add_documents_async(vectorstore, documents):
    """Append *documents* to an existing *vectorstore* asynchronously.

    Thin wrapper around the vectorstore's native async add so batch
    workers share a single call site.
    """
    await vectorstore.aadd_documents(documents)
58
async def process_batch(vectorstore, batch, is_first_batch, pbar):
    """Embed one batch of documents and fold it into the vectorstore.

    On the first batch a fresh FAISS index is built from *batch* (using the
    module-level ``hf_embeddings``); on later batches the documents are
    appended to the existing *vectorstore*. The progress bar *pbar* is
    advanced by the batch size either way.

    Returns the vectorstore that now contains the batch (newly created on
    the first call, the passed-in one afterwards).
    """
    if is_first_batch:
        store = await FAISS.afrom_documents(batch, hf_embeddings)
    else:
        await add_documents_async(vectorstore, batch)
        store = vectorstore
    pbar.update(len(batch))
    return store
67
async def main():
    """Build a FAISS vectorstore from ``split_documents`` in batches and
    return a retriever over it.

    The first batch is awaited alone because it creates the index; the
    remaining batches are appended concurrently via ``asyncio.gather``.

    Raises:
        ValueError: if there are no documents to index (the original code
            crashed later with ``None.as_retriever()`` in that case).

    NOTE(review): the gathered ``aadd_documents`` calls all mutate one
    shared FAISS index concurrently; FAISS adds are not documented as
    thread-safe — confirm this is safe, or serialize the batches if index
    corruption is ever observed.
    """
    print("Indexing Files")

    if not split_documents:
        # Fail loudly with a clear message instead of an AttributeError below.
        raise ValueError("No documents to index: split_documents is empty.")

    batch_size = 32
    batches = [
        split_documents[i:i + batch_size]
        for i in range(0, len(split_documents), batch_size)
    ]

    # One progress bar per batch, stacked by position, as in the original.
    pbars = [
        tqdm(total=len(batch), desc=f"Batch {i + 1}/{len(batches)}", position=i)
        for i, batch in enumerate(batches)
    ]

    # First batch creates the vectorstore the other batches append to.
    vectorstore = await process_batch(None, batches[0], True, pbars[0])

    tasks = [
        process_batch(vectorstore, batch, False, pbar)
        for batch, pbar in zip(batches[1:], pbars[1:])
    ]
    if tasks:
        await asyncio.gather(*tasks)

    for pbar in pbars:
        pbar.close()

    print("\nIndexing complete. Vectorstore is ready for use.")
    return vectorstore.as_retriever()


async def run():
    """Backward-compatible wrapper: the original indirection is kept but
    reduced to a single ``return await``."""
    return await main()


# Index once at import time so the retriever is ready before any chat starts.
hf_retriever = asyncio.run(run())
106
 
107
  # -- AUGMENTED -- #
108
  """
 
117
  User Query:
118
  {query}
119
 
120
+ Context:
 
 
 
121
  {context}<|eot_id|>
122
 
123
  <|start_header_id|>assistant<|end_header_id|>
 
160
 
161
  The user session is a dictionary that is unique to each user session, and is stored in the memory of the server.
162
  """
163
+
164
+ lcel_rag_chain = (
165
  {"context": itemgetter("query") | hf_retriever, "query": itemgetter("query")}
166
  | rag_prompt | hf_llm
167
+ )
 
168
 
169
+ cl.user_session.set("lcel_rag_chain", lcel_rag_chain)
 
 
 
 
 
 
170
 
171
  @cl.on_message
172
  async def main(message: cl.Message):
 
177
 
178
  The LCEL RAG chain is stored in the user session, and is unique to each user session - this is why we can access it here.
179
  """
 
 
 
 
180
  lcel_rag_chain = cl.user_session.get("lcel_rag_chain")
181
 
182
  msg = cl.Message(content="")
chainlit.md CHANGED
@@ -1 +1 @@
1
- # FILL OUT YOUR CHAINLIT MD HERE WITH A DESCRIPTION OF YOUR APPLICATION
 
1
+ This bot to help users explore the writings of Paul Graham.
requirements.txt CHANGED
@@ -1,354 +1,132 @@
1
- #
2
- # This file is autogenerated by pip-compile with Python 3.9
3
- # by the following command:
4
- #
5
- # pip-compile requirement.in
6
- #
7
  aiofiles==23.2.1
8
- # via chainlit
9
  aiohappyeyeballs==2.4.3
10
- # via aiohttp
11
- aiohttp==3.10.9
12
- # via
13
- # langchain
14
- # langchain-community
15
  aiosignal==1.3.1
16
- # via aiohttp
17
  annotated-types==0.7.0
18
- # via pydantic
19
  anyio==3.7.1
20
- # via
21
- # asyncer
22
- # httpx
23
- # starlette
24
- # watchfiles
25
  async-timeout==4.0.3
26
- # via
27
- # aiohttp
28
- # langchain
29
  asyncer==0.0.2
30
- # via chainlit
31
  attrs==24.2.0
32
- # via aiohttp
33
  bidict==0.23.1
34
- # via python-socketio
35
  certifi==2024.8.30
36
- # via
37
- # httpcore
38
- # httpx
39
- # requests
40
- chainlit==1.1.302
41
- # via -r requirement.in
42
  charset-normalizer==3.3.2
43
- # via requests
44
- chevron==0.14.0
45
- # via literalai
46
  click==8.1.7
47
- # via
48
- # chainlit
49
- # uvicorn
50
  dataclasses-json==0.5.14
51
- # via
52
- # chainlit
53
- # langchain-community
54
- deprecated==1.2.14
55
- # via
56
- # opentelemetry-api
57
- # opentelemetry-exporter-otlp-proto-grpc
58
- # opentelemetry-exporter-otlp-proto-http
59
- # opentelemetry-semantic-conventions
60
  exceptiongroup==1.2.2
61
- # via anyio
62
  faiss-cpu==1.8.0.post1
63
- # via -r requirement.in
64
- fastapi==0.110.3
65
- # via chainlit
66
  filelock==3.16.1
67
- # via
68
- # huggingface-hub
69
- # torch
70
- # transformers
71
  filetype==1.2.0
72
- # via chainlit
73
  frozenlist==1.4.1
74
- # via
75
- # aiohttp
76
- # aiosignal
77
  fsspec==2024.9.0
78
- # via
79
- # huggingface-hub
80
- # torch
81
  googleapis-common-protos==1.65.0
82
- # via
83
- # opentelemetry-exporter-otlp-proto-grpc
84
- # opentelemetry-exporter-otlp-proto-http
85
  greenlet==3.1.1
86
- # via sqlalchemy
87
  grpcio==1.66.2
88
- # via opentelemetry-exporter-otlp-proto-grpc
89
  h11==0.14.0
90
- # via
91
- # httpcore
92
- # uvicorn
93
- # wsproto
94
- httpcore==1.0.6
95
- # via httpx
96
- httpx==0.27.2
97
- # via
98
- # chainlit
99
- # langsmith
100
- # literalai
101
  huggingface-hub==0.25.1
102
- # via
103
- # langchain-huggingface
104
- # sentence-transformers
105
- # tokenizers
106
- # transformers
107
  idna==3.10
108
- # via
109
- # anyio
110
- # httpx
111
- # requests
112
- # yarl
113
- importlib-metadata==8.4.0
114
- # via opentelemetry-api
115
- jinja2==3.1.4
116
- # via torch
117
  joblib==1.4.2
118
- # via scikit-learn
119
  jsonpatch==1.33
120
- # via langchain-core
121
  jsonpointer==3.0.0
122
- # via jsonpatch
123
- langchain==0.2.5
124
- # via
125
- # -r requirement.in
126
- # langchain-community
127
- langchain-community==0.2.5
128
- # via -r requirement.in
129
- langchain-core==0.2.9
130
- # via
131
- # -r requirement.in
132
- # langchain
133
- # langchain-community
134
- # langchain-huggingface
135
- # langchain-text-splitters
136
- langchain-huggingface==0.0.3
137
- # via -r requirement.in
138
- langchain-text-splitters==0.2.1
139
- # via
140
- # -r requirement.in
141
- # langchain
142
- langsmith==0.1.132
143
- # via
144
- # langchain
145
- # langchain-community
146
- # langchain-core
147
- lazify==0.4.0
148
- # via chainlit
149
- literalai==0.0.604
150
- # via chainlit
151
- markupsafe==3.0.0
152
- # via jinja2
153
  marshmallow==3.22.0
154
- # via dataclasses-json
155
  mpmath==1.3.0
156
- # via sympy
157
  multidict==6.1.0
158
- # via
159
- # aiohttp
160
- # yarl
161
  mypy-extensions==1.0.0
162
- # via typing-inspect
163
  nest-asyncio==1.6.0
164
- # via chainlit
165
  networkx==3.2.1
166
- # via torch
167
  numpy==1.26.4
168
- # via
169
- # chainlit
170
- # faiss-cpu
171
- # langchain
172
- # langchain-community
173
- # scikit-learn
174
- # scipy
175
- # transformers
 
 
 
 
 
176
  opentelemetry-api==1.27.0
177
- # via
178
- # opentelemetry-exporter-otlp-proto-grpc
179
- # opentelemetry-exporter-otlp-proto-http
180
- # opentelemetry-instrumentation
181
- # opentelemetry-sdk
182
- # opentelemetry-semantic-conventions
183
- # uptrace
184
  opentelemetry-exporter-otlp==1.27.0
185
- # via uptrace
186
  opentelemetry-exporter-otlp-proto-common==1.27.0
187
- # via
188
- # opentelemetry-exporter-otlp-proto-grpc
189
- # opentelemetry-exporter-otlp-proto-http
190
  opentelemetry-exporter-otlp-proto-grpc==1.27.0
191
- # via opentelemetry-exporter-otlp
192
  opentelemetry-exporter-otlp-proto-http==1.27.0
193
- # via opentelemetry-exporter-otlp
194
  opentelemetry-instrumentation==0.48b0
195
- # via uptrace
196
  opentelemetry-proto==1.27.0
197
- # via
198
- # opentelemetry-exporter-otlp-proto-common
199
- # opentelemetry-exporter-otlp-proto-grpc
200
- # opentelemetry-exporter-otlp-proto-http
201
  opentelemetry-sdk==1.27.0
202
- # via
203
- # opentelemetry-exporter-otlp-proto-grpc
204
- # opentelemetry-exporter-otlp-proto-http
205
- # uptrace
206
  opentelemetry-semantic-conventions==0.48b0
207
- # via opentelemetry-sdk
208
  orjson==3.10.7
209
- # via langsmith
210
  packaging==23.2
211
- # via
212
- # chainlit
213
- # faiss-cpu
214
- # huggingface-hub
215
- # langchain-core
216
- # literalai
217
- # marshmallow
218
- # transformers
219
  pillow==10.4.0
220
- # via sentence-transformers
221
  protobuf==4.25.5
222
- # via
223
- # googleapis-common-protos
224
- # opentelemetry-proto
225
  pydantic==2.9.2
226
- # via
227
- # chainlit
228
- # fastapi
229
- # langchain
230
- # langchain-core
231
- # langsmith
232
- # literalai
233
- pydantic-core==2.23.4
234
- # via pydantic
235
- pyjwt==2.9.0
236
- # via chainlit
237
  python-dotenv==1.0.1
238
- # via
239
- # -r requirement.in
240
- # chainlit
241
  python-engineio==4.9.1
242
- # via python-socketio
243
- python-multipart==0.0.9
244
- # via chainlit
245
  python-socketio==5.11.4
246
- # via chainlit
247
- pyyaml==6.0.2
248
- # via
249
- # huggingface-hub
250
- # langchain
251
- # langchain-community
252
- # langchain-core
253
- # transformers
254
  regex==2024.9.11
255
- # via transformers
256
  requests==2.32.3
257
- # via
258
- # huggingface-hub
259
- # langchain
260
- # langchain-community
261
- # langsmith
262
- # opentelemetry-exporter-otlp-proto-http
263
- # requests-toolbelt
264
- # transformers
265
- requests-toolbelt==1.0.0
266
- # via langsmith
267
  safetensors==0.4.5
268
- # via transformers
269
  scikit-learn==1.5.2
270
- # via sentence-transformers
271
  scipy==1.13.1
272
- # via
273
- # scikit-learn
274
- # sentence-transformers
275
  sentence-transformers==3.1.1
276
- # via langchain-huggingface
277
  simple-websocket==1.0.0
278
- # via python-engineio
279
  sniffio==1.3.1
280
- # via
281
- # anyio
282
- # httpx
283
- sqlalchemy==2.0.35
284
- # via
285
- # langchain
286
- # langchain-community
287
- starlette==0.37.2
288
- # via
289
- # chainlit
290
- # fastapi
291
  sympy==1.13.3
292
- # via torch
293
  syncer==2.0.3
294
- # via chainlit
295
  tenacity==8.5.0
296
- # via
297
- # langchain
298
- # langchain-community
299
- # langchain-core
300
  threadpoolctl==3.5.0
301
- # via scikit-learn
302
  tokenizers==0.20.0
303
- # via
304
- # langchain-huggingface
305
- # transformers
306
- tomli==2.0.2
307
- # via chainlit
308
- torch==2.2.2
309
- # via sentence-transformers
310
  tqdm==4.66.5
311
- # via
312
- # huggingface-hub
313
- # sentence-transformers
314
- # transformers
315
- transformers==4.45.2
316
- # via
317
- # langchain-huggingface
318
- # sentence-transformers
319
- typing-extensions==4.12.2
320
- # via
321
- # fastapi
322
- # huggingface-hub
323
- # multidict
324
- # opentelemetry-sdk
325
- # pydantic
326
- # pydantic-core
327
- # sqlalchemy
328
- # starlette
329
- # torch
330
- # typing-inspect
331
- # uvicorn
332
  typing-inspect==0.9.0
333
- # via dataclasses-json
334
- uptrace==1.27.0
335
- # via chainlit
336
  urllib3==2.2.3
337
- # via requests
338
- uvicorn==0.25.0
339
- # via chainlit
340
  watchfiles==0.20.0
341
- # via chainlit
342
  wrapt==1.16.0
343
- # via
344
- # deprecated
345
- # opentelemetry-instrumentation
346
  wsproto==1.2.0
347
- # via simple-websocket
348
  yarl==1.13.1
349
- # via aiohttp
350
- zipp==3.20.2
351
- # via importlib-metadata
352
-
353
- # The following packages are considered to be unsafe in a requirements file:
354
- # setuptools
 
 
 
 
 
 
 
1
  aiofiles==23.2.1
 
2
  aiohappyeyeballs==2.4.3
3
+ aiohttp==3.10.8
 
 
 
 
4
  aiosignal==1.3.1
 
5
  annotated-types==0.7.0
 
6
  anyio==3.7.1
 
 
 
 
 
7
  async-timeout==4.0.3
 
 
 
8
  asyncer==0.0.2
 
9
  attrs==24.2.0
 
10
  bidict==0.23.1
 
11
  certifi==2024.8.30
12
+ chainlit==0.7.700
 
 
 
 
 
13
  charset-normalizer==3.3.2
 
 
 
14
  click==8.1.7
 
 
 
15
  dataclasses-json==0.5.14
16
+ Deprecated==1.2.14
17
+ distro==1.9.0
 
 
 
 
 
 
 
18
  exceptiongroup==1.2.2
 
19
  faiss-cpu==1.8.0.post1
20
+ fastapi==0.100.1
21
+ fastapi-socketio==0.0.10
 
22
  filelock==3.16.1
 
 
 
 
23
  filetype==1.2.0
 
24
  frozenlist==1.4.1
 
 
 
25
  fsspec==2024.9.0
 
 
 
26
  googleapis-common-protos==1.65.0
 
 
 
27
  greenlet==3.1.1
 
28
  grpcio==1.66.2
29
+ grpcio-tools==1.62.3
30
  h11==0.14.0
31
+ h2==4.1.0
32
+ hpack==4.0.0
33
+ httpcore==0.17.3
34
+ httpx==0.24.1
 
 
 
 
 
 
 
35
  huggingface-hub==0.25.1
36
+ hyperframe==6.0.1
 
 
 
 
37
  idna==3.10
38
+ importlib_metadata==8.4.0
39
+ Jinja2==3.1.4
40
+ jiter==0.5.0
 
 
 
 
 
 
41
  joblib==1.4.2
 
42
  jsonpatch==1.33
 
43
  jsonpointer==3.0.0
44
+ langchain==0.3.0
45
+ langchain-community==0.3.0
46
+ langchain-core==0.3.1
47
+ langchain-huggingface==0.1.0
48
+ langchain-openai==0.2.0
49
+ langchain-qdrant==0.1.4
50
+ langchain-text-splitters==0.3.0
51
+ langsmith==0.1.121
52
+ Lazify==0.4.0
53
+ MarkupSafe==2.1.5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  marshmallow==3.22.0
 
55
  mpmath==1.3.0
 
56
  multidict==6.1.0
 
 
 
57
  mypy-extensions==1.0.0
 
58
  nest-asyncio==1.6.0
 
59
  networkx==3.2.1
 
60
  numpy==1.26.4
61
+ nvidia-cublas-cu12==12.1.3.1
62
+ nvidia-cuda-cupti-cu12==12.1.105
63
+ nvidia-cuda-nvrtc-cu12==12.1.105
64
+ nvidia-cuda-runtime-cu12==12.1.105
65
+ nvidia-cudnn-cu12==9.1.0.70
66
+ nvidia-cufft-cu12==11.0.2.54
67
+ nvidia-curand-cu12==10.3.2.106
68
+ nvidia-cusolver-cu12==11.4.5.107
69
+ nvidia-cusparse-cu12==12.1.0.106
70
+ nvidia-nccl-cu12==2.20.5
71
+ nvidia-nvjitlink-cu12==12.6.77
72
+ nvidia-nvtx-cu12==12.1.105
73
+ openai==1.51.0
74
  opentelemetry-api==1.27.0
 
 
 
 
 
 
 
75
  opentelemetry-exporter-otlp==1.27.0
 
76
  opentelemetry-exporter-otlp-proto-common==1.27.0
 
 
 
77
  opentelemetry-exporter-otlp-proto-grpc==1.27.0
 
78
  opentelemetry-exporter-otlp-proto-http==1.27.0
 
79
  opentelemetry-instrumentation==0.48b0
 
80
  opentelemetry-proto==1.27.0
 
 
 
 
81
  opentelemetry-sdk==1.27.0
 
 
 
 
82
  opentelemetry-semantic-conventions==0.48b0
 
83
  orjson==3.10.7
 
84
  packaging==23.2
 
 
 
 
 
 
 
 
85
  pillow==10.4.0
86
+ portalocker==2.10.1
87
  protobuf==4.25.5
 
 
 
88
  pydantic==2.9.2
89
+ pydantic-settings==2.5.2
90
+ pydantic_core==2.23.4
91
+ PyJWT==2.9.0
92
+ PyMuPDF==1.24.10
93
+ PyMuPDFb==1.24.10
 
 
 
 
 
 
94
  python-dotenv==1.0.1
 
 
 
95
  python-engineio==4.9.1
96
+ python-graphql-client==0.4.3
97
+ python-multipart==0.0.6
 
98
  python-socketio==5.11.4
99
+ PyYAML==6.0.2
100
+ qdrant-client==1.11.2
 
 
 
 
 
 
101
  regex==2024.9.11
 
102
  requests==2.32.3
 
 
 
 
 
 
 
 
 
 
103
  safetensors==0.4.5
 
104
  scikit-learn==1.5.2
 
105
  scipy==1.13.1
 
 
 
106
  sentence-transformers==3.1.1
 
107
  simple-websocket==1.0.0
 
108
  sniffio==1.3.1
109
+ SQLAlchemy==2.0.35
110
+ starlette==0.27.0
 
 
 
 
 
 
 
 
 
111
  sympy==1.13.3
 
112
  syncer==2.0.3
 
113
  tenacity==8.5.0
 
 
 
 
114
  threadpoolctl==3.5.0
115
+ tiktoken==0.7.0
116
  tokenizers==0.20.0
117
+ tomli==2.0.1
118
+ torch==2.4.1
 
 
 
 
 
119
  tqdm==4.66.5
120
+ transformers==4.45.1
121
+ triton==3.0.0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
122
  typing-inspect==0.9.0
123
+ typing_extensions==4.12.2
124
+ uptrace==1.26.0
 
125
  urllib3==2.2.3
126
+ uvicorn==0.23.2
 
 
127
  watchfiles==0.20.0
128
+ websockets==13.1
129
  wrapt==1.16.0
 
 
 
130
  wsproto==1.2.0
 
131
  yarl==1.13.1
132
+ zipp==3.20.2