multimodal_previsit / agent_langgraph.py
frabbani
context added ..
d45903a
"""
MedGemma Agent using LangGraph
==============================
A proper graph-based workflow using LangGraph's StateGraph.
Workflow Graph:
    START -> discover -> skin_analysis -> plan -> execute -> reflect
                                           ^                    |
                                           |___(gaps found)_____|
                                                                |
                                                                v
                                                           synthesize -> END
Reuses all working logic from agent_v2.py:
- Dynamic tool registry with category filtering
- Condensed tool descriptions (token reduction)
- Semantic parameter normalization
- LLM-based error recovery
- Skin image analysis with Derm Foundation
"""
import os
import json
import re
from typing import Dict, List, Optional, Set, Any, Literal, AsyncGenerator
from typing_extensions import TypedDict, Annotated
# LangGraph imports
from langgraph.graph import StateGraph, START, END
from langgraph.checkpoint.memory import MemorySaver
# Reuse ALL working logic from agent_v2
from agent_v2 import (
call_llm, stream_llm, filter_thinking,
get_patient_manifest, plan_tools, execute_and_extract,
reflect_on_facts, synthesize_answer,
get_relevant_categories, get_filtered_tools_description,
format_conversation_history,
WorkflowPhase, AgentState,
TOOLS
)
from tools import execute_tool, get_tools_description
# Base URL of the llama server; read from the LLAMA_SERVER_URL environment
# variable, falling back to a local server on port 8080 when it is unset.
LLAMA_SERVER_URL = os.environ.get('LLAMA_SERVER_URL', 'http://localhost:8080')
# =============================================================================
# State Definition
# =============================================================================
def _append_events(existing: List[Dict], new: List[Dict]) -> List[Dict]:
"""Reducer: append new events to existing list."""
return existing + new
def _append_facts(existing: List[Dict], new: List[Dict]) -> List[Dict]:
"""Reducer: append new facts to existing list."""
return existing + new
def _merge_tools(existing: Set[str], new: Set[str]) -> Set[str]:
"""Reducer: merge executed tools sets."""
return existing | new
class GraphState(TypedDict):
    """Shared state passed between the nodes of the LangGraph workflow.

    Fields declared with ``Annotated[..., reducer]`` are combined with the
    stored value through that reducer; the remaining fields are plain values.
    """

    # --- Request inputs (set once by the runner) ---
    patient_id: str
    question: str
    skin_image_data: Optional[str]
    conversation_history: List[Dict]  # earlier turns of the conversation

    # --- Discovery: manifest of the data available for the patient ---
    manifest: Dict

    # --- Planning: tool calls chosen for the current iteration ---
    planned_tools: List[Dict]

    # --- Execution: accumulated facts, tool names already run, chart payload ---
    collected_facts: Annotated[List[Dict], _append_facts]
    executed_tools: Annotated[Set[str], _merge_tools]
    chart_data: Optional[Dict]

    # --- Skin image analysis output ---
    skin_analysis_result: Optional[Dict]
    skin_llm_prompt: Optional[str]

    # --- Reflection: outstanding gaps and the loop decision ---
    reflection_gaps: List[str]
    should_continue: bool

    # --- Loop control ---
    iteration: int
    max_iterations: int

    # --- Output: UI events gathered from every node, plus the answer slot ---
    events: Annotated[List[Dict], _append_events]
    final_answer: str
# =============================================================================
# Node Functions
# =============================================================================
async def discover_node(state: GraphState) -> Dict:
    """DISCOVER: Get the patient data manifest.

    Fetches the manifest for ``state["patient_id"]`` and builds a one-line
    summary (patient demographics plus per-table record counts) for the UI.

    Args:
        state: Current graph state; only ``patient_id`` is read.

    Returns:
        Partial state update with ``manifest`` and the ``events`` emitted by
        this node (a status event followed by a discovery event).
    """
    print("[LANGGRAPH] discover_node")
    events = [{"type": "status", "message": "Discovering available data..."}]

    # `or {}` guards against a missing manifest or explicit None sections.
    manifest = get_patient_manifest(state["patient_id"]) or {}
    patient_info = manifest.get("patient_info") or {}
    available = manifest.get("available_data") or {}

    # Build summary
    summary_parts = []
    if patient_info:
        summary_parts.append(
            f"Patient: {patient_info.get('name', 'Unknown')}, "
            f"{patient_info.get('age', '?')}y, {patient_info.get('gender', '')}"
        )

    data_counts = []
    for table, info in available.items():
        # A malformed entry (no "count" key, or not a dict) must not abort
        # discovery; show "?" for it instead.
        count = info.get("count", "?") if isinstance(info, dict) else "?"
        data_counts.append(f"{count} {table}")
    if data_counts:
        summary_parts.append(f"Available: {', '.join(data_counts)}")

    events.append({
        "type": "discovery",
        "manifest": manifest,
        "summary": " | ".join(summary_parts)
    })
    return {
        "manifest": manifest,
        "events": events
    }
