# NOTE: removed page-scrape residue ("Spaces: Sleeping / Sleeping" UI text)
# that was not part of this module and is not valid Python.
| """ | |
| MedGemma Agent using LangGraph | |
| ============================== | |
| A proper graph-based workflow using LangGraph's StateGraph. | |
| Workflow Graph: | |
| START -> discover -> skin_analysis -> plan -> execute -> reflect | |
| ^ | | |
| |___(gaps found)_____| | |
| | | |
| v | |
| synthesize -> END | |
| Reuses all working logic from agent_v2.py: | |
| - Dynamic tool registry with category filtering | |
| - Condensed tool descriptions (token reduction) | |
| - Semantic parameter normalization | |
| - LLM-based error recovery | |
| - Skin image analysis with Derm Foundation | |
| """ | |
| import os | |
| import json | |
| import re | |
| from typing import Dict, List, Optional, Set, Any, Literal, AsyncGenerator | |
| from typing_extensions import TypedDict, Annotated | |
| # LangGraph imports | |
| from langgraph.graph import StateGraph, START, END | |
| from langgraph.checkpoint.memory import MemorySaver | |
| # Reuse ALL working logic from agent_v2 | |
| from agent_v2 import ( | |
| call_llm, stream_llm, filter_thinking, | |
| get_patient_manifest, plan_tools, execute_and_extract, | |
| reflect_on_facts, synthesize_answer, | |
| get_relevant_categories, get_filtered_tools_description, | |
| format_conversation_history, | |
| WorkflowPhase, AgentState, | |
| TOOLS | |
| ) | |
| from tools import execute_tool, get_tools_description | |
| LLAMA_SERVER_URL = os.environ.get('LLAMA_SERVER_URL', 'http://localhost:8080') | |
| # ============================================================================= | |
| # State Definition | |
| # ============================================================================= | |
| def _append_events(existing: List[Dict], new: List[Dict]) -> List[Dict]: | |
| """Reducer: append new events to existing list.""" | |
| return existing + new | |
| def _append_facts(existing: List[Dict], new: List[Dict]) -> List[Dict]: | |
| """Reducer: append new facts to existing list.""" | |
| return existing + new | |
| def _merge_tools(existing: Set[str], new: Set[str]) -> Set[str]: | |
| """Reducer: merge executed tools sets.""" | |
| return existing | new | |
class GraphState(TypedDict):
    """State that flows through the LangGraph.

    Fields typed ``Annotated[..., reducer]`` are merged by LangGraph with that
    reducer whenever a node returns a partial update; all other fields are
    overwritten wholesale by the most recent node that sets them.
    """
    # Input (set once when the run starts)
    patient_id: str
    question: str
    skin_image_data: Optional[str]  # uploaded image payload, if any
    conversation_history: List[Dict]  # Prior conversation turns
    # Discovery (produced by discover_node)
    manifest: Dict
    # Planning (produced by plan_node each iteration)
    planned_tools: List[Dict]
    # Execution — reducers accumulate across plan/execute iterations
    collected_facts: Annotated[List[Dict], _append_facts]
    executed_tools: Annotated[Set[str], _merge_tools]
    chart_data: Optional[Dict]  # chart payload surfaced by execute_node, if any
    # Skin analysis (produced by skin_analysis_node)
    skin_analysis_result: Optional[Dict]
    skin_llm_prompt: Optional[str]
    # Reflection (produced by reflect_node)
    reflection_gaps: List[str]  # information gaps feeding the next plan pass
    should_continue: bool  # True -> loop back to plan
    # Control
    iteration: int
    max_iterations: int
    # Output - events collected by each node for streaming to UI
    events: Annotated[List[Dict], _append_events]
    final_answer: str
| # ============================================================================= | |
| # Node Functions | |
| # ============================================================================= | |
async def discover_node(state: GraphState) -> Dict:
    """DISCOVER: fetch the patient's data manifest and summarize what exists."""
    print(f"[LANGGRAPH] discover_node")
    events = [{"type": "status", "message": "Discovering available data..."}]
    manifest = get_patient_manifest(state["patient_id"])
    # Assemble a one-line, human-readable summary of the manifest for the UI.
    patient_info = manifest.get("patient_info", {})
    available = manifest.get("available_data", {})
    summary_parts = []
    if patient_info:
        summary_parts.append(
            f"Patient: {patient_info.get('name', 'Unknown')}, "
            f"{patient_info.get('age', '?')}y, {patient_info.get('gender', '')}"
        )
    data_counts = [f"{info['count']} {table}" for table, info in available.items()]
    if data_counts:
        summary_parts.append(f"Available: {', '.join(data_counts)}")
    events.append({
        "type": "discovery",
        "manifest": manifest,
        "summary": " | ".join(summary_parts),
    })
    return {"manifest": manifest, "events": events}
