"""
MedGemma Agent using LangGraph
==============================
A proper graph-based workflow using LangGraph's StateGraph.

Workflow Graph:
    START -> discover -> skin_analysis -> plan -> execute -> reflect
                                           ^                    |
                                           |___(gaps found)_____|
                                                                |
                                                                v
                                                           synthesize -> END

Reuses all working logic from agent_v2.py:
- Dynamic tool registry with category filtering
- Condensed tool descriptions (token reduction)
- Semantic parameter normalization
- LLM-based error recovery
- Skin image analysis with Derm Foundation
"""

import os
import json
import re
import uuid
from typing import Dict, List, Optional, Set, Any, Literal, AsyncGenerator
from typing_extensions import TypedDict, Annotated

# LangGraph imports
from langgraph.graph import StateGraph, START, END
from langgraph.checkpoint.memory import MemorySaver

# Reuse ALL working logic from agent_v2
from agent_v2 import (
    call_llm,
    stream_llm,
    filter_thinking,
    get_patient_manifest,
    plan_tools,
    execute_and_extract,
    reflect_on_facts,
    synthesize_answer,
    get_relevant_categories,
    get_filtered_tools_description,
    format_conversation_history,
    WorkflowPhase,
    AgentState,
    TOOLS
)
from tools import execute_tool, get_tools_description

LLAMA_SERVER_URL = os.environ.get('LLAMA_SERVER_URL', 'http://localhost:8080')


# =============================================================================
# State Definition
# =============================================================================

def _append_events(existing: List[Dict], new: List[Dict]) -> List[Dict]:
    """Reducer: append new events to existing list."""
    return existing + new


def _append_facts(existing: List[Dict], new: List[Dict]) -> List[Dict]:
    """Reducer: append new facts to existing list."""
    return existing + new


def _merge_tools(existing: Set[str], new: Set[str]) -> Set[str]:
    """Reducer: merge executed tools sets."""
    return existing | new


class GraphState(TypedDict):
    """State that flows through the LangGraph.

    Fields annotated with a reducer (``collected_facts``, ``executed_tools``,
    ``events``) are combined with the previous value via that reducer; the
    other fields are plain values returned by the nodes.
    """
    # Input (set once)
    patient_id: str
    question: str
    skin_image_data: Optional[str]
    conversation_history: List[Dict]  # Prior conversation turns

    # Discovery
    manifest: Dict

    # Planning
    planned_tools: List[Dict]

    # Execution
    collected_facts: Annotated[List[Dict], _append_facts]
    executed_tools: Annotated[Set[str], _merge_tools]
    chart_data: Optional[Dict]

    # Skin analysis
    skin_analysis_result: Optional[Dict]
    skin_llm_prompt: Optional[str]

    # Reflection
    reflection_gaps: List[str]
    should_continue: bool

    # Control
    iteration: int
    max_iterations: int

    # Output - events collected by each node for streaming to UI
    events: Annotated[List[Dict], _append_events]
    final_answer: str


# =============================================================================
# Node Functions
# =============================================================================

async def discover_node(state: GraphState) -> Dict:
    """DISCOVER: Get patient data manifest.

    Fetches the manifest for ``state["patient_id"]`` and emits a ``status``
    event followed by a ``discovery`` event carrying the manifest and a short
    human-readable summary.

    Returns:
        Partial state update with ``manifest`` and ``events``.
    """
    print("[LANGGRAPH] discover_node")

    events = [{"type": "status", "message": "Discovering available data..."}]

    manifest = get_patient_manifest(state["patient_id"])

    # Build summary
    patient_info = manifest.get("patient_info", {})
    available = manifest.get("available_data", {})

    summary_parts = []
    if patient_info:
        summary_parts.append(
            f"Patient: {patient_info.get('name', 'Unknown')}, "
            f"{patient_info.get('age', '?')}y, {patient_info.get('gender', '')}"
        )

    data_counts = [f"{info['count']} {table}" for table, info in available.items()]
    if data_counts:
        summary_parts.append(f"Available: {', '.join(data_counts)}")

    events.append({
        "type": "discovery",
        "manifest": manifest,
        "summary": " | ".join(summary_parts)
    })

    return {
        "manifest": manifest,
        "events": events
    }


