drewgenai committed on
Commit
614a873
·
1 Parent(s): 56f76ad

reorder, remove unnecessary functions

Browse files
Files changed (1) hide show
  1. app.py +206 -95
app.py CHANGED
@@ -30,25 +30,38 @@ from langchain_core.output_parsers import StrOutputParser
30
  # Load environment variables
31
  load_dotenv()
32
 
33
- # Constants
 
34
  UPLOAD_PATH = "./uploads"
35
  INITIAL_EMBEDDINGS_DIR = "./initial_embeddings"
36
  INITIAL_EMBEDDINGS_NAME = "initial_embeddings"
37
  USER_EMBEDDINGS_NAME = "user_embeddings"
 
38
 
39
- #XLSX_MODEL_ID = "Snowflake/snowflake-arctic-embed-m"
40
- XLSX_MODEL_ID = "pritamdeka/S-PubMedBert-MS-MARCO"
41
- #PDF_MODEL_ID = "Snowflake/snowflake-arctic-embed-m"
42
- PDF_MODEL_ID = "pritamdeka/S-PubMedBert-MS-MARCO"
43
-
44
  INSTRUMENT_SEARCH_LLM = "gpt-4o" # LLM for searching instruments
45
  INSTRUMENT_ANALYSIS_LLM = "gpt-4o" # LLM for analyzing all domains
46
 
47
- # Add this after the other global variables
48
- pdf_embedding_model = None
49
- xlsx_embedding_model = None
 
 
 
 
 
 
 
 
 
 
50
 
51
- # Add this utility function before the initialize_embedding_models function
 
 
 
52
  def get_embedding_model(model_id):
53
  """Creates and returns the appropriate embedding model based on the model ID."""
54
  if "text-embedding" in model_id:
@@ -59,41 +72,39 @@ def get_embedding_model(model_id):
59
  # HuggingFace embeddings
60
  return HuggingFaceEmbeddings(model_name=model_id)
61
 
62
- # Initialize embedding models
63
  def initialize_embedding_models():
64
- """Initialize the embedding models once at startup"""
65
- global pdf_embedding_model, xlsx_embedding_model
66
- pdf_embedding_model = get_embedding_model(PDF_MODEL_ID)
67
- xlsx_embedding_model = get_embedding_model(XLSX_MODEL_ID)
68
- print(f"Initialized embedding models: {PDF_MODEL_ID} and {XLSX_MODEL_ID}")
 
 
69
 
70
- # Call this function after loading environment variables
71
  initialize_embedding_models()
72
 
73
- # Make sure upload directory exists
74
- os.makedirs(UPLOAD_PATH, exist_ok=True)
75
-
76
- # NIH HEAL CDE core domains
77
- NIH_HEAL_CORE_DOMAINS = [
78
- "Anxiety",
79
- "Depression",
80
- "Global satisfaction with treatment",
81
- "Pain catastrophizing",
82
- "Pain interference",
83
- "Pain intensity",
84
- "Physical functioning",
85
- "Quality of Life (QoL)",
86
- "Sleep",
87
- "Substance Use Screener"
88
- ]
89
 
 
90
  # Initialize Qdrant (in-memory)
91
  qdrant_client = QdrantClient(":memory:")
92
 
93
- # Create a semantic splitter for PDF documents
 
94
  semantic_splitter = RecursiveCharacterTextSplitter(chunk_size=512, chunk_overlap=50)
95
 
96
- # Utility functions
 
 
 
 
97
  def load_and_chunk_excel_files():
98
  """Loads all .xlsx files from the initial embeddings directory and splits them into chunks."""
99
  text_splitter = RecursiveCharacterTextSplitter(chunk_size=512, chunk_overlap=50)
@@ -115,6 +126,7 @@ def load_and_chunk_excel_files():
115
  for chunk in chunks:
116
  chunk.metadata = chunk.metadata or {}
117
  chunk.metadata["filename"] = file
 
118
 
119
  all_chunks.extend(chunks)
120
  file_count += 1
@@ -128,31 +140,40 @@ def load_and_chunk_excel_files():
128
 
129
  def embed_chunks_in_qdrant(chunks):
130
  """Embeds document chunks and stores them in Qdrant."""
131
- global xlsx_embedding_model
132
 
133
  if not chunks:
