#!/usr/bin/env python3
"""
MedGemma Agent with Tool Calling

Simple agent loop that:
1. Receives a question
2. Decides which tools to call
3. Executes tools and gathers results
4. Synthesizes a final answer
"""
import os
import json
import re
from typing import AsyncGenerator, Optional, Dict

import httpx

from tools import get_tools_description, execute_tool

LLAMA_SERVER_URL = os.getenv("LLAMA_SERVER_URL", "http://localhost:8081")
MAX_STEPS = 5  # Max tool calls per question

# Headers for LLM requests (ngrok requires this to skip browser warning)
LLM_HEADERS = {
    "Content-Type": "application/json",
    "ngrok-skip-browser-warning": "true",
}


def build_system_prompt(patient_id: str) -> str:
    """Build the system prompt with tool descriptions.

    Args:
        patient_id: Identifier injected into the prompt so the model can
            pass it through to tool calls.

    Returns:
        The full system prompt, including the tool catalog from
        ``get_tools_description()`` and usage guidance.
    """
    tools_desc = get_tools_description()

    return f"""You are MedGemma, a helpful medical AI assistant with access to a patient's health records.

Patient ID: {patient_id}

{tools_desc}

HOW TO USE TOOLS:
When you need information, respond with a tool call in this format:
TOOL_CALL: {{"tool": "tool_name", "args": {{"param1": "value1"}}}}

WHEN TO USE TOOLS vs ANSWER DIRECTLY:
- USE TOOLS when user asks about THEIR specific data: "show MY blood pressure", "what are MY medications"
- ANSWER DIRECTLY for general health questions: "is walking good?", "what is diabetes?", "how does aspirin work?"
- You can combine both: get patient data THEN provide personalized advice

CHART TOOL GUIDELINES:
- Use get_vital_chart_data for VITALS: blood pressure, heart rate, weight, temperature, oxygen
- Use get_lab_chart_data for LABS: cholesterol, A1c, glucose, kidney function
- Use these chart tools when user asks to "show", "display", "graph", "trend", or "visualize"

EXAMPLES:
- "Show my blood pressure" → get_vital_chart_data with vital_type="blood_pressure"
- "Show my cholesterol" → get_lab_chart_data with lab_type="cholesterol"
- "How is my A1c trending?" → get_lab_chart_data with lab_type="a1c"
- "Is walking good for health?" → ANSWER directly (general knowledge)
- "Is walking good for MY heart given my conditions?" → get_conditions, then synthesize answer

GENERAL GUIDELINES:
1. Use get_recent_vitals or get_lab_results for TEXT summaries only
2. Use chart tools for any visual/trend/graph request
3. Be specific - include numbers, dates, and medication names
4. For general health questions, you can answer from medical knowledge
5. Remind users to consult their healthcare provider for medical decisions

When ready to give your final answer, start with "ANSWER:" followed by your response."""


def build_prompt(system: str, question: str, history: list) -> str:
    """Build the full Gemma-format prompt from the system text and history.

    NOTE(review): the <start_of_turn>/<end_of_turn> control tokens had been
    stripped from this file by an earlier text-mangling pass (the code read
    ``f\"\"\"user ...\"\"\"`` with no turn markers); they are restored here per
    the Gemma chat template — confirm against the deployed model's format.

    Args:
        system: System prompt from build_system_prompt().
        question: The user's question.
        history: Alternating ``assistant`` / ``tool_result`` entries from
            earlier agent steps.

    Returns:
        A single prompt string ending with an open model turn.
    """
    prompt = f"""<start_of_turn>user
{system}

Question: {question}<end_of_turn>

"""

    for entry in history:
        if entry["role"] == "assistant":
            prompt += f"<start_of_turn>model\n{entry['content']}<end_of_turn>\n\n"
        elif entry["role"] == "tool_result":
            prompt += (
                f"<start_of_turn>user\nTool result ({entry['tool']}):\n"
                f"{entry['content']}\n\nContinue or provide your ANSWER:<end_of_turn>\n\n"
            )

    prompt += "<start_of_turn>model\n"
    return prompt