async def skin_analysis_node(state: GraphState) -> Dict:
    """SKIN ANALYSIS: Analyze uploaded skin image (if present).

    No-op (empty event list) when the request carried no image. On success the
    partial update carries the raw analysis dict, the LLM synthesis prompt,
    extracted facts, the executed-tool marker, and UI events; on any exception
    only an error event is returned so the workflow proceeds without skin facts.
    """
    if not state.get("skin_image_data"):
        return {"events": []}
    print(f"[LANGGRAPH] skin_analysis_node - image: {len(state['skin_image_data'])} chars")
    events = []
    events.append({"type": "status", "message": "Analyzing skin image with Derm Foundation + SCIN Classifier..."})
    try:
        # Lazy import so this module loads even if the skin tooling is missing.
        from tools import analyze_skin_image
        skin_result_str = analyze_skin_image(
            state["patient_id"],
            state["skin_image_data"],
            symptoms=state["question"]
        )
        # The tool returns a JSON string; parse it into a dict.
        skin_result = json.loads(skin_result_str)
        events.append({
            "type": "skin_analysis",
            "data": skin_result,
            "image_data": state["skin_image_data"]
        })
        # Extract facts from skin analysis
        facts = []
        if skin_result.get("status") == "success":
            facts_parts = []
            conditions = skin_result.get("conditions", [])
            if conditions:
                # Entries may be dicts (name/confidence/likelihood) or plain
                # strings; only the top 3 are summarized either way.
                if isinstance(conditions[0], dict):
                    cond_strs = [f"{c.get('name', 'Unknown')} ({c.get('confidence', 0)}% - {c.get('likelihood', 'possible')})" for c in conditions[:3]]
                else:
                    cond_strs = [str(c) for c in conditions[:3]]
                facts_parts.append(f"Possible conditions: {', '.join(cond_strs)}")
            symptoms_img = skin_result.get("symptoms_from_image", [])
            if symptoms_img:
                # Same dict-or-string handling for detected symptoms.
                if isinstance(symptoms_img[0], dict):
                    symp_strs = [s.get('name', '') for s in symptoms_img[:3] if s.get('name')]
                else:
                    symp_strs = [str(s) for s in symptoms_img[:3]]
                if symp_strs:
                    facts_parts.append(f"Detected symptoms: {', '.join(symp_strs)}")
            facts_str = ". ".join(facts_parts) if facts_parts else f"Skin image analyzed with {skin_result.get('model', 'Derm Foundation')}."
            # Always surface the model's disclaimer alongside the findings.
            facts_str += "\n" + skin_result.get("disclaimer", "")
            facts.append({"tool": "analyze_skin_image", "facts": facts_str})
        else:
            facts.append({"tool": "analyze_skin_image", "facts": f"Skin analysis error: {skin_result.get('error', 'Unknown')}"})
        return {
            "skin_analysis_result": skin_result,
            "skin_llm_prompt": skin_result.get("llm_synthesis_prompt", ""),
            "collected_facts": facts,
            "executed_tools": {"analyze_skin_image"},
            "events": events
        }
    except Exception as e:
        print(f"[LANGGRAPH] Skin analysis error: {e}")
        events.append({"type": "error", "message": f"Skin analysis error: {str(e)}"})
        return {"events": events}