def _summarize_skin_result(skin_result: Dict) -> str:
    """Condense a successful skin-analysis payload into one fact string."""
    parts = []

    # Up to three candidate conditions. Each entry is checked on its own, so a
    # list mixing dicts and plain values is handled.
    conditions = skin_result.get("conditions") or []
    if conditions:
        cond_strs = [
            f"{c.get('name', 'Unknown')} ({c.get('confidence', 0)}% - {c.get('likelihood', 'possible')})"
            if isinstance(c, dict) else str(c)
            for c in conditions[:3]
        ]
        parts.append(f"Possible conditions: {', '.join(cond_strs)}")

    # Up to three symptoms read from the image; dict entries without a name
    # are skipped.
    symptoms_img = skin_result.get("symptoms_from_image") or []
    symp_strs = []
    for s in symptoms_img[:3]:
        if isinstance(s, dict):
            if s.get("name"):
                symp_strs.append(str(s["name"]))
        else:
            symp_strs.append(str(s))
    if symp_strs:
        parts.append(f"Detected symptoms: {', '.join(symp_strs)}")

    if parts:
        summary = ". ".join(parts)
    else:
        summary = f"Skin image analyzed with {skin_result.get('model', 'Derm Foundation')}."

    # Append the disclaimer only when one is present: avoids a dangling
    # newline and a TypeError when the field is None.
    disclaimer = skin_result.get("disclaimer")
    if disclaimer:
        summary += "\n" + str(disclaimer)
    return summary


async def skin_analysis_node(state: GraphState) -> Dict:
    """SKIN ANALYSIS: Analyze the uploaded skin image (if present).

    Does nothing when the state carries no ``skin_image_data``. Otherwise it
    calls ``tools.analyze_skin_image``, parses its JSON result, emits a
    ``skin_analysis`` event and records one fact for the synthesis step.

    Args:
        state: Current graph state; reads ``skin_image_data``, ``patient_id``
            and ``question`` (the question is passed as the symptoms text).

    Returns:
        Partial state update. On success or a reported analysis error it
        contains ``skin_analysis_result``, ``skin_llm_prompt``,
        ``collected_facts``, ``executed_tools`` and ``events``. If the
        analysis raises, only ``events`` (including an error event) is
        returned.
    """
    image_data = state.get("skin_image_data")
    if not image_data:
        return {"events": []}

    print(f"[LANGGRAPH] skin_analysis_node - image: {len(image_data)} chars")
    events = [{
        "type": "status",
        "message": "Analyzing skin image with Derm Foundation + SCIN Classifier..."
    }]

    # Best-effort step: any failure is reported to the UI as an error event
    # rather than aborting the workflow.
    try:
        from tools import analyze_skin_image
        skin_result_str = analyze_skin_image(
            state["patient_id"],
            image_data,
            symptoms=state["question"]
        )
        skin_result = json.loads(skin_result_str)
        events.append({
            "type": "skin_analysis",
            "data": skin_result,
            "image_data": image_data
        })

        # Extract facts from skin analysis
        if skin_result.get("status") == "success":
            fact_text = _summarize_skin_result(skin_result)
        else:
            fact_text = f"Skin analysis error: {skin_result.get('error', 'Unknown')}"

        return {
            "skin_analysis_result": skin_result,
            "skin_llm_prompt": skin_result.get("llm_synthesis_prompt", ""),
            "collected_facts": [{"tool": "analyze_skin_image", "facts": fact_text}],
            "executed_tools": {"analyze_skin_image"},
            "events": events
        }
    except Exception as e:
        print(f"[LANGGRAPH] Skin analysis error: {e}")
        events.append({"type": "error", "message": f"Skin analysis error: {str(e)}"})
        return {"events": events}