def parse_tool_call(text: str) -> Optional[Dict]:
    """Extract a tool call from a model response.

    Tries four formats, most explicit first:
    1. ``TOOL_CALL: {...}``
    2. fenced ```tool_call / ```tool blocks
    3. fenced ```json blocks containing ``tool``/``args`` keys
    4. any bare JSON object with ``tool`` and ``args`` keys

    Returns:
        The parsed dict (expected keys ``tool`` and ``args``), or None if
        no parseable tool call is found.
    """
    # Format 1: TOOL_CALL: {...} — grab from the marker to the end, then
    # balance braces manually because the JSON may be followed by prose.
    match = re.search(r'TOOL_CALL:\s*(\{.*)', text, re.IGNORECASE | re.DOTALL)
    if match:
        try:
            json_str = match.group(1)
            brace_count = 0
            end_idx = 0
            for i, char in enumerate(json_str):
                if char == '{':
                    brace_count += 1
                elif char == '}':
                    brace_count -= 1
                    if brace_count == 0:
                        end_idx = i + 1
                        break
            if end_idx > 0:
                return json.loads(json_str[:end_idx])
        except json.JSONDecodeError:
            pass

    # Format 2: ```tool_call\n{...}\n``` or ```tool\n{...}\n```
    match = re.search(r'```(?:tool_call|tool)\s*\n?\s*(\{.*?\})\s*\n?```',
                      text, re.IGNORECASE | re.DOTALL)
    if match:
        try:
            return json.loads(match.group(1))
        except json.JSONDecodeError:
            pass

    # Format 3: ```json\n{"tool":...}\n``` - find last occurrence (in case of thinking)
    matches = re.findall(r'```json\s*\n?\s*(\{[^`]*\})\s*\n?```',
                         text, re.IGNORECASE | re.DOTALL)
    for m in reversed(matches):  # Check from last to first
        try:
            parsed = json.loads(m)
            if "tool" in parsed and "args" in parsed:
                return parsed
        except json.JSONDecodeError:
            pass

    # Format 4: Just find any JSON with "tool" and "args" keys
    # Use a more flexible pattern
    for match in re.finditer(
            r'\{\s*"tool"\s*:\s*"([^"]+)"\s*,\s*"args"\s*:\s*(\{[^}]*\})\s*\}',
            text):
        try:
            return json.loads(match.group(0))
        except json.JSONDecodeError:
            pass

    return None


def extract_answer(text: str) -> str:
    """Extract the final answer from a response.

    Strips everything up to and including the first recognized answer
    marker; if no marker is present the whole (stripped) text is returned.
    """
    for marker in ["ANSWER:", "Answer:", "FINAL ANSWER:", "Final Answer:"]:
        if marker in text:
            idx = text.find(marker)
            return text[idx + len(marker):].strip()
    return text.strip()


def has_answer(text: str) -> bool:
    """Check if a response contains a final-answer marker."""
    markers = ["ANSWER:", "Answer:", "FINAL ANSWER:", "Final Answer:"]
    return any(m in text for m in markers)


def filter_thinking(text: str) -> str:
    """Remove thinking blocks from text.

    NOTE(review): the tag pair in the original regex was lost by an earlier
    text-mangling pass (it read ``re.sub(r'.*?', '', text)``, a no-op).
    <think>/<thinking> pairs are restored as the most likely delimiters —
    confirm against actual MedGemma output.
    """
    # Remove <think>...</think> / <thinking>...</thinking> blocks
    text = re.sub(r'<think(?:ing)?>.*?</think(?:ing)?>', '', text, flags=re.DOTALL)

    # Remove "thought ..." at the start (MedGemma sometimes outputs this).
    # Keep everything from TOOL_CALL: or ANSWER: onwards.
    if text.lower().strip().startswith('thought'):
        # Find where the actual content starts
        tool_match = re.search(r'(TOOL_CALL:.*)', text, re.IGNORECASE | re.DOTALL)
        answer_match = re.search(r'(ANSWER:.*)', text, re.IGNORECASE | re.DOTALL)
        if tool_match:
            text = tool_match.group(1)
        elif answer_match:
            text = answer_match.group(1)

    return text.strip()