134
  print("No Excel files found to process or all files were empty.")
135
  return None
136
 
137
- print(f"Using embedding model: {XLSX_MODEL_ID}")
 
 
 
 
 
138
  print("Creating vector store...")
139
- vector_store = QdrantVectorStore.from_documents(
140
- documents=chunks,
141
- embedding=xlsx_embedding_model,
142
- location=":memory:",
143
- collection_name=INITIAL_EMBEDDINGS_NAME
144
- )
145
- print(f"Successfully loaded all .xlsx files into Qdrant collection '{INITIAL_EMBEDDINGS_NAME}'.")
146
- return vector_store
 
 
 
 
 
 
147
 
148
  def process_initial_embeddings():
149
  """Loads all .xlsx files, extracts text, embeds, and stores in Qdrant."""
150
  chunks = load_and_chunk_excel_files()
151
  return embed_chunks_in_qdrant(chunks)
152
 
153
- def format_docs(docs):
154
- return "\n\n".join(doc.page_content for doc in docs)
155
-
156
  async def load_and_chunk_pdf_files(files):
157
  """Load PDF files and split them into chunks with metadata."""
158
  print(f"Loading {len(files)} uploaded PDF files")
@@ -176,7 +197,13 @@ async def load_and_chunk_pdf_files(files):
176
  source_name = file.name
177
  chunks = semantic_splitter.split_text(doc.page_content)
178
  for chunk in chunks:
179
- doc_chunk = Document(page_content=chunk, metadata={"source": source_name})
 
 
 
 
 
 
180
  documents_with_metadata.append(doc_chunk)
181
 
182
  print(f"Successfully processed {file.name}, extracted {len(documents_with_metadata)} chunks")
@@ -185,17 +212,9 @@ async def load_and_chunk_pdf_files(files):
185
 
186
  return documents_with_metadata
187
 
188
- # Add this utility function to get vector dimensions
189
- def get_embedding_dimensions(model_id):
190
- """Gets the dimensions of embeddings from a specific model."""
191
- model = get_embedding_model(model_id)
192
- sample_text = "Sample text to determine embedding dimension"
193
- sample_embedding = model.embed_query(sample_text)
194
- return len(sample_embedding)
195
-
196
- async def embed_pdf_chunks_in_qdrant(documents_with_metadata, model_name=PDF_MODEL_ID):
197
  """Create a vector store and embed PDF chunks into Qdrant."""
198
- global pdf_embedding_model
199
 
200
  if not documents_with_metadata:
201
  print("No documents to embed")
