added async
Browse files
app.py
CHANGED
|
@@ -320,7 +320,7 @@ async def process_uploaded_protocol(files, session_qdrant_client, model_name=EMB
|
|
| 320 |
return await embed_protocol_in_qdrant(documents_with_metadata, session_qdrant_client, model_name)
|
| 321 |
|
| 322 |
# ==================== RETRIEVAL FUNCTIONS ====================
|
| 323 |
-
def retrieve_from_core(query, k=5):
|
| 324 |
"""Retrieve documents from core reference database"""
|
| 325 |
global core_retriever
|
| 326 |
|
|
@@ -331,11 +331,11 @@ def retrieve_from_core(query, k=5):
|
|
| 331 |
# Override k if needed
|
| 332 |
if k != 10: # Assuming default k=10 was used when creating the retriever
|
| 333 |
retriever = core_vectorstore.as_retriever(search_kwargs={"k": k})
|
| 334 |
-
return retriever.
|
| 335 |
|
| 336 |
-
return core_retriever.
|
| 337 |
|
| 338 |
-
def retrieve_from_protocol(query, k=5):
|
| 339 |
"""Retrieve documents from protocol database"""
|
| 340 |
# Get the session-specific client
|
| 341 |
session_qdrant_client = cl.user_session.get("session_qdrant_client")
|
|
@@ -361,20 +361,20 @@ def retrieve_from_protocol(query, k=5):
|
|
| 361 |
|
| 362 |
# Create and use retriever
|
| 363 |
protocol_retriever = protocol_vectorstore.as_retriever(search_kwargs={"k": k})
|
| 364 |
-
return protocol_retriever.
|
| 365 |
|
| 366 |
# ==================== TOOL DEFINITIONS ====================
|
| 367 |
@tool
|
| 368 |
-
def search_all_data(query: str, doc_type: str = None) -> str:
|
| 369 |
"""Search all data or filter by document type (protocol/core_reference)"""
|
| 370 |
try:
|
| 371 |
-
chain = create_rag_chain(doc_type)
|
| 372 |
-
return chain.
|
| 373 |
except Exception as e:
|
| 374 |
return f"Error searching data: {str(e)}"
|
| 375 |
|
| 376 |
@tool
|
| 377 |
-
def analyze_protocol_domains(export_csv: bool = True) -> str:
|
| 378 |
"""Analyze all NIH HEAL CDE core domains and identify instruments used in the protocol.
|
| 379 |
|
| 380 |
Args:
|
|
@@ -394,9 +394,12 @@ def analyze_protocol_domains(export_csv: bool = True) -> str:
|
|
| 394 |
# For each domain, search for relevant instruments
|
| 395 |
domain_analysis_results = []
|
| 396 |
|
| 397 |
-
|
| 398 |
-
|
| 399 |
-
|
|
|
|
|
|
|
|
|
|
| 400 |
print(f"Identified instrument for {domain}: {result['instrument']}")
|
| 401 |
|
| 402 |
# Add the result to our list of analysis results
|
|
@@ -453,7 +456,7 @@ def analyze_protocol_domains(export_csv: bool = True) -> str:
|
|
| 453 |
return result
|
| 454 |
|
| 455 |
# Helper functions (not exposed as tools)
|
| 456 |
-
def _search_protocol_for_instruments(domain: str) -> dict:
|
| 457 |
"""Search the protocol for instruments related to a specific NIH HEAL CDE core domain."""
|
| 458 |
global embedding_model
|
| 459 |
|
|
@@ -486,12 +489,12 @@ def _search_protocol_for_instruments(domain: str) -> dict:
|
|
| 486 |
|
| 487 |
try:
|
| 488 |
# Retrieve relevant chunks from the protocol
|
| 489 |
-
docs = user_retriever.
|
| 490 |
protocol_context = format_docs(docs)
|
| 491 |
|
| 492 |
# Search for instruments in the core reference data that match this domain
|
| 493 |
core_reference_query = f"What are standard instruments or measures for {domain}?"
|
| 494 |
-
core_reference_instruments = core_reference_retrieval_chain.
|
| 495 |
|
| 496 |
# Use the model to identify the most likely instrument for this domain
|
| 497 |
prompt = f"""
|
|
@@ -506,12 +509,12 @@ def _search_protocol_for_instruments(domain: str) -> dict:
|
|
| 506 |
Respond with only the name of the identified instrument. If you cannot identify a specific instrument, respond with "Not identified".
|
| 507 |
"""
|
| 508 |
|
| 509 |
-
instrument = domain_chat_model.
|
| 510 |
|
| 511 |
# Return the results as a dictionary
|
| 512 |
return {
|
| 513 |
"domain": domain,
|
| 514 |
-
"instrument": instrument.strip(),
|
| 515 |
"context": protocol_context,
|
| 516 |
"known_instruments": core_reference_instruments
|
| 517 |
}
|
|
@@ -519,7 +522,7 @@ def _search_protocol_for_instruments(domain: str) -> dict:
|
|
| 519 |
print(f"Error identifying instrument for {domain}: {str(e)}")
|
| 520 |
return {"domain": domain, "instrument": "Error during identification", "context": str(e)}
|
| 521 |
|
| 522 |
-
def create_rag_chain(doc_type=None):
|
| 523 |
"""Create a RAG chain based on the document type."""
|
| 524 |
# Get the session-specific Qdrant client
|
| 525 |
session_qdrant_client = cl.user_session.get("session_qdrant_client")
|
|
@@ -580,7 +583,7 @@ def create_rag_chain(doc_type=None):
|
|
| 580 |
else:
|
| 581 |
retriever = retrievers[0]
|
| 582 |
|
| 583 |
-
# Create and return the RAG chain
|
| 584 |
return (
|
| 585 |
{"context": itemgetter("question") | retriever | format_docs,
|
| 586 |
"question": itemgetter("question")}
|
|
@@ -629,12 +632,12 @@ def should_continue(state: MessagesState) -> Literal["tools", END]:
|
|
| 629 |
# Otherwise, we end the graph (reply to the user)
|
| 630 |
return END
|
| 631 |
|
| 632 |
-
def call_model(state: MessagesState):
|
| 633 |
messages = state["messages"]
|
| 634 |
# Add the system message at the beginning of the messages list
|
| 635 |
if messages and not any(isinstance(msg, SystemMessage) for msg in messages):
|
| 636 |
messages = [SystemMessage(content=system_message)] + messages
|
| 637 |
-
response = model.
|
| 638 |
# We return a list, because this will get added to the existing list
|
| 639 |
return {"messages": [response]}
|
| 640 |
|
|
@@ -708,8 +711,8 @@ async def on_message(msg: cl.Message):
|
|
| 708 |
# For all messages, use the graph to handle the logic
|
| 709 |
final_answer = cl.Message(content="")
|
| 710 |
|
| 711 |
-
#
|
| 712 |
-
for msg_response, metadata in graph.
|
| 713 |
{"messages": [HumanMessage(content=msg.content)]},
|
| 714 |
stream_mode="messages",
|
| 715 |
config=config
|
|
|
|
| 320 |
return await embed_protocol_in_qdrant(documents_with_metadata, session_qdrant_client, model_name)
|
| 321 |
|
| 322 |
# ==================== RETRIEVAL FUNCTIONS ====================
|
| 323 |
+
async def retrieve_from_core(query, k=5):
|
| 324 |
"""Retrieve documents from core reference database"""
|
| 325 |
global core_retriever
|
| 326 |
|
|
|
|
| 331 |
# Override k if needed
|
| 332 |
if k != 10: # Assuming default k=10 was used when creating the retriever
|
| 333 |
retriever = core_vectorstore.as_retriever(search_kwargs={"k": k})
|
| 334 |
+
return await retriever.ainvoke(query)
|
| 335 |
|
| 336 |
+
return await core_retriever.ainvoke(query)
|
| 337 |
|
| 338 |
+
async def retrieve_from_protocol(query, k=5):
|
| 339 |
"""Retrieve documents from protocol database"""
|
| 340 |
# Get the session-specific client
|
| 341 |
session_qdrant_client = cl.user_session.get("session_qdrant_client")
|
|
|
|
| 361 |
|
| 362 |
# Create and use retriever
|
| 363 |
protocol_retriever = protocol_vectorstore.as_retriever(search_kwargs={"k": k})
|
| 364 |
+
return await protocol_retriever.ainvoke(query)
|
| 365 |
|
| 366 |
# ==================== TOOL DEFINITIONS ====================
|
| 367 |
@tool
|
| 368 |
+
async def search_all_data(query: str, doc_type: str = None) -> str:
|
| 369 |
"""Search all data or filter by document type (protocol/core_reference)"""
|
| 370 |
try:
|
| 371 |
+
chain = await create_rag_chain(doc_type)
|
| 372 |
+
return await chain.ainvoke({"question": query})
|
| 373 |
except Exception as e:
|
| 374 |
return f"Error searching data: {str(e)}"
|
| 375 |
|
| 376 |
@tool
|
| 377 |
+
async def analyze_protocol_domains(export_csv: bool = True) -> str:
|
| 378 |
"""Analyze all NIH HEAL CDE core domains and identify instruments used in the protocol.
|
| 379 |
|
| 380 |
Args:
|
|
|
|
| 394 |
# For each domain, search for relevant instruments
|
| 395 |
domain_analysis_results = []
|
| 396 |
|
| 397 |
+
# Use asyncio.gather to run all domain searches in parallel
|
| 398 |
+
import asyncio
|
| 399 |
+
tasks = [_search_protocol_for_instruments(domain) for domain in NIH_HEAL_CORE_DOMAINS]
|
| 400 |
+
results = await asyncio.gather(*tasks)
|
| 401 |
+
|
| 402 |
+
for domain, result in zip(NIH_HEAL_CORE_DOMAINS, results):
|
| 403 |
print(f"Identified instrument for {domain}: {result['instrument']}")
|
| 404 |
|
| 405 |
# Add the result to our list of analysis results
|
|
|
|
| 456 |
return result
|
| 457 |
|
| 458 |
# Helper functions (not exposed as tools)
|
| 459 |
+
async def _search_protocol_for_instruments(domain: str) -> dict:
|
| 460 |
"""Search the protocol for instruments related to a specific NIH HEAL CDE core domain."""
|
| 461 |
global embedding_model
|
| 462 |
|
|
|
|
| 489 |
|
| 490 |
try:
|
| 491 |
# Retrieve relevant chunks from the protocol
|
| 492 |
+
docs = await user_retriever.ainvoke(query)
|
| 493 |
protocol_context = format_docs(docs)
|
| 494 |
|
| 495 |
# Search for instruments in the core reference data that match this domain
|
| 496 |
core_reference_query = f"What are standard instruments or measures for {domain}?"
|
| 497 |
+
core_reference_instruments = await core_reference_retrieval_chain.ainvoke({"question": core_reference_query})
|
| 498 |
|
| 499 |
# Use the model to identify the most likely instrument for this domain
|
| 500 |
prompt = f"""
|
|
|
|
| 509 |
Respond with only the name of the identified instrument. If you cannot identify a specific instrument, respond with "Not identified".
|
| 510 |
"""
|
| 511 |
|
| 512 |
+
instrument = await domain_chat_model.ainvoke([HumanMessage(content=prompt)])
|
| 513 |
|
| 514 |
# Return the results as a dictionary
|
| 515 |
return {
|
| 516 |
"domain": domain,
|
| 517 |
+
"instrument": instrument.content.strip(),
|
| 518 |
"context": protocol_context,
|
| 519 |
"known_instruments": core_reference_instruments
|
| 520 |
}
|
|
|
|
| 522 |
print(f"Error identifying instrument for {domain}: {str(e)}")
|
| 523 |
return {"domain": domain, "instrument": "Error during identification", "context": str(e)}
|
| 524 |
|
| 525 |
+
async def create_rag_chain(doc_type=None):
|
| 526 |
"""Create a RAG chain based on the document type."""
|
| 527 |
# Get the session-specific Qdrant client
|
| 528 |
session_qdrant_client = cl.user_session.get("session_qdrant_client")
|
|
|
|
| 583 |
else:
|
| 584 |
retriever = retrievers[0]
|
| 585 |
|
| 586 |
+
# Create and return the RAG chain with async support
|
| 587 |
return (
|
| 588 |
{"context": itemgetter("question") | retriever | format_docs,
|
| 589 |
"question": itemgetter("question")}
|
|
|
|
| 632 |
# Otherwise, we end the graph (reply to the user)
|
| 633 |
return END
|
| 634 |
|
| 635 |
+
async def call_model(state: MessagesState):
|
| 636 |
messages = state["messages"]
|
| 637 |
# Add the system message at the beginning of the messages list
|
| 638 |
if messages and not any(isinstance(msg, SystemMessage) for msg in messages):
|
| 639 |
messages = [SystemMessage(content=system_message)] + messages
|
| 640 |
+
response = await model.ainvoke(messages)
|
| 641 |
# We return a list, because this will get added to the existing list
|
| 642 |
return {"messages": [response]}
|
| 643 |
|
|
|
|
| 711 |
# For all messages, use the graph to handle the logic
|
| 712 |
final_answer = cl.Message(content="")
|
| 713 |
|
| 714 |
+
# Use astream instead of stream since we're using async functions
|
| 715 |
+
async for msg_response, metadata in graph.astream(
|
| 716 |
{"messages": [HumanMessage(content=msg.content)]},
|
| 717 |
stream_mode="messages",
|
| 718 |
config=config
|