Spaces:
Sleeping
Sleeping
#!/usr/bin/env python3
"""
MedGemma Agent v3 - Enhanced Agentic Architecture
Key improvements for Agentic Workflow Prize:
1. VISIBLE REASONING: Stream thought process to user
2. MULTI-STEP PLANNING: Explicit plan with dependencies
3. SELF-CORRECTION: Verify outputs and retry if needed
4. REFLECTION: Check if answer is complete before responding
Architecture:
┌─────────────────────────────────────────────────────────────┐
│ User Query: "Prepare me for my appointment"                 │
└─────────────────────────────────────────────────────────────┘
                              │
                              ▼
┌─────────────────────────────────────────────────────────────┐
│ PHASE 1: DISCOVER                                           │
│   • Get patient manifest (what data exists)                 │
│   • Identify patient context                                │
└─────────────────────────────────────────────────────────────┘
                              │
                              ▼
┌─────────────────────────────────────────────────────────────┐
│ PHASE 2: PLAN                                               │
│   • Analyze query intent                                    │
│   • Create ordered tool execution plan                      │
│   • Identify dependencies between tools                     │
└─────────────────────────────────────────────────────────────┘
                              │
                              ▼
┌─────────────────────────────────────────────────────────────┐
│ PHASE 3: EXECUTE (with self-correction)                     │
│   • Execute tools in planned order                          │
│   • Verify each result is valid                             │
│   • Retry with modified params if needed                    │
│   • Extract relevant facts                                  │
└─────────────────────────────────────────────────────────────┘
                              │
                              ▼
┌─────────────────────────────────────────────────────────────┐
│ PHASE 4: REFLECT                                            │
│   • Check if collected facts answer the question            │
│   • Identify any gaps                                       │
│   • Plan additional queries if needed                       │
└─────────────────────────────────────────────────────────────┘
                              │
                              ▼
┌─────────────────────────────────────────────────────────────┐
│ PHASE 5: SYNTHESIZE                                         │
│   • Combine all facts into coherent answer                  │
│   • Add clinical context and recommendations                │
│   • Stream response to user                                 │
└─────────────────────────────────────────────────────────────┘
"""
import asyncio
import json
import os
import re
import sys
from dataclasses import dataclass, field
from datetime import datetime
from typing import AsyncGenerator, Optional, Dict, List, Any

import httpx

from tools import get_tools_description, execute_tool, TOOLS, get_db
# Base URL of the completion server; override with the LLAMA_SERVER_URL env var.
LLAMA_SERVER_URL = os.environ.get("LLAMA_SERVER_URL", "http://localhost:8081")

# Headers attached to every LLM request. The ngrok header presumably lets
# requests through an ngrok tunnel without its browser-warning page.
LLM_HEADERS = dict(
    [
        ("Content-Type", "application/json"),
        ("ngrok-skip-browser-warning", "true"),
    ]
)
# =============================================================================
# Agent State with Enhanced Tracking
# =============================================================================
@dataclass
class ReasoningStep:
    """A single step in the agent's reasoning chain.

    Instances are appended to ``AgentState.reasoning_trace`` and serialized
    into the final "done" event for the UI.
    """

    phase: str  # discover, plan, execute, reflect, synthesize
    action: str  # What the agent is doing
    result: Optional[str] = None  # Outcome of this step
    duration_ms: int = 0  # NOTE(review): not populated by any code visible in this module
    # Local wall-clock time (naive ISO-8601 string) captured when the step is created.
    timestamp: str = field(default_factory=lambda: datetime.now().isoformat())
@dataclass
class ToolExecution:
    """Record of a single tool execution.

    One record is kept per plan step; retries of that step are counted in
    ``retry_count`` rather than creating new records.
    """

    tool_name: str
    args: Dict  # arguments used for the (last) attempt
    reason: str  # why the planner chose this call
    result: Optional[str] = None  # raw tool output from the last attempt
    facts_extracted: Optional[str] = None  # condensed facts derived from ``result``
    success: bool = True
    retry_count: int = 0
    error: Optional[str] = None  # failure description when ``success`` is False
@dataclass
class AgentState:
    """Enhanced agent state with full reasoning trace.

    One instance is threaded through all five phases of a single agent run.
    It holds the inputs, the discovery manifest, the plan, tool results,
    reflection output and the reasoning trace that is returned to the UI.
    """

    patient_id: str
    question: str
    # Discovery
    manifest: Dict = field(default_factory=dict)
    # Planning
    query_analysis: str = ""  # LLM's understanding of the query
    execution_plan: List[Dict] = field(default_factory=list)
    # Execution
    tool_executions: List[ToolExecution] = field(default_factory=list)
    collected_facts: List[str] = field(default_factory=list)
    chart_data: Optional[Dict] = None  # most recent chart payload produced by a chart tool
    # Reflection
    completeness_check: str = ""  # raw reflection reply from the LLM
    gaps_identified: List[str] = field(default_factory=list)
    # Reasoning trace (for UI)
    reasoning_trace: List[ReasoningStep] = field(default_factory=list)
    # Output
    final_answer: str = ""
    error: Optional[str] = None  # NOTE(review): not assigned by any code visible in this module

    def add_reasoning(self, phase: str, action: str, result: Optional[str] = None) -> None:
        """Append a new, timestamped ReasoningStep to the reasoning trace.

        Args:
            phase: Workflow phase name (discover, plan, execute, reflect, synthesize).
            action: Human-readable description of what the agent is doing.
            result: Optional outcome text for the step.
        """
        self.reasoning_trace.append(ReasoningStep(
            phase=phase,
            action=action,
            result=result
        ))
