drewgenai committed on
Commit
0805914
·
1 Parent(s): d0861d5

Change to a session-based secondary Qdrant DB

Browse files
Files changed (1) hide show
  1. app.py +165 -106
app.py CHANGED
@@ -36,7 +36,7 @@ UPLOAD_PATH = "./uploads"
36
  INITIAL_EMBEDDINGS_DIR = "./initial_embeddings"
37
  INITIAL_EMBEDDINGS_NAME = "initial_embeddings"
38
  USER_EMBEDDINGS_NAME = "user_embeddings"
39
- VECTOR_STORE_COLLECTION = "documents"
40
 
41
  # Model IDs
42
  EMBEDDING_MODEL_ID = "pritamdeka/S-PubMedBert-MS-MARCO"
@@ -61,15 +61,13 @@ NIH_HEAL_CORE_DOMAINS = [
61
  # Make sure upload directory exists
62
  os.makedirs(UPLOAD_PATH, exist_ok=True)
63
 
64
- # ==================== EMBEDDING MODEL SETUP ====================
65
  def get_embedding_model(model_id):
66
  """Creates and returns the appropriate embedding model based on the model ID."""
67
  if "text-embedding" in model_id:
68
- # OpenAI embeddings
69
  from langchain_openai import OpenAIEmbeddings
70
  return OpenAIEmbeddings(model=model_id)
71
  else:
72
- # HuggingFace embeddings
73
  return HuggingFaceEmbeddings(model_name=model_id)
74
 
75
  def initialize_embedding_models():
@@ -84,16 +82,13 @@ def initialize_embedding_models():
84
  # Initialize the embedding model
85
  initialize_embedding_models()
86
 
87
- # Get embedding dimensions utility
88
- def get_embedding_dimensions(model_id):
89
- """Gets the dimensions of embeddings from a specific model."""
90
- model = get_embedding_model(model_id)
91
- sample_text = "Sample text to determine embedding dimension"
92
- sample_embedding = model.embed_query(sample_text)
93
- return len(sample_embedding)
94
-
95
  # ==================== QDRANT SETUP ====================
96
- qdrant_client = QdrantClient(":memory:")
 
 
 
 
 
97
 
98
  # ==================== DOCUMENT PROCESSING ====================
99
  # Create a semantic splitter for documents
@@ -138,8 +133,8 @@ def load_and_chunk_core_reference_files():
138
  return all_chunks
139
 
140
  def embed_core_reference_in_qdrant(chunks):
141
- """Embeds core reference chunks and stores them in Qdrant."""
142
- global embedding_model
143
 
144
  if not chunks:
145
  print("No Excel files found to process or all files were empty.")
@@ -151,15 +146,32 @@ def embed_core_reference_in_qdrant(chunks):
151
  initialize_embedding_models()
152
 
153
  print(f"Using embedding model: {EMBEDDING_MODEL_ID}")
154
- print("Creating vector store...")
155
 
156
  try:
157
- vector_store = QdrantVectorStore.from_documents(
158
- documents=chunks,
159
- embedding=embedding_model,
160
- location=":memory:",
161
- collection_name=INITIAL_EMBEDDINGS_NAME
 
 
 
 
 
 
162
  )
 
 
 
 
 
 
 
 
 
 
 
163
  print(f"Successfully loaded all .xlsx files into Qdrant collection '{INITIAL_EMBEDDINGS_NAME}'.")
164
  return vector_store
165
  except Exception as e:
@@ -167,10 +179,18 @@ def embed_core_reference_in_qdrant(chunks):
167
  print(f"Embedding model status: {embedding_model is not None}")
168
  return None
169
 
 
 
 
170
  def initialize_core_reference_embeddings():
171
- """Loads all .xlsx files, extracts text, embeds, and stores in Qdrant."""
 
172
  chunks = load_and_chunk_core_reference_files()
173
- return embed_core_reference_in_qdrant(chunks)
 
 
 
 
174
 
175
  # ==================== PROTOCOL DOCUMENT PROCESSING ====================
176
  async def load_and_chunk_protocol_files(files):
@@ -182,10 +202,6 @@ async def load_and_chunk_protocol_files(files):
182
  print(f"Processing file: {file.name}, size: {file.size} bytes")
183
  file_path = os.path.join(UPLOAD_PATH, file.name)
184
 
185
- # Ensure the upload directory exists
186
- os.makedirs(UPLOAD_PATH, exist_ok=True)
187
-
188
- # Copy the file to the upload directory
189
  shutil.copyfile(file.path, file_path)
190
 
191
  try:
@@ -211,33 +227,29 @@ async def load_and_chunk_protocol_files(files):
211
 
212
  return documents_with_metadata
213
 
214
- async def embed_protocol_in_qdrant(documents_with_metadata, model_name=EMBEDDING_MODEL_ID):
215
- """Create a vector store and embed protocol chunks into Qdrant."""
216
  global embedding_model
217
-
218
- if not documents_with_metadata:
219
- print("No documents to embed")
220
- return None
221
-
222
  print(f"Using embedding model: {model_name}")
223
 
224
  try:
225
  # First, check if collection exists and delete it if it does
226
- if USER_EMBEDDINGS_NAME in [c.name for c in qdrant_client.get_collections().collections]:
227
- qdrant_client.delete_collection(USER_EMBEDDINGS_NAME)
228
 
229
  # Create the collection with proper parameters
230
  # Get the embedding dimension from the model
231
  embedding_dimension = len(embedding_model.embed_query("Sample text"))
232
 
233
- qdrant_client.create_collection(
234
  collection_name=USER_EMBEDDINGS_NAME,
235
  vectors_config=VectorParams(size=embedding_dimension, distance=Distance.COSINE)
236
  )
237
 
238
  # Create the vector store
239
  user_vectorstore = QdrantVectorStore(
240
- client=qdrant_client,
241
  collection_name=USER_EMBEDDINGS_NAME,
242
  embedding=embedding_model
243
  )
@@ -251,25 +263,40 @@ async def embed_protocol_in_qdrant(documents_with_metadata, model_name=EMBEDDING
251
  print(f"Error creating vector store: {str(e)}")
252
  return None
253
 
254
- async def process_uploaded_protocol(files, model_name=EMBEDDING_MODEL_ID):
255
- """Process uploaded protocol PDF files and add them to a separate vector store collection"""
256
  documents_with_metadata = await load_and_chunk_protocol_files(files)
257
- return await embed_protocol_in_qdrant(documents_with_metadata, model_name)
258
 
259
  # ==================== RETRIEVAL FUNCTIONS ====================
260
  def retrieve_documents(query, doc_type=None, k=5):
261
- """Retrieve documents, optionally filtering by document type"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
262
  vector_store = QdrantVectorStore(
263
- client=qdrant_client,
264
- collection_name=VECTOR_STORE_COLLECTION,
265
  embedding=embedding_model
266
  )
267
 
268
- # Set up filter if doc_type is specified
269
  search_kwargs = {"k": k}
270
- if doc_type:
271
- search_kwargs["filter"] = {"type": doc_type}
272
-
273
  retriever = vector_store.as_retriever(search_kwargs=search_kwargs)
274
  return retriever.invoke(query)
275
 
@@ -288,11 +315,12 @@ Context:
288
  """
289
 
290
  rag_prompt = ChatPromptTemplate.from_template(RAG_TEMPLATE)
291
- chat_model = ChatOpenAI()
 
292
 
293
  # Create a RAG chain that can be filtered by document type
294
  def create_rag_chain(doc_type=None):
295
- """Create a RAG chain that can be filtered by document type"""
296
  def retrieve_with_type(query):
297
  docs = retrieve_documents(query, doc_type=doc_type)
298
  return format_docs(docs)
@@ -337,21 +365,42 @@ def search_all_data(query: str, doc_type: str = None) -> str:
337
  @tool
338
  def search_core_reference(query: str, top_k: int = 3) -> str:
339
  """Search core reference data and protocol data for information related to the query."""
340
- global embedding_model
341
 
342
- # Use the existing core_reference_retrieval_chain
343
- result = core_reference_retrieval_chain.invoke({"question": query})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
344
 
345
  # If we have a user collection, also search that
346
  try:
347
  # Check if user collection exists
348
- if USER_EMBEDDINGS_NAME not in [c.name for c in qdrant_client.get_collections().collections]:
349
  # If no user collection exists yet, just return core reference results
350
  return result
351
 
352
  # Create a retrieval chain for user documents
353
  user_retriever = QdrantVectorStore(
354
- client=qdrant_client,
355
  collection_name=USER_EMBEDDINGS_NAME,
356
  embedding=embedding_model
357
  ).as_retriever(search_kwargs={"k": top_k})
@@ -376,76 +425,75 @@ def search_core_reference(query: str, top_k: int = 3) -> str:
376
  @tool
377
  def load_and_embed_protocol(file_path: str = None) -> str:
378
  """Load and embed a protocol PDF file into the vector store."""
379
- """Load and embed a protocol PDF file into the vector store.
380
-
381
- Args:
382
- file_path: Optional path to the PDF file. If None, will use files in the upload directory.
 
 
 
 
 
 
383
 
384
- Returns:
385
- String indicating success or failure of the embedding process
386
- """
387
- try:
388
- # If no specific file path is provided, use all PDFs in the upload directory
389
- if not file_path:
390
- uploaded_files = [f for f in os.listdir(UPLOAD_PATH) if f.endswith('.pdf')]
391
- if not uploaded_files:
392
- return "No protocol documents found in the upload directory."
393
-
394
- # Create file objects for processing
395
- files = []
396
- for filename in uploaded_files:
397
- file_path = os.path.join(UPLOAD_PATH, filename)
398
- # Create a simple object with the necessary attributes
399
- class FileObj:
400
- def __init__(self, path, name, size):
401
- self.path = path
402
- self.name = name
403
- self.size = size
404
-
405
- file_size = os.path.getsize(file_path)
406
- files.append(FileObj(file_path, filename, file_size))
407
- else:
408
- # Create a file object for the specific file
409
- if not os.path.exists(file_path):
410
- return f"File not found: {file_path}"
411
-
412
- filename = os.path.basename(file_path)
413
- file_size = os.path.getsize(file_path)
414
-
415
  class FileObj:
416
  def __init__(self, path, name, size):
417
  self.path = path
418
  self.name = name
419
  self.size = size
420
 
421
- files = [FileObj(file_path, filename, file_size)]
 
 
 
 
 
422
 
423
- # Process the files asynchronously
424
- import asyncio
425
- documents_with_metadata = asyncio.run(load_and_chunk_protocol_files(files))
426
- user_vectorstore = asyncio.run(embed_protocol_in_qdrant(documents_with_metadata, EMBEDDING_MODEL_ID))
427
 
428
- if user_vectorstore:
429
- return f"Successfully embedded {len(documents_with_metadata)} chunks from {len(files)} protocol document(s)."
430
- else:
431
- return "Failed to embed protocol document(s)."
432
- except Exception as e:
433
- return f"Error embedding protocol document: {str(e)}"
 
 
 
 
 
 
 
 
 
 
 
434
 
435
  @tool
436
  def search_protocol_for_instruments(domain: str) -> dict:
437
  """Search the protocol for instruments related to a specific NIH HEAL CDE core domain."""
438
  global embedding_model
439
 
 
 
 
 
 
440
  # Check if user collection exists
441
  try:
442
  # Check if collection exists
443
- if USER_EMBEDDINGS_NAME not in [c.name for c in qdrant_client.get_collections().collections]:
444
  return {"domain": domain, "instrument": "No protocol document embedded", "context": ""}
445
 
446
  # Create retriever for user documents
447
  user_retriever = QdrantVectorStore(
448
- client=qdrant_client,
449
  collection_name=USER_EMBEDDINGS_NAME,
450
  embedding=embedding_model
451
  ).as_retriever(search_kwargs={"k": 10})
@@ -611,7 +659,8 @@ def call_final_model(state: MessagesState):
611
  last_ai_message = messages[-1]
612
  response = final_model.invoke(
613
  [
614
- SystemMessage("Rewrite this in the voice of a helpful and kind assistant"),
 
615
  HumanMessage(last_ai_message.content),
616
  ]
617
  )
@@ -640,6 +689,16 @@ graph = builder.compile()
640
  # ==================== CHAINLIT HANDLERS ====================
641
  @cl.on_chat_start
642
  async def on_chat_start():
 
 
 
 
 
 
 
 
 
 
643
  # Welcome message
644
  welcome_msg = cl.Message(content="Welcome! Please upload a NIH HEAL protocol PDF file to get started.")
645
  await welcome_msg.send()
@@ -656,9 +715,9 @@ async def on_chat_start():
656
  processing_msg = cl.Message(content="Processing your protocol PDF file...")
657
  await processing_msg.send()
658
 
659
- # Process the uploaded files
660
  documents_with_metadata = await load_and_chunk_protocol_files(files)
661
- user_vectorstore = await embed_protocol_in_qdrant(documents_with_metadata)
662
 
663
  if user_vectorstore:
664
  analysis_msg = cl.Message(content="Analyzing your protocol to identify instruments (CRF questionaires) for NIH HEAL CDE core domains...")
 
36
  INITIAL_EMBEDDINGS_DIR = "./initial_embeddings"
37
  INITIAL_EMBEDDINGS_NAME = "initial_embeddings"
38
  USER_EMBEDDINGS_NAME = "user_embeddings"
39
+ # VECTOR_STORE_COLLECTION = "documents"
40
 
41
  # Model IDs
42
  EMBEDDING_MODEL_ID = "pritamdeka/S-PubMedBert-MS-MARCO"
 
61
  # Make sure upload directory exists
62
  os.makedirs(UPLOAD_PATH, exist_ok=True)
63
 
64
+ # ==================== EMBEDDING MODEL SETUP (supports flexible model selection) ====================
65
  def get_embedding_model(model_id):
66
  """Creates and returns the appropriate embedding model based on the model ID."""
67
  if "text-embedding" in model_id:
 
68
  from langchain_openai import OpenAIEmbeddings
69
  return OpenAIEmbeddings(model=model_id)
70
  else:
 
71
  return HuggingFaceEmbeddings(model_name=model_id)
72
 
73
  def initialize_embedding_models():
 
82
  # Initialize the embedding model
83
  initialize_embedding_models()
84
 
 
 
 
 
 
 
 
 
85
  # ==================== QDRANT SETUP ====================
86
+ # Create a global Qdrant client for the core embeddings (available to all sessions)
87
+ global_qdrant_client = QdrantClient(":memory:")
88
+
89
+ # Factory function that creates a new session-specific in-memory Qdrant client
90
+ def create_session_qdrant_client():
91
+ return QdrantClient(":memory:")
92
 
93
  # ==================== DOCUMENT PROCESSING ====================
94
  # Create a semantic splitter for documents
 
133
  return all_chunks
134
 
135
  def embed_core_reference_in_qdrant(chunks):
136
+ """Embeds core reference chunks and stores them in the global Qdrant instance."""
137
+ global embedding_model, global_qdrant_client
138
 
139
  if not chunks:
140
  print("No Excel files found to process or all files were empty.")
 
146
  initialize_embedding_models()
147
 
148
  print(f"Using embedding model: {EMBEDDING_MODEL_ID}")
149
+ print("Creating vector store for core reference data...")
150
 
151
  try:
152
+ # First, check if collection exists and delete it if it does
153
+ if INITIAL_EMBEDDINGS_NAME in [c.name for c in global_qdrant_client.get_collections().collections]:
154
+ global_qdrant_client.delete_collection(INITIAL_EMBEDDINGS_NAME)
155
+
156
+ # Create the collection with proper parameters
157
+ # Get the embedding dimension from the model
158
+ embedding_dimension = len(embedding_model.embed_query("Sample text"))
159
+
160
+ global_qdrant_client.create_collection(
161
+ collection_name=INITIAL_EMBEDDINGS_NAME,
162
+ vectors_config=VectorParams(size=embedding_dimension, distance=Distance.COSINE)
163
  )
164
+
165
+ # Create the vector store
166
+ vector_store = QdrantVectorStore(
167
+ client=global_qdrant_client,
168
+ collection_name=INITIAL_EMBEDDINGS_NAME,
169
+ embedding=embedding_model
170
+ )
171
+
172
+ # Add documents to the vector store
173
+ vector_store.add_documents(chunks)
174
+
175
  print(f"Successfully loaded all .xlsx files into Qdrant collection '{INITIAL_EMBEDDINGS_NAME}'.")
176
  return vector_store
177
  except Exception as e:
 
179
  print(f"Embedding model status: {embedding_model is not None}")
180
  return None
181
 
182
+ # Module-level handle to the core vector store; assigned by initialize_core_reference_embeddings()
183
+ core_vectorstore = None
184
+
185
  def initialize_core_reference_embeddings():
186
+ """Loads all .xlsx files, extracts text, embeds, and stores in global Qdrant."""
187
+ global core_vectorstore
188
  chunks = load_and_chunk_core_reference_files()
189
+ core_vectorstore = embed_core_reference_in_qdrant(chunks)
190
+ return core_vectorstore
191
+
192
+ # Call this function when the application starts
193
+ initialize_core_reference_embeddings()
194
 
195
  # ==================== PROTOCOL DOCUMENT PROCESSING ====================
196
  async def load_and_chunk_protocol_files(files):
 
202
  print(f"Processing file: {file.name}, size: {file.size} bytes")
203
  file_path = os.path.join(UPLOAD_PATH, file.name)
204
 
 
 
 
 
205
  shutil.copyfile(file.path, file_path)
206
 
207
  try:
 
227
 
228
  return documents_with_metadata
229
 
230
+ async def embed_protocol_in_qdrant(documents_with_metadata, session_qdrant_client, model_name=EMBEDDING_MODEL_ID):
231
+ """Create a vector store and embed protocol chunks into session-specific Qdrant."""
232
  global embedding_model
233
+
 
 
 
 
234
  print(f"Using embedding model: {model_name}")
235
 
236
  try:
237
  # First, check if collection exists and delete it if it does
238
+ if USER_EMBEDDINGS_NAME in [c.name for c in session_qdrant_client.get_collections().collections]:
239
+ session_qdrant_client.delete_collection(USER_EMBEDDINGS_NAME)
240
 
241
  # Create the collection with proper parameters
242
  # Get the embedding dimension from the model
243
  embedding_dimension = len(embedding_model.embed_query("Sample text"))
244
 
245
+ session_qdrant_client.create_collection(
246
  collection_name=USER_EMBEDDINGS_NAME,
247
  vectors_config=VectorParams(size=embedding_dimension, distance=Distance.COSINE)
248
  )
249
 
250
  # Create the vector store
251
  user_vectorstore = QdrantVectorStore(
252
+ client=session_qdrant_client,
253
  collection_name=USER_EMBEDDINGS_NAME,
254
  embedding=embedding_model
255
  )
 
263
  print(f"Error creating vector store: {str(e)}")
264
  return None
265
 
266
+ async def process_uploaded_protocol(files, session_qdrant_client, model_name=EMBEDDING_MODEL_ID):
267
+ """Process uploaded protocol PDF files and add them to a session-specific vector store collection"""
268
  documents_with_metadata = await load_and_chunk_protocol_files(files)
269
+ return await embed_protocol_in_qdrant(documents_with_metadata, session_qdrant_client, model_name)
270
 
271
  # ==================== RETRIEVAL FUNCTIONS ====================
272
  def retrieve_documents(query, doc_type=None, k=5):
273
+ """Retrieve documents from either core or session database based on doc_type"""
274
+ global embedding_model, global_qdrant_client
275
+
276
+ # Get the appropriate client and collection name
277
+ if doc_type == "protocol":
278
+ # Use session-specific client for protocol documents
279
+ client = cl.user_session.get("session_qdrant_client")
280
+ collection_name = USER_EMBEDDINGS_NAME
281
+ if not client:
282
+ print("No session client available")
283
+ return []
284
+ else:
285
+ # Use global client for core reference documents
286
+ client = global_qdrant_client
287
+ collection_name = INITIAL_EMBEDDINGS_NAME
288
+
289
+ # Create vector store with the appropriate client
290
  vector_store = QdrantVectorStore(
291
+ client=client,
292
+ collection_name=collection_name,
293
  embedding=embedding_model
294
  )
295
 
296
+ # Set up search parameters
297
  search_kwargs = {"k": k}
298
+
299
+ # Create and use retriever
 
300
  retriever = vector_store.as_retriever(search_kwargs=search_kwargs)
301
  return retriever.invoke(query)
302
 
 
315
  """
316
 
317
  rag_prompt = ChatPromptTemplate.from_template(RAG_TEMPLATE)
318
+ #chat_model = ChatOpenAI()
319
+ chat_model = ChatOpenAI(model_name="gpt-4o")
320
 
321
  # Create a RAG chain that can be filtered by document type
322
  def create_rag_chain(doc_type=None):
323
+ """Create a RAG chain that can be filtered by document type (protocol/core_reference)"""
324
  def retrieve_with_type(query):
325
  docs = retrieve_documents(query, doc_type=doc_type)
326
  return format_docs(docs)
 
365
  @tool
366
  def search_core_reference(query: str, top_k: int = 3) -> str:
367
  """Search core reference data and protocol data for information related to the query."""
368
+ global embedding_model, global_qdrant_client
369
 
370
+ # Create a retriever for the core embeddings
371
+ core_retriever = QdrantVectorStore(
372
+ client=global_qdrant_client,
373
+ collection_name=INITIAL_EMBEDDINGS_NAME,
374
+ embedding=embedding_model
375
+ ).as_retriever(search_kwargs={"k": top_k})
376
+
377
+ # Create a retrieval chain for core documents
378
+ core_retrieval_chain = (
379
+ {"context": itemgetter("question") | core_retriever | format_docs,
380
+ "question": itemgetter("question")}
381
+ | rag_prompt
382
+ | chat_model
383
+ | StrOutputParser()
384
+ )
385
+
386
+ # Get results from core reference
387
+ result = core_retrieval_chain.invoke({"question": query})
388
+
389
+ # Get the session-specific Qdrant client
390
+ session_qdrant_client = cl.user_session.get("session_qdrant_client")
391
+ if not session_qdrant_client:
392
+ return result # Return only core results if no session client
393
 
394
  # If we have a user collection, also search that
395
  try:
396
  # Check if user collection exists
397
+ if USER_EMBEDDINGS_NAME not in [c.name for c in session_qdrant_client.get_collections().collections]:
398
  # If no user collection exists yet, just return core reference results
399
  return result
400
 
401
  # Create a retrieval chain for user documents
402
  user_retriever = QdrantVectorStore(
403
+ client=session_qdrant_client,
404
  collection_name=USER_EMBEDDINGS_NAME,
405
  embedding=embedding_model
406
  ).as_retriever(search_kwargs={"k": top_k})
 
425
  @tool
426
  def load_and_embed_protocol(file_path: str = None) -> str:
427
  """Load and embed a protocol PDF file into the vector store."""
428
+ # Get the session-specific Qdrant client
429
+ session_qdrant_client = cl.user_session.get("session_qdrant_client")
430
+ if not session_qdrant_client:
431
+ return "No session-specific Qdrant client found. Please restart the chat."
432
+
433
+ # If no specific file path is provided, use all PDFs in the upload directory
434
+ if not file_path:
435
+ uploaded_files = [f for f in os.listdir(UPLOAD_PATH) if f.endswith('.pdf')]
436
+ if not uploaded_files:
437
+ return "No protocol documents found in the upload directory."
438
 
439
+ # Create file objects for processing
440
+ files = []
441
+ for filename in uploaded_files:
442
+ file_path = os.path.join(UPLOAD_PATH, filename)
443
+ # Create a simple object with the necessary attributes
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
444
  class FileObj:
445
  def __init__(self, path, name, size):
446
  self.path = path
447
  self.name = name
448
  self.size = size
449
 
450
+ file_size = os.path.getsize(file_path)
451
+ files.append(FileObj(file_path, filename, file_size))
452
+ else:
453
+ # Create a file object for the specific file
454
+ if not os.path.exists(file_path):
455
+ return f"File not found: {file_path}"
456
 
457
+ filename = os.path.basename(file_path)
458
+ file_size = os.path.getsize(file_path)
 
 
459
 
460
+ class FileObj:
461
+ def __init__(self, path, name, size):
462
+ self.path = path
463
+ self.name = name
464
+ self.size = size
465
+
466
+ files = [FileObj(file_path, filename, file_size)]
467
+
468
+ # Process the files asynchronously
469
+ import asyncio
470
+ documents_with_metadata = asyncio.run(load_and_chunk_protocol_files(files))
471
+ user_vectorstore = asyncio.run(embed_protocol_in_qdrant(documents_with_metadata, session_qdrant_client, EMBEDDING_MODEL_ID))
472
+
473
+ if user_vectorstore:
474
+ return f"Successfully embedded {len(documents_with_metadata)} chunks from {len(files)} protocol document(s)."
475
+ else:
476
+ return "Failed to embed protocol document(s)."
477
 
478
  @tool
479
  def search_protocol_for_instruments(domain: str) -> dict:
480
  """Search the protocol for instruments related to a specific NIH HEAL CDE core domain."""
481
  global embedding_model
482
 
483
+ # Get the session-specific Qdrant client
484
+ session_qdrant_client = cl.user_session.get("session_qdrant_client")
485
+ if not session_qdrant_client:
486
+ return {"domain": domain, "instrument": "No session-specific Qdrant client found", "context": ""}
487
+
488
  # Check if user collection exists
489
  try:
490
  # Check if collection exists
491
+ if USER_EMBEDDINGS_NAME not in [c.name for c in session_qdrant_client.get_collections().collections]:
492
  return {"domain": domain, "instrument": "No protocol document embedded", "context": ""}
493
 
494
  # Create retriever for user documents
495
  user_retriever = QdrantVectorStore(
496
+ client=session_qdrant_client,
497
  collection_name=USER_EMBEDDINGS_NAME,
498
  embedding=embedding_model
499
  ).as_retriever(search_kwargs={"k": 10})
 
659
  last_ai_message = messages[-1]
660
  response = final_model.invoke(
661
  [
662
+ #SystemMessage("Rewrite this in the voice of a helpful and kind assistant"),
663
+ SystemMessage("do not alter just present the information"),
664
  HumanMessage(last_ai_message.content),
665
  ]
666
  )
 
689
  # ==================== CHAINLIT HANDLERS ====================
690
  @cl.on_chat_start
691
  async def on_chat_start():
692
+ # Create a session-specific Qdrant client
693
+ session_qdrant_client = create_session_qdrant_client()
694
+ cl.user_session.set("session_qdrant_client", session_qdrant_client)
695
+
696
+ # Create a retriever for the core embeddings
697
+ global core_vectorstore
698
+ if core_vectorstore:
699
+ core_retriever = core_vectorstore.as_retriever()
700
+ cl.user_session.set("core_retriever", core_retriever)
701
+
702
  # Welcome message
703
  welcome_msg = cl.Message(content="Welcome! Please upload a NIH HEAL protocol PDF file to get started.")
704
  await welcome_msg.send()
 
715
  processing_msg = cl.Message(content="Processing your protocol PDF file...")
716
  await processing_msg.send()
717
 
718
+ # Process the uploaded files with the session-specific client
719
  documents_with_metadata = await load_and_chunk_protocol_files(files)
720
+ user_vectorstore = await embed_protocol_in_qdrant(documents_with_metadata, session_qdrant_client)
721
 
722
  if user_vectorstore:
723
  analysis_msg = cl.Message(content="Analyzing your protocol to identify instruments (CRF questionaires) for NIH HEAL CDE core domains...")