drewgenai committed on
Commit
e8df9ea
·
1 Parent(s): 9ad40ad

split tools breakout llms

Browse files
Files changed (1) hide show
  1. app.py +226 -71
app.py CHANGED
@@ -37,12 +37,14 @@ INITIAL_EMBEDDINGS_NAME = "initial_embeddings"
37
  USER_EMBEDDINGS_NAME = "user_embeddings"
38
 
39
  #XLSX_MODEL_ID = "Snowflake/snowflake-arctic-embed-m"
40
- #XLSX_MODEL_ID = "text-embedding-3-small"
41
  XLSX_MODEL_ID = "pritamdeka/S-PubMedBert-MS-MARCO"
42
  #PDF_MODEL_ID = "Snowflake/snowflake-arctic-embed-m"
43
- #PDF_MODEL_ID = "text-embedding-3-small"
44
  PDF_MODEL_ID = "pritamdeka/S-PubMedBert-MS-MARCO"
45
 
 
 
 
 
46
 
47
  # Make sure upload directory exists
48
  os.makedirs(UPLOAD_PATH, exist_ok=True)
@@ -299,15 +301,17 @@ def search_excel_data(query: str, top_k: int = 3) -> str:
299
 
300
  # If we have a user collection, also search that
301
  try:
302
- # Use the same model that was used to create the collection
303
- user_vectorstore = QdrantVectorStore(
 
 
 
 
 
304
  client=qdrant_client,
305
  collection_name=USER_EMBEDDINGS_NAME,
306
- embedding=get_embedding_model(PDF_MODEL_ID) # Use PDF_MODEL_ID here
307
- )
308
-
309
- # Create a retrieval chain for user documents
310
- user_retriever = user_vectorstore.as_retriever(search_kwargs={"k": top_k})
311
 