def _skin_facts_text(skin_result: Dict) -> str:
    """Build the fact text for a successful skin analysis result.

    Summarizes up to three entries of ``conditions`` and of
    ``symptoms_from_image`` (each entry may be a dict or any other value,
    which is stringified), falls back to a generic sentence naming the model
    when neither yields text, and appends the disclaimer on a new line.
    """
    parts = []

    conditions = skin_result.get("conditions", [])
    if conditions:
        if isinstance(conditions[0], dict):
            cond_strs = [
                f"{c.get('name', 'Unknown')} ({c.get('confidence', 0)}% - {c.get('likelihood', 'possible')})"
                for c in conditions[:3]
            ]
        else:
            cond_strs = [str(c) for c in conditions[:3]]
        parts.append(f"Possible conditions: {', '.join(cond_strs)}")

    symptoms_img = skin_result.get("symptoms_from_image", [])
    if symptoms_img:
        if isinstance(symptoms_img[0], dict):
            symp_strs = [s.get('name', '') for s in symptoms_img[:3] if s.get('name')]
        else:
            symp_strs = [str(s) for s in symptoms_img[:3]]
        if symp_strs:
            parts.append(f"Detected symptoms: {', '.join(symp_strs)}")

    if parts:
        text = ". ".join(parts)
    else:
        text = f"Skin image analyzed with {skin_result.get('model', 'Derm Foundation')}."

    # `or ""` so an explicit null disclaimer cannot raise TypeError and throw
    # away the whole analysis result in the caller's except branch.
    text += "\n" + (skin_result.get("disclaimer") or "")
    return text


async def skin_analysis_node(state: GraphState) -> Dict:
    """SKIN ANALYSIS: Analyze uploaded skin image (if present).

    When ``skin_image_data`` is empty the node is a no-op. Otherwise it calls
    ``tools.analyze_skin_image`` (passing the question as ``symptoms``),
    parses the JSON string it returns, emits a ``skin_analysis`` event and
    records a fact for the ``analyze_skin_image`` tool.

    Returns:
        Partial state update. On success or a reported analysis error it
        contains ``skin_analysis_result``, ``skin_llm_prompt``,
        ``collected_facts``, ``executed_tools`` and ``events``; if an
        exception is raised only ``events`` (including an ``error`` event)
        is returned.
    """
    if not state.get("skin_image_data"):
        return {"events": []}

    print(f"[LANGGRAPH] skin_analysis_node - image: {len(state['skin_image_data'])} chars")

    events = [{
        "type": "status",
        "message": "Analyzing skin image with Derm Foundation + SCIN Classifier..."
    }]

    try:
        from tools import analyze_skin_image

        skin_result_str = analyze_skin_image(
            state["patient_id"],
            state["skin_image_data"],
            symptoms=state["question"]
        )
        skin_result = json.loads(skin_result_str)

        events.append({
            "type": "skin_analysis",
            "data": skin_result,
            "image_data": state["skin_image_data"]
        })

        # Extract facts from skin analysis
        if skin_result.get("status") == "success":
            facts_str = _skin_facts_text(skin_result)
        else:
            facts_str = f"Skin analysis error: {skin_result.get('error', 'Unknown')}"
        facts = [{"tool": "analyze_skin_image", "facts": facts_str}]

        return {
            "skin_analysis_result": skin_result,
            "skin_llm_prompt": skin_result.get("llm_synthesis_prompt", ""),
            "collected_facts": facts,
            "executed_tools": {"analyze_skin_image"},
            "events": events
        }

    except Exception as e:
        # Best effort: a failed image analysis must not abort the workflow.
        print(f"[LANGGRAPH] Skin analysis error: {e}")
        events.append({"type": "error", "message": f"Skin analysis error: {str(e)}"})
        return {"events": events}