async def plan_node(state: GraphState) -> Dict:
    """PLAN: Use the LLM to identify which tools are needed.

    Builds an ``AgentState`` from the graph state, delegates to
    ``plan_tools`` from agent_v2, then drops tool calls whose tool name has
    already been executed (and any ``analyze_skin_image`` call when an image
    is present, since the skin-analysis node handles that).

    Args:
        state: Current graph state.

    Returns:
        Partial state update with the filtered ``planned_tools``, the
        incremented ``iteration``, an emptied ``reflection_gaps`` and the
        ``events`` emitted by this node.
    """
    iteration = state.get("iteration", 0) + 1
    print(f"[LANGGRAPH] plan_node - iteration {iteration}")
    events = [{"type": "status", "message": f"Planning approach (iteration {iteration})..."}]

    gaps = state.get("reflection_gaps") or []
    executed = state.get("executed_tools") or set()

    # Build an AgentState to reuse plan_tools from agent_v2
    agent_state = AgentState(
        patient_id=state["patient_id"],
        question=state["question"],
        conversation_history=state.get("conversation_history", []),
        manifest=state["manifest"],
        skin_image_data=state.get("skin_image_data"),
        executed_tools=executed,
        reflection_gaps=gaps,
        iteration=iteration,
        max_iterations=state.get("max_iterations", 3)
    )

    # If we have gaps from reflection, add context. Gaps come from an LLM
    # response, so stringify each one instead of assuming str.
    if gaps:
        gap_context = (
            "\n\nPrevious iteration found these gaps: "
            f"{', '.join(str(g) for g in gaps)}"
        )
        agent_state.question = state["question"] + gap_context

    # Call the working plan_tools from agent_v2; tolerate an empty/None plan.
    planned = (await plan_tools(agent_state)) or []

    # Drop malformed entries, tools that already ran, and (when an image was
    # supplied) the skin analysis that its dedicated node already performed.
    skip_skin = bool(state.get("skin_image_data"))
    planned = [
        t for t in planned
        if isinstance(t, dict)
        and t.get("tool") not in executed
        and not (skip_skin and t.get("tool") == "analyze_skin_image")
    ]

    relevant_categories = get_relevant_categories(state["question"], state["manifest"])
    events.append({
        "type": "plan",
        "tools": planned,
        "iteration": iteration,
        "tool_filtering": {
            "categories_used": sorted(relevant_categories),
            "total_tools": len(TOOLS)
        }
    })

    if not planned:
        print("[LANGGRAPH] No new tools to execute")

    return {
        "planned_tools": planned,
        "iteration": iteration,
        "reflection_gaps": [],  # Clear after use
        "events": events
    }