async def plan_node(state: GraphState) -> Dict:
    """PLAN: ask the LLM (via agent_v2.plan_tools) which tools to run next."""
    iteration = state.get("iteration", 0) + 1
    print(f"[LANGGRAPH] plan_node - iteration {iteration}")
    events = [{"type": "status", "message": f"Planning approach (iteration {iteration})..."}]
    executed = state.get("executed_tools", set())
    # Build an AgentState so we can reuse the working planner from agent_v2.
    agent_state = AgentState(
        patient_id=state["patient_id"],
        question=state["question"],
        conversation_history=state.get("conversation_history", []),
        manifest=state["manifest"],
        skin_image_data=state.get("skin_image_data"),
        executed_tools=executed,
        reflection_gaps=state.get("reflection_gaps", []),
        iteration=iteration,
        max_iterations=state.get("max_iterations", 3),
    )
    # When reflection reported gaps, fold them into the planner's question.
    if state.get("reflection_gaps"):
        agent_state.question = (
            state["question"]
            + f"\n\nPrevious iteration found these gaps: {', '.join(state['reflection_gaps'])}"
        )
    planned = await plan_tools(agent_state)
    # Never re-run a tool that already executed in an earlier iteration.
    planned = [t for t in planned if t.get("tool") not in executed]
    if state.get("skin_image_data"):
        # The dedicated skin_analysis node already covered the image.
        planned = [t for t in planned if t.get("tool") != "analyze_skin_image"]
    relevant_categories = get_relevant_categories(state["question"], state["manifest"])
    events.append({
        "type": "plan",
        "tools": planned,
        "iteration": iteration,
        "tool_filtering": {
            "categories_used": sorted(relevant_categories),
            "total_tools": len(TOOLS),
        },
    })
    if not planned:
        print(f"[LANGGRAPH] No new tools to execute")
    return {
        "planned_tools": planned,
        "iteration": iteration,
        "reflection_gaps": [],  # gaps were consumed by this planning pass
        "events": events,
    }
async def execute_node(state: GraphState) -> Dict:
    """EXECUTE: run each planned tool and collect the extracted facts."""
    print(f"[LANGGRAPH] execute_node - {len(state['planned_tools'])} tools")
    events = []
    new_facts = []
    new_executed = set()
    chart_data = None
    already_done = state.get("executed_tools", set())
    # Shared AgentState lets execute_and_extract reuse the agent_v2 machinery.
    agent_state = AgentState(
        patient_id=state["patient_id"],
        question=state["question"],
        manifest=state["manifest"],
        executed_tools=already_done,
        iteration=state.get("iteration", 1),
        max_iterations=state.get("max_iterations", 3),
    )
    for tool_call in state["planned_tools"]:
        tool_name = tool_call.get("tool", "unknown")
        if tool_name in already_done:
            continue  # defensive: plan_node should already have filtered these
        tool_args = tool_call.get("args", {})
        reason = tool_call.get("reason", "")
        events.append({"type": "status", "message": f"Retrieving {tool_name}..."})
        events.append({"type": "tool_call", "tool": tool_name, "args": tool_args, "reason": reason})
        try:
            fact_result = await execute_and_extract(agent_state, tool_call)
            new_facts.append(fact_result)
            new_executed.add(tool_name)
            events.append({
                "type": "tool_result",
                "tool": tool_name,
                "facts": fact_result.get("facts", ""),
                "raw_preview": fact_result.get("raw_data", "")[:200],
            })
            # execute_and_extract stashes chart payloads on the AgentState;
            # forward them to the UI and clear the slot.
            if agent_state.chart_data:
                chart_data = agent_state.chart_data
                events.append({"type": "chart_data", "data": chart_data})
                agent_state.chart_data = None
        except Exception as e:
            # A failed tool still counts as executed so we don't retry forever.
            print(f"[LANGGRAPH] Tool error {tool_name}: {e}")
            events.append({"type": "tool_error", "tool": tool_name, "error": str(e)})
            new_executed.add(tool_name)
    return {
        "collected_facts": new_facts,
        "executed_tools": new_executed,
        "chart_data": chart_data,
        "events": events,
    }
async def reflect_node(state: GraphState) -> Dict:
    """REFLECT: decide whether the gathered facts are enough to answer."""
    iteration = state.get("iteration", 1)
    max_iter = state.get("max_iterations", 3)
    print(f"[LANGGRAPH] reflect_node - iteration {iteration}/{max_iter}")
    events = [{"type": "status", "message": "Reflecting on gathered information..."}]
    # Reuse agent_v2's reflection via a throwaway AgentState.
    agent_state = AgentState(
        patient_id=state["patient_id"],
        question=state["question"],
        collected_facts=state.get("collected_facts", []),
        executed_tools=state.get("executed_tools", set()),
        iteration=iteration,
        max_iterations=max_iter,
    )
    reflection = await reflect_on_facts(agent_state)
    has_enough = reflection.get("has_enough_info", True)
    confidence = reflection.get("confidence", 0.8)
    gaps = reflection.get("gaps", [])
    events.append({
        "type": "reflection",
        "has_enough_info": has_enough,
        "confidence": confidence,
        "gaps": gaps,
        "reasoning": reflection.get("reasoning", ""),
        "iteration": iteration,
    })
    # Loop back only when information is missing AND iteration budget remains.
    should_continue = not has_enough and iteration < max_iter
    if has_enough:
        print(f"[LANGGRAPH] Reflection: Have enough info (confidence: {confidence})")
    else:
        print(f"[LANGGRAPH] Reflection: Need more info. Gaps: {gaps}")
    return {
        "reflection_gaps": gaps if should_continue else [],
        "should_continue": should_continue,
        "events": events,
    }