async def call_llm(prompt: str) -> str:
    """Call the LLM completion endpoint and return the full response text.

    Non-streaming variant; raises httpx.HTTPStatusError on non-2xx.
    NOTE(review): the first two stop strings were blanked by the same
    token-stripping pass; Gemma turn/EOS tokens restored — confirm.
    """
    async with httpx.AsyncClient(timeout=300.0) as client:
        response = await client.post(
            f"{LLAMA_SERVER_URL}/completion",
            headers=LLM_HEADERS,
            json={
                "prompt": prompt,
                "n_predict": 1024,
                "temperature": 0.7,
                "stop": ["<end_of_turn>", "<eos>", "<|im_end|>"],
                "stream": False,
            },
        )
        response.raise_for_status()
        result = response.json()
        return result.get("content", "").strip()


async def stream_llm(prompt: str) -> AsyncGenerator[str, None]:
    """Stream the LLM response token by token (server-sent events).

    Yields each non-empty ``content`` chunk; stops on the ``[DONE]``
    sentinel. Malformed SSE data lines are skipped.
    """
    async with httpx.AsyncClient(timeout=300.0) as client:
        async with client.stream(
            "POST",
            f"{LLAMA_SERVER_URL}/completion",
            headers=LLM_HEADERS,
            json={
                "prompt": prompt,
                "n_predict": 1024,
                "temperature": 0.7,
                "stop": ["<end_of_turn>", "<eos>", "<|im_end|>"],
                "stream": True,
            },
        ) as response:
            async for line in response.aiter_lines():
                if line.startswith("data: "):
                    data = line[6:]
                    if data.strip() == "[DONE]":
                        break
                    try:
                        chunk = json.loads(data)
                        content = chunk.get("content", "")
                        if content:
                            yield content
                    except json.JSONDecodeError:
                        pass


