# multimodal_previsit/agent_v3.py
# Author: frabbani
# Fix fact extraction - pass raw data for simple tools (commit 1173c31)
#!/usr/bin/env python3
"""
MedGemma Agent v3 - Enhanced Agentic Architecture
Key improvements for Agentic Workflow Prize:
1. VISIBLE REASONING: Stream thought process to user
2. MULTI-STEP PLANNING: Explicit plan with dependencies
3. SELF-CORRECTION: Verify outputs and retry if needed
4. REFLECTION: Check if answer is complete before responding
Architecture:
β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚ User Query: "Prepare me for my appointment" β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
β”‚
β–Ό
β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚ PHASE 1: DISCOVER β”‚
β”‚ β€’ Get patient manifest (what data exists) β”‚
β”‚ β€’ Identify patient context β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
β”‚
β–Ό
β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚ PHASE 2: PLAN β”‚
β”‚ β€’ Analyze query intent β”‚
β”‚ β€’ Create ordered tool execution plan β”‚
β”‚ β€’ Identify dependencies between tools β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
β”‚
β–Ό
β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚ PHASE 3: EXECUTE (with self-correction) β”‚
β”‚ β€’ Execute tools in planned order β”‚
β”‚ β€’ Verify each result is valid β”‚
β”‚ β€’ Retry with modified params if needed β”‚
β”‚ β€’ Extract relevant facts β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
β”‚
β–Ό
β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚ PHASE 4: REFLECT β”‚
β”‚ β€’ Check if collected facts answer the question β”‚
β”‚ β€’ Identify any gaps β”‚
β”‚ β€’ Plan additional queries if needed β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
β”‚
β–Ό
β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚ PHASE 5: SYNTHESIZE β”‚
β”‚ β€’ Combine all facts into coherent answer β”‚
β”‚ β€’ Add clinical context and recommendations β”‚
β”‚ β€’ Stream response to user β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
"""
import os
import json
import re
import asyncio
from typing import AsyncGenerator, Optional, Dict, List, Any
from dataclasses import dataclass, field
from datetime import datetime
import httpx
from tools import get_tools_description, execute_tool, TOOLS, get_db
LLAMA_SERVER_URL = os.getenv("LLAMA_SERVER_URL", "http://localhost:8081")
LLM_HEADERS = {
"Content-Type": "application/json",
"ngrok-skip-browser-warning": "true"
}
# =============================================================================
# Agent State with Enhanced Tracking
# =============================================================================
@dataclass
class ReasoningStep:
    """A single step in the agent's reasoning chain.

    Steps are appended via AgentState.add_reasoning and serialized for the
    UI in the final "done" event of run_agent_v3.
    """
    phase: str  # One of: discover, plan, execute, reflect, synthesize
    action: str  # Human-readable description of what the agent is doing
    result: Optional[str] = None  # Outcome of this step, if any
    duration_ms: int = 0  # Elapsed time; never populated in this module
    # ISO-8601 wall-clock time when the step was created
    timestamp: str = field(default_factory=lambda: datetime.now().isoformat())
@dataclass
class ToolExecution:
    """Record of a single tool execution (created by execute_with_retry)."""
    tool_name: str  # Name of the tool that was invoked
    args: Dict  # Arguments the tool was called with
    reason: str  # Why the plan included this call
    result: Optional[str] = None  # Raw tool output, if the call produced one
    facts_extracted: Optional[str] = None  # Facts distilled from the result
    success: bool = True  # Set False once all retry attempts have failed
    retry_count: int = 0  # Number of retries performed
    error: Optional[str] = None  # Message of the last exception raised, if any
@dataclass
class AgentState:
    """Enhanced agent state with full reasoning trace.

    One instance is created per run_agent_v3 invocation and threaded through
    every phase (discover -> plan -> execute -> reflect -> synthesize).
    """
    patient_id: str
    question: str
    # Discovery: manifest of available data (built by discover_patient_data)
    manifest: Dict = field(default_factory=dict)
    # Planning
    query_analysis: str = ""  # LLM's understanding of the query
    execution_plan: List[Dict] = field(default_factory=list)  # Ordered tool-call steps
    # Execution
    tool_executions: List[ToolExecution] = field(default_factory=list)
    collected_facts: List[str] = field(default_factory=list)  # "[tool_name] facts" strings
    chart_data: Optional[Dict] = None  # Parsed chart payload, if a chart tool ran
    # Reflection
    completeness_check: str = ""  # Raw LLM response from the completeness check
    gaps_identified: List[str] = field(default_factory=list)
    # Reasoning trace (for UI)
    reasoning_trace: List[ReasoningStep] = field(default_factory=list)
    # Output
    final_answer: str = ""
    error: Optional[str] = None

    def add_reasoning(self, phase: str, action: str, result: Optional[str] = None) -> None:
        """Add a step to the reasoning trace."""
        self.reasoning_trace.append(ReasoningStep(
            phase=phase,
            action=action,
            result=result
        ))
# =============================================================================
# LLM Helpers
# =============================================================================
async def call_llm(prompt: str, max_tokens: int = 1024, temperature: float = 0.3) -> str:
    """Send one non-streaming completion request and return the text content."""
    payload = {
        "prompt": prompt,
        "n_predict": max_tokens,
        "temperature": temperature,
        "stop": ["<end_of_turn>", "</s>", "<|im_end|>"],
        "stream": False,
    }
    async with httpx.AsyncClient(timeout=300.0) as client:
        response = await client.post(
            f"{LLAMA_SERVER_URL}/completion",
            headers=LLM_HEADERS,
            json=payload,
        )
        response.raise_for_status()
        body = response.json()
    # The llama.cpp server returns the generated text under "content".
    return body.get("content", "").strip()
async def stream_llm(prompt: str, max_tokens: int = 1024) -> AsyncGenerator[str, None]:
    """Yield LLM completion text piece by piece from the SSE stream."""
    payload = {
        "prompt": prompt,
        "n_predict": max_tokens,
        "temperature": 0.7,
        "stop": ["<end_of_turn>", "</s>", "<|im_end|>"],
        "stream": True,
    }
    async with httpx.AsyncClient(timeout=300.0) as client:
        async with client.stream(
            "POST",
            f"{LLAMA_SERVER_URL}/completion",
            headers=LLM_HEADERS,
            json=payload,
        ) as response:
            async for raw_line in response.aiter_lines():
                # Server-Sent Events: only "data: ..." lines carry payloads.
                if not raw_line.startswith("data: "):
                    continue
                data = raw_line[6:]
                if data.strip() == "[DONE]":
                    break
                try:
                    piece = json.loads(data).get("content", "")
                except json.JSONDecodeError:
                    # Ignore malformed keep-alive/partial frames.
                    continue
                if piece:
                    yield piece
def extract_json(text: str) -> Optional[Dict]:
    """Best-effort extraction of a JSON value embedded in LLM output.

    Tries, in order: a fenced ```json block, any fenced block, a bare JSON
    array, a (possibly nested) JSON object, a flat object fallback, and
    finally the whole text. Returns the first candidate that parses, or
    None if nothing does.

    Fixes over the previous version:
    - removed a dead ternary whose branches were identical;
    - added a greedy object pattern so nested objects like
      {"a": {"b": 1}} are returned whole instead of yielding the inner
      fragment that the flat pattern happens to match first.
    """
    patterns = [
        r'```json\s*\n?(.*?)\n?```',  # fenced ```json ... ``` block
        r'```\s*\n?(.*?)\n?```',      # any fenced block
        r'\[.*\]',                    # bare array (greedy, may span lines)
        r'\{.*\}',                    # full object, handles nested braces
        r'\{[^{}]*\}',                # flat-object fallback (no nesting)
    ]
    for pattern in patterns:
        for candidate in re.findall(pattern, text, re.DOTALL):
            try:
                return json.loads(candidate)
            except json.JSONDecodeError:
                continue
    # Last resort: the entire text may already be valid JSON.
    try:
        return json.loads(text.strip())
    except json.JSONDecodeError:
        return None
# =============================================================================
# PHASE 1: DISCOVER
# =============================================================================
def discover_patient_data(patient_id: str) -> Dict:
    """
    Discover what data is available for this patient.

    Returns a manifest with:
      - patient_info:   name, age (calendar-accurate), gender
      - available_data: {table: {description, count}} for non-empty tables
      - sample_values:  example conditions, active medications, vital types
      - data_range:     earliest/latest observation dates (YYYY-MM-DD)

    Any top-level failure is recorded under manifest["error"]; the DB
    connection is always closed.
    """
    conn = get_db()
    manifest = {
        "patient_info": {},
        "available_data": {},
        "sample_values": {},
        "data_range": {}
    }
    try:
        # Get patient demographics
        cursor = conn.execute("SELECT * FROM patients WHERE id = ?", (patient_id,))
        patient = cursor.fetchone()
        if patient:
            birth = datetime.strptime(patient["birth_date"], "%Y-%m-%d")
            today = datetime.now()
            # Calendar-accurate age: `days // 365` drifts with leap years and
            # is off around birthdays, so compare (month, day) tuples instead.
            age = today.year - birth.year - (
                (today.month, today.day) < (birth.month, birth.day)
            )
            manifest["patient_info"] = {
                "name": f"{patient['given_name']} {patient['family_name']}",
                "age": age,
                "gender": patient['gender']
            }
        # Count records in each category
        tables = {
            'conditions': 'Medical conditions/diagnoses',
            'medications': 'Medications (active and historical)',
            'observations': 'Vital signs and lab results',
            'allergies': 'Known allergies',
            'encounters': 'Healthcare visits',
            'immunizations': 'Vaccinations',
            'procedures': 'Medical procedures'
        }
        for table, description in tables.items():
            try:
                # Table names come from the fixed dict above, never from user
                # input, so the f-string here is not an injection risk.
                cursor = conn.execute(
                    f"SELECT COUNT(*) FROM {table} WHERE patient_id = ?",
                    (patient_id,)
                )
                count = cursor.fetchone()[0]
                if count > 0:
                    manifest["available_data"][table] = {
                        "description": description,
                        "count": count
                    }
            except Exception:
                # Narrowed from a bare except (which also swallowed
                # KeyboardInterrupt/SystemExit); a missing table is skipped.
                pass
        # Get sample conditions
        cursor = conn.execute(
            "SELECT display, clinical_status FROM conditions WHERE patient_id = ? LIMIT 5",
            (patient_id,)
        )
        conditions = [f"{row['display']}" for row in cursor.fetchall()]
        if conditions:
            manifest["sample_values"]["conditions"] = conditions
        # Get sample medications
        cursor = conn.execute(
            "SELECT display, status FROM medications WHERE patient_id = ? AND status = 'active' LIMIT 5",
            (patient_id,)
        )
        medications = [row['display'] for row in cursor.fetchall()]
        if medications:
            manifest["sample_values"]["active_medications"] = medications
        # Get available vital types
        cursor = conn.execute("""
            SELECT DISTINCT code_display FROM observations
            WHERE patient_id = ? AND category = 'vital-signs'
        """, (patient_id,))
        vitals = [row['code_display'] for row in cursor.fetchall()]
        if vitals:
            manifest["sample_values"]["vital_types"] = vitals
        # Get observation date range
        cursor = conn.execute("""
            SELECT MIN(effective_date) as earliest, MAX(effective_date) as latest
            FROM observations WHERE patient_id = ?
        """, (patient_id,))
        obs_range = cursor.fetchone()
        if obs_range and obs_range['earliest']:
            manifest["data_range"] = {
                "earliest_record": obs_range['earliest'][:10],
                "latest_record": obs_range['latest'][:10]
            }
    except Exception as e:
        # Surface the failure in the manifest rather than crashing discovery.
        manifest["error"] = str(e)
    finally:
        conn.close()
    return manifest
# =============================================================================
# PHASE 2: PLAN
# =============================================================================
async def analyze_query(state: AgentState) -> str:
    """Have the LLM produce a short (2-3 sentence) analysis of query intent."""
    patient_name = state.manifest.get('patient_info', {}).get('name', 'Unknown')
    prompt = f"""<start_of_turn>user
Analyze this patient query and identify:
1. The primary intent (what does the patient want to know?)
2. What type of information is needed (vitals, medications, conditions, labs, etc.)
3. Any time constraints mentioned (last 30 days, recent, historical, etc.)
4. Priority level (urgent concern vs routine question)
Patient: {patient_name}
Query: {state.question}
Provide a brief analysis (2-3 sentences).
<end_of_turn>
<start_of_turn>model
"""
    return await call_llm(prompt, max_tokens=200, temperature=0.3)
async def create_execution_plan(state: AgentState) -> List[Dict]:
    """
    Create a detailed execution plan with tool dependencies.

    Asks the LLM for a JSON array of steps ("step", "tool", "args",
    "reason", "depends_on"). If the response cannot be parsed into a list,
    falls back to the keyword-based create_fallback_plan.
    """
    tools_desc = get_tools_description()
    manifest_text = json.dumps(state.manifest, indent=2)
    prompt = f"""<start_of_turn>user
You are a medical AI planning how to gather information to answer a patient's question.
PATIENT DATA MANIFEST:
{manifest_text}
AVAILABLE TOOLS:
{tools_desc}
QUERY ANALYSIS: {state.query_analysis}
USER QUESTION: {state.question}
Create an execution plan. For each step, specify:
- "step": step number (1, 2, 3...)
- "tool": tool name to call
- "args": arguments for the tool
- "reason": why this step is needed
- "depends_on": list of step numbers this depends on (empty if none)
Rules:
1. Order steps logically (gather context first, then specifics)
2. For comprehensive queries, gather multiple data types
3. Include chart tools if visualization would help
4. For "prepare for appointment" queries, gather: conditions, medications, recent vitals, allergies
Output JSON array only:
<end_of_turn>
<start_of_turn>model
["""
    response = await call_llm(prompt, max_tokens=800, temperature=0.3)
    # Parse the plan (the prompt primes the model with "[", so restore it).
    try:
        plan = extract_json("[" + response)
        if plan and isinstance(plan, list):
            # Guarantee every step has an args dict carrying the patient_id,
            # even when the LLM omitted "args" entirely -- execute_plan
            # indexes step["args"], which previously raised KeyError.
            for step in plan:
                step.setdefault("args", {})["patient_id"] = state.patient_id
            return plan
    except Exception:
        # Narrowed from a bare except; any malformed plan (e.g. non-dict
        # steps) falls through to the deterministic fallback below.
        pass
    # Fallback: simple plan based on query keywords
    return create_fallback_plan(state)
def create_fallback_plan(state: AgentState) -> List[Dict]:
    """Create a basic keyword-driven plan when LLM planning fails."""
    q = state.question.lower()
    pid = state.patient_id

    # Comprehensive queries get a fixed five-step workup.
    if any(kw in q for kw in ("appointment", "prepare", "summary", "overview", "everything")):
        return [
            {"step": 1, "tool": "get_conditions", "args": {"patient_id": pid},
             "reason": "Get all medical conditions", "depends_on": []},
            {"step": 2, "tool": "get_medications", "args": {"patient_id": pid, "status": "active"},
             "reason": "Get current medications", "depends_on": []},
            {"step": 3, "tool": "get_allergies", "args": {"patient_id": pid},
             "reason": "Check for allergies", "depends_on": []},
            {"step": 4, "tool": "get_recent_vitals", "args": {"patient_id": pid},
             "reason": "Get recent vital signs", "depends_on": []},
            {"step": 5, "tool": "get_vital_chart_data", "args": {"patient_id": pid, "vital_type": "blood_pressure"},
             "reason": "Visualize BP trends", "depends_on": [4]},
        ]

    # Specific queries: (keywords, tool, extra args, reason), checked in order.
    routes = [
        (("medication", "drug", "prescription"), "get_medications",
         {"status": "active"}, "User asked about medications"),
        (("condition", "diagnosis", "disease"), "get_conditions",
         {}, "User asked about conditions"),
        (("blood pressure", "bp"), "get_vital_chart_data",
         {"vital_type": "blood_pressure"}, "User asked about blood pressure"),
        (("heart rate", "pulse"), "get_vital_chart_data",
         {"vital_type": "heart_rate"}, "User asked about heart rate"),
        (("weight",), "get_vital_chart_data",
         {"vital_type": "weight"}, "User asked about weight"),
        (("lab", "test", "result"), "get_recent_labs",
         {}, "User asked about labs"),
        (("allerg",), "get_allergies",
         {}, "User asked about allergies"),
    ]
    plan = []
    for keywords, tool, extra_args, reason in routes:
        if any(kw in q for kw in keywords):
            plan.append({
                "step": len(plan) + 1,
                "tool": tool,
                "args": {"patient_id": pid, **extra_args},
                "reason": reason,
                "depends_on": [],
            })

    # Default: general overview when nothing matched.
    if not plan:
        plan.append({"step": 1, "tool": "get_patient_summary",
                     "args": {"patient_id": pid},
                     "reason": "Get general patient overview", "depends_on": []})
    return plan
# =============================================================================
# PHASE 3: EXECUTE (with self-correction)
# =============================================================================
async def execute_with_retry(
    tool_name: str,
    args: Dict,
    reason: str,
    max_retries: int = 2
) -> ToolExecution:
    """Execute a tool with retry logic on failure.

    Retries when the tool raises, or when its output starts with something
    that looks like an error message. Between data-error retries the "days"
    window (if present) is doubled to widen the search.

    Fix: operates on a private copy of ``args`` so retries no longer mutate
    the caller's plan step in place.

    Returns a ToolExecution; ``success`` is False only after all
    (1 + max_retries) attempts have failed.
    """
    args = dict(args)  # defensive copy: the retry path mutates "days" below
    execution = ToolExecution(
        tool_name=tool_name,
        args=args,
        reason=reason
    )
    for attempt in range(max_retries + 1):
        try:
            result = execute_tool(tool_name, args)
            execution.result = result
            # Soft-failure heuristic: "error" near the start of the output.
            # Slice before lowercasing so huge results aren't fully copied.
            if result and "error" not in result[:100].lower():
                execution.success = True
                return execution
            # Error-looking output: widen the lookback window and retry.
            if attempt < max_retries:
                execution.retry_count += 1
                if "days" in args:
                    args["days"] = args.get("days", 30) * 2
        except Exception as e:
            execution.error = str(e)
            if attempt < max_retries:
                execution.retry_count += 1
                await asyncio.sleep(0.5)  # Brief delay before retry
    execution.success = False
    return execution
async def extract_facts_from_result(
    tool_name: str,
    result: str,
    question: str,
    state: AgentState
) -> str:
    """Extract relevant facts from a tool's output.

    Routing:
    - Chart tools: parse the JSON payload, stash it on state.chart_data,
      and return a short textual summary for the fact list.
    - Simple data tools: return the raw output unchanged (it is already
      concise structured text).
    - Anything else: ask the LLM to distill the output into 2-4 sentences.
    """
    # Handle chart tools specially
    chart_tools = ["get_vital_chart_data", "get_lab_panel_chart", "get_lab_trend_chart"]
    if tool_name in chart_tools:
        try:
            parsed = json.loads(result)
            if "chart_type" in parsed:
                state.chart_data = parsed
            # Prefer the chart's own summary text for the fact list.
            summary = parsed.get("summary", "")
            if summary:
                return f"Chart: {parsed.get('title', tool_name)}\n{summary}"
            return f"Chart generated: {parsed.get('title', tool_name)}"
        except (json.JSONDecodeError, TypeError, AttributeError):
            # Not valid JSON / not an object: fall through to generic
            # handling instead of a bare except that hid every error.
            pass
    # For simple data tools, return the raw output as-is.
    simple_tools = [
        "get_patient_summary", "get_conditions", "get_medications",
        "get_allergies", "get_immunizations", "get_encounters",
        "get_procedures", "get_recent_vitals", "get_recent_labs"
    ]
    if tool_name in simple_tools:
        return result
    # For complex results, use LLM to extract relevant facts
    prompt = f"""<start_of_turn>user
Extract the key facts from this tool output that are relevant to: {question}
Tool: {tool_name}
Output: {result[:2000]}
Provide a concise summary of relevant facts (2-4 sentences).
<end_of_turn>
<start_of_turn>model
"""
    facts = await call_llm(prompt, max_tokens=300, temperature=0.3)
    return facts
async def execute_plan(state: AgentState) -> None:
    """Run every step of state.execution_plan in order, recording results."""
    completed_steps = set()
    for step in state.execution_plan:
        step_num = step.get("step", 0)
        # Dependency bookkeeping: noted but not enforced — steps already run
        # in planned order, so dependencies are normally satisfied.
        unmet = [d for d in step.get("depends_on", []) if d not in completed_steps]
        if unmet:
            pass  # In a real system we'd defer or re-order here.
        tool_name = step["tool"]
        args = step["args"]
        reason = step.get("reason", "")
        state.add_reasoning(
            "execute",
            f"Calling {tool_name}: {reason}",
        )
        # Execute with retry, then record the outcome.
        execution = await execute_with_retry(tool_name, args, reason)
        state.tool_executions.append(execution)
        if execution.success and execution.result:
            facts = await extract_facts_from_result(
                tool_name,
                execution.result,
                state.question,
                state
            )
            execution.facts_extracted = facts
            state.collected_facts.append(f"[{tool_name}] {facts}")
            preview = facts[:200] + "..." if len(facts) > 200 else facts
            state.add_reasoning(
                "execute",
                f"Extracted facts from {tool_name}",
                result=preview
            )
        else:
            state.add_reasoning(
                "execute",
                f"Tool {tool_name} failed",
                result=execution.error or "Unknown error"
            )
        completed_steps.add(step_num)
# =============================================================================
# PHASE 4: REFLECT
# =============================================================================
async def reflect_on_completeness(state: AgentState) -> bool:
    """
    Check whether the collected facts suffice to answer the question.
    Returns True if complete, False if more data would be needed.
    """
    facts_summary = "\n".join(state.collected_facts[:10])
    prompt = f"""<start_of_turn>user
You are checking if enough information has been gathered to answer a patient's question.
QUESTION: {state.question}
COLLECTED FACTS:
{facts_summary}
Evaluate:
1. Do we have enough information to provide a helpful answer? (yes/no)
2. What key information might be missing? (list any gaps)
Format your response as:
COMPLETE: yes/no
GAPS: [list any missing information, or "none" if complete]
<end_of_turn>
<start_of_turn>model
"""
    response = await call_llm(prompt, max_tokens=200, temperature=0.3)
    state.completeness_check = response
    lowered = response.lower()
    is_complete = "complete: yes" in lowered
    # Pull out the GAPS section, if the model followed the format.
    if "gaps:" in lowered:
        gaps_section = lowered.split("gaps:")[1].strip()
        if gaps_section and "none" not in gaps_section:
            state.gaps_identified = [g.strip() for g in gaps_section.split(",")]
    state.add_reasoning(
        "reflect",
        "Checking if information is complete",
        result=f"Complete: {is_complete}, Gaps: {state.gaps_identified}"
    )
    return is_complete
# =============================================================================
# PHASE 5: SYNTHESIZE
# =============================================================================
async def synthesize_answer(state: AgentState) -> AsyncGenerator[str, None]:
    """
    Stream a conversational answer synthesized from every collected fact.
    """
    patient_name = state.manifest.get("patient_info", {}).get("name", "there")
    facts_text = "\n\n".join(state.collected_facts)
    prompt = f"""<start_of_turn>user
You are a helpful medical AI assistant. Using the collected information, provide a clear and helpful response to the patient.
PATIENT: {patient_name}
QUESTION: {state.question}
COLLECTED INFORMATION:
{facts_text}
Guidelines:
1. Address the patient by name
2. Directly answer their question
3. Highlight any concerning trends or important findings
4. Use the exact statistics provided (min, max, avg, count) - do NOT compute your own
5. Suggest follow-up actions if appropriate
6. Keep the tone warm but professional
7. If showing charts, reference what they display
Provide your response:
<end_of_turn>
<start_of_turn>model
"""
    # Pass every token straight through to the caller.
    async for piece in stream_llm(prompt, max_tokens=1500):
        yield piece
# =============================================================================
# Main Agent Entry Point
# =============================================================================
async def run_agent_v3(
    patient_id: str,
    question: str,
    stream_reasoning: bool = True
) -> AsyncGenerator[Dict, None]:
    """
    Run the enhanced agentic workflow.

    Orchestrates the five phases over a fresh AgentState:
    discover -> plan -> execute -> reflect -> synthesize.

    Yields events that can be:
    - {"type": "reasoning", "phase": "...", "action": "...", "result": "..."}
    - {"type": "plan", "steps": [...]}
    - {"type": "chart", "data": {...}}
    - {"type": "token", "content": "..."}
    - {"type": "done", "trace": [...]}

    When ``stream_reasoning`` is False, only chart/token/done events are
    emitted; the internal trace is still recorded on the state.
    """
    state = AgentState(patient_id=patient_id, question=question)
    # =========================================================================
    # PHASE 1: DISCOVER — inspect the DB for what data this patient has
    # =========================================================================
    state.add_reasoning("discover", "Analyzing available patient data...")
    if stream_reasoning:
        yield {"type": "reasoning", "phase": "discover", "action": "Analyzing available patient data..."}
    state.manifest = discover_patient_data(patient_id)
    patient_name = state.manifest.get("patient_info", {}).get("name", "Unknown")
    data_types = list(state.manifest.get("available_data", {}).keys())
    state.add_reasoning(
        "discover",
        f"Found patient: {patient_name}",
        result=f"Available data: {', '.join(data_types)}"
    )
    if stream_reasoning:
        yield {
            "type": "reasoning",
            "phase": "discover",
            "action": f"Found patient: {patient_name}",
            "result": f"Available data: {', '.join(data_types)}"
        }
    # =========================================================================
    # PHASE 2: PLAN — analyze intent, then build an ordered tool plan
    # =========================================================================
    state.add_reasoning("plan", "Analyzing your question...")
    if stream_reasoning:
        yield {"type": "reasoning", "phase": "plan", "action": "Analyzing your question..."}
    state.query_analysis = await analyze_query(state)
    state.add_reasoning("plan", "Creating execution plan...")
    if stream_reasoning:
        yield {"type": "reasoning", "phase": "plan", "action": "Creating execution plan..."}
    state.execution_plan = await create_execution_plan(state)
    # Stream the plan as human-readable one-liners alongside the raw steps.
    plan_summary = [f"Step {s['step']}: {s['tool']} - {s.get('reason', '')}"
                    for s in state.execution_plan]
    state.add_reasoning(
        "plan",
        f"Planned {len(state.execution_plan)} tool calls",
        result="\n".join(plan_summary)
    )
    if stream_reasoning:
        yield {
            "type": "plan",
            "steps": state.execution_plan,
            "summary": plan_summary
        }
    # =========================================================================
    # PHASE 3: EXECUTE — run all planned tools, collecting facts on the state
    # =========================================================================
    state.add_reasoning("execute", "Gathering information...")
    if stream_reasoning:
        yield {"type": "reasoning", "phase": "execute", "action": "Gathering information..."}
    await execute_plan(state)
    # Stream chart if we have one (always emitted, even without reasoning).
    if state.chart_data:
        yield {"type": "chart", "data": state.chart_data}
    # Summary of execution
    successful = sum(1 for e in state.tool_executions if e.success)
    state.add_reasoning(
        "execute",
        f"Completed {successful}/{len(state.tool_executions)} tool calls",
        result=f"Collected {len(state.collected_facts)} fact sets"
    )
    if stream_reasoning:
        yield {
            "type": "reasoning",
            "phase": "execute",
            "action": f"Completed {successful}/{len(state.tool_executions)} tool calls",
            "result": f"Collected {len(state.collected_facts)} fact sets"
        }
    # =========================================================================
    # PHASE 4: REFLECT — LLM judges whether the facts answer the question
    # =========================================================================
    state.add_reasoning("reflect", "Checking completeness...")
    if stream_reasoning:
        yield {"type": "reasoning", "phase": "reflect", "action": "Checking if information is complete..."}
    is_complete = await reflect_on_completeness(state)
    # If not complete and we have gaps, we could do additional queries here.
    # For now, we proceed with what we have (gaps are surfaced to the user).
    if not is_complete and state.gaps_identified:
        state.add_reasoning(
            "reflect",
            "Some information may be incomplete",
            result=f"Gaps: {', '.join(state.gaps_identified)}"
        )
        if stream_reasoning:
            yield {
                "type": "reasoning",
                "phase": "reflect",
                "action": "Note: Some information may be incomplete",
                "result": f"Proceeding with available data"
            }
    # =========================================================================
    # PHASE 5: SYNTHESIZE — stream the final answer token by token
    # =========================================================================
    state.add_reasoning("synthesize", "Generating response...")
    if stream_reasoning:
        yield {"type": "reasoning", "phase": "synthesize", "action": "Generating response..."}
    full_response = ""
    async for token in synthesize_answer(state):
        full_response += token
        yield {"type": "token", "content": token}
    state.final_answer = full_response
    # =========================================================================
    # DONE — emit the full reasoning trace for the UI
    # =========================================================================
    yield {
        "type": "done",
        "trace": [
            {
                "phase": step.phase,
                "action": step.action,
                "result": step.result,
                "timestamp": step.timestamp
            }
            for step in state.reasoning_trace
        ],
        "tool_calls": len(state.tool_executions),
        "facts_collected": len(state.collected_facts)
    }
# =============================================================================
# Convenience wrapper for backward compatibility
# =============================================================================
async def chat_with_agent_v3(
    patient_id: str,
    message: str
) -> AsyncGenerator[str, None]:
    """
    Simple wrapper that yields SSE-formatted events.
    Compatible with existing frontend.
    """
    def sse(obj):
        # Serialize one event as a Server-Sent-Events data frame.
        return f"data: {json.dumps(obj)}\n\n"

    async for event in run_agent_v3(patient_id, message, stream_reasoning=True):
        kind = event.get("type")
        if kind == "reasoning":
            # Forward the agent's visible thought process to the client.
            yield sse({'type': 'reasoning', 'phase': event.get('phase'),
                       'action': event.get('action'), 'result': event.get('result')})
        elif kind == "plan":
            yield sse({'type': 'plan', 'steps': event.get('summary', [])})
        elif kind == "chart":
            yield sse({'type': 'chart', 'data': event.get('data')})
        elif kind == "token":
            yield sse({'type': 'token', 'content': event.get('content')})
        elif kind == "done":
            yield sse({'type': 'done', 'tool_calls': event.get('tool_calls'),
                       'facts': event.get('facts_collected')})
    yield "data: [DONE]\n\n"
# =============================================================================
# Test
# =============================================================================
if __name__ == "__main__":
    import asyncio

    async def test():
        # Smoke test: run the full agent loop against a placeholder patient
        # and print every event type it emits. Requires the llama server and
        # the tools DB to be reachable.
        patient_id = "test-patient-id"
        question = "Prepare me for my appointment tomorrow"
        print("=" * 60)
        print("Testing Agent v3 with enhanced agentic features")
        print("=" * 60)
        async for event in run_agent_v3(patient_id, question):
            event_type = event.get("type")
            if event_type == "reasoning":
                print(f"\n🤔 [{event['phase'].upper()}] {event['action']}")
                if event.get('result'):
                    print(f" → {event['result'][:100]}...")
            elif event_type == "plan":
                print(f"\n📋 EXECUTION PLAN:")
                for step in event.get('summary', []):
                    print(f" {step}")
            elif event_type == "chart":
                print(f"\n📊 Chart generated: {event['data'].get('title', 'Unknown')}")
            elif event_type == "token":
                # Stream answer tokens inline without newlines.
                print(event['content'], end="", flush=True)
            elif event_type == "done":
                print(f"\n\n✅ Done! Tool calls: {event['tool_calls']}, Facts: {event['facts_collected']}")
    asyncio.run(test())