async def plan_node(state: GraphState) -> Dict:
    """PLAN: Use LLM to identify which tools are needed.

    Increments the iteration counter, delegates to ``plan_tools`` from
    agent_v2 (with any reflection gaps appended to the question), then drops
    tools that were already executed and, when an image was supplied,
    ``analyze_skin_image`` (handled by the skin analysis node).

    Returns:
        Partial state update with ``planned_tools``, ``iteration``, a cleared
        ``reflection_gaps`` and ``events``.
    """
    iteration = state.get("iteration", 0) + 1
    print(f"[LANGGRAPH] plan_node - iteration {iteration}")

    events = [{"type": "status", "message": f"Planning approach (iteration {iteration})..."}]

    # Build an AgentState to reuse plan_tools from agent_v2
    agent_state = AgentState(
        patient_id=state["patient_id"],
        question=state["question"],
        conversation_history=state.get("conversation_history", []),
        manifest=state["manifest"],
        skin_image_data=state.get("skin_image_data"),
        executed_tools=state.get("executed_tools", set()),
        reflection_gaps=state.get("reflection_gaps", []),
        iteration=iteration,
        max_iterations=state.get("max_iterations", 3)
    )

    # If we have gaps from reflection, add context
    if state.get("reflection_gaps"):
        gap_context = f"\n\nPrevious iteration found these gaps: {', '.join(state['reflection_gaps'])}"
        agent_state.question = state["question"] + gap_context

    # Call the working plan_tools from agent_v2
    planned = await plan_tools(agent_state)

    # Remove already executed tools
    executed = state.get("executed_tools", set())
    planned = [t for t in planned if t.get("tool") not in executed]

    # Remove skin analysis if already done
    if state.get("skin_image_data"):
        planned = [t for t in planned if t.get("tool") != "analyze_skin_image"]

    relevant_categories = get_relevant_categories(state["question"], state["manifest"])

    events.append({
        "type": "plan",
        "tools": planned,
        "iteration": iteration,
        "tool_filtering": {
            "categories_used": sorted(relevant_categories),
            "total_tools": len(TOOLS)
        }
    })

    if not planned:
        print("[LANGGRAPH] No new tools to execute")

    return {
        "planned_tools": planned,
        "iteration": iteration,
        "reflection_gaps": [],  # Clear after use
        "events": events
    }