async def execute_node(state: GraphState) -> Dict:
    """EXECUTE: Run the planned tools and extract facts.

    Each planned tool call is passed to ``execute_and_extract`` from
    agent_v2. Successful calls contribute a fact and a ``tool_result`` event;
    failing calls contribute a ``tool_error`` event. Either way the tool name
    is recorded as executed so it is not planned again.

    Args:
        state: Current graph state; reads ``planned_tools``,
            ``executed_tools`` and the identifying fields.

    Returns:
        Partial state update with the new ``collected_facts``, the new
        ``executed_tools`` and ``events``. ``chart_data`` is included only
        when a tool produced a chart in this pass, so a chart stored by an
        earlier iteration is not overwritten with ``None``.
    """
    print(f"[LANGGRAPH] execute_node - {len(state['planned_tools'])} tools")
    events = []
    new_facts = []
    new_executed = set()
    chart_data = None
    already_executed = state.get("executed_tools") or set()

    # Build AgentState for execute_and_extract
    agent_state = AgentState(
        patient_id=state["patient_id"],
        question=state["question"],
        manifest=state["manifest"],
        executed_tools=already_executed,
        iteration=state.get("iteration", 1),
        max_iterations=state.get("max_iterations", 3)
    )

    for tool_call in state["planned_tools"]:
        tool_name = tool_call.get("tool", "unknown")
        tool_args = tool_call.get("args", {})
        reason = tool_call.get("reason", "")
        if tool_name in already_executed:
            continue

        events.append({"type": "status", "message": f"Retrieving {tool_name}..."})
        events.append({"type": "tool_call", "tool": tool_name, "args": tool_args, "reason": reason})

        try:
            fact_result = await execute_and_extract(agent_state, tool_call)
            # raw_data may be None or a non-string payload; stringify before
            # slicing so building the preview cannot fail.
            raw_data = fact_result.get("raw_data") or ""
            result_event = {
                "type": "tool_result",
                "tool": tool_name,
                "facts": fact_result.get("facts", ""),
                "raw_preview": str(raw_data)[:200]
            }
        except Exception as e:
            print(f"[LANGGRAPH] Tool error {tool_name}: {e}")
            events.append({"type": "tool_error", "tool": tool_name, "error": str(e)})
            # Mark as executed even on failure so the planner does not retry it.
            new_executed.add(tool_name)
            continue

        # Record the fact only once the result event is built, so one tool
        # never yields both a stored fact and a tool_error event.
        new_facts.append(fact_result)
        new_executed.add(tool_name)
        events.append(result_event)

        # Check for chart data; reset the slot so the next tool starts clean.
        if agent_state.chart_data:
            chart_data = agent_state.chart_data
            events.append({"type": "chart_data", "data": chart_data})
            agent_state.chart_data = None

    update = {
        "collected_facts": new_facts,
        "executed_tools": new_executed,
        "events": events
    }
    if chart_data is not None:
        update["chart_data"] = chart_data
    return update
async def reflect_node(state: GraphState) -> Dict:
    """REFLECT: Decide whether the facts gathered so far are sufficient.

    Delegates to ``reflect_on_facts`` and requests another planning pass only
    when information is still missing and the iteration budget is not used up.

    Returns:
        Partial state update with ``reflection_gaps`` (kept only when looping
        again), ``should_continue`` and the node's ``events``.
    """
    current = state.get("iteration", 1)
    limit = state.get("max_iterations", 3)
    print(f"[LANGGRAPH] reflect_node - iteration {current}/{limit}")

    # Snapshot of the graph state in the shape reflect_on_facts expects.
    snapshot = AgentState(
        patient_id=state["patient_id"],
        question=state["question"],
        collected_facts=state.get("collected_facts", []),
        executed_tools=state.get("executed_tools", set()),
        iteration=current,
        max_iterations=limit
    )
    verdict = await reflect_on_facts(snapshot)

    enough = verdict.get("has_enough_info", True)
    score = verdict.get("confidence", 0.8)
    missing = verdict.get("gaps", [])

    emitted = [
        {"type": "status", "message": "Reflecting on gathered information..."},
        {
            "type": "reflection",
            "has_enough_info": enough,
            "confidence": score,
            "gaps": missing,
            "reasoning": verdict.get("reasoning", ""),
            "iteration": current
        },
    ]

    if enough:
        print(f"[LANGGRAPH] Reflection: Have enough info (confidence: {score})")
    else:
        print(f"[LANGGRAPH] Reflection: Need more info. Gaps: {missing}")

    # Loop again only if something is missing AND iterations remain.
    keep_going = (not enough) and current < limit
    return {
        "reflection_gaps": missing if keep_going else [],
        "should_continue": keep_going,
        "events": emitted
    }
async def synthesize_node(state: GraphState) -> Dict:
    """SYNTHESIZE: Flag the state as ready for answer generation.

    No tokens are produced here; the runner streams the answer itself. The
    node stores a sentinel in ``final_answer`` and emits one status event.
    """
    print("[LANGGRAPH] synthesize_node - marking ready")
    status_event = {"type": "status", "message": "Generating answer..."}
    update = {
        "final_answer": "__READY_FOR_SYNTHESIS__",
        "events": [status_event],
    }
    return update
# =============================================================================
# Conditional Edge Functions
# =============================================================================
def should_execute(state: GraphState) -> Literal["execute", "synthesize"]:
    """Route after planning: run tools when any were planned, else synthesize."""
    return "execute" if state.get("planned_tools") else "synthesize"
