Update agent.py
Browse files
agent.py
CHANGED
|
@@ -23,7 +23,6 @@ logging.basicConfig(level=logging.INFO)
|
|
| 23 |
UMLS_API_KEY = os.getenv("UMLS_API_KEY")
|
| 24 |
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
|
| 25 |
TAVILY_API_KEY = os.getenv("TAVILY_API_KEY")
|
| 26 |
-
|
| 27 |
if not all([UMLS_API_KEY, GROQ_API_KEY, TAVILY_API_KEY]):
|
| 28 |
logger.error("Missing one or more required API keys: UMLS_API_KEY, GROQ_API_KEY, TAVILY_API_KEY")
|
| 29 |
raise RuntimeError("Missing required API keys")
|
|
@@ -43,8 +42,8 @@ class ClinicalPrompts:
|
|
| 43 |
def wrap_message(msg: Any) -> AIMessage:
|
| 44 |
"""
|
| 45 |
Ensures the given message is an AIMessage.
|
| 46 |
-
If it is a dict,
|
| 47 |
-
Otherwise,
|
| 48 |
"""
|
| 49 |
if isinstance(msg, AIMessage):
|
| 50 |
return msg
|
|
@@ -358,6 +357,7 @@ def tool_node(state: AgentState) -> Dict[str, Any]:
|
|
| 358 |
new_state = {"messages": []}
|
| 359 |
return propagate_state(new_state, state)
|
| 360 |
last = wrap_message(messages_list[-1])
|
|
|
|
| 361 |
tool_calls = last.__dict__.get("tool_calls")
|
| 362 |
if not (isinstance(last, AIMessage) and tool_calls):
|
| 363 |
logger.warning("tool_node invoked without pending tool_calls")
|
|
@@ -463,6 +463,7 @@ def should_continue(state: AgentState) -> str:
|
|
| 463 |
state["done"] = True
|
| 464 |
return "end_conversation_turn"
|
| 465 |
state["done"] = False
|
|
|
|
| 466 |
return "start"
|
| 467 |
|
| 468 |
def after_tools_router(state: AgentState) -> str:
|
|
@@ -479,15 +480,17 @@ class ClinicalAgent:
|
|
| 479 |
wf.add_node("tools", tool_node)
|
| 480 |
wf.add_node("reflection", reflection_node)
|
| 481 |
wf.set_entry_point("start")
|
|
|
|
| 482 |
wf.add_conditional_edges("start", should_continue, {
|
| 483 |
"continue_tools": "tools",
|
|
|
|
| 484 |
"end_conversation_turn": END
|
| 485 |
})
|
| 486 |
wf.add_conditional_edges("tools", after_tools_router, {
|
| 487 |
"reflection": "reflection",
|
| 488 |
"end_conversation_turn": END
|
| 489 |
})
|
| 490 |
-
# Removed edge from reflection back to start.
|
| 491 |
self.graph_app = wf.compile()
|
| 492 |
logger.info("ClinicalAgent ready")
|
| 493 |
|
|
|
|
| 23 |
UMLS_API_KEY = os.getenv("UMLS_API_KEY")
|
| 24 |
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
|
| 25 |
TAVILY_API_KEY = os.getenv("TAVILY_API_KEY")
|
|
|
|
| 26 |
if not all([UMLS_API_KEY, GROQ_API_KEY, TAVILY_API_KEY]):
|
| 27 |
logger.error("Missing one or more required API keys: UMLS_API_KEY, GROQ_API_KEY, TAVILY_API_KEY")
|
| 28 |
raise RuntimeError("Missing required API keys")
|
|
|
|
| 42 |
def wrap_message(msg: Any) -> AIMessage:
|
| 43 |
"""
|
| 44 |
Ensures the given message is an AIMessage.
|
| 45 |
+
If it is a dict, extracts the 'content' field (or serializes the dict).
|
| 46 |
+
Otherwise, converts the message to a string.
|
| 47 |
"""
|
| 48 |
if isinstance(msg, AIMessage):
|
| 49 |
return msg
|
|
|
|
| 357 |
new_state = {"messages": []}
|
| 358 |
return propagate_state(new_state, state)
|
| 359 |
last = wrap_message(messages_list[-1])
|
| 360 |
+
# Safely retrieve pending tool_calls from the message's __dict__
|
| 361 |
tool_calls = last.__dict__.get("tool_calls")
|
| 362 |
if not (isinstance(last, AIMessage) and tool_calls):
|
| 363 |
logger.warning("tool_node invoked without pending tool_calls")
|
|
|
|
| 463 |
state["done"] = True
|
| 464 |
return "end_conversation_turn"
|
| 465 |
state["done"] = False
|
| 466 |
+
# Return "start" to loop back.
|
| 467 |
return "start"
|
| 468 |
|
| 469 |
def after_tools_router(state: AgentState) -> str:
|
|
|
|
| 480 |
wf.add_node("tools", tool_node)
|
| 481 |
wf.add_node("reflection", reflection_node)
|
| 482 |
wf.set_entry_point("start")
|
| 483 |
+
# Note: Added a "start" branch in the conditional edges.
|
| 484 |
wf.add_conditional_edges("start", should_continue, {
|
| 485 |
"continue_tools": "tools",
|
| 486 |
+
"start": "start",
|
| 487 |
"end_conversation_turn": END
|
| 488 |
})
|
| 489 |
wf.add_conditional_edges("tools", after_tools_router, {
|
| 490 |
"reflection": "reflection",
|
| 491 |
"end_conversation_turn": END
|
| 492 |
})
|
| 493 |
+
# Removed edge from reflection back to start to break the cycle.
|
| 494 |
self.graph_app = wf.compile()
|
| 495 |
logger.info("ClinicalAgent ready")
|
| 496 |
|