async def execute_node(state: GraphState) -> Dict:
    """EXECUTE: Run planned tools and extract facts.

    Runs each entry of ``planned_tools`` through ``execute_and_extract`` from
    agent_v2, skipping tools already in ``executed_tools``. A tool that
    raises is reported with a ``tool_error`` event and still marked as
    executed, so it is filtered out of later plans.

    Returns:
        Partial state update with the new ``collected_facts``, the new
        ``executed_tools``, the last ``chart_data`` produced (or ``None``)
        and ``events``.
    """
    print(f"[LANGGRAPH] execute_node - {len(state['planned_tools'])} tools")

    events = []
    new_facts = []
    new_executed = set()
    chart_data = None

    already_executed = state.get("executed_tools", set())

    # Build AgentState for execute_and_extract
    agent_state = AgentState(
        patient_id=state["patient_id"],
        question=state["question"],
        manifest=state["manifest"],
        executed_tools=already_executed,
        iteration=state.get("iteration", 1),
        max_iterations=state.get("max_iterations", 3)
    )

    for tool_call in state["planned_tools"]:
        tool_name = tool_call.get("tool", "unknown")
        tool_args = tool_call.get("args", {})
        reason = tool_call.get("reason", "")

        if tool_name in already_executed:
            continue

        events.append({"type": "status", "message": f"Retrieving {tool_name}..."})
        events.append({"type": "tool_call", "tool": tool_name, "args": tool_args, "reason": reason})

        try:
            fact_result = await execute_and_extract(agent_state, tool_call)
            new_facts.append(fact_result)
            new_executed.add(tool_name)

            # Coerce the preview source to text: a None or non-string raw_data
            # would otherwise raise here, after the fact was already recorded,
            # and produce a misleading tool_error event.
            raw_data = fact_result.get("raw_data") or ""
            if not isinstance(raw_data, str):
                raw_data = str(raw_data)

            events.append({
                "type": "tool_result",
                "tool": tool_name,
                "facts": fact_result.get("facts", ""),
                "raw_preview": raw_data[:200]
            })

            # Check for chart data
            if agent_state.chart_data:
                chart_data = agent_state.chart_data
                events.append({"type": "chart_data", "data": chart_data})
                # Reset so the next tool does not re-emit the same chart.
                agent_state.chart_data = None

        except Exception as e:
            print(f"[LANGGRAPH] Tool error {tool_name}: {e}")
            events.append({"type": "tool_error", "tool": tool_name, "error": str(e)})
            new_executed.add(tool_name)

    return {
        "collected_facts": new_facts,
        "executed_tools": new_executed,
        "chart_data": chart_data,
        "events": events
    }


async def reflect_node(state: GraphState) -> Dict:
    """REFLECT: Evaluate if collected data is sufficient.

    Delegates to ``reflect_on_facts`` from agent_v2 and decides whether
    another plan/execute round is needed: the loop continues only while the
    reflection reports missing information and ``iteration`` is below
    ``max_iterations``.

    Returns:
        Partial state update with ``reflection_gaps`` (empty when not
        continuing), ``should_continue`` and ``events``.
    """
    iteration = state.get("iteration", 1)
    max_iter = state.get("max_iterations", 3)
    print(f"[LANGGRAPH] reflect_node - iteration {iteration}/{max_iter}")

    events = [{"type": "status", "message": "Reflecting on gathered information..."}]

    # Build AgentState for reflect_on_facts
    agent_state = AgentState(
        patient_id=state["patient_id"],
        question=state["question"],
        collected_facts=state.get("collected_facts", []),
        executed_tools=state.get("executed_tools", set()),
        iteration=iteration,
        max_iterations=max_iter
    )

    reflection = await reflect_on_facts(agent_state)

    has_enough = reflection.get("has_enough_info", True)
    confidence = reflection.get("confidence", 0.8)
    gaps = reflection.get("gaps", [])

    events.append({
        "type": "reflection",
        "has_enough_info": has_enough,
        "confidence": confidence,
        "gaps": gaps,
        "reasoning": reflection.get("reasoning", ""),
        "iteration": iteration
    })

    should_continue = not has_enough and iteration < max_iter

    if has_enough:
        print(f"[LANGGRAPH] Reflection: Have enough info (confidence: {confidence})")
    else:
        print(f"[LANGGRAPH] Reflection: Need more info. Gaps: {gaps}")

    return {
        "reflection_gaps": gaps if should_continue else [],
        "should_continue": should_continue,
        "events": events
    }


async def synthesize_node(state: GraphState) -> Dict:
    """SYNTHESIZE: Mark ready for synthesis (actual streaming happens in runner)."""
    print("[LANGGRAPH] synthesize_node - marking ready")

    # Don't buffer tokens here - the runner will handle streaming directly
    return {
        "final_answer": "__READY_FOR_SYNTHESIS__",
        "events": [{"type": "status", "message": "Generating answer..."}]
    }


# =============================================================================
# Conditional Edge Functions
# =============================================================================

def should_execute(state: GraphState) -> Literal["execute", "synthesize"]:
    """After plan: execute tools or skip to synthesis."""
    if state.get("planned_tools"):
        return "execute"
    return "synthesize"