# =============================================================================
# LLM Helpers
# =============================================================================
async def call_llm(prompt: str, max_tokens: int = 1024, temperature: float = 0.3) -> str:
    """Send one non-streaming completion request and return the stripped text.

    Posts a llama.cpp-style payload (``n_predict``, ``stop``) to
    ``{LLAMA_SERVER_URL}/completion``. Generation stops at any of the chat
    end-of-turn markers listed in ``stop``.

    Args:
        prompt: Fully formatted prompt text.
        max_tokens: Upper bound on generated tokens (``n_predict``).
        temperature: Sampling temperature.

    Returns:
        The response's ``content`` field with surrounding whitespace removed
        (empty string if the field is absent).

    Raises:
        httpx.HTTPStatusError: If the server replies with a non-2xx status.
    """
    endpoint = f"{LLAMA_SERVER_URL}/completion"
    payload = {
        "prompt": prompt,
        "n_predict": max_tokens,
        "temperature": temperature,
        "stop": ["<end_of_turn>", "</s>", "<|im_end|>"],
        "stream": False,
    }
    async with httpx.AsyncClient(timeout=300.0) as client:
        reply = await client.post(endpoint, headers=LLM_HEADERS, json=payload)
        reply.raise_for_status()
        body = reply.json()
    return body.get("content", "").strip()
async def stream_llm(
    prompt: str,
    max_tokens: int = 1024,
    temperature: float = 0.7,
) -> AsyncGenerator[str, None]:
    """Stream LLM response token by token.

    Posts a streaming llama.cpp-style request to
    ``{LLAMA_SERVER_URL}/completion`` and yields the non-empty ``content``
    field of each ``data: `` server-sent-event line until the stream ends
    or a ``[DONE]`` sentinel arrives.

    Args:
        prompt: Fully formatted prompt text.
        max_tokens: Upper bound on generated tokens (``n_predict``).
        temperature: Sampling temperature (default 0.7, the previously
            hard-coded value).

    Yields:
        Text fragments in generation order.

    Raises:
        httpx.HTTPStatusError: If the server replies with a non-2xx status.
            Previously the error body was iterated and silently produced
            no tokens.
    """
    payload = {
        "prompt": prompt,
        "n_predict": max_tokens,
        "temperature": temperature,
        "stop": ["<end_of_turn>", "</s>", "<|im_end|>"],
        "stream": True,
    }
    async with httpx.AsyncClient(timeout=300.0) as client:
        async with client.stream(
            "POST",
            f"{LLAMA_SERVER_URL}/completion",
            headers=LLM_HEADERS,
            json=payload,
        ) as response:
            # Fail loudly on 4xx/5xx instead of yielding an empty answer.
            response.raise_for_status()
            async for line in response.aiter_lines():
                if not line.startswith("data: "):
                    continue  # blank separators / comments in the SSE stream
                data = line[6:]
                if data.strip() == "[DONE]":
                    break
                try:
                    chunk = json.loads(data)
                except json.JSONDecodeError:
                    continue  # skip malformed or partial event payloads
                content = chunk.get("content", "")
                if content:
                    yield content
def extract_json(text: str) -> Optional[Any]:
    """Extract the first JSON value embedded in LLM output.

    Strategies, in order:
      1. The body of a fenced code block (```json ... ``` or ``` ... ```).
      2. The whole (stripped) text.
      3. The leftmost ``[`` or ``{`` from which a complete JSON value can be
         decoded (``json.JSONDecoder.raw_decode``). This correctly handles
         nested objects/arrays and trailing prose after the JSON.

    Earlier versions used a greedy ``\\[.*\\]`` regex first (so
    ``{"x": [1]}`` returned ``[1]``) and a flat ``\\{[^{}]*\\}`` regex (which
    returned an inner object of nested JSON). Both are replaced by step 3.

    Args:
        text: Raw model output.

    Returns:
        The decoded JSON value (usually a dict or list), or None if nothing
        decodes.
    """
    if not text:
        return None

    for fenced in re.finditer(r"```(?:json)?\s*\n?(.*?)\n?```", text, re.DOTALL | re.IGNORECASE):
        try:
            return json.loads(fenced.group(1))
        except json.JSONDecodeError:
            continue

    try:
        return json.loads(text.strip())
    except json.JSONDecodeError:
        pass

    decoder = json.JSONDecoder()
    for index, char in enumerate(text):
        if char not in "[{":
            continue
        try:
            value, _end = decoder.raw_decode(text, index)
        except json.JSONDecodeError:
            continue
        return value
    return None