def should_loop_or_finish(state: GraphState) -> Literal["plan", "synthesize"]:
    """Route after reflection: re-plan when more data is wanted, else synthesize."""
    return "plan" if state.get("should_continue", False) else "synthesize"
# =============================================================================
# Build the Graph
# =============================================================================
def create_med_agent_graph():
    """
    Build and compile the MedGemma agent workflow.

    Topology:
        START -> discover -> skin_analysis -> plan
        plan    -> execute     (tools were planned)
        plan    -> synthesize  (nothing to run)
        execute -> reflect
        reflect -> plan        (should_continue is set)
        reflect -> synthesize  (otherwise)
        synthesize -> END

    Returns the graph compiled with an in-memory checkpointer.
    """
    builder = StateGraph(GraphState)

    # Register every node under its public name.
    node_table = (
        ("discover", discover_node),
        ("skin_analysis", skin_analysis_node),
        ("plan", plan_node),
        ("execute", execute_node),
        ("reflect", reflect_node),
        ("synthesize", synthesize_node),
    )
    for label, handler in node_table:
        builder.add_node(label, handler)

    # Fixed prefix of the pipeline.
    linear_prefix = (
        (START, "discover"),
        ("discover", "skin_analysis"),
        ("skin_analysis", "plan"),
    )
    for source, target in linear_prefix:
        builder.add_edge(source, target)

    # plan branches on whether any tools were planned.
    builder.add_conditional_edges(
        "plan",
        should_execute,
        {"execute": "execute", "synthesize": "synthesize"}
    )

    # Tool execution always feeds reflection.
    builder.add_edge("execute", "reflect")

    # reflect either loops back to plan or finishes.
    builder.add_conditional_edges(
        "reflect",
        should_loop_or_finish,
        {"plan": "plan", "synthesize": "synthesize"}
    )

    builder.add_edge("synthesize", END)

    # The runner later reads the accumulated state back through the checkpointer.
    workflow = builder.compile(checkpointer=MemorySaver())
    print("[LANGGRAPH] Graph compiled successfully")
    return workflow
# Module-wide compiled graph, built on first use.
_GRAPH = None


def get_graph():
    """Return the shared compiled graph, compiling it on the first call."""
    global _GRAPH
    if _GRAPH is not None:
        return _GRAPH
    _GRAPH = create_med_agent_graph()
    return _GRAPH