async def synthesize_node(state: GraphState) -> Dict:
    """SYNTHESIZE: flag that synthesis may begin.

    Tokens are deliberately NOT buffered here — run_agent_langgraph streams
    the answer directly for a real-time typing effect.
    """
    print(f"[LANGGRAPH] synthesize_node - marking ready")
    return {
        "final_answer": "__READY_FOR_SYNTHESIS__",
        "events": [{"type": "status", "message": "Generating answer..."}],
    }
| # ============================================================================= | |
| # Conditional Edge Functions | |
| # ============================================================================= | |
def should_execute(state: GraphState) -> Literal["execute", "synthesize"]:
    """After plan: run the planned tools, or skip straight to synthesis."""
    return "execute" if state.get("planned_tools") else "synthesize"
def should_loop_or_finish(state: GraphState) -> Literal["plan", "synthesize"]:
    """After reflect: loop back to planning, or proceed to synthesis."""
    return "plan" if state.get("should_continue", False) else "synthesize"
| # ============================================================================= | |
| # Build the Graph | |
| # ============================================================================= | |
def create_med_agent_graph():
    """
    Create and compile the MedGemma agent graph.

    Graph:
        START -> discover -> skin_analysis -> plan -> [execute or synthesize]
        execute -> reflect -> [plan (loop) or synthesize] -> END
    """
    graph = StateGraph(GraphState)
    # Register every workflow node.
    for name, fn in [
        ("discover", discover_node),
        ("skin_analysis", skin_analysis_node),
        ("plan", plan_node),
        ("execute", execute_node),
        ("reflect", reflect_node),
        ("synthesize", synthesize_node),
    ]:
        graph.add_node(name, fn)
    # Linear prefix: START -> discover -> skin_analysis -> plan
    graph.add_edge(START, "discover")
    graph.add_edge("discover", "skin_analysis")
    graph.add_edge("skin_analysis", "plan")
    # plan branches: execute when tools were planned, otherwise synthesize.
    graph.add_conditional_edges(
        "plan",
        should_execute,
        {"execute": "execute", "synthesize": "synthesize"},
    )
    graph.add_edge("execute", "reflect")
    # reflect branches: loop back to plan on gaps, otherwise synthesize.
    graph.add_conditional_edges(
        "reflect",
        should_loop_or_finish,
        {"plan": "plan", "synthesize": "synthesize"},
    )
    graph.add_edge("synthesize", END)
    # In-memory checkpointer so run state can be read back after streaming.
    compiled = graph.compile(checkpointer=MemorySaver())
    print("[LANGGRAPH] Graph compiled successfully")
    return compiled
# Module-level cache for the compiled graph (lazy singleton).
_GRAPH = None


def get_graph():
    """Return the compiled graph, building it on first use."""
    global _GRAPH
    if _GRAPH is None:
        _GRAPH = create_med_agent_graph()
    return _GRAPH