# =============================================================================
# PHASE 1: DISCOVER
# =============================================================================
def _calculate_age(birth_date: Any) -> Optional[int]:
    """Return completed years since ``birth_date``, or None if it cannot be parsed.

    Only the leading ``YYYY-MM-DD`` part is used, so full ISO timestamps also
    work. Unlike ``days // 365``, this does not drift by a year around
    birthdays as leap days accumulate.
    """
    try:
        born = datetime.strptime(str(birth_date)[:10], "%Y-%m-%d").date()
    except ValueError:
        return None
    today = datetime.now().date()
    before_birthday = (today.month, today.day) < (born.month, born.day)
    return today.year - born.year - int(before_birthday)


def discover_patient_data(patient_id: str) -> Dict:
    """
    Discover what data is available for this patient.

    Returns a manifest dict with these keys:
      - ``patient_info``: ``name``, ``age`` (None if the birth date is
        unparseable) and ``gender``; empty if the patient is not found.
      - ``available_data``: per table, ``description`` and row ``count``
        (only tables with at least one row are listed).
      - ``sample_values``: up to 5 ``conditions``, up to 5
        ``active_medications``, and the distinct vital-sign ``vital_types``.
      - ``data_range``: ``earliest_record`` / ``latest_record`` (first 10
        characters of the observation dates).
      - ``error``: present only if a query failed part-way through.

    The connection from ``get_db()`` is always closed. Assumes it supports
    ``conn.execute`` with ``?`` placeholders and rows indexable by column name
    (e.g. ``sqlite3.Row``) — TODO confirm against ``tools.get_db``.
    """
    conn = get_db()
    manifest = {
        "patient_info": {},
        "available_data": {},
        "sample_values": {},
        "data_range": {}
    }
    try:
        # Get patient info
        cursor = conn.execute("SELECT * FROM patients WHERE id = ?", (patient_id,))
        patient = cursor.fetchone()
        if patient:
            manifest["patient_info"] = {
                "name": f"{patient['given_name']} {patient['family_name']}",
                "age": _calculate_age(patient["birth_date"]),
                "gender": patient['gender']
            }
        # Count records in each category. Table names come from this fixed
        # mapping only, so interpolating them into the SQL is safe.
        tables = {
            'conditions': 'Medical conditions/diagnoses',
            'medications': 'Medications (active and historical)',
            'observations': 'Vital signs and lab results',
            'allergies': 'Known allergies',
            'encounters': 'Healthcare visits',
            'immunizations': 'Vaccinations',
            'procedures': 'Medical procedures'
        }
        for table, description in tables.items():
            try:
                cursor = conn.execute(
                    f"SELECT COUNT(*) FROM {table} WHERE patient_id = ?",
                    (patient_id,)
                )
                count = cursor.fetchone()[0]
            except Exception:
                # Best effort: a table this database lacks is simply not listed.
                continue
            if count > 0:
                manifest["available_data"][table] = {
                    "description": description,
                    "count": count
                }
        # Get sample conditions
        cursor = conn.execute(
            "SELECT display, clinical_status FROM conditions WHERE patient_id = ? LIMIT 5",
            (patient_id,)
        )
        conditions = [f"{row['display']}" for row in cursor.fetchall()]
        if conditions:
            manifest["sample_values"]["conditions"] = conditions
        # Get sample medications
        cursor = conn.execute(
            "SELECT display, status FROM medications WHERE patient_id = ? AND status = 'active' LIMIT 5",
            (patient_id,)
        )
        medications = [row['display'] for row in cursor.fetchall()]
        if medications:
            manifest["sample_values"]["active_medications"] = medications
        # Get available vital types
        cursor = conn.execute("""
SELECT DISTINCT code_display FROM observations
WHERE patient_id = ? AND category = 'vital-signs'
""", (patient_id,))
        vitals = [row['code_display'] for row in cursor.fetchall()]
        if vitals:
            manifest["sample_values"]["vital_types"] = vitals
        # Get observation date range
        cursor = conn.execute("""
SELECT MIN(effective_date) as earliest, MAX(effective_date) as latest
FROM observations WHERE patient_id = ?
""", (patient_id,))
        obs_range = cursor.fetchone()
        if obs_range and obs_range['earliest']:
            manifest["data_range"] = {
                "earliest_record": obs_range['earliest'][:10],
                "latest_record": obs_range['latest'][:10]
            }
    except Exception as e:
        manifest["error"] = str(e)
    finally:
        conn.close()
    return manifest
# =============================================================================
# PHASE 2: PLAN
# =============================================================================
async def analyze_query(state: AgentState) -> str:
    """Ask the LLM for a short (2-3 sentence) reading of the question.

    The prompt asks for intent, needed data types, time constraints and
    priority. It uses the patient's name from the discovery manifest
    ("Unknown" if absent).

    Returns:
        The model's stripped analysis text.
    """
    patient_info = state.manifest.get('patient_info', {})
    display_name = patient_info.get('name', 'Unknown')
    prompt = f"""<start_of_turn>user
Analyze this patient query and identify:
1. The primary intent (what does the patient want to know?)
2. What type of information is needed (vitals, medications, conditions, labs, etc.)
3. Any time constraints mentioned (last 30 days, recent, historical, etc.)
4. Priority level (urgent concern vs routine question)
Patient: {display_name}
Query: {state.question}
Provide a brief analysis (2-3 sentences).
<end_of_turn>
<start_of_turn>model
"""
    return await call_llm(prompt, max_tokens=200, temperature=0.3)