# =============================================================================
# Main Runner - Streams events to the UI
# =============================================================================
async def run_agent_langgraph(
    patient_id: str,
    question: str,
    skin_image_data: Optional[str] = None,
    conversation_history: Optional[List[Dict]] = None,
    thread_id: Optional[str] = None
) -> AsyncGenerator[Dict, None]:
    """
    Run the LangGraph agent and yield SSE events for the UI.

    The graph handles the discover -> skin -> plan -> execute -> reflect loop.
    Synthesis is streamed directly (not buffered) for a real-time typing effect.

    Args:
        patient_id: Patient whose data is queried.
        question: The user's question.
        skin_image_data: Optional skin image payload passed to the
            skin-analysis node.
        conversation_history: Prior conversation turns; defaults to empty.
        thread_id: Checkpointer thread id. When omitted, a unique id is
            generated for this run.

    Yields:
        Event dicts in order: every event emitted by the graph nodes, then
        ``answer_start``, one ``token`` event per synthesized token,
        ``answer_end`` and ``workflow_complete``. If anything raises, a single
        ``error`` event is yielded instead of the remaining events.
    """
    import traceback
    import uuid

    graph = get_graph()

    # Initial state
    initial_state = {
        "patient_id": patient_id,
        "question": question,
        "skin_image_data": skin_image_data,
        "conversation_history": conversation_history or [],
        "manifest": {},
        "planned_tools": [],
        "collected_facts": [],
        "executed_tools": set(),
        "chart_data": None,
        "skin_analysis_result": None,
        "skin_llm_prompt": None,
        "reflection_gaps": [],
        "should_continue": True,
        "iteration": 0,
        "max_iterations": 3,
        "events": [],
        "final_answer": ""
    }

    # id(question) is only unique among live objects and is reused by CPython,
    # so two runs could share a checkpoint thread. A random UUID keeps each
    # default thread distinct.
    # NOTE(review): a caller-supplied thread_id that is reused will presumably
    # resume that thread's saved state (accumulated facts / executed tools) --
    # confirm this is intended by callers.
    run_thread_id = thread_id or f"thread_{patient_id}_{uuid.uuid4().hex}"
    config = {"configurable": {"thread_id": run_thread_id}}

    try:
        # Stream through graph - yields events from each node in real-time
        async for step in graph.astream(initial_state, config, stream_mode="updates"):
            for state_update in step.values():
                # Skip updates that are not plain dicts (e.g. a node that
                # returned nothing) instead of failing on `.get`.
                if not isinstance(state_update, dict):
                    continue
                for event in state_update.get("events", []):
                    yield event

        # Now get the full final state from the graph for synthesis.
        # We need the accumulated state (collected_facts, manifest, etc.)
        graph_state = graph.get_state(config)
        full_state = graph_state.values if graph_state else initial_state

        # Stream synthesis tokens directly (not buffered!)
        yield {"type": "answer_start", "content": ""}
        agent_state = AgentState(
            patient_id=full_state.get("patient_id", patient_id),
            question=full_state.get("question", question),
            conversation_history=full_state.get("conversation_history", []),
            manifest=full_state.get("manifest", {}),
            collected_facts=full_state.get("collected_facts", []),
            skin_image_data=full_state.get("skin_image_data"),
            skin_analysis_result=full_state.get("skin_analysis_result"),
            skin_llm_prompt=full_state.get("skin_llm_prompt"),
            iteration=full_state.get("iteration", 1)
        )
        async for token in synthesize_answer(agent_state):
            yield {"type": "token", "content": token}
        yield {"type": "answer_end", "content": ""}

        yield {
            "type": "workflow_complete",
            "iterations": full_state.get("iteration", 1),
            # Sorted so the reported order is stable between runs.
            "tools_executed": sorted(full_state.get("executed_tools") or set(), key=str),
        }
    except Exception as e:
        # Top-level boundary: log the traceback and surface the failure to the
        # UI as an event rather than letting the stream die silently.
        traceback.print_exc()
        yield {"type": "error", "message": f"Agent error: {str(e)}"}
# =============================================================================
# Simple Interface for Testing
# =============================================================================
async def run_agent_langgraph_simple(patient_id: str, question: str) -> str:
    """Run the agent and return only the final answer text.

    Token events are concatenated in arrival order; an error event replaces
    whatever text has been gathered so far with ``"Error: <message>"``.
    """
    chunks = []
    async for event in run_agent_langgraph(patient_id, question):
        kind = event["type"]
        if kind == "token":
            chunks.append(event["content"])
        elif kind == "error":
            chunks = [f"Error: {event['message']}"]
    return "".join(chunks)
# =============================================================================
# Visualization
# =============================================================================
def get_graph_visualization():
    """Return a Mermaid diagram of the workflow graph.

    Asks the compiled graph to render itself; if rendering raises, a
    hand-written diagram of the same topology is returned instead.
    """
    # Static fallback diagram, used only when rendering fails.
    fallback_diagram = (
        "\n"
        "graph TD\n"
        "START --> discover\n"
        "discover --> skin_analysis\n"
        "skin_analysis --> plan\n"
        "plan -->|has tools| execute\n"
        "plan -->|no tools| synthesize\n"
        "execute --> reflect\n"
        "reflect -->|gaps found| plan\n"
        "reflect -->|enough info| synthesize\n"
        "synthesize --> END\n"
    )
    compiled = get_graph()
    try:
        diagram = compiled.get_graph().draw_mermaid()
    except Exception:
        diagram = fallback_diagram
    return diagram
if __name__ == "__main__":
    # Print the banner first, then render the diagram (rendering may compile
    # the graph, which prints its own log line).
    banner = (
        "MedGemma LangGraph Agent",
        "=" * 50,
        "\nWorkflow Graph (Mermaid):\n",
    )
    for banner_line in banner:
        print(banner_line)
    print(get_graph_visualization())