@@ -210,7 +229,7 @@ async def embed_pdf_chunks_in_qdrant(documents_with_metadata, model_name=PDF_MOD
210
 
211
  # Create the collection with proper parameters
212
  # Get the embedding dimension from the model
213
- embedding_dimension = len(pdf_embedding_model.embed_query("Sample text"))
214
 
215
  qdrant_client.create_collection(
216
  collection_name=USER_EMBEDDINGS_NAME,
@@ -221,7 +240,7 @@ async def embed_pdf_chunks_in_qdrant(documents_with_metadata, model_name=PDF_MOD
221
  user_vectorstore = QdrantVectorStore(
222
  client=qdrant_client,
223
  collection_name=USER_EMBEDDINGS_NAME,
224
- embedding=pdf_embedding_model
225
  )
226
 
227
  # Add documents to the vector store
@@ -233,34 +252,31 @@ async def embed_pdf_chunks_in_qdrant(documents_with_metadata, model_name=PDF_MOD
233
  print(f"Error creating vector store: {str(e)}")
234
  return None
235
 
236
- async def process_uploaded_files(files, model_name=PDF_MODEL_ID):
237
  """Process uploaded PDF files and add them to a separate vector store collection"""
238
  documents_with_metadata = await load_and_chunk_pdf_files(files)
239
  return await embed_pdf_chunks_in_qdrant(documents_with_metadata, model_name)
240
 
241
- # Data processing and initialization
242
- vectorstore = process_initial_embeddings()
243
 
244
- # Create retrievers for each collection
245
- if vectorstore:
246
- # Retriever for initial Excel embeddings
247
- excel_retriever = vectorstore.as_retriever(search_kwargs={"k": 10})
248
- print("Excel retriever created successfully.")
249
- else:
250
- print("Failed to create Excel retriever: No vector store available.")
 
 
 
 
 
 
 
 
 
251
 
252
- # The PDF retriever is created dynamically when files are uploaded
253
- # in the embed_pdf_chunks_in_qdrant function:
254
- #
255
- # user_vectorstore = QdrantVectorStore(
256
- # client=qdrant_client,
257
- # collection_name=USER_EMBEDDINGS_NAME,
258
- # embedding=pdf_model
259
- # )
260
- #
261
- # user_retriever = user_vectorstore.as_retriever(search_kwargs={"k": top_k})
262
-
263
- # RAG setup for Excel data
264
  RAG_TEMPLATE = """\
265
  You are a helpful and kind assistant. Use the context provided below to answer the question.
266
 
@@ -274,9 +290,33 @@ Context:
274
  """
275
 
276
  rag_prompt = ChatPromptTemplate.from_template(RAG_TEMPLATE)
277
-
278
  chat_model = ChatOpenAI()
279
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
280
  # Chain for retrieving from Excel embeddings
281
  initialembeddings_retrieval_chain = (
282
  {"context": itemgetter("question") | excel_retriever | format_docs,
@@ -286,11 +326,20 @@ initialembeddings_retrieval_chain = (
286
  | StrOutputParser()
287
  )
288
 
289
- # Tool definitions
 
 
 
 
 
 
 
 
 
290
  @tool
291
  def search_excel_data(query: str, top_k: int = 3) -> str:
292
  """Search both Excel data and user-uploaded PDF data for information related to the query."""
293
- global pdf_embedding_model
294
 
295
  # Use the existing initialembeddings_retrieval_chain
296
  result = initialembeddings_retrieval_chain.invoke({"question": query})
@@ -306,7 +355,7 @@ def search_excel_data(query: str, top_k: int = 3) -> str:
306
  user_retriever = QdrantVectorStore(
307
  client=qdrant_client,
308
  collection_name=USER_EMBEDDINGS_NAME,
309
- embedding=pdf_embedding_model # Use the global model
310
  ).as_retriever(search_kwargs={"k": top_k})
311
 
312
  user_retrieval_chain = (
@@ -375,7 +424,7 @@ def load_and_embed_protocol_pdf(file_path: str = None) -> str:
375
  # Process the files asynchronously
376
  import asyncio
377
  documents_with_metadata = asyncio.run(load_and_chunk_pdf_files(files))
378
- user_vectorstore = asyncio.run(embed_pdf_chunks_in_qdrant(documents_with_metadata, PDF_MODEL_ID))
379
 
380
  if user_vectorstore:
381
  return f"Successfully embedded {len(documents_with_metadata)} chunks from {len(files)} protocol document(s)."
@@ -387,7 +436,7 @@ def load_and_embed_protocol_pdf(file_path: str = None) -> str:
387
  @tool
388
  def search_protocol(query: str, top_k: int = 5) -> str:
389
  """Search the protocol for information related to the query."""
390
- global pdf_embedding_model
391
 
392
  try:
393
  # Check if user collection exists
@@ -398,7 +447,7 @@ def search_protocol(query: str, top_k: int = 5) -> str:
398
  user_retriever = QdrantVectorStore(
399
  client=qdrant_client,
400
  collection_name=USER_EMBEDDINGS_NAME,
401
- embedding=pdf_embedding_model # Use the global model
402
  ).as_retriever(search_kwargs={"k": top_k})
403
 
404
  user_retrieval_chain = (
@@ -417,7 +466,7 @@ def search_protocol(query: str, top_k: int = 5) -> str:
417
  @tool
418
  def search_protocol_for_instruments(domain: str) -> dict:
419
  """Search the protocol for instruments related to a specific NIH HEAL CDE core domain."""
420
- global pdf_embedding_model
421
 
422
  # Check if user collection exists
423
  try:
@@ -429,7 +478,7 @@ def search_protocol_for_instruments(domain: str) -> dict:
429
  user_retriever = QdrantVectorStore(
430
  client=qdrant_client,
431
  collection_name=USER_EMBEDDINGS_NAME,
432
- embedding=pdf_embedding_model # Use the global model
433
  ).as_retriever(search_kwargs={"k": 10})
434
  except Exception as e:
435
  print(f"Error accessing user vector store: {str(e)}")
@@ -476,6 +525,44 @@ def search_protocol_for_instruments(domain: str) -> dict:
476
  print(f"Error identifying instrument for {domain}: {str(e)}")
477
  return {"domain": domain, "instrument": "Error during identification", "context": str(e)}
478
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
479
  @tool
480
  def analyze_all_heal_domains() -> str:
481
  """Analyze all NIH HEAL CDE core domains and identify instruments used in the protocol.
@@ -507,6 +594,25 @@ def analyze_all_heal_domains() -> str:
507
 
508
  return result
509
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
510
  @tool
511
  def format_instrument_analysis(analysis_results: list, title: str = "NIH HEAL CDE Core Domains Analysis") -> str:
512
  """Format instrument analysis results into a markdown table.
@@ -530,21 +636,25 @@ def format_instrument_analysis(analysis_results: list, title: str = "NIH HEAL CD
530
 
531
  return result
532
 
533
- # Update the tools list
534
  tools = [
 
535
  search_excel_data,
536
  load_and_embed_protocol_pdf,
537
  search_protocol,
538
  search_protocol_for_instruments,
 
539
  analyze_all_heal_domains,
 
540
  format_instrument_analysis
541
  ]
542
 
 
543
  # LangGraph components
544
  model = ChatOpenAI(model_name=INSTRUMENT_ANALYSIS_LLM, temperature=0)
545
  final_model = ChatOpenAI(model_name=INSTRUMENT_ANALYSIS_LLM, temperature=0)
546
 
547
- # Update the system message
548
  system_message = """You are a helpful assistant specializing in NIH HEAL CDE protocols.
549
 
550
  You have access to:
@@ -621,7 +731,7 @@ builder.add_edge("final", END)
621
 
622
  graph = builder.compile()
623
 
624
- # Chainlit handlers
625
  @cl.on_chat_start
626
  async def on_chat_start():
627
  # Welcome message
@@ -702,3 +812,4 @@ async def on_message(msg: cl.Message):
702
 
703
  await final_answer.send()
704
 
 
 
30
  # Load environment variables
31
  load_dotenv()
32
 
33
+ # ==================== CONSTANTS ====================
34
+ # Paths and directories
35
  UPLOAD_PATH = "./uploads"
36
  INITIAL_EMBEDDINGS_DIR = "./initial_embeddings"
37
  INITIAL_EMBEDDINGS_NAME = "initial_embeddings"
38
  USER_EMBEDDINGS_NAME = "user_embeddings"
39
+ VECTOR_STORE_COLLECTION = "documents"
40
 
41
+ # Model IDs
42
+ EMBEDDING_MODEL_ID = "pritamdeka/S-PubMedBert-MS-MARCO"
43
+ #EMBEDDING_MODEL_ID = "Snowflake/snowflake-arctic-embed-m"
 
 
44
  INSTRUMENT_SEARCH_LLM = "gpt-4o" # LLM for searching instruments
45
  INSTRUMENT_ANALYSIS_LLM = "gpt-4o" # LLM for analyzing all domains
46
 
47
+ # NIH HEAL CDE core domains
48
+ NIH_HEAL_CORE_DOMAINS = [
49
+ "Anxiety",
50
+ "Depression",
51
+ "Global satisfaction with treatment",
52
+ "Pain catastrophizing",
53
+ "Pain interference",
54
+ "Pain intensity",
55
+ "Physical functioning",
56
+ "Quality of Life (QoL)",
57
+ "Sleep",
58
+ "Substance Use Screener"
59
+ ]
60
 
61
+ # Make sure upload directory exists
62
+ os.makedirs(UPLOAD_PATH, exist_ok=True)
63
+
64
+ # ==================== EMBEDDING MODEL SETUP ====================
65
  def get_embedding_model(model_id):
66
  """Creates and returns the appropriate embedding model based on the model ID."""
67
  if "text-embedding" in model_id:
 
72
  # HuggingFace embeddings
73
  return HuggingFaceEmbeddings(model_name=model_id)
74
 
 
75
def initialize_embedding_models():
    """Create the shared embedding model and bind it to a module global.

    The model named by EMBEDDING_MODEL_ID is built through
    get_embedding_model and stored in the global ``embedding_model``;
    the model ID is then printed.
    """
    global embedding_model

    # One model instance is used for every document type.
    shared_model = get_embedding_model(EMBEDDING_MODEL_ID)
    embedding_model = shared_model

    print(f"Initialized embedding model: {EMBEDDING_MODEL_ID}")
83
 
84
+ # Initialize the embedding model
85
  initialize_embedding_models()
86
 
87
+ # Get embedding dimensions utility
88
def get_embedding_dimensions(model_id):
    """Return the vector length produced by the embedding model ``model_id``.

    A model is created with get_embedding_model(model_id), a fixed probe
    sentence is embedded, and the length of the resulting vector is returned.
    """
    probe = "Sample text to determine embedding dimension"
    probe_vector = get_embedding_model(model_id).embed_query(probe)
    return len(probe_vector)
 
 
 
 
 
 
 
 
 
94
 
95
+ # ==================== QDRANT SETUP ====================
96
  # Initialize Qdrant (in-memory)
97
  qdrant_client = QdrantClient(":memory:")
98
 
99
+ # ==================== DOCUMENT PROCESSING ====================
100
+ # Create a semantic splitter for documents
101
  semantic_splitter = RecursiveCharacterTextSplitter(chunk_size=512, chunk_overlap=50)
102
 
103
def format_docs(docs):
    """Concatenate each document's page_content, separated by a blank line."""
    contents = [document.page_content for document in docs]
    separator = "\n\n"
    return separator.join(contents)
106
+
107
+ # ==================== EXCEL DOCUMENT PROCESSING ====================
108
  def load_and_chunk_excel_files():
109
  """Loads all .xlsx files from the initial embeddings directory and splits them into chunks."""
110
  text_splitter = RecursiveCharacterTextSplitter(chunk_size=512, chunk_overlap=50)
 
126
  for chunk in chunks:
127
  chunk.metadata = chunk.metadata or {}
128
  chunk.metadata["filename"] = file
129
+ chunk.metadata["type"] = "excel" # Add document type
130
 
131
  all_chunks.extend(chunks)
132
  file_count += 1
 
140
 
141
def embed_chunks_in_qdrant(chunks):
    """Embed document chunks into the initial-embeddings Qdrant collection.

    Returns the vector store built by QdrantVectorStore.from_documents, or
    None when ``chunks`` is empty or building the store raises (the error
    is printed rather than propagated).
    """
    global embedding_model

    # Nothing to embed: report and bail out early.
    if not chunks:
        print("No Excel files found to process or all files were empty.")
        return None

    # Lazily (re)create the shared model if it has not been set.
    if embedding_model is None:
        print("ERROR: No embedding model available. Initializing now.")
        initialize_embedding_models()

    print(f"Using embedding model: {EMBEDDING_MODEL_ID}")
    print("Creating vector store...")

    try:
        store_options = {
            "documents": chunks,
            "embedding": embedding_model,
            "location": ":memory:",
            "collection_name": INITIAL_EMBEDDINGS_NAME,
        }
        store = QdrantVectorStore.from_documents(**store_options)
        print(f"Successfully loaded all .xlsx files into Qdrant collection '{INITIAL_EMBEDDINGS_NAME}'.")
        return store
    except Exception as exc:
        print(f"Error creating vector store: {str(exc)}")
        print(f"Embedding model status: {embedding_model is not None}")
        return None
170
 
171
def process_initial_embeddings():
    """Chunk the initial .xlsx files and embed the chunks in Qdrant.

    Returns whatever embed_chunks_in_qdrant returns for the chunks produced
    by load_and_chunk_excel_files.
    """
    excel_chunks = load_and_chunk_excel_files()
    initial_store = embed_chunks_in_qdrant(excel_chunks)
    return initial_store
175
 
176
+ # ==================== PDF DOCUMENT PROCESSING ====================
 
 
177
  async def load_and_chunk_pdf_files(files):
178
  """Load PDF files and split them into chunks with metadata."""
179
  print(f"Loading {len(files)} uploaded PDF files")
 
197
  source_name = file.name
198
  chunks = semantic_splitter.split_text(doc.page_content)
199
  for chunk in chunks:
200
+ doc_chunk = Document(
201
+ page_content=chunk,
202
+ metadata={
203
+ "source": source_name,
204
+ "type": "pdf" # Add document type
205
+ }
206
+ )
207
  documents_with_metadata.append(doc_chunk)
208
 
209
  print(f"Successfully processed {file.name}, extracted {len(documents_with_metadata)} chunks")
 
212
 
213
  return documents_with_metadata
214
 
215
+ async def embed_pdf_chunks_in_qdrant(documents_with_metadata, model_name=EMBEDDING_MODEL_ID):
 
 
 
 
 
 
 
 
216
  """Create a vector store and embed PDF chunks into Qdrant."""
217
+ global embedding_model
218
 
219
  if not documents_with_metadata:
220
  print("No documents to embed")
 
229
 
230
  # Create the collection with proper parameters
231
  # Get the embedding dimension from the model
232
+ embedding_dimension = len(embedding_model.embed_query("Sample text"))
233
 
234
  qdrant_client.create_collection(
235
  collection_name=USER_EMBEDDINGS_NAME,
 
240
  user_vectorstore = QdrantVectorStore(
241
  client=qdrant_client,
242
  collection_name=USER_EMBEDDINGS_NAME,
243
+ embedding=embedding_model
244
  )
245
 
246
  # Add documents to the vector store
 
252
  print(f"Error creating vector store: {str(e)}")
253
  return None
254
 
255
async def process_uploaded_files(files, model_name=EMBEDDING_MODEL_ID):
    """Chunk uploaded PDF files and embed them into the user collection.

    ``files`` is forwarded to load_and_chunk_pdf_files; the resulting chunks
    and ``model_name`` are forwarded to embed_pdf_chunks_in_qdrant, whose
    result is returned.
    """
    pdf_chunks = await load_and_chunk_pdf_files(files)
    user_store = await embed_pdf_chunks_in_qdrant(pdf_chunks, model_name)
    return user_store
259
 
 
 
260
 
261
+ # ==================== RETRIEVAL FUNCTIONS ====================
262
def retrieve_documents(query, doc_type=None, k=5):
    """Retrieve up to ``k`` documents for ``query``, optionally by document type.

    Args:
        query: Search text passed to the retriever.
        doc_type: Value of the ``type`` metadata field to match (the loaders
            in this file tag chunks as "pdf" or "excel"). A falsy value
            disables filtering.
        k: Maximum number of documents to return.

    Returns:
        The documents returned by the retriever for ``query``.
    """
    # Function-scope import so this block brings its own dependency with it.
    from qdrant_client import models

    # NOTE(review): nothing visible in this file writes documents into
    # VECTOR_STORE_COLLECTION (Excel chunks go to INITIAL_EMBEDDINGS_NAME and
    # PDF chunks to USER_EMBEDDINGS_NAME) -- confirm the collection is
    # populated elsewhere, otherwise this lookup cannot succeed.
    vector_store = QdrantVectorStore(
        client=qdrant_client,
        collection_name=VECTOR_STORE_COLLECTION,
        embedding=embedding_model,
    )

    search_kwargs = {"k": k}
    if doc_type:
        # QdrantVectorStore expects a qdrant Filter object, not a plain dict,
        # and stores document metadata under the "metadata" payload key, so
        # the field has to be addressed as "metadata.type".
        search_kwargs["filter"] = models.Filter(
            must=[
                models.FieldCondition(
                    key="metadata.type",
                    match=models.MatchValue(value=doc_type),
                )
            ]
        )

    retriever = vector_store.as_retriever(search_kwargs=search_kwargs)
    return retriever.invoke(query)
277
 
278
+ # ==================== RAG SETUP ====================
279
+ # RAG template for all retrievals
 
 
 
 
 
 
 
 
 
 
280
  RAG_TEMPLATE = """\
281
  You are a helpful and kind assistant. Use the context provided below to answer the question.
282
 
 
290
  """
291
 
292
  rag_prompt = ChatPromptTemplate.from_template(RAG_TEMPLATE)
 
293
  chat_model = ChatOpenAI()
294
 
295
+ # Create a RAG chain that can be filtered by document type
296
def create_rag_chain(doc_type=None):
    """Build a RAG chain whose retrieval step is limited to ``doc_type``.

    The chain expects a mapping with a "question" key. It fills the prompt's
    "context" with the formatted documents retrieved for that question,
    passes the prompt to chat_model and parses the reply to a string.
    """
    def context_for(inputs):
        # Retrieve for this question, then flatten the hits into one string.
        matched = retrieve_documents(inputs["question"], doc_type=doc_type)
        return format_docs(matched)

    prompt_inputs = {
        "context": context_for,
        "question": itemgetter("question"),
    }
    return prompt_inputs | rag_prompt | chat_model | StrOutputParser()
311
+
312
+ # Initialize the Excel retriever
313
# Build the Excel vector store once at import time and expose a retriever
# over it that returns the 10 closest chunks.
vectorstore = process_initial_embeddings()
if vectorstore is not None:
    excel_retriever = vectorstore.as_retriever(search_kwargs={"k": 10})
    print("Excel retriever created successfully.")
else:
    print("Failed to create Excel retriever: No vector store available.")
    # The retrieval chain defined right after this block references
    # excel_retriever at import time. Leaving the name unbound would surface
    # as an unrelated NameError there, so stop here with the real cause.
    raise RuntimeError(
        "Excel retriever could not be created: no initial embeddings "
        f"vector store was built from '{INITIAL_EMBEDDINGS_DIR}'."
    )
319
+
320
  # Chain for retrieving from Excel embeddings
321
  initialembeddings_retrieval_chain = (
322
  {"context": itemgetter("question") | excel_retriever | format_docs,
 
326
  | StrOutputParser()
327
  )
328
 
329
+ # ==================== TOOL DEFINITIONS ====================
330
@tool
def search_data(query: str, doc_type: str = None) -> str:
    """Search all data or filter by document type (pdf/excel)"""
    # The docstring above is what @tool exposes as the tool description,
    # so its wording is left untouched.
    try:
        rag_chain = create_rag_chain(doc_type)
        answer = rag_chain.invoke({"question": query})
    except Exception as exc:
        # Report the failure as text instead of raising.
        return f"Error searching data: {str(exc)}"
    return answer
338
+
339
  @tool
340
  def search_excel_data(query: str, top_k: int = 3) -> str:
341
  """Search both Excel data and user-uploaded PDF data for information related to the query."""
342
+ global embedding_model
343
 
344
  # Use the existing initialembeddings_retrieval_chain
345
  result = initialembeddings_retrieval_chain.invoke({"question": query})
 
355
  user_retriever = QdrantVectorStore(
356
  client=qdrant_client,
357
  collection_name=USER_EMBEDDINGS_NAME,
358
+ embedding=embedding_model
359
  ).as_retriever(search_kwargs={"k": top_k})
360
 
361
  user_retrieval_chain = (
 
424
  # Process the files asynchronously
425
  import asyncio
426
  documents_with_metadata = asyncio.run(load_and_chunk_pdf_files(files))
427
+ user_vectorstore = asyncio.run(embed_pdf_chunks_in_qdrant(documents_with_metadata, EMBEDDING_MODEL_ID))
428
 
429
  if user_vectorstore:
430
  return f"Successfully embedded {len(documents_with_metadata)} chunks from {len(files)} protocol document(s)."
 
436
  @tool
437
  def search_protocol(query: str, top_k: int = 5) -> str:
438
  """Search the protocol for information related to the query."""
439
+ global embedding_model
440
 
441
  try:
442
  # Check if user collection exists
 
447
  user_retriever = QdrantVectorStore(
448
  client=qdrant_client,
449
  collection_name=USER_EMBEDDINGS_NAME,
450
+ embedding=embedding_model
451
  ).as_retriever(search_kwargs={"k": top_k})
452
 
453
  user_retrieval_chain = (
 
466
  @tool
467
  def search_protocol_for_instruments(domain: str) -> dict:
468
  """Search the protocol for instruments related to a specific NIH HEAL CDE core domain."""
469
+ global embedding_model
470
 
471
  # Check if user collection exists
472
  try:
 
478
  user_retriever = QdrantVectorStore(
479
  client=qdrant_client,
480
  collection_name=USER_EMBEDDINGS_NAME,
481
+ embedding=embedding_model
482
  ).as_retriever(search_kwargs={"k": 10})
483
  except Exception as e:
484
  print(f"Error accessing user vector store: {str(e)}")
 
525
  print(f"Error identifying instrument for {domain}: {str(e)}")
526
  return {"domain": domain, "instrument": "Error during identification", "context": str(e)}
527
 
528
@tool
def analyze_domain(domain: str) -> dict:
    """Analyze a specific NIH HEAL CDE core domain"""
    # Returns a dict with "domain", "instrument" and "context" keys. Any
    # failure is reported through the same error payload that
    # search_protocol_for_instruments uses, so a single failing domain does
    # not abort a caller that loops over several domains.
    try:
        # Protocol-side evidence: chunks tagged as PDF.
        query = f"What instrument or measure is used for {domain} in the protocol?"
        protocol_docs = retrieve_documents(query, doc_type="pdf", k=5)
        protocol_context = format_docs(protocol_docs)

        # Reference-side evidence: chunks tagged as Excel.
        excel_query = f"What are standard instruments or measures for {domain}?"
        excel_docs = retrieve_documents(excel_query, doc_type="excel", k=5)
        excel_context = format_docs(excel_docs)

        # Ask the model to name the instrument from both contexts.
        prompt = f"""
    Based on the protocol information and known instruments, identify which instrument is being used for the domain: {domain}

    Protocol information:
    {protocol_context}

    Known instruments for this domain:
    {excel_context}

    Respond with only the name of the identified instrument. If you cannot identify a specific instrument, respond with "Not identified".
    """

        llm = ChatOpenAI(model_name=INSTRUMENT_SEARCH_LLM, temperature=0)
        instrument = llm.invoke([HumanMessage(content=prompt)]).content

        return {
            "domain": domain,
            "instrument": instrument.strip(),
            "context": protocol_context
        }
    except Exception as e:
        print(f"Error identifying instrument for {domain}: {str(e)}")
        return {"domain": domain, "instrument": "Error during identification", "context": str(e)}
565
+
566
  @tool
567
  def analyze_all_heal_domains() -> str:
568
  """Analyze all NIH HEAL CDE core domains and identify instruments used in the protocol.
 
594
 
595
  return result
596
 
597
@tool
def analyze_all_domains() -> str:
    """Analyze all NIH HEAL CDE core domains at once"""
    # analyze_domain is wrapped by @tool, so it is a tool object rather than
    # a plain function: run it through .invoke() with its arguments as a dict
    # instead of calling it directly.
    results = [
        analyze_domain.invoke({"domain": domain})
        for domain in NIH_HEAL_CORE_DOMAINS
    ]

    def cell(value):
        # Keep model output from breaking the table: one line, pipes escaped.
        return str(value).replace("|", "\\|").replace("\n", " ")

    # Format as markdown table: title, blank line, header, separator, rows.
    lines = [
        "# NIH HEAL CDE Core Domains Analysis",
        "",
        "| Domain | Protocol Instrument |",
        "|--------|--------------------|",
    ]
    lines.extend(
        f"| {cell(result['domain'])} | {cell(result['instrument'])} |"
        for result in results
    )
    return "\n".join(lines) + "\n"
615
+
616
  @tool
617
  def format_instrument_analysis(analysis_results: list, title: str = "NIH HEAL CDE Core Domains Analysis") -> str:
618
  """Format instrument analysis results into a markdown table.
 
636
 
637
  return result
638
 
639
+ # Collect all tools
640
  tools = [
641
+ search_data,
642
  search_excel_data,
643
  load_and_embed_protocol_pdf,
644
  search_protocol,
645
  search_protocol_for_instruments,
646
+ analyze_domain,
647
  analyze_all_heal_domains,
648
+ analyze_all_domains,
649
  format_instrument_analysis
650
  ]
651
 
652
+ # ==================== LANGGRAPH SETUP ====================
653
  # LangGraph components
654
  model = ChatOpenAI(model_name=INSTRUMENT_ANALYSIS_LLM, temperature=0)
655
  final_model = ChatOpenAI(model_name=INSTRUMENT_ANALYSIS_LLM, temperature=0)
656
 
657
+ # System message
658
  system_message = """You are a helpful assistant specializing in NIH HEAL CDE protocols.
659
 
660
  You have access to:
 
731
 
732
  graph = builder.compile()
733
 
734
+ # ==================== CHAINLIT HANDLERS ====================
735
  @cl.on_chat_start
736
  async def on_chat_start():
737
  # Welcome message
 
812
 
813
  await final_answer.send()
814
 
815
+