async def create_execution_plan(state: AgentState) -> List[Dict]:
    """
    Create a detailed execution plan with tool dependencies.

    The LLM is prompted with the discovery manifest, the tool descriptions and
    the query analysis. Its reply is pre-filled with "[" so it continues a
    JSON array. Every parsed step is normalized so the executor can run it:
      - non-dict entries and entries without a ``tool`` are dropped;
      - ``args`` is coerced to a dict and ``patient_id`` is always forced to
        the current patient, whatever the model wrote;
      - a missing or non-integer ``step`` gets its 1-based position;
      - ``depends_on`` defaults to an empty list.

    Falls back to ``create_fallback_plan`` when nothing usable is parsed.
    LLM/transport errors from ``call_llm`` propagate to the caller.
    """
    tools_desc = get_tools_description()
    manifest_text = json.dumps(state.manifest, indent=2)
    prompt = f"""<start_of_turn>user
You are a medical AI planning how to gather information to answer a patient's question.
PATIENT DATA MANIFEST:
{manifest_text}
AVAILABLE TOOLS:
{tools_desc}
QUERY ANALYSIS: {state.query_analysis}
USER QUESTION: {state.question}
Create an execution plan. For each step, specify:
- "step": step number (1, 2, 3...)
- "tool": tool name to call
- "args": arguments for the tool
- "reason": why this step is needed
- "depends_on": list of step numbers this depends on (empty if none)
Rules:
1. Order steps logically (gather context first, then specifics)
2. For comprehensive queries, gather multiple data types
3. Include chart tools if visualization would help
4. For "prepare for appointment" queries, gather: conditions, medications, recent vitals, allergies
Output JSON array only:
<end_of_turn>
<start_of_turn>model
["""
    response = await call_llm(prompt, max_tokens=800, temperature=0.3)
    # The prompt pre-fills "[" for the model, so restore it before parsing.
    plan = extract_json("[" + response)
    if isinstance(plan, list):
        normalized: List[Dict] = []
        for position, step in enumerate(plan, start=1):
            if not isinstance(step, dict) or not step.get("tool"):
                continue  # nothing the executor could run
            args = step.get("args")
            if not isinstance(args, dict):
                args = {}
            args["patient_id"] = state.patient_id
            step["args"] = args
            if not isinstance(step.get("step"), int):
                step["step"] = position
            if not isinstance(step.get("depends_on"), list):
                step["depends_on"] = []
            normalized.append(step)
        if normalized:
            return normalized
    # Fallback: simple plan based on query keywords
    return create_fallback_plan(state)
def create_fallback_plan(state: AgentState) -> List[Dict]:
    """Create a basic keyword-driven plan if LLM planning fails.

    Broad questions ("appointment", "prepare", "summary", "overview",
    "everything") get a fixed five-step plan. Otherwise, each topic the
    question mentions adds one step, in a fixed order. A patient summary is
    used when nothing matches.

    Short keywords ("bp", "lab", "test", "result") are matched at word
    boundaries, so "available" no longer triggers a lab query and "latest"
    is no longer mistaken for "test".

    Args:
        state: Agent state; only ``question`` and ``patient_id`` are read.

    Returns:
        Plan steps as dicts with ``step``, ``tool``, ``args``, ``reason`` and
        ``depends_on`` keys; every ``args`` includes ``patient_id``.
    """
    question_lower = state.question.lower()
    patient_id = state.patient_id

    def mentions(*phrases: str) -> bool:
        # Plain substring match: fine for long, distinctive phrases (and
        # stems such as "allerg" that should match "allergies").
        return any(phrase in question_lower for phrase in phrases)

    def mentions_word(pattern: str) -> bool:
        # Regex match for short keywords that would otherwise hit inside other words.
        return re.search(pattern, question_lower) is not None

    # Comprehensive queries
    if mentions("appointment", "prepare", "summary", "overview", "everything"):
        return [
            {"step": 1, "tool": "get_conditions", "args": {"patient_id": patient_id},
             "reason": "Get all medical conditions", "depends_on": []},
            {"step": 2, "tool": "get_medications", "args": {"patient_id": patient_id, "status": "active"},
             "reason": "Get current medications", "depends_on": []},
            {"step": 3, "tool": "get_allergies", "args": {"patient_id": patient_id},
             "reason": "Check for allergies", "depends_on": []},
            {"step": 4, "tool": "get_recent_vitals", "args": {"patient_id": patient_id},
             "reason": "Get recent vital signs", "depends_on": []},
            {"step": 5, "tool": "get_vital_chart_data", "args": {"patient_id": patient_id, "vital_type": "blood_pressure"},
             "reason": "Visualize BP trends", "depends_on": [4]},
        ]

    # Specific queries: (matched?, tool, extra args, reason), in emission order.
    rules = [
        (mentions("medication", "drug", "prescription"),
         "get_medications", {"status": "active"}, "User asked about medications"),
        (mentions("condition", "diagnosis", "disease"),
         "get_conditions", {}, "User asked about conditions"),
        (mentions("blood pressure") or mentions_word(r"\bbp\b"),
         "get_vital_chart_data", {"vital_type": "blood_pressure"}, "User asked about blood pressure"),
        (mentions("heart rate", "pulse"),
         "get_vital_chart_data", {"vital_type": "heart_rate"}, "User asked about heart rate"),
        (mentions("weight"),
         "get_vital_chart_data", {"vital_type": "weight"}, "User asked about weight"),
        (mentions_word(r"\blab(?:s|work|oratory)?\b|\btest|\bresult"),
         "get_recent_labs", {}, "User asked about labs"),
        (mentions("allerg"),
         "get_allergies", {}, "User asked about allergies"),
    ]
    plan: List[Dict] = []
    for matched, tool, extra_args, reason in rules:
        if matched:
            plan.append({
                "step": len(plan) + 1,
                "tool": tool,
                "args": {"patient_id": patient_id, **extra_args},
                "reason": reason,
                "depends_on": [],
            })

    # Default: get summary if no specific match
    if not plan:
        plan.append({"step": 1, "tool": "get_patient_summary",
                     "args": {"patient_id": patient_id},
                     "reason": "Get general patient overview", "depends_on": []})
    return plan
