drewgenai commited on
Commit
4f17e1e
·
1 Parent(s): bb7b31c

update function/tool names

Browse files
Files changed (1) hide show
  1. app.py +73 -79
app.py CHANGED
@@ -104,8 +104,8 @@ def format_docs(docs):
104
  """Format a list of documents into a single string."""
105
  return "\n\n".join(doc.page_content for doc in docs)
106
 
107
- # ==================== EXCEL DOCUMENT PROCESSING ====================
108
- def load_and_chunk_excel_files():
109
  """Loads all .xlsx files from the initial embeddings directory and splits them into chunks."""
110
  text_splitter = RecursiveCharacterTextSplitter(chunk_size=512, chunk_overlap=50)
111
  all_chunks = []
@@ -138,8 +138,8 @@ def load_and_chunk_excel_files():
138
  print(f"Processed {file_count} Excel files with a total of {len(all_chunks)} chunks.")
139
  return all_chunks
140
 
141
- def embed_chunks_in_qdrant(chunks):
142
- """Embeds document chunks and stores them in Qdrant."""
143
  global embedding_model
144
 
145
  if not chunks:
@@ -168,14 +168,14 @@ def embed_chunks_in_qdrant(chunks):
168
  print(f"Embedding model status: {embedding_model is not None}")
169
  return None
170
 
171
- def process_initial_embeddings():
172
  """Loads all .xlsx files, extracts text, embeds, and stores in Qdrant."""
173
- chunks = load_and_chunk_excel_files()
174
- return embed_chunks_in_qdrant(chunks)
175
 
176
- # ==================== PDF DOCUMENT PROCESSING ====================
177
- async def load_and_chunk_pdf_files(files):
178
- """Load PDF files and split them into chunks with metadata."""
179
  print(f"Loading {len(files)} uploaded PDF files")
180
  documents_with_metadata = []
181
 
@@ -212,8 +212,8 @@ async def load_and_chunk_pdf_files(files):
212
 
213
  return documents_with_metadata
214
 
215
- async def embed_pdf_chunks_in_qdrant(documents_with_metadata, model_name=EMBEDDING_MODEL_ID):
216
- """Create a vector store and embed PDF chunks into Qdrant."""
217
  global embedding_model
218
 
219
  if not documents_with_metadata:
@@ -252,11 +252,10 @@ async def embed_pdf_chunks_in_qdrant(documents_with_metadata, model_name=EMBEDDI
252
  print(f"Error creating vector store: {str(e)}")
253
  return None
254
 
255
- async def process_uploaded_files(files, model_name=EMBEDDING_MODEL_ID):
256
- """Process uploaded PDF files and add them to a separate vector store collection"""
257
- documents_with_metadata = await load_and_chunk_pdf_files(files)
258
- return await embed_pdf_chunks_in_qdrant(documents_with_metadata, model_name)
259
-
260
 
261
  # ==================== RETRIEVAL FUNCTIONS ====================
262
  def retrieve_documents(query, doc_type=None, k=5):
@@ -309,17 +308,17 @@ def create_rag_chain(doc_type=None):
309
 
310
  return chain
311
 
312
- # Initialize the Excel retriever
313
- vectorstore = process_initial_embeddings()
314
  if vectorstore:
315
- excel_retriever = vectorstore.as_retriever(search_kwargs={"k": 10})
316
- print("Excel retriever created successfully.")
317
  else:
318
- print("Failed to create Excel retriever: No vector store available.")
319
 
320
- # Chain for retrieving from Excel embeddings
321
- initialembeddings_retrieval_chain = (
322
- {"context": itemgetter("question") | excel_retriever | format_docs,
323
  "question": itemgetter("question")}
324
  | rag_prompt
325
  | chat_model
@@ -328,8 +327,8 @@ initialembeddings_retrieval_chain = (
328
 
329
  # ==================== TOOL DEFINITIONS ====================
330
  @tool
331
- def search_data(query: str, doc_type: str = None) -> str:
332
- """Search all data or filter by document type (pdf/excel)"""
333
  try:
334
  chain = create_rag_chain(doc_type)
335
  return chain.invoke({"question": query})
@@ -337,18 +336,18 @@ def search_data(query: str, doc_type: str = None) -> str:
337
  return f"Error searching data: {str(e)}"
338
 
339
  @tool
340
- def search_excel_data(query: str, top_k: int = 3) -> str:
341
- """Search both Excel data and user-uploaded PDF data for information related to the query."""
342
  global embedding_model
343
 
344
- # Use the existing initialembeddings_retrieval_chain
345
- result = initialembeddings_retrieval_chain.invoke({"question": query})
346
 
347
  # If we have a user collection, also search that
348
  try:
349
  # Check if user collection exists
350
  if USER_EMBEDDINGS_NAME not in [c.name for c in qdrant_client.get_collections().collections]:
351
- # If no user collection exists yet, just return Excel results
352
  return result
353
 
354
  # Create a retrieval chain for user documents
@@ -369,14 +368,15 @@ def search_excel_data(query: str, top_k: int = 3) -> str:
369
  user_result = user_retrieval_chain.invoke({"question": query})
370
 
371
  # Combine results
372
- return f"From Excel files:\n{result}\n\nFrom your uploaded PDF:\n{user_result}"
373
  except Exception as e:
374
  print(f"Error searching user vector store: {str(e)}")
375
- # If error occurs, just return Excel results
376
  return result
377
 
378
  @tool
379
- def load_and_embed_protocol_pdf(file_path: str = None) -> str:
 
380
  """Load and embed a protocol PDF file into the vector store.
381
 
382
  Args:
@@ -423,8 +423,8 @@ def load_and_embed_protocol_pdf(file_path: str = None) -> str:
423
 
424
  # Process the files asynchronously
425
  import asyncio
426
- documents_with_metadata = asyncio.run(load_and_chunk_pdf_files(files))
427
- user_vectorstore = asyncio.run(embed_pdf_chunks_in_qdrant(documents_with_metadata, EMBEDDING_MODEL_ID))
428
 
429
  if user_vectorstore:
430
  return f"Successfully embedded {len(documents_with_metadata)} chunks from {len(files)} protocol document(s)."
@@ -433,7 +433,6 @@ def load_and_embed_protocol_pdf(file_path: str = None) -> str:
433
  except Exception as e:
434
  return f"Error embedding protocol document: {str(e)}"
435
 
436
-
437
  @tool
438
  def search_protocol_for_instruments(domain: str) -> dict:
439
  """Search the protocol for instruments related to a specific NIH HEAL CDE core domain."""
@@ -466,9 +465,9 @@ def search_protocol_for_instruments(domain: str) -> dict:
466
  docs = user_retriever.invoke(query)
467
  protocol_context = format_docs(docs)
468
 
469
- # Search for instruments in the Excel data that match this domain
470
- excel_query = f"What are standard instruments or measures for {domain}?"
471
- excel_instruments = initialembeddings_retrieval_chain.invoke({"question": excel_query})
472
 
473
  # Use the model to identify the most likely instrument for this domain
474
  prompt = f"""
@@ -478,7 +477,7 @@ def search_protocol_for_instruments(domain: str) -> dict:
478
  {protocol_context}
479
 
480
  Known instruments for this domain:
481
- {excel_instruments}
482
 
483
  Respond with only the name of the identified instrument. If you cannot identify a specific instrument, respond with "Not identified".
484
  """
@@ -490,21 +489,15 @@ def search_protocol_for_instruments(domain: str) -> dict:
490
  "domain": domain,
491
  "instrument": instrument.strip(),
492
  "context": protocol_context,
493
- "known_instruments": excel_instruments
494
  }
495
  except Exception as e:
496
  print(f"Error identifying instrument for {domain}: {str(e)}")
497
  return {"domain": domain, "instrument": "Error during identification", "context": str(e)}
498
 
499
-
500
-
501
  @tool
502
- def analyze_all_heal_domains() -> str:
503
- """Analyze all NIH HEAL CDE core domains and identify instruments used in the protocol.
504
-
505
- Returns:
506
- Markdown formatted table of domains and identified instruments
507
- """
508
  # Check if protocol document exists
509
  uploaded_files = [f for f in os.listdir(UPLOAD_PATH) if f.endswith('.pdf')]
510
  if not uploaded_files:
@@ -529,18 +522,19 @@ def analyze_all_heal_domains() -> str:
529
 
530
  return result
531
 
532
-
533
  @tool
534
- def format_instrument_analysis(analysis_results: list, title: str = "NIH HEAL CDE Core Domains Analysis") -> str:
535
- """Format instrument analysis results into a markdown table.
536
-
537
- Args:
538
  analysis_results: List of dictionaries with domain and instrument information
539
  title: Title for the markdown output
540
 
541
  Returns:
542
  Markdown formatted table of domains and identified instruments
 
 
543
  """
 
544
  # Format the results as a markdown table
545
  result = f"# {title}\n\n"
546
  result += "| Domain | Protocol Instrument |\n"
@@ -555,12 +549,12 @@ def format_instrument_analysis(analysis_results: list, title: str = "NIH HEAL CD
555
 
556
  # Collect all tools
557
  tools = [
558
- search_data,
559
- search_excel_data,
560
- load_and_embed_protocol_pdf,
561
  search_protocol_for_instruments,
562
- analyze_all_heal_domains,
563
- format_instrument_analysis
564
  ]
565
 
566
  # ==================== LANGGRAPH SETUP ====================
@@ -572,23 +566,23 @@ final_model = ChatOpenAI(model_name=INSTRUMENT_ANALYSIS_LLM, temperature=0)
572
  system_message = """You are a helpful assistant specializing in NIH HEAL CDE protocols.
573
 
574
  You have access to:
575
- 1. Excel data through the search_excel_data tool
576
- 2. A tool to load and embed protocol PDFs (load_and_embed_protocol_pdf)
577
- 3. A tool to search protocol documents for general information (search_protocol)
578
- 4. A tool to search for instruments in protocols for specific domains (search_protocol_for_instruments)
579
- 5. A tool to analyze all NIH HEAL domains at once (analyze_all_heal_domains)
580
- 6. A tool to format analysis results into a markdown table (format_instrument_analysis)
581
 
582
  WHEN TO USE TOOLS:
583
- - When users upload a protocol PDF, use the load_and_embed_protocol_pdf tool.
584
- - When users ask general questions about the protocol, use the search_protocol tool.
585
  - When users ask about a specific instrument for a domain, use the search_protocol_for_instruments tool.
586
- - When users want a complete analysis of all domains, use the analyze_all_heal_domains tool.
587
- - When users ask about data or information in the Excel files, use the search_excel_data tool.
588
- - When you have multiple analysis results to present, use format_instrument_analysis to create a nice table.
589
 
590
  Be specific in your tool queries to get the most relevant information.
591
- Always use the appropriate tool before responding to questions about the protocol or Excel data.
592
  """
593
 
594
  # Bind tools and configure models
@@ -654,7 +648,7 @@ async def on_chat_start():
654
 
655
  # Wait for file upload
656
  files = await cl.AskFileMessage(
657
- content="Please upload a NIH HEAL protocol PDF file to analyze alongside the Excel data.",
658
  accept=["application/pdf"],
659
  max_size_mb=20,
660
  timeout=180,
@@ -665,14 +659,14 @@ async def on_chat_start():
665
  await processing_msg.send()
666
 
667
  # Process the uploaded files
668
- documents_with_metadata = await load_and_chunk_pdf_files(files)
669
- user_vectorstore = await embed_pdf_chunks_in_qdrant(documents_with_metadata)
670
 
671
  if user_vectorstore:
672
  analysis_msg = cl.Message(content="Analyzing your protocol to identify instruments (CRF questionnaires) for NIH HEAL CDE core domains...")
673
  await analysis_msg.send()
674
 
675
- # Use the analyze_all_heal_domains tool to analyze the protocol
676
  config = {"configurable": {"thread_id": cl.context.session.id}}
677
 
678
  # Create a message to trigger the analysis
@@ -694,11 +688,11 @@ async def on_chat_start():
694
 
695
  await final_answer.send()
696
 
697
- await cl.Message(content="You can now ask additional questions about the protocol or the Excel data.").send()
698
  else:
699
  await cl.Message(content="There was an issue processing your PDF. Please try uploading again.").send()
700
  else:
701
- await cl.Message(content="No file was uploaded. You can still ask questions about the Excel data.").send()
702
 
703
  @cl.on_message
704
  async def on_message(msg: cl.Message):
 
104
  """Format a list of documents into a single string."""
105
  return "\n\n".join(doc.page_content for doc in docs)
106
 
107
+ # ==================== CORE EMBEDDINGS PROCESSING ====================
108
+ def load_and_chunk_core_reference_files():
109
  """Loads all .xlsx files from the initial embeddings directory and splits them into chunks."""
110
  text_splitter = RecursiveCharacterTextSplitter(chunk_size=512, chunk_overlap=50)
111
  all_chunks = []
 
138
  print(f"Processed {file_count} Excel files with a total of {len(all_chunks)} chunks.")
139
  return all_chunks
140
 
141
+ def embed_core_reference_in_qdrant(chunks):
142
+ """Embeds core reference chunks and stores them in Qdrant."""
143
  global embedding_model
144
 
145
  if not chunks:
 
168
  print(f"Embedding model status: {embedding_model is not None}")
169
  return None
170
 
171
+ def initialize_core_reference_embeddings():
172
  """Loads all .xlsx files, extracts text, embeds, and stores in Qdrant."""
173
+ chunks = load_and_chunk_core_reference_files()
174
+ return embed_core_reference_in_qdrant(chunks)
175
 
176
+ # ==================== PROTOCOL DOCUMENT PROCESSING ====================
177
+ async def load_and_chunk_protocol_files(files):
178
+ """Load protocol PDF files and split them into chunks with metadata."""
179
  print(f"Loading {len(files)} uploaded PDF files")
180
  documents_with_metadata = []
181
 
 
212
 
213
  return documents_with_metadata
214
 
215
+ async def embed_protocol_in_qdrant(documents_with_metadata, model_name=EMBEDDING_MODEL_ID):
216
+ """Create a vector store and embed protocol chunks into Qdrant."""
217
  global embedding_model
218
 
219
  if not documents_with_metadata:
 
252
  print(f"Error creating vector store: {str(e)}")
253
  return None
254
 
255
+ async def process_uploaded_protocol(files, model_name=EMBEDDING_MODEL_ID):
256
+ """Process uploaded protocol PDF files and add them to a separate vector store collection"""
257
+ documents_with_metadata = await load_and_chunk_protocol_files(files)
258
+ return await embed_protocol_in_qdrant(documents_with_metadata, model_name)
 
259
 
260
  # ==================== RETRIEVAL FUNCTIONS ====================
261
  def retrieve_documents(query, doc_type=None, k=5):
 
308
 
309
  return chain
310
 
311
+ # Initialize the core reference retriever
312
+ vectorstore = initialize_core_reference_embeddings()
313
  if vectorstore:
314
+ core_reference_retriever = vectorstore.as_retriever(search_kwargs={"k": 10})
315
+ print("Core reference retriever created successfully.")
316
  else:
317
+ print("Failed to create core reference retriever: No vector store available.")
318
 
319
+ # Chain for retrieving from core reference embeddings
320
+ core_reference_retrieval_chain = (
321
+ {"context": itemgetter("question") | core_reference_retriever | format_docs,
322
  "question": itemgetter("question")}
323
  | rag_prompt
324
  | chat_model
 
327
 
328
  # ==================== TOOL DEFINITIONS ====================
329
  @tool
330
+ def search_all_data(query: str, doc_type: str = None) -> str:
331
+ """Search all data or filter by document type (protocol/core_reference)"""
332
  try:
333
  chain = create_rag_chain(doc_type)
334
  return chain.invoke({"question": query})
 
336
  return f"Error searching data: {str(e)}"
337
 
338
  @tool
339
+ def search_core_reference(query: str, top_k: int = 3) -> str:
340
+ """Search core reference data and protocol data for information related to the query."""
341
  global embedding_model
342
 
343
+ # Use the existing core_reference_retrieval_chain
344
+ result = core_reference_retrieval_chain.invoke({"question": query})
345
 
346
  # If we have a user collection, also search that
347
  try:
348
  # Check if user collection exists
349
  if USER_EMBEDDINGS_NAME not in [c.name for c in qdrant_client.get_collections().collections]:
350
+ # If no user collection exists yet, just return core reference results
351
  return result
352
 
353
  # Create a retrieval chain for user documents
 
368
  user_result = user_retrieval_chain.invoke({"question": query})
369
 
370
  # Combine results
371
+ return f"From core reference files:\n{result}\n\nFrom your uploaded PDF:\n{user_result}"
372
  except Exception as e:
373
  print(f"Error searching user vector store: {str(e)}")
374
+ # If error occurs, just return core reference results
375
  return result
376
 
377
  @tool
378
+ def load_and_embed_protocol(file_path: str = None) -> str:
379
  """Load and embed a protocol PDF file into the vector store.
381
 
382
  Args:
 
423
 
424
  # Process the files asynchronously
425
  import asyncio
426
+ documents_with_metadata = asyncio.run(load_and_chunk_protocol_files(files))
427
+ user_vectorstore = asyncio.run(embed_protocol_in_qdrant(documents_with_metadata, EMBEDDING_MODEL_ID))
428
 
429
  if user_vectorstore:
430
  return f"Successfully embedded {len(documents_with_metadata)} chunks from {len(files)} protocol document(s)."
 
433
  except Exception as e:
434
  return f"Error embedding protocol document: {str(e)}"
435
 
 
436
  @tool
437
  def search_protocol_for_instruments(domain: str) -> dict:
438
  """Search the protocol for instruments related to a specific NIH HEAL CDE core domain."""
 
465
  docs = user_retriever.invoke(query)
466
  protocol_context = format_docs(docs)
467
 
468
+ # Search for instruments in the core reference data that match this domain
469
+ core_reference_query = f"What are standard instruments or measures for {domain}?"
470
+ core_reference_instruments = core_reference_retrieval_chain.invoke({"question": core_reference_query})
471
 
472
  # Use the model to identify the most likely instrument for this domain
473
  prompt = f"""
 
477
  {protocol_context}
478
 
479
  Known instruments for this domain:
480
+ {core_reference_instruments}
481
 
482
  Respond with only the name of the identified instrument. If you cannot identify a specific instrument, respond with "Not identified".
483
  """
 
489
  "domain": domain,
490
  "instrument": instrument.strip(),
491
  "context": protocol_context,
492
+ "known_instruments": core_reference_instruments
493
  }
494
  except Exception as e:
495
  print(f"Error identifying instrument for {domain}: {str(e)}")
496
  return {"domain": domain, "instrument": "Error during identification", "context": str(e)}
497
 
 
 
498
  @tool
499
+ def analyze_protocol_domains() -> str:
500
+ """Analyze all NIH HEAL CDE core domains and identify instruments used in the protocol."""
 
 
 
 
501
  # Check if protocol document exists
502
  uploaded_files = [f for f in os.listdir(UPLOAD_PATH) if f.endswith('.pdf')]
503
  if not uploaded_files:
 
522
 
523
  return result
524
 
 
525
  @tool
526
+ def format_domain_analysis(analysis_results: list, title: str = "NIH HEAL CDE Core Domains Analysis") -> str:
527
+ """Format domain analysis results into a markdown table.
528
+
+ Args:
 
529
  analysis_results: List of dictionaries with domain and instrument information
530
  title: Title for the markdown output
531
 
532
  Returns:
533
  Markdown formatted table of domains and identified instruments
534
+
535
+
536
  """
537
+
538
  # Format the results as a markdown table
539
  result = f"# {title}\n\n"
540
  result += "| Domain | Protocol Instrument |\n"
 
549
 
550
  # Collect all tools
551
  tools = [
552
+ search_all_data,
553
+ search_core_reference,
554
+ load_and_embed_protocol,
555
  search_protocol_for_instruments,
556
+ analyze_protocol_domains,
557
+ format_domain_analysis
558
  ]
559
 
560
  # ==================== LANGGRAPH SETUP ====================
 
566
  system_message = """You are a helpful assistant specializing in NIH HEAL CDE protocols.
567
 
568
  You have access to:
569
+ 1. Core reference data through the search_core_reference tool
570
+ 2. A tool to load and embed protocol PDFs (load_and_embed_protocol)
571
+ 3. A tool to search for instruments in protocols for specific domains (search_protocol_for_instruments)
572
+ 4. A tool to analyze all NIH HEAL domains at once (analyze_protocol_domains)
573
+ 5. A tool to format analysis results into a markdown table (format_domain_analysis)
574
+ 6. A tool to search all available data (search_all_data)
575
 
576
  WHEN TO USE TOOLS:
577
+ - When users upload a protocol PDF, use the load_and_embed_protocol tool.
578
+ - When users ask general questions about the protocol, use the search_all_data tool.
579
  - When users ask about a specific instrument for a domain, use the search_protocol_for_instruments tool.
580
+ - When users want a complete analysis of all domains, use the analyze_protocol_domains tool.
581
+ - When users ask about data or information in the core reference files, use the search_core_reference tool.
582
+ - When you have multiple analysis results to present, use format_domain_analysis to create a nice table.
583
 
584
  Be specific in your tool queries to get the most relevant information.
585
+ Always use the appropriate tool before responding to questions about the protocol or core reference data.
586
  """
587
 
588
  # Bind tools and configure models
 
648
 
649
  # Wait for file upload
650
  files = await cl.AskFileMessage(
651
+ content="Please upload a NIH HEAL protocol PDF file to analyze alongside the core reference data.",
652
  accept=["application/pdf"],
653
  max_size_mb=20,
654
  timeout=180,
 
659
  await processing_msg.send()
660
 
661
  # Process the uploaded files
662
+ documents_with_metadata = await load_and_chunk_protocol_files(files)
663
+ user_vectorstore = await embed_protocol_in_qdrant(documents_with_metadata)
664
 
665
  if user_vectorstore:
666
  analysis_msg = cl.Message(content="Analyzing your protocol to identify instruments (CRF questionnaires) for NIH HEAL CDE core domains...")
667
  await analysis_msg.send()
668
 
669
+ # Use the analyze_protocol_domains tool to analyze the protocol
670
  config = {"configurable": {"thread_id": cl.context.session.id}}
671
 
672
  # Create a message to trigger the analysis
 
688
 
689
  await final_answer.send()
690
 
691
+ await cl.Message(content="You can now ask additional questions about the protocol or the core reference data.").send()
692
  else:
693
  await cl.Message(content="There was an issue processing your PDF. Please try uploading again.").send()
694
  else:
695
+ await cl.Message(content="No file was uploaded. You can still ask questions about the core reference data.").send()
696
 
697
  @cl.on_message
698
  async def on_message(msg: cl.Message):