def should_loop_or_finish(state: GraphState) -> Literal["plan", "synthesize"]:
    """After reflect: loop back to plan or proceed to synthesis."""
    if state.get("should_continue", False):
        return "plan"
    return "synthesize"


# =============================================================================
# Build the Graph
# =============================================================================

def create_med_agent_graph():
    """
    Create and compile the MedGemma agent graph.

    Graph:
        START -> discover -> skin_analysis -> plan -> [execute or synthesize]
                                               ^           |
                                               |           v
                                               |        reflect -> [plan or synthesize]
                                               |___________|

    Returns:
        The graph compiled with an in-memory ``MemorySaver`` checkpointer.
    """
    graph = StateGraph(GraphState)

    # Add nodes
    graph.add_node("discover", discover_node)
    graph.add_node("skin_analysis", skin_analysis_node)
    graph.add_node("plan", plan_node)
    graph.add_node("execute", execute_node)
    graph.add_node("reflect", reflect_node)
    graph.add_node("synthesize", synthesize_node)

    # Edges: START -> discover -> skin_analysis -> plan
    graph.add_edge(START, "discover")
    graph.add_edge("discover", "skin_analysis")
    graph.add_edge("skin_analysis", "plan")

    # Conditional: plan -> execute (if tools) or synthesize (if none)
    graph.add_conditional_edges(
        "plan",
        should_execute,
        {"execute": "execute", "synthesize": "synthesize"}
    )

    # execute -> reflect
    graph.add_edge("execute", "reflect")

    # Conditional: reflect -> plan (loop back) or synthesize (done)
    graph.add_conditional_edges(
        "reflect",
        should_loop_or_finish,
        {"plan": "plan", "synthesize": "synthesize"}
    )

    # synthesize -> END
    graph.add_edge("synthesize", END)

    # Compile with checkpointer
    checkpointer = MemorySaver()
    compiled = graph.compile(checkpointer=checkpointer)

    print("[LANGGRAPH] Graph compiled successfully")
    return compiled


# Global graph instance (lazy init)
_GRAPH = None


def get_graph():
    """Get or create the compiled graph."""
    global _GRAPH
    if _GRAPH is None:
        _GRAPH = create_med_agent_graph()
    return _GRAPH


# =============================================================================
# Main Runner - Streams events to the UI
# =============================================================================