# =============================================================================
# PHASE 3: EXECUTE (with self-correction)
# =============================================================================
async def execute_with_retry(
    tool_name: str,
    args: Dict,
    reason: str,
    max_retries: int = 2
) -> ToolExecution:
    """Execute a tool with retry logic on failure.

    ``execute_tool`` is called up to ``max_retries + 1`` times.

    - A result counts as successful when it is non-empty and the word
      "error" does not appear in its first 100 characters.
    - After an error-looking result, a ``days`` argument (if present) is
      doubled before the next attempt, to widen the lookback window.
    - After an exception, the call is retried with the same arguments after
      a 0.5 s pause.

    The caller's ``args`` dict is copied, so retries never rewrite the plan
    step it came from.

    Returns:
        A ToolExecution with ``success``, ``result`` and ``retry_count``.
        On failure, ``error`` holds the exception text or the start of the
        error-looking result. A later success clears an earlier error.
    """
    call_args = dict(args)
    execution = ToolExecution(
        tool_name=tool_name,
        args=call_args,
        reason=reason
    )
    for attempt in range(max_retries + 1):
        can_retry = attempt < max_retries
        try:
            result = execute_tool(tool_name, call_args)
            execution.result = result
            text = "" if result is None else str(result)
            # NOTE(review): substring heuristic; legitimate data such as a
            # "Refractive error" condition near the start would also trip it.
            if text and "error" not in text.lower()[:100]:
                execution.success = True
                execution.error = None
                return execution
            execution.error = text[:200] or "Tool returned an empty result"
            if can_retry:
                execution.retry_count += 1
                # Modify args for retry (e.g., expand date range)
                if "days" in call_args:
                    call_args["days"] = (call_args.get("days") or 30) * 2
        except Exception as e:
            execution.error = str(e)
            if can_retry:
                execution.retry_count += 1
                await asyncio.sleep(0.5)  # Brief delay before retry
    execution.success = False
    return execution
async def extract_facts_from_result(
    tool_name: str,
    result: str,
    question: str,
    state: AgentState
) -> str:
    """Extract relevant facts from a tool's output.

    - Chart tools: if the output is a JSON object containing ``chart_type``,
      it is stored on ``state.chart_data`` (replacing any earlier chart). A
      one-line "Chart: <title>" header plus its ``summary`` is returned, or
      "Chart generated: <title>" when there is no summary.
    - Simple data tools: the raw output is returned unchanged.
    - Anything else (including chart output that is not a chart object) is
      condensed by the LLM, using the first 2000 characters of the output.

    Args:
        tool_name: Name of the tool that produced ``result``.
        result: Raw tool output.
        question: The user's question, used to focus the LLM summary.
        state: Agent state; only ``chart_data`` is written.

    Returns:
        The fact text to add to the collected facts.
    """
    chart_tools = {"get_vital_chart_data", "get_lab_panel_chart", "get_lab_trend_chart"}
    if tool_name in chart_tools:
        try:
            parsed = json.loads(result)
        except (TypeError, ValueError):
            parsed = None  # not JSON text: handled by the generic path below
        if isinstance(parsed, dict) and "chart_type" in parsed:
            state.chart_data = parsed
            title = parsed.get("title", tool_name)
            summary = parsed.get("summary", "")
            if summary:
                return f"Chart: {title}\n{summary}"
            return f"Chart generated: {title}"

    # For simple data tools, return as-is
    simple_tools = {
        "get_patient_summary", "get_conditions", "get_medications",
        "get_allergies", "get_immunizations", "get_encounters",
        "get_procedures", "get_recent_vitals", "get_recent_labs"
    }
    if tool_name in simple_tools:
        return result

    # For complex results, use LLM to extract relevant facts
    prompt = f"""<start_of_turn>user
Extract the key facts from this tool output that are relevant to: {question}
Tool: {tool_name}
Output: {result[:2000]}
Provide a concise summary of relevant facts (2-4 sentences).
<end_of_turn>
<start_of_turn>model
"""
    return await call_llm(prompt, max_tokens=300, temperature=0.3)
async def execute_plan(state: AgentState) -> None:
    """Execute all planned tools in order, recording results on ``state``.

    For each step:
      - Steps without a ``tool`` are skipped, with a reasoning note.
      - ``args`` is coerced to a dict, and ``patient_id`` is forced to the
        current patient so no call can read another patient's records.
      - Dependencies listed in ``depends_on`` that have not completed
        successfully are reported in the trace. The step still runs (best
        effort), since steps are executed strictly in plan order.
      - The tool is run via ``execute_with_retry``. On success, facts are
        extracted and appended to ``state.collected_facts``.

    Only successful steps count as completed for dependency checks.
    """
    completed_steps = set()
    for step in state.execution_plan:
        step_num = step.get("step", 0)
        tool_name = step.get("tool")
        if not tool_name:
            state.add_reasoning(
                "execute",
                f"Skipping malformed plan step {step_num}",
                result="No tool specified"
            )
            continue
        raw_args = step.get("args")
        args = dict(raw_args) if isinstance(raw_args, dict) else {}
        args["patient_id"] = state.patient_id
        reason = step.get("reason", "")

        unmet = [dep for dep in (step.get("depends_on") or []) if dep not in completed_steps]
        if unmet:
            state.add_reasoning(
                "execute",
                f"Step {step_num} depends on step(s) {unmet} that did not complete; running anyway",
            )

        state.add_reasoning(
            "execute",
            f"Calling {tool_name}: {reason}",
        )
        # Execute with retry
        execution = await execute_with_retry(tool_name, args, reason)
        state.tool_executions.append(execution)
        if execution.success and execution.result:
            # Extract facts
            facts = await extract_facts_from_result(
                tool_name,
                execution.result,
                state.question,
                state
            )
            execution.facts_extracted = facts
            state.collected_facts.append(f"[{tool_name}] {facts}")
            facts_text = str(facts)
            preview = facts_text if len(facts_text) <= 200 else facts_text[:200] + "..."
            state.add_reasoning(
                "execute",
                f"Extracted facts from {tool_name}",
                result=preview
            )
            completed_steps.add(step_num)
        else:
            state.add_reasoning(
                "execute",
                f"Tool {tool_name} failed",
                result=execution.error or "Unknown error"
            )
# =============================================================================
# PHASE 4: REFLECT
# =============================================================================
def _parse_reflection(response: str) -> tuple:
    """Parse a reflection reply into ``(is_complete, gaps)``.

    Accepts variants such as ``COMPLETE: yes``, ``COMPLETE:yes`` and
    ``**COMPLETE:** Yes``, but not ``INCOMPLETE: yes``. Gaps are read from
    everything after ``GAPS:`` and split on commas, semicolons and newlines.
    Brackets, bullets and quotes are trimmed from each item. Items reading
    "none" (e.g. "None", "[none]", "none identified") mean no gap. The
    original casing of gap text is preserved.
    """
    is_complete = re.search(r"\bcomplete\b\W*yes\b", response.lower()) is not None
    gaps: List[str] = []
    header = re.search(r"gaps\W*?:", response, re.IGNORECASE)
    if header:
        for item in re.split(r"[,;\n]", response[header.end():]):
            cleaned = item.strip().strip("[]-*•\"' ").strip()
            if cleaned and not re.match(r"none\b", cleaned, re.IGNORECASE):
                gaps.append(cleaned)
    return is_complete, gaps


async def reflect_on_completeness(state: AgentState) -> bool:
    """
    Check if we have enough information to answer the question.

    Only the first 10 collected fact sets are shown to the LLM. The raw reply
    is stored on ``state.completeness_check``, and parsed gaps on
    ``state.gaps_identified``. A reasoning step records the verdict.

    Returns True if complete, False if more data needed.
    """
    facts_summary = "\n".join(state.collected_facts[:10])
    prompt = f"""<start_of_turn>user
You are checking if enough information has been gathered to answer a patient's question.
QUESTION: {state.question}
COLLECTED FACTS:
{facts_summary}
Evaluate:
1. Do we have enough information to provide a helpful answer? (yes/no)
2. What key information might be missing? (list any gaps)
Format your response as:
COMPLETE: yes/no
GAPS: [list any missing information, or "none" if complete]
<end_of_turn>
<start_of_turn>model
"""
    response = await call_llm(prompt, max_tokens=200, temperature=0.3)
    state.completeness_check = response
    is_complete, state.gaps_identified = _parse_reflection(response)
    state.add_reasoning(
        "reflect",
        "Checking if information is complete",
        result=f"Complete: {is_complete}, Gaps: {state.gaps_identified}"
    )
    return is_complete
# =============================================================================
# PHASE 5: SYNTHESIZE
# =============================================================================
async def synthesize_answer(state: AgentState) -> AsyncGenerator[str, None]:
    """
    Stream the final patient-facing answer built from the collected facts.

    The prompt gets the patient's name from the manifest ("there" if absent),
    the question, and every collected fact set separated by blank lines. Its
    guidelines ask the model to reuse the provided statistics verbatim
    rather than compute its own.

    Yields:
        Text fragments from ``stream_llm`` (up to 1500 tokens) as they arrive.
    """
    info = state.manifest.get("patient_info", {})
    patient_name = info.get("name", "there")
    separator = "\n\n"
    facts_text = separator.join(state.collected_facts)
    prompt = f"""<start_of_turn>user
You are a helpful medical AI assistant. Using the collected information, provide a clear and helpful response to the patient.
PATIENT: {patient_name}
QUESTION: {state.question}
COLLECTED INFORMATION:
{facts_text}
Guidelines:
1. Address the patient by name
2. Directly answer their question
3. Highlight any concerning trends or important findings
4. Use the exact statistics provided (min, max, avg, count) - do NOT compute your own
5. Suggest follow-up actions if appropriate
6. Keep the tone warm but professional
7. If showing charts, reference what they display
Provide your response:
<end_of_turn>
<start_of_turn>model
"""
    token_stream = stream_llm(prompt, max_tokens=1500)
    async for piece in token_stream:
        yield piece
# =============================================================================
# Main Agent Entry Point
# =============================================================================
async def run_agent_v3(
    patient_id: str,
    question: str,
    stream_reasoning: bool = True
) -> AsyncGenerator[Dict, None]:
    """
    Run the enhanced agentic workflow for one patient question.

    Phases run in order: discover -> plan -> execute -> reflect -> synthesize.
    Every phase records ReasoningStep entries on the state. When
    ``stream_reasoning`` is True, those steps are also yielded. The per-tool
    steps recorded inside ``execute_plan`` are replayed right after that
    phase finishes, because ``execute_plan`` itself cannot yield.

    Args:
        patient_id: Patient whose records are queried.
        question: The user's natural-language question.
        stream_reasoning: If False, only "chart", "token" and "done" events
            are yielded; the full trace is still included in "done".

    Yields events that can be:
    - {"type": "reasoning", "phase": "...", "action": "...", "result": "..."}
    - {"type": "plan", "steps": [...], "summary": [...]}
    - {"type": "chart", "data": {...}}
    - {"type": "token", "content": "..."}
    - {"type": "done", "trace": [...], "tool_calls": int, "facts_collected": int}

    Exceptions from discovery or the LLM calls are not caught here and
    propagate to the caller.
    """
    state = AgentState(patient_id=patient_id, question=question)
    # =========================================================================
    # PHASE 1: DISCOVER
    # =========================================================================
    state.add_reasoning("discover", "Analyzing available patient data...")
    if stream_reasoning:
        yield {"type": "reasoning", "phase": "discover", "action": "Analyzing available patient data..."}
    state.manifest = discover_patient_data(patient_id)
    patient_name = state.manifest.get("patient_info", {}).get("name", "Unknown")
    data_types = list(state.manifest.get("available_data", {}).keys())
    state.add_reasoning(
        "discover",
        f"Found patient: {patient_name}",
        result=f"Available data: {', '.join(data_types)}"
    )
    if stream_reasoning:
        yield {
            "type": "reasoning",
            "phase": "discover",
            "action": f"Found patient: {patient_name}",
            "result": f"Available data: {', '.join(data_types)}"
        }
    # =========================================================================
    # PHASE 2: PLAN
    # =========================================================================
    state.add_reasoning("plan", "Analyzing your question...")
    if stream_reasoning:
        yield {"type": "reasoning", "phase": "plan", "action": "Analyzing your question..."}
    state.query_analysis = await analyze_query(state)
    state.add_reasoning("plan", "Creating execution plan...")
    if stream_reasoning:
        yield {"type": "reasoning", "phase": "plan", "action": "Creating execution plan..."}
    raw_plan = await create_execution_plan(state)
    # Keep only runnable step dicts; an LLM plan can contain anything.
    state.execution_plan = [s for s in raw_plan if isinstance(s, dict) and s.get("tool")]
    # Stream the plan
    plan_summary = [
        f"Step {s.get('step', position)}: {s['tool']} - {s.get('reason', '')}"
        for position, s in enumerate(state.execution_plan, start=1)
    ]
    state.add_reasoning(
        "plan",
        f"Planned {len(state.execution_plan)} tool calls",
        result="\n".join(plan_summary)
    )
    if stream_reasoning:
        yield {
            "type": "plan",
            "steps": state.execution_plan,
            "summary": plan_summary
        }
    # =========================================================================
    # PHASE 3: EXECUTE
    # =========================================================================
    state.add_reasoning("execute", "Gathering information...")
    if stream_reasoning:
        yield {"type": "reasoning", "phase": "execute", "action": "Gathering information..."}
    trace_mark = len(state.reasoning_trace)
    await execute_plan(state)
    if stream_reasoning:
        # Surface the per-tool steps execute_plan recorded (calls, facts, failures).
        for entry in state.reasoning_trace[trace_mark:]:
            yield {
                "type": "reasoning",
                "phase": entry.phase,
                "action": entry.action,
                "result": entry.result
            }
    # Stream chart if we have one
    if state.chart_data:
        yield {"type": "chart", "data": state.chart_data}
    # Summary of execution
    successful = sum(1 for e in state.tool_executions if e.success)
    state.add_reasoning(
        "execute",
        f"Completed {successful}/{len(state.tool_executions)} tool calls",
        result=f"Collected {len(state.collected_facts)} fact sets"
    )
    if stream_reasoning:
        yield {
            "type": "reasoning",
            "phase": "execute",
            "action": f"Completed {successful}/{len(state.tool_executions)} tool calls",
            "result": f"Collected {len(state.collected_facts)} fact sets"
        }
    # =========================================================================
    # PHASE 4: REFLECT
    # =========================================================================
    state.add_reasoning("reflect", "Checking completeness...")
    if stream_reasoning:
        yield {"type": "reasoning", "phase": "reflect", "action": "Checking if information is complete..."}
    is_complete = await reflect_on_completeness(state)
    # Gaps are reported but not yet acted on; we proceed with what we have.
    if not is_complete and state.gaps_identified:
        state.add_reasoning(
            "reflect",
            "Some information may be incomplete",
            result=f"Gaps: {', '.join(state.gaps_identified)}"
        )
        if stream_reasoning:
            yield {
                "type": "reasoning",
                "phase": "reflect",
                "action": "Note: Some information may be incomplete",
                "result": "Proceeding with available data"
            }
    # =========================================================================
    # PHASE 5: SYNTHESIZE
    # =========================================================================
    state.add_reasoning("synthesize", "Generating response...")
    if stream_reasoning:
        yield {"type": "reasoning", "phase": "synthesize", "action": "Generating response..."}
    full_response = ""
    async for token in synthesize_answer(state):
        full_response += token
        yield {"type": "token", "content": token}
    state.final_answer = full_response
    # =========================================================================
    # DONE
    # =========================================================================
    yield {
        "type": "done",
        "trace": [
            {
                "phase": step.phase,
                "action": step.action,
                "result": step.result,
                "timestamp": step.timestamp
            }
            for step in state.reasoning_trace
        ],
        "tool_calls": len(state.tool_executions),
        "facts_collected": len(state.collected_facts)
    }
# =============================================================================
# Convenience wrapper for backward compatibility
# =============================================================================
async def chat_with_agent_v3(
    patient_id: str,
    message: str
) -> AsyncGenerator[str, None]:
    """
    Simple wrapper that yields SSE-formatted events.
    Compatible with existing frontend.

    Each agent event becomes one ``data: <json>\\n\\n`` frame. Plan events
    carry only the human-readable step summaries. The stream always ends
    with ``data: [DONE]\\n\\n``, even if the agent raises. In that case an
    ``{"type": "error", ...}`` frame is emitted first, so clients are not
    left waiting on a dead stream. Cancellation and generator close
    (BaseException subclasses) are not intercepted.
    """
    def _sse(payload: Dict) -> str:
        # One server-sent-event frame.
        return f"data: {json.dumps(payload)}\n\n"

    try:
        async for event in run_agent_v3(patient_id, message, stream_reasoning=True):
            event_type = event.get("type")
            if event_type == "reasoning":
                # Send as a special reasoning event
                yield _sse({'type': 'reasoning', 'phase': event.get('phase'), 'action': event.get('action'), 'result': event.get('result')})
            elif event_type == "plan":
                yield _sse({'type': 'plan', 'steps': event.get('summary', [])})
            elif event_type == "chart":
                yield _sse({'type': 'chart', 'data': event.get('data')})
            elif event_type == "token":
                yield _sse({'type': 'token', 'content': event.get('content')})
            elif event_type == "done":
                yield _sse({'type': 'done', 'tool_calls': event.get('tool_calls'), 'facts': event.get('facts_collected')})
    except Exception as exc:
        # Internal details (URLs, stack) are not sent to the client; only the class name.
        yield _sse({
            'type': 'error',
            'message': 'The assistant hit an error while answering. Please try again.',
            'detail': type(exc).__name__,
        })
    yield "data: [DONE]\n\n"
# =============================================================================
# Test
# =============================================================================
if __name__ == "__main__":
    async def test():
        """Smoke-test the agent end to end and pretty-print its event stream.

        Usage: python <this file> [patient_id] [question]
        Without arguments, a placeholder patient id and an
        appointment-preparation question are used.
        """
        patient_id = sys.argv[1] if len(sys.argv) > 1 else "test-patient-id"
        question = sys.argv[2] if len(sys.argv) > 2 else "Prepare me for my appointment tomorrow"
        print("=" * 60)
        print("Testing Agent v3 with enhanced agentic features")
        print("=" * 60)
        async for event in run_agent_v3(patient_id, question):
            event_type = event.get("type")
            if event_type == "reasoning":
                print(f"\nπ€ [{event['phase'].upper()}] {event['action']}")
                result = event.get('result')
                if result:
                    # Only mark truncation when something was actually cut off.
                    preview = result if len(result) <= 100 else result[:100] + "..."
                    print(f"   β {preview}")
            elif event_type == "plan":
                print("\nπ EXECUTION PLAN:")
                for step in event.get('summary', []):
                    print(f"   {step}")
            elif event_type == "chart":
                print(f"\nπ Chart generated: {event['data'].get('title', 'Unknown')}")
            elif event_type == "token":
                print(event['content'], end="", flush=True)
            elif event_type == "done":
                print(f"\n\nβ Done! Tool calls: {event['tool_calls']}, Facts: {event['facts_collected']}")

    asyncio.run(test())