312
  user_retrieval_chain = (
313
  {"context": itemgetter("question") | user_retriever | format_docs,
@@ -323,72 +327,190 @@ def search_excel_data(query: str, top_k: int = 3) -> str:
323
  return f"From Excel files:\n{result}\n\nFrom your uploaded PDF:\n{user_result}"
324
  except Exception as e:
325
  print(f"Error searching user vector store: {str(e)}")
326
- # If no user collection exists yet, just return Excel results
327
  return result
328
 
329
  @tool
330
- def identify_heal_instruments(protocol_text: str = "") -> str:
331
- """Identify instruments (CRF questionaires) used in the protocol for each NIH HEAL CDE core domain.
332
 
333
  Args:
334
- protocol_text: Optional text from the protocol to analyze
335
 
336
  Returns:
337
- String containing identified instruments for each domain
338
  """
339
- # Check if user collection exists
340
  try:
341
- # Check if files exist in the upload directory
342
- uploaded_files = [f for f in os.listdir(UPLOAD_PATH) if f.endswith('.pdf')]
343
-
344
- if not uploaded_files:
345
- return "No protocol document has been uploaded yet."
346
 
347
- user_vectorstore = QdrantVectorStore(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
348
  client=qdrant_client,
349
  collection_name=USER_EMBEDDINGS_NAME,
350
- embedding=get_embedding_model(PDF_MODEL_ID) # Use PDF_MODEL_ID here
 
 
 
 
 
 
 
 
351
  )
352
- user_retriever = user_vectorstore.as_retriever(search_kwargs={"k": 10})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
353
  except Exception as e:
354
  print(f"Error accessing user vector store: {str(e)}")
355
- return "No protocol document has been uploaded yet or there was an error accessing it."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
356
 
357
  # For each domain, search for relevant instruments
358
  domain_instruments = {}
359
 
360
  for domain in NIH_HEAL_CORE_DOMAINS:
361
- # Search for instruments related to this domain in the protocol
362
- query = f"What instrument or measure is used for {domain} in the protocol?"
363
-
364
- # Retrieve relevant chunks from the protocol
365
- try:
366
- docs = user_retriever.invoke(query)
367
- protocol_context = format_docs(docs)
368
-
369
- # Search for instruments in the Excel data that match this domain
370
- excel_query = f"What are standard instruments or measures for {domain}?"
371
- excel_instruments = initialembeddings_retrieval_chain.invoke({"question": excel_query})
372
-
373
- # Use the model to identify the most likely instrument for this domain
374
- prompt = f"""
375
- Based on the protocol information and known instruments, identify which instrument is being used for the domain: {domain}
376
-
377
- Protocol information:
378
- {protocol_context}
379
-
380
- Known instruments for this domain:
381
- {excel_instruments}
382
-
383
- Respond with only the name of the identified instrument. If you cannot identify a specific instrument, respond with "Not identified".
384
- """
385
-
386
- instrument = chat_model.invoke([HumanMessage(content=prompt)]).content
387
- domain_instruments[domain] = instrument.strip()
388
- print(f"Identified instrument for {domain}: {instrument.strip()}")
389
- except Exception as e:
390
- print(f"Error identifying instrument for {domain}: {str(e)}")
391
- domain_instruments[domain] = "Error during identification"
392
 
393
  # Format the results as a markdown table
394
  result = "# NIH HEAL CDE Core Domains and Identified Instruments\n\n"
@@ -400,23 +522,61 @@ def identify_heal_instruments(protocol_text: str = "") -> str:
400
 
401
  return result
402
 
403
- tools = [search_excel_data, identify_heal_instruments]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
404
 
405
  # LangGraph components
406
- model = ChatOpenAI(model_name="gpt-4o-mini", temperature=0)
407
- final_model = ChatOpenAI(model_name="gpt-4o-mini", temperature=0)
408
 
409
- # System message for the model
410
  system_message = """You are a helpful assistant specializing in NIH HEAL CDE protocols.
411
 
412
  You have access to:
413
  1. Excel data through the search_excel_data tool
414
- 2. A tool to identify instruments (CRF questionaires) in NIH HEAL protocols (identify_heal_instruments)
 
 
 
 
415
 
416
  WHEN TO USE TOOLS:
417
- - When users ask about instruments, measures, assessments, questionnaires, or scales in a protocol, use the identify_heal_instruments tool.
 
 
 
418
  - When users ask about data or information in the Excel files, use the search_excel_data tool.
419
- - For general questions about NIH HEAL CDE domains, use the search_excel_data tool.
420
 
421
  Be specific in your tool queries to get the most relevant information.
422
  Always use the appropriate tool before responding to questions about the protocol or Excel data.
@@ -496,17 +656,18 @@ async def on_chat_start():
496
  await processing_msg.send()
497
 
498
  # Process the uploaded files
499
- user_vectorstore = await process_uploaded_files(files)
 
500
 
501
  if user_vectorstore:
502
  analysis_msg = cl.Message(content="Analyzing your protocol to identify instruments (CRF questionaires) for NIH HEAL CDE core domains...")
503
  await analysis_msg.send()
504
 
505
- # Use the identify_heal_instruments tool to analyze the protocol
506
  config = {"configurable": {"thread_id": cl.context.session.id}}
507
 
508
  # Create a message to trigger the analysis
509
- analysis_request = HumanMessage(content="Please analyze the uploaded protocol and identify instruments (CRF questionaires) for each NIH HEAL CDE core domain.")
510
 
511
  final_answer = cl.Message(content="")
512
 
@@ -541,12 +702,6 @@ async def on_message(msg: cl.Message):
541
  # For all messages, use the graph to handle the logic
542
  final_answer = cl.Message(content="")
543
 
544
- # Check if files exist for instrument-related queries
545
- if (any(keyword in msg.content.lower() for keyword in ["instrument", "measure", "assessment", "questionnaire", "scale", "protocol"]) and
546
- not any(f for f in os.listdir(UPLOAD_PATH) if f.endswith('.pdf'))):
547
- await cl.Message(content="No protocol document has been detected. Please upload a protocol document first.").send()
548
- return
549
-
550
  # Let the graph handle all message processing
551
  for msg_response, metadata in graph.stream(
552
  {"messages": [HumanMessage(content=msg.content)]},
 
37
  USER_EMBEDDINGS_NAME = "user_embeddings"
38
 
39
  #XLSX_MODEL_ID = "Snowflake/snowflake-arctic-embed-m"
 
40
  XLSX_MODEL_ID = "pritamdeka/S-PubMedBert-MS-MARCO"
41
  #PDF_MODEL_ID = "Snowflake/snowflake-arctic-embed-m"
 
42
  PDF_MODEL_ID = "pritamdeka/S-PubMedBert-MS-MARCO"
43
 
44
+ INSTRUMENT_SEARCH_LLM = "gpt-4o" # LLM for searching instruments
45
+ INSTRUMENT_ANALYSIS_LLM = "gpt-4o" # LLM for analyzing all domains
46
+
47
+
48
 
49
  # Make sure upload directory exists
50
  os.makedirs(UPLOAD_PATH, exist_ok=True)
 
301
 
302
  # If we have a user collection, also search that
303
  try:
304
+ # Check if user collection exists
305
+ if USER_EMBEDDINGS_NAME not in [c.name for c in qdrant_client.get_collections().collections]:
306
+ # If no user collection exists yet, just return Excel results
307
+ return result
308
+
309
+ # Create a retrieval chain for user documents
310
+ user_retriever = QdrantVectorStore(
311
  client=qdrant_client,
312
  collection_name=USER_EMBEDDINGS_NAME,
313
+ embedding=get_embedding_model(PDF_MODEL_ID) # Still need embedding model for retrieval
314
+ ).as_retriever(search_kwargs={"k": top_k})
 
 
 
315
 
316
  user_retrieval_chain = (
317
  {"context": itemgetter("question") | user_retriever | format_docs,
 
327
  return f"From Excel files:\n{result}\n\nFrom your uploaded PDF:\n{user_result}"
328
  except Exception as e:
329
  print(f"Error searching user vector store: {str(e)}")
330
+ # If error occurs, just return Excel results
331
  return result
332
 
333
  @tool
334
+ def load_and_embed_protocol_pdf(file_path: str = None) -> str:
335
+ """Load and embed a protocol PDF file into the vector store.
336
 
337
  Args:
338
+ file_path: Optional path to the PDF file. If None, will use files in the upload directory.
339
 
340
  Returns:
341
+ String indicating success or failure of the embedding process
342
  """
 
343
  try:
344
+ # If no specific file path is provided, use all PDFs in the upload directory
345
+ if not file_path:
346
+ uploaded_files = [f for f in os.listdir(UPLOAD_PATH) if f.endswith('.pdf')]
347
+ if not uploaded_files:
348
+ return "No protocol documents found in the upload directory."
349
 
350
+ # Create file objects for processing
351
+ files = []
352
+ for filename in uploaded_files:
353
+ file_path = os.path.join(UPLOAD_PATH, filename)
354
+ # Create a simple object with the necessary attributes
355
+ class FileObj:
356
+ def __init__(self, path, name, size):
357
+ self.path = path
358
+ self.name = name
359
+ self.size = size
360
+
361
+ file_size = os.path.getsize(file_path)
362
+ files.append(FileObj(file_path, filename, file_size))
363
+ else:
364
+ # Create a file object for the specific file
365
+ if not os.path.exists(file_path):
366
+ return f"File not found: {file_path}"
367
+
368
+ filename = os.path.basename(file_path)
369
+ file_size = os.path.getsize(file_path)
370
+
371
+ class FileObj:
372
+ def __init__(self, path, name, size):
373
+ self.path = path
374
+ self.name = name
375
+ self.size = size
376
+
377
+ files = [FileObj(file_path, filename, file_size)]
378
+
379
+ # Process the files asynchronously
380
+ import asyncio
381
+ documents_with_metadata = asyncio.run(load_and_chunk_pdf_files(files))
382
+ user_vectorstore = asyncio.run(embed_pdf_chunks_in_qdrant(documents_with_metadata, PDF_MODEL_ID))
383
+
384
+ if user_vectorstore:
385
+ return f"Successfully embedded {len(documents_with_metadata)} chunks from {len(files)} protocol document(s)."
386
+ else:
387
+ return "Failed to embed protocol document(s)."
388
+ except Exception as e:
389
+ return f"Error embedding protocol document: {str(e)}"
390
+
391
+ @tool
392
+ def search_protocol(query: str, top_k: int = 5) -> str:
393
+ """Search the protocol for information related to the query.
394
+
395
+ Args:
396
+ query: The search query
397
+ top_k: Number of results to return (default: 5)
398
+
399
+ Returns:
400
+ String containing the search results with their content and source files
401
+ """
402
+ try:
403
+ # Check if user collection exists
404
+ if USER_EMBEDDINGS_NAME not in [c.name for c in qdrant_client.get_collections().collections]:
405
+ return "No protocol document has been embedded yet. Please upload and embed a protocol first."
406
+
407
+ # Create a retrieval chain for user documents
408
+ user_retriever = QdrantVectorStore(
409
  client=qdrant_client,
410
  collection_name=USER_EMBEDDINGS_NAME,
411
+ embedding=get_embedding_model(PDF_MODEL_ID) # Still need embedding model for retrieval
412
+ ).as_retriever(search_kwargs={"k": top_k})
413
+
414
+ user_retrieval_chain = (
415
+ {"context": itemgetter("question") | user_retriever | format_docs,
416
+ "question": itemgetter("question")}
417
+ | rag_prompt
418
+ | chat_model
419
+ | StrOutputParser()
420
  )
421
+
422
+ result = user_retrieval_chain.invoke({"question": query})
423
+ return result
424
+ except Exception as e:
425
+ return f"Error searching protocol: {str(e)}"
426
+
427
+ @tool
428
+ def search_protocol_for_instruments(domain: str) -> dict:
429
+ """Search the protocol for instruments related to a specific NIH HEAL CDE core domain.
430
+
431
+ Args:
432
+ domain: The NIH HEAL CDE core domain to search for
433
+
434
+ Returns:
435
+ Dictionary containing the domain, identified instrument, and supporting context
436
+ """
437
+ # Check if user collection exists
438
+ try:
439
+ # Check if collection exists
440
+ if USER_EMBEDDINGS_NAME not in [c.name for c in qdrant_client.get_collections().collections]:
441
+ return {"domain": domain, "instrument": "No protocol document embedded", "context": ""}
442
+
443
+ # Create retriever for user documents
444
+ user_retriever = QdrantVectorStore(
445
+ client=qdrant_client,
446
+ collection_name=USER_EMBEDDINGS_NAME,
447
+ embedding=get_embedding_model(PDF_MODEL_ID) # Still need embedding model for retrieval
448
+ ).as_retriever(search_kwargs={"k": 10})
449
  except Exception as e:
450
  print(f"Error accessing user vector store: {str(e)}")
451
+ return {"domain": domain, "instrument": "Error accessing protocol", "context": str(e)}
452
+
453
+ # Create the chat model with the specified model from constants
454
+ domain_chat_model = ChatOpenAI(model_name=INSTRUMENT_SEARCH_LLM, temperature=0)
455
+
456
+ # Search for instruments related to this domain in the protocol
457
+ query = f"What instrument or measure is used for {domain} in the protocol?"
458
+
459
+ try:
460
+ # Retrieve relevant chunks from the protocol
461
+ docs = user_retriever.invoke(query)
462
+ protocol_context = format_docs(docs)
463
+
464
+ # Search for instruments in the Excel data that match this domain
465
+ excel_query = f"What are standard instruments or measures for {domain}?"
466
+ excel_instruments = initialembeddings_retrieval_chain.invoke({"question": excel_query})
467
+
468
+ # Use the model to identify the most likely instrument for this domain
469
+ prompt = f"""
470
+ Based on the protocol information and known instruments, identify which instrument is being used for the domain: {domain}
471
+
472
+ Protocol information:
473
+ {protocol_context}
474
+
475
+ Known instruments for this domain:
476
+ {excel_instruments}
477
+
478
+ Respond with only the name of the identified instrument. If you cannot identify a specific instrument, respond with "Not identified".
479
+ """
480
+
481
+ instrument = domain_chat_model.invoke([HumanMessage(content=prompt)]).content
482
+
483
+ # Return the results as a dictionary
484
+ return {
485
+ "domain": domain,
486
+ "instrument": instrument.strip(),
487
+ "context": protocol_context,
488
+ "known_instruments": excel_instruments
489
+ }
490
+ except Exception as e:
491
+ print(f"Error identifying instrument for {domain}: {str(e)}")
492
+ return {"domain": domain, "instrument": "Error during identification", "context": str(e)}
493
+
494
+ @tool
495
+ def analyze_all_heal_domains() -> str:
496
+ """Analyze all NIH HEAL CDE core domains and identify instruments used in the protocol.
497
+
498
+ Returns:
499
+ Markdown formatted table of domains and identified instruments
500
+ """
501
+ # Check if protocol document exists
502
+ uploaded_files = [f for f in os.listdir(UPLOAD_PATH) if f.endswith('.pdf')]
503
+ if not uploaded_files:
504
+ return "No protocol document has been uploaded yet."
505
 
506
  # For each domain, search for relevant instruments
507
  domain_instruments = {}
508
 
509
  for domain in NIH_HEAL_CORE_DOMAINS:
510
+ # Use the search_protocol_for_instruments tool to get results for each domain
511
+ result = search_protocol_for_instruments(domain)
512
+ domain_instruments[domain] = result["instrument"]
513
+ print(f"Identified instrument for {domain}: {result['instrument']}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
514
 
515
  # Format the results as a markdown table
516
  result = "# NIH HEAL CDE Core Domains and Identified Instruments\n\n"
 
522
 
523
  return result
524
 
525
+ @tool
526
+ def format_instrument_analysis(analysis_results: list, title: str = "NIH HEAL CDE Core Domains Analysis") -> str:
527
+ """Format instrument analysis results into a markdown table.
528
+
529
+ Args:
530
+ analysis_results: List of dictionaries with domain and instrument information
531
+ title: Title for the markdown output
532
+
533
+ Returns:
534
+ Markdown formatted table of domains and identified instruments
535
+ """
536
+ # Format the results as a markdown table
537
+ result = f"# {title}\n\n"
538
+ result += "| Domain | Protocol Instrument |\n"
539
+ result += "|--------|--------------------|\n"
540
+
541
+ for item in analysis_results:
542
+ domain = item.get("domain", "Unknown")
543
+ instrument = item.get("instrument", "Not identified")
544
+ result += f"| {domain} | {instrument} |\n"
545
+
546
+ return result
547
+
548
+ # Update the tools list
549
+ tools = [
550
+ search_excel_data,
551
+ load_and_embed_protocol_pdf,
552
+ search_protocol,
553
+ search_protocol_for_instruments,
554
+ analyze_all_heal_domains,
555
+ format_instrument_analysis
556
+ ]
557
 
558
  # LangGraph components
559
+ model = ChatOpenAI(model_name=INSTRUMENT_ANALYSIS_LLM, temperature=0)
560
+ final_model = ChatOpenAI(model_name=INSTRUMENT_ANALYSIS_LLM, temperature=0)
561
 
562
+ # Update the system message
563
  system_message = """You are a helpful assistant specializing in NIH HEAL CDE protocols.
564
 
565
  You have access to:
566
  1. Excel data through the search_excel_data tool
567
+ 2. A tool to load and embed protocol PDFs (load_and_embed_protocol_pdf)
568
+ 3. A tool to search protocol documents for general information (search_protocol)
569
+ 4. A tool to search for instruments in protocols for specific domains (search_protocol_for_instruments)
570
+ 5. A tool to analyze all NIH HEAL domains at once (analyze_all_heal_domains)
571
+ 6. A tool to format analysis results into a markdown table (format_instrument_analysis)
572
 
573
  WHEN TO USE TOOLS:
574
+ - When users upload a protocol PDF, use the load_and_embed_protocol_pdf tool.
575
+ - When users ask general questions about the protocol, use the search_protocol tool.
576
+ - When users ask about a specific instrument for a domain, use the search_protocol_for_instruments tool.
577
+ - When users want a complete analysis of all domains, use the analyze_all_heal_domains tool.
578
  - When users ask about data or information in the Excel files, use the search_excel_data tool.
579
+ - When you have multiple analysis results to present, use format_instrument_analysis to create a nice table.
580
 
581
  Be specific in your tool queries to get the most relevant information.
582
  Always use the appropriate tool before responding to questions about the protocol or Excel data.
 
656
  await processing_msg.send()
657
 
658
  # Process the uploaded files
659
+ documents_with_metadata = await load_and_chunk_pdf_files(files)
660
+ user_vectorstore = await embed_pdf_chunks_in_qdrant(documents_with_metadata)
661
 
662
  if user_vectorstore:
663
  analysis_msg = cl.Message(content="Analyzing your protocol to identify instruments (CRF questionaires) for NIH HEAL CDE core domains...")
664
  await analysis_msg.send()
665
 
666
+ # Use the analyze_all_heal_domains tool to analyze the protocol
667
  config = {"configurable": {"thread_id": cl.context.session.id}}
668
 
669
  # Create a message to trigger the analysis
670
+ analysis_request = HumanMessage(content="Please analyze the uploaded protocol and identify instruments for each NIH HEAL CDE core domain.")
671
 
672
  final_answer = cl.Message(content="")
673
 
 
702
  # For all messages, use the graph to handle the logic
703
  final_answer = cl.Message(content="")
704
 
 
 
 
 
 
 
705
  # Let the graph handle all message processing
706
  for msg_response, metadata in graph.stream(
707
  {"messages": [HumanMessage(content=msg.content)]},