async def run_agent_langgraph(
    patient_id: str,
    question: str,
    skin_image_data: Optional[str] = None,
    conversation_history: Optional[List[Dict]] = None,
    thread_id: Optional[str] = None
) -> AsyncGenerator[Dict, None]:
    """
    Run the LangGraph agent and yield SSE events for the UI.

    The graph handles discover→skin→plan→execute→reflect loop.
    Synthesis is streamed directly (not buffered) for real-time typing effect.

    Args:
        patient_id: Patient whose data is queried.
        question: The user's question.
        skin_image_data: Optional image payload for the skin analysis node.
        conversation_history: Prior conversation turns; defaults to empty.
        thread_id: Checkpointer thread id. When omitted a fresh unique id is
            generated for this run.
            NOTE(review): a caller-supplied id that is reused across runs
            shares one checkpoint thread, so reducer-backed fields
            (collected_facts, executed_tools, events) would presumably carry
            over from the earlier run - confirm callers pass unique ids.

    Yields:
        Event dicts: every event produced by the graph nodes, then
        ``answer_start``, one ``token`` per synthesized token,
        ``answer_end`` and ``workflow_complete``. Any exception is reported
        as a single ``error`` event instead of being raised.
    """
    graph = get_graph()

    # Initial state
    initial_state = {
        "patient_id": patient_id,
        "question": question,
        "skin_image_data": skin_image_data,
        "conversation_history": conversation_history or [],
        "manifest": {},
        "planned_tools": [],
        "collected_facts": [],
        "executed_tools": set(),
        "chart_data": None,
        "skin_analysis_result": None,
        "skin_llm_prompt": None,
        "reflection_gaps": [],
        "should_continue": True,
        "iteration": 0,
        "max_iterations": 3,
        "events": [],
        "final_answer": ""
    }

    # The default id must be unique per run. The previous id(question)-based
    # default could repeat (id() values are reused once an object is freed),
    # which would make two runs for the same patient share a checkpoint
    # thread in the process-wide MemorySaver.
    run_thread_id = thread_id or f"thread_{patient_id}_{uuid.uuid4().hex}"
    config = {"configurable": {"thread_id": run_thread_id}}

    try:
        # Stream through graph - yields events from each node in real-time
        async for step in graph.astream(initial_state, config, stream_mode="updates"):
            for state_update in step.values():
                for event in (state_update or {}).get("events", []):
                    yield event

        # Now get the full final state from the graph for synthesis
        # We need the accumulated state (collected_facts, manifest, etc.)
        graph_state = graph.get_state(config)
        full_state = graph_state.values if graph_state else initial_state

        # Stream synthesis tokens directly (not buffered!)
        yield {"type": "answer_start", "content": ""}

        agent_state = AgentState(
            patient_id=full_state.get("patient_id", patient_id),
            question=full_state.get("question", question),
            conversation_history=full_state.get("conversation_history", []),
            manifest=full_state.get("manifest", {}),
            collected_facts=full_state.get("collected_facts", []),
            skin_image_data=full_state.get("skin_image_data"),
            skin_analysis_result=full_state.get("skin_analysis_result"),
            skin_llm_prompt=full_state.get("skin_llm_prompt"),
            iteration=full_state.get("iteration", 1)
        )

        async for token in synthesize_answer(agent_state):
            yield {"type": "token", "content": token}

        yield {"type": "answer_end", "content": ""}

        yield {
            "type": "workflow_complete",
            "iterations": full_state.get("iteration", 1),
            # sorted() for a deterministic order (set iteration order is not).
            "tools_executed": sorted(full_state.get("executed_tools", set())),
        }

    except Exception as e:
        # Top-level boundary: surface the failure to the UI as an event.
        import traceback
        traceback.print_exc()
        yield {"type": "error", "message": f"Agent error: {str(e)}"}


# =============================================================================
# Simple Interface for Testing
# =============================================================================

async def run_agent_langgraph_simple(patient_id: str, question: str) -> str:
    """Simple interface - returns just the final answer.

    Concatenates the ``token`` events of ``run_agent_langgraph``; an
    ``error`` event replaces whatever was collected so far with
    ``"Error: <message>"``.
    """
    answer = ""
    async for event in run_agent_langgraph(patient_id, question):
        if event["type"] == "token":
            answer += event["content"]
        elif event["type"] == "error":
            answer = f"Error: {event['message']}"
    return answer


# =============================================================================
# Visualization
# =============================================================================

# Static diagram used when the compiled graph cannot render itself.
_FALLBACK_MERMAID = """
graph TD
    START --> discover
    discover --> skin_analysis
    skin_analysis --> plan
    plan -->|has tools| execute
    plan -->|no tools| synthesize
    execute --> reflect
    reflect -->|gaps found| plan
    reflect -->|enough info| synthesize
    synthesize --> END
"""


def get_graph_visualization():
    """Get Mermaid diagram of the workflow graph.

    Asks the compiled graph to draw itself; if that raises, a hand-written
    diagram of the same workflow is returned instead.
    """
    graph = get_graph()
    try:
        return graph.get_graph().draw_mermaid()
    except Exception:
        return _FALLBACK_MERMAID


if __name__ == "__main__":
    print("MedGemma LangGraph Agent")
    print("=" * 50)
    print("\nWorkflow Graph (Mermaid):\n")
    print(get_graph_visualization())