| # ============================================================================= | |
| # Main Runner - Streams events to the UI | |
| # ============================================================================= | |
async def run_agent_langgraph(
    patient_id: str,
    question: str,
    skin_image_data: Optional[str] = None,
    conversation_history: Optional[List[Dict]] = None,
    thread_id: Optional[str] = None
) -> AsyncGenerator[Dict, None]:
    """
    Run the LangGraph agent and yield SSE events for the UI.

    The graph handles the discover -> skin -> plan -> execute -> reflect loop;
    each node buffers its UI events into state["events"] and they are forwarded
    here as soon as the node finishes. Synthesis is streamed token-by-token
    directly (not buffered) for a real-time typing effect.

    Args:
        patient_id: Identifier passed through to the data-retrieval tools.
        question: The user's question for this turn.
        skin_image_data: Optional image payload to run skin analysis on.
        conversation_history: Prior turns fed into planning/synthesis.
        thread_id: Checkpointer thread id; autogenerated when omitted.

    Yields:
        Dict events: node events (status/plan/tool_result/...), then
        answer_start, token..., answer_end, and a final workflow_complete —
        or a single error event if anything raises.
    """
    graph = get_graph()
    # Initial state for this run.
    initial_state = {
        "patient_id": patient_id,
        "question": question,
        "skin_image_data": skin_image_data,
        "conversation_history": conversation_history or [],
        "manifest": {},
        "planned_tools": [],
        "collected_facts": [],
        "executed_tools": set(),
        "chart_data": None,
        "skin_analysis_result": None,
        "skin_llm_prompt": None,
        "reflection_gaps": [],
        "should_continue": True,
        "iteration": 0,
        "max_iterations": 3,
        "events": [],
        "final_answer": ""
    }
    config = {"configurable": {"thread_id": thread_id or f"thread_{patient_id}_{id(question)}"}}
    try:
        # Stream through graph - forward each node's buffered events in
        # real time. (The per-node updates themselves are not needed after
        # this loop; the accumulated state is fetched from the checkpointer.)
        async for step in graph.astream(initial_state, config, stream_mode="updates"):
            for node_name, state_update in step.items():
                for event in state_update.get("events", []):
                    yield event
        # Pull the full accumulated state (collected_facts, manifest, etc.)
        # from the checkpointer for synthesis.
        graph_state = graph.get_state(config)
        full_state = graph_state.values if graph_state else initial_state
        # Stream synthesis tokens directly (not buffered!).
        yield {"type": "answer_start", "content": ""}
        agent_state = AgentState(
            patient_id=full_state.get("patient_id", patient_id),
            question=full_state.get("question", question),
            conversation_history=full_state.get("conversation_history", []),
            manifest=full_state.get("manifest", {}),
            collected_facts=full_state.get("collected_facts", []),
            skin_image_data=full_state.get("skin_image_data"),
            skin_analysis_result=full_state.get("skin_analysis_result"),
            skin_llm_prompt=full_state.get("skin_llm_prompt"),
            iteration=full_state.get("iteration", 1)
        )
        async for token in synthesize_answer(agent_state):
            yield {"type": "token", "content": token}
        yield {"type": "answer_end", "content": ""}
        yield {
            "type": "workflow_complete",
            "iterations": full_state.get("iteration", 1),
            "tools_executed": list(full_state.get("executed_tools", set())),
        }
    except Exception as e:
        import traceback
        traceback.print_exc()
        yield {"type": "error", "message": f"Agent error: {str(e)}"}
| # ============================================================================= | |
| # Simple Interface for Testing | |
| # ============================================================================= | |
async def run_agent_langgraph_simple(patient_id: str, question: str) -> str:
    """Simple interface - consume the event stream and return only the answer.

    Token events are concatenated; an error event replaces whatever has been
    accumulated so far (matching the streaming agent's error semantics).
    """
    answer = ""
    async for event in run_agent_langgraph(patient_id, question):
        kind = event["type"]
        if kind == "token":
            answer += event["content"]
        elif kind == "error":
            answer = f"Error: {event['message']}"
    return answer
| # ============================================================================= | |
| # Visualization | |
| # ============================================================================= | |
def get_graph_visualization():
    """Return a Mermaid diagram of the workflow graph.

    Falls back to a hand-written diagram if LangGraph cannot render one.
    (Graph construction errors from get_graph() still propagate.)
    """
    graph = get_graph()
    try:
        return graph.get_graph().draw_mermaid()
    except Exception:
        return """
graph TD
    START --> discover
    discover --> skin_analysis
    skin_analysis --> plan
    plan -->|has tools| execute
    plan -->|no tools| synthesize
    execute --> reflect
    reflect -->|gaps found| plan
    reflect -->|enough info| synthesize
    synthesize --> END
"""
if __name__ == "__main__":
    # Print a small banner, then the workflow diagram.
    for line in ("MedGemma LangGraph Agent", "=" * 50, "\nWorkflow Graph (Mermaid):\n"):
        print(line)
    print(get_graph_visualization())