drewgenai committed on
Commit
f5b7c50
·
1 Parent(s): 621b412

added async

Browse files
Files changed (1) hide show
  1. app.py +26 -23
app.py CHANGED
@@ -320,7 +320,7 @@ async def process_uploaded_protocol(files, session_qdrant_client, model_name=EMB
320
  return await embed_protocol_in_qdrant(documents_with_metadata, session_qdrant_client, model_name)
321
 
322
  # ==================== RETRIEVAL FUNCTIONS ====================
323
- def retrieve_from_core(query, k=5):
324
  """Retrieve documents from core reference database"""
325
  global core_retriever
326
 
@@ -331,11 +331,11 @@ def retrieve_from_core(query, k=5):
331
  # Override k if needed
332
  if k != 10: # Assuming default k=10 was used when creating the retriever
333
  retriever = core_vectorstore.as_retriever(search_kwargs={"k": k})
334
- return retriever.invoke(query)
335
 
336
- return core_retriever.invoke(query)
337
 
338
- def retrieve_from_protocol(query, k=5):
339
  """Retrieve documents from protocol database"""
340
  # Get the session-specific client
341
  session_qdrant_client = cl.user_session.get("session_qdrant_client")
@@ -361,20 +361,20 @@ def retrieve_from_protocol(query, k=5):
361
 
362
  # Create and use retriever
363
  protocol_retriever = protocol_vectorstore.as_retriever(search_kwargs={"k": k})
364
- return protocol_retriever.invoke(query)
365
 
366
  # ==================== TOOL DEFINITIONS ====================
367
  @tool
368
- def search_all_data(query: str, doc_type: str = None) -> str:
369
  """Search all data or filter by document type (protocol/core_reference)"""
370
  try:
371
- chain = create_rag_chain(doc_type)
372
- return chain.invoke({"question": query})
373
  except Exception as e:
374
  return f"Error searching data: {str(e)}"
375
 
376
  @tool
377
- def analyze_protocol_domains(export_csv: bool = True) -> str:
378
  """Analyze all NIH HEAL CDE core domains and identify instruments used in the protocol.
379
 
380
  Args:
@@ -394,9 +394,12 @@ def analyze_protocol_domains(export_csv: bool = True) -> str:
394
  # For each domain, search for relevant instruments
395
  domain_analysis_results = []
396
 
397
- for domain in NIH_HEAL_CORE_DOMAINS:
398
- # Search for instruments related to this domain in the protocol
399
- result = _search_protocol_for_instruments(domain)
 
 
 
400
  print(f"Identified instrument for {domain}: {result['instrument']}")
401
 
402
  # Add the result to our list of analysis results
@@ -453,7 +456,7 @@ def analyze_protocol_domains(export_csv: bool = True) -> str:
453
  return result
454
 
455
  # Helper functions (not exposed as tools)
456
- def _search_protocol_for_instruments(domain: str) -> dict:
457
  """Search the protocol for instruments related to a specific NIH HEAL CDE core domain."""
458
  global embedding_model
459
 
@@ -486,12 +489,12 @@ def _search_protocol_for_instruments(domain: str) -> dict:
486
 
487
  try:
488
  # Retrieve relevant chunks from the protocol
489
- docs = user_retriever.invoke(query)
490
  protocol_context = format_docs(docs)
491
 
492
  # Search for instruments in the core reference data that match this domain
493
  core_reference_query = f"What are standard instruments or measures for {domain}?"
494
- core_reference_instruments = core_reference_retrieval_chain.invoke({"question": core_reference_query})
495
 
496
  # Use the model to identify the most likely instrument for this domain
497
  prompt = f"""
@@ -506,12 +509,12 @@ def _search_protocol_for_instruments(domain: str) -> dict:
506
  Respond with only the name of the identified instrument. If you cannot identify a specific instrument, respond with "Not identified".
507
  """
508
 
509
- instrument = domain_chat_model.invoke([HumanMessage(content=prompt)]).content
510
 
511
  # Return the results as a dictionary
512
  return {
513
  "domain": domain,
514
- "instrument": instrument.strip(),
515
  "context": protocol_context,
516
  "known_instruments": core_reference_instruments
517
  }
@@ -519,7 +522,7 @@ def _search_protocol_for_instruments(domain: str) -> dict:
519
  print(f"Error identifying instrument for {domain}: {str(e)}")
520
  return {"domain": domain, "instrument": "Error during identification", "context": str(e)}
521
 
522
- def create_rag_chain(doc_type=None):
523
  """Create a RAG chain based on the document type."""
524
  # Get the session-specific Qdrant client
525
  session_qdrant_client = cl.user_session.get("session_qdrant_client")
@@ -580,7 +583,7 @@ def create_rag_chain(doc_type=None):
580
  else:
581
  retriever = retrievers[0]
582
 
583
- # Create and return the RAG chain
584
  return (
585
  {"context": itemgetter("question") | retriever | format_docs,
586
  "question": itemgetter("question")}
@@ -629,12 +632,12 @@ def should_continue(state: MessagesState) -> Literal["tools", END]:
629
  # Otherwise, we end the graph (reply to the user)
630
  return END
631
 
632
- def call_model(state: MessagesState):
633
  messages = state["messages"]
634
  # Add the system message at the beginning of the messages list
635
  if messages and not any(isinstance(msg, SystemMessage) for msg in messages):
636
  messages = [SystemMessage(content=system_message)] + messages
637
- response = model.invoke(messages)
638
  # We return a list, because this will get added to the existing list
639
  return {"messages": [response]}
640
 
@@ -708,8 +711,8 @@ async def on_message(msg: cl.Message):
708
  # For all messages, use the graph to handle the logic
709
  final_answer = cl.Message(content="")
710
 
711
- # Let the graph handle all message processing
712
- for msg_response, metadata in graph.stream(
713
  {"messages": [HumanMessage(content=msg.content)]},
714
  stream_mode="messages",
715
  config=config
 
320
  return await embed_protocol_in_qdrant(documents_with_metadata, session_qdrant_client, model_name)
321
 
322
  # ==================== RETRIEVAL FUNCTIONS ====================
323
+ async def retrieve_from_core(query, k=5):
324
  """Retrieve documents from core reference database"""
325
  global core_retriever
326
 
 
331
  # Override k if needed
332
  if k != 10: # Assuming default k=10 was used when creating the retriever
333
  retriever = core_vectorstore.as_retriever(search_kwargs={"k": k})
334
+ return await retriever.ainvoke(query)
335
 
336
+ return await core_retriever.ainvoke(query)
337
 
338
+ async def retrieve_from_protocol(query, k=5):
339
  """Retrieve documents from protocol database"""
340
  # Get the session-specific client
341
  session_qdrant_client = cl.user_session.get("session_qdrant_client")
 
361
 
362
  # Create and use retriever
363
  protocol_retriever = protocol_vectorstore.as_retriever(search_kwargs={"k": k})
364
+ return await protocol_retriever.ainvoke(query)
365
 
366
  # ==================== TOOL DEFINITIONS ====================
367
  @tool
368
+ async def search_all_data(query: str, doc_type: str = None) -> str:
369
  """Search all data or filter by document type (protocol/core_reference)"""
370
  try:
371
+ chain = await create_rag_chain(doc_type)
372
+ return await chain.ainvoke({"question": query})
373
  except Exception as e:
374
  return f"Error searching data: {str(e)}"
375
 
376
  @tool
377
+ async def analyze_protocol_domains(export_csv: bool = True) -> str:
378
  """Analyze all NIH HEAL CDE core domains and identify instruments used in the protocol.
379
 
380
  Args:
 
394
  # For each domain, search for relevant instruments
395
  domain_analysis_results = []
396
 
397
+ # Use asyncio.gather to run all domain searches in parallel
398
+ import asyncio
399
+ tasks = [_search_protocol_for_instruments(domain) for domain in NIH_HEAL_CORE_DOMAINS]
400
+ results = await asyncio.gather(*tasks)
401
+
402
+ for domain, result in zip(NIH_HEAL_CORE_DOMAINS, results):
403
  print(f"Identified instrument for {domain}: {result['instrument']}")
404
 
405
  # Add the result to our list of analysis results
 
456
  return result
457
 
458
  # Helper functions (not exposed as tools)
459
+ async def _search_protocol_for_instruments(domain: str) -> dict:
460
  """Search the protocol for instruments related to a specific NIH HEAL CDE core domain."""
461
  global embedding_model
462
 
 
489
 
490
  try:
491
  # Retrieve relevant chunks from the protocol
492
+ docs = await user_retriever.ainvoke(query)
493
  protocol_context = format_docs(docs)
494
 
495
  # Search for instruments in the core reference data that match this domain
496
  core_reference_query = f"What are standard instruments or measures for {domain}?"
497
+ core_reference_instruments = await core_reference_retrieval_chain.ainvoke({"question": core_reference_query})
498
 
499
  # Use the model to identify the most likely instrument for this domain
500
  prompt = f"""
 
509
  Respond with only the name of the identified instrument. If you cannot identify a specific instrument, respond with "Not identified".
510
  """
511
 
512
+ instrument = await domain_chat_model.ainvoke([HumanMessage(content=prompt)])
513
 
514
  # Return the results as a dictionary
515
  return {
516
  "domain": domain,
517
+ "instrument": instrument.content.strip(),
518
  "context": protocol_context,
519
  "known_instruments": core_reference_instruments
520
  }
 
522
  print(f"Error identifying instrument for {domain}: {str(e)}")
523
  return {"domain": domain, "instrument": "Error during identification", "context": str(e)}
524
 
525
+ async def create_rag_chain(doc_type=None):
526
  """Create a RAG chain based on the document type."""
527
  # Get the session-specific Qdrant client
528
  session_qdrant_client = cl.user_session.get("session_qdrant_client")
 
583
  else:
584
  retriever = retrievers[0]
585
 
586
+ # Create and return the RAG chain with async support
587
  return (
588
  {"context": itemgetter("question") | retriever | format_docs,
589
  "question": itemgetter("question")}
 
632
  # Otherwise, we end the graph (reply to the user)
633
  return END
634
 
635
+ async def call_model(state: MessagesState):
636
  messages = state["messages"]
637
  # Add the system message at the beginning of the messages list
638
  if messages and not any(isinstance(msg, SystemMessage) for msg in messages):
639
  messages = [SystemMessage(content=system_message)] + messages
640
+ response = await model.ainvoke(messages)
641
  # We return a list, because this will get added to the existing list
642
  return {"messages": [response]}
643
 
 
711
  # For all messages, use the graph to handle the logic
712
  final_answer = cl.Message(content="")
713
 
714
+ # Use astream instead of stream since we're using async functions
715
+ async for msg_response, metadata in graph.astream(
716
  {"messages": [HumanMessage(content=msg.content)]},
717
  stream_mode="messages",
718
  config=config