async def run_agent(patient_id: str, question: str) -> AsyncGenerator[dict, None]:
    """
    Run the agent loop with streaming support.

    Yields events:
    - {"type": "status", "message": "..."}
    - {"type": "tool_call", "tool": "...", "args": {...}}
    - {"type": "tool_result", "tool": "...", "result": "..."}
    - {"type": "chart_data", "data": {...}}
    - {"type": "answer_start"}
    - {"type": "token", "content": "..."}
    - {"type": "answer_end"}
    - {"type": "error", "message": "..."}
    """
    system = build_system_prompt(patient_id)
    history = []

    yield {"type": "status", "message": "Analyzing your question..."}

    for _ in range(MAX_STEPS):
        prompt = build_prompt(system, question, history)

        # Stream the response and detect tool calls vs answers
        full_response = ""
        is_streaming_answer = False

        try:
            async for token in stream_llm(prompt):
                full_response += token

                # Check for tool call patterns anywhere in response
                has_tool_json = ('"tool"' in full_response and '"args"' in full_response)
                has_tool_marker = ("TOOL_CALL:" in full_response
                                   or "```tool" in full_response.lower()
                                   or "```json" in full_response.lower()
                                   or has_tool_json)

                # If we see tool patterns, keep buffering until JSON is complete
                # (balanced braces is a cheap completeness heuristic).
                if has_tool_marker:
                    if (full_response.count('{') > 0
                            and full_response.count('{') == full_response.count('}')):
                        break
                    continue  # Keep buffering

                # Check for PARTIAL tool markers - keep buffering
                stripped = full_response.strip().upper()
                if stripped.startswith("TOOL") or stripped.startswith("`"):
                    continue  # Wait for more tokens

                # Check for thinking patterns - keep buffering
                thinking_patterns = ["thought", "thinking", "let me",
                                     "i need to", "i will", "step 1", "1."]
                has_thinking = any(p in full_response.lower()[:200]
                                   for p in thinking_patterns)
                if has_thinking:
                    # Model is thinking - keep buffering until we see what it
                    # decides, but cap the buffer to avoid buffering forever.
                    if len(full_response) < 2000:
                        continue

                # No tool call or thinking patterns - stream as direct answer
                if "ANSWER:" in full_response:
                    if not is_streaming_answer:
                        is_streaming_answer = True
                        yield {"type": "answer_start", "content": ""}
                        # Flush everything after the marker in one chunk
                        answer_part = full_response.split("ANSWER:", 1)[1]
                        if answer_part.strip():
                            yield {"type": "token", "content": answer_part}
                    else:
                        yield {"type": "token", "content": token}
                else:
                    # Direct answer without ANSWER: prefix
                    if not is_streaming_answer:
                        is_streaming_answer = True
                        yield {"type": "answer_start", "content": ""}
                        yield {"type": "token", "content": full_response}
                    else:
                        yield {"type": "token", "content": token}
        except Exception as e:
            yield {"type": "error", "message": f"LLM error: {str(e)}"}
            return

        # If we were streaming an answer, we're done
        if is_streaming_answer:
            yield {"type": "answer_end", "content": ""}
            return

        # Handle tool call
        full_response = filter_thinking(full_response)
        tool_call = parse_tool_call(full_response)

        if tool_call:
            tool_name = tool_call.get("tool", "")
            tool_args = tool_call.get("args", {})
            # Every tool is scoped to the current patient by default
            if "patient_id" not in tool_args:
                tool_args["patient_id"] = patient_id

            yield {"type": "tool_call", "tool": tool_name, "args": tool_args}

            # Execute tool
            result = execute_tool(tool_name, tool_args)

            # For chart tools, emit the chart payload and finish immediately
            if tool_name in ["get_vital_chart_data", "get_lab_chart_data",
                             "compare_before_after_treatment"]:
                try:
                    parsed = json.loads(result)
                    if "chart_type" in parsed and "error" not in parsed:
                        yield {"type": "chart_data", "data": parsed}
                        chart_title = parsed.get("title", "chart")
                        if "summary" in parsed:
                            summary_text = "\n".join(parsed["summary"])
                            yield {"type": "answer_start", "content": ""}
                            yield {"type": "token", "content": f"Here's your {chart_title.lower()}.\n\n**Changes:** {summary_text}\n\nDiscuss these results with your healthcare provider."}
                            yield {"type": "answer_end", "content": ""}
                        else:
                            yield {"type": "answer_start", "content": ""}
                            yield {"type": "token", "content": f"Here's your {chart_title.lower()}. If you notice any concerning patterns, please discuss with your healthcare provider."}
                            yield {"type": "answer_end", "content": ""}
                        return
                except Exception:
                    # Unparseable chart payload: fall through and surface it
                    # as a plain tool result instead of crashing the stream.
                    pass

            # Show tool result (truncated for display)
            display_result = result[:500] + "..." if len(result) > 500 else result
            yield {"type": "tool_result", "tool": tool_name, "result": display_result}

            # Add to history (tighter truncation to keep the prompt small)
            history_result = result[:300] + "\n... [truncated]" if len(result) > 300 else result
            history.append({"role": "assistant", "content": full_response})
            history.append({"role": "tool_result", "tool": tool_name,
                            "content": history_result})
        else:
            # No tool call detected - treat response as answer
            yield {"type": "answer_start", "content": ""}
            yield {"type": "token", "content": filter_thinking(full_response)}
            yield {"type": "answer_end", "content": ""}
            return

    # Max steps reached - stream final answer
    yield {"type": "status", "message": "Generating final answer..."}
    prompt = build_prompt(system, question, history)
    prompt += "\nProvide your ANSWER now based on the information gathered:"

    try:
        yield {"type": "answer_start", "content": ""}
        async for token in stream_llm(prompt):
            # Skip thinking blocks and ANSWER: prefix
            yield {"type": "token", "content": token}
        yield {"type": "answer_end", "content": ""}
    except Exception as e:
        yield {"type": "error", "message": f"Failed to generate answer: {str(e)}"}


async def run_agent_simple(patient_id: str, question: str) -> str:
    """Simple interface - returns just the final answer.

    BUGFIX: the previous version listened for an "answer" event type that
    run_agent never emits (answers arrive as "token" events between
    "answer_start" and "answer_end"), so it always returned "". Now the
    token stream is accumulated instead.
    """
    answer = ""
    async for event in run_agent(patient_id, question):
        if event["type"] == "token":
            answer += event["content"]
        elif event["type"] == "error":
            answer = f"Error: {event['message']}"
    return answer.strip()