#!/usr/bin/env python3
"""
MedGemma Agent with Tool Calling
Simple agent loop that:
1. Receives a question
2. Decides which tools to call
3. Executes tools and gathers results
4. Synthesizes a final answer
"""
import os
import json
import re
from typing import AsyncGenerator, Optional, Dict
import httpx
from tools import get_tools_description, execute_tool
# Base URL of the llama.cpp-style completion server; override via env var.
LLAMA_SERVER_URL = os.getenv("LLAMA_SERVER_URL", "http://localhost:8081")
MAX_STEPS = 5  # Max tool calls per question
# Headers for LLM requests (ngrok requires this to skip browser warning)
LLM_HEADERS = {
    "Content-Type": "application/json",
    "ngrok-skip-browser-warning": "true"
}
def build_system_prompt(patient_id: str) -> str:
    """Compose the agent's system prompt, embedding the tool catalog and patient ID."""
    tool_docs = get_tools_description()
    prompt = f"""You are MedGemma, a helpful medical AI assistant with access to a patient's health records.
Patient ID: {patient_id}
{tool_docs}
HOW TO USE TOOLS:
When you need information, respond with a tool call in this format:
TOOL_CALL: {{"tool": "tool_name", "args": {{"param1": "value1"}}}}
WHEN TO USE TOOLS vs ANSWER DIRECTLY:
- USE TOOLS when user asks about THEIR specific data: "show MY blood pressure", "what are MY medications"
- ANSWER DIRECTLY for general health questions: "is walking good?", "what is diabetes?", "how does aspirin work?"
- You can combine both: get patient data THEN provide personalized advice
CHART TOOL GUIDELINES:
- Use get_vital_chart_data for VITALS: blood pressure, heart rate, weight, temperature, oxygen
- Use get_lab_chart_data for LABS: cholesterol, A1c, glucose, kidney function
- Use these chart tools when user asks to "show", "display", "graph", "trend", or "visualize"
EXAMPLES:
- "Show my blood pressure" → get_vital_chart_data with vital_type="blood_pressure"
- "Show my cholesterol" → get_lab_chart_data with lab_type="cholesterol"
- "How is my A1c trending?" → get_lab_chart_data with lab_type="a1c"
- "Is walking good for health?" → ANSWER directly (general knowledge)
- "Is walking good for MY heart given my conditions?" → get_conditions, then synthesize answer
GENERAL GUIDELINES:
1. Use get_recent_vitals or get_lab_results for TEXT summaries only
2. Use chart tools for any visual/trend/graph request
3. Be specific - include numbers, dates, and medication names
4. For general health questions, you can answer from medical knowledge
5. Remind users to consult their healthcare provider for medical decisions
When ready to give your final answer, start with "ANSWER:" followed by your response."""
    return prompt
def build_prompt(system: str, question: str, history: list) -> str:
    """Assemble the full Gemma-turn-format prompt from system text, question, and history."""
    parts = [f"""<start_of_turn>user
{system}
Question: {question}
<end_of_turn>
"""]
    for turn in history:
        role = turn["role"]
        if role == "assistant":
            parts.append(f"<start_of_turn>model\n{turn['content']}\n<end_of_turn>\n")
        elif role == "tool_result":
            # Tool output is fed back as a user turn with a continuation nudge.
            parts.append(
                f"<start_of_turn>user\nTool result ({turn['tool']}):\n{turn['content']}\n\n"
                "Continue or provide your ANSWER:\n<end_of_turn>\n"
            )
    parts.append("<start_of_turn>model\n")
    return "".join(parts)
def parse_tool_call(text: str) -> Optional[Dict]:
    """Locate and decode a tool-call JSON object embedded in model output.

    Four formats are tried in order: an explicit TOOL_CALL: marker, a
    ```tool/```tool_call fence, a ```json fence containing "tool"/"args",
    and finally any bare {"tool": ..., "args": {...}} object. Returns the
    parsed dict, or None when nothing decodable is found.
    """
    # Format 1: TOOL_CALL: {...} - scan forward to the balanced closing brace.
    m = re.search(r'TOOL_CALL:\s*(\{.*)', text, re.IGNORECASE | re.DOTALL)
    if m:
        candidate = m.group(1)
        depth, stop = 0, 0
        for pos, ch in enumerate(candidate):
            if ch == '{':
                depth += 1
            elif ch == '}':
                depth -= 1
                if depth == 0:
                    stop = pos + 1
                    break
        if stop > 0:
            try:
                return json.loads(candidate[:stop])
            except json.JSONDecodeError:
                pass
    # Format 2: ```tool_call\n{...}\n``` or ```tool\n{...}\n```
    m = re.search(r'```(?:tool_call|tool)\s*\n?\s*(\{.*?\})\s*\n?```', text, re.IGNORECASE | re.DOTALL)
    if m:
        try:
            return json.loads(m.group(1))
        except json.JSONDecodeError:
            pass
    # Format 3: ```json fences - scan from the last one backwards, since
    # earlier fences may belong to thinking text.
    fenced = re.findall(r'```json\s*\n?\s*(\{[^`]*\})\s*\n?```', text, re.IGNORECASE | re.DOTALL)
    for blob in reversed(fenced):
        try:
            decoded = json.loads(blob)
        except json.JSONDecodeError:
            continue
        if "tool" in decoded and "args" in decoded:
            return decoded
    # Format 4: any bare JSON object carrying "tool" and "args" keys.
    for m in re.finditer(r'\{\s*"tool"\s*:\s*"([^"]+)"\s*,\s*"args"\s*:\s*(\{[^}]*\})\s*\}', text):
        try:
            return json.loads(m.group(0))
        except json.JSONDecodeError:
            pass
    return None
def extract_answer(text: str) -> str:
    """Return the text after the first recognized answer marker, or the whole text stripped."""
    for marker in ("ANSWER:", "Answer:", "FINAL ANSWER:", "Final Answer:"):
        pos = text.find(marker)
        if pos != -1:
            return text[pos + len(marker):].strip()
    return text.strip()
def has_answer(text: str) -> bool:
    """Return True if the text contains one of the final-answer markers."""
    for marker in ("ANSWER:", "Answer:", "FINAL ANSWER:", "Final Answer:"):
        if marker in text:
            return True
    return False
def filter_thinking(text: str) -> str:
    """Strip <think> blocks and leading 'thought' preambles from model output."""
    # Drop any <think>...</think> blocks wholesale.
    text = re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL)
    # MedGemma sometimes opens with "thought ..."; when it does, keep only the
    # actionable tail - TOOL_CALL takes priority over ANSWER.
    if text.strip().lower().startswith('thought'):
        for pattern in (r'(TOOL_CALL:.*)', r'(ANSWER:.*)'):
            m = re.search(pattern, text, re.IGNORECASE | re.DOTALL)
            if m:
                text = m.group(1)
                break
    return text.strip()
async def call_llm(prompt: str) -> str:
    """Request a single non-streaming completion from the llama server and return its text."""
    payload = {
        "prompt": prompt,
        "n_predict": 1024,
        "temperature": 0.7,
        "stop": ["<end_of_turn>", "</s>", "<|im_end|>"],
        "stream": False
    }
    async with httpx.AsyncClient(timeout=300.0) as client:
        resp = await client.post(
            f"{LLAMA_SERVER_URL}/completion",
            headers=LLM_HEADERS,
            json=payload
        )
        resp.raise_for_status()
        body = resp.json()
    return body.get("content", "").strip()
async def stream_llm(prompt: str) -> AsyncGenerator[str, None]:
    """Yield completion text chunks from the llama server as SSE lines arrive."""
    payload = {
        "prompt": prompt,
        "n_predict": 1024,
        "temperature": 0.7,
        "stop": ["<end_of_turn>", "</s>", "<|im_end|>"],
        "stream": True
    }
    async with httpx.AsyncClient(timeout=300.0) as client:
        async with client.stream(
            "POST",
            f"{LLAMA_SERVER_URL}/completion",
            headers=LLM_HEADERS,
            json=payload
        ) as response:
            async for line in response.aiter_lines():
                # Server-sent events: payload lines are prefixed "data: ".
                if not line.startswith("data: "):
                    continue
                data = line[6:]
                if data.strip() == "[DONE]":
                    break
                try:
                    piece = json.loads(data).get("content", "")
                except json.JSONDecodeError:
                    continue  # Skip malformed chunks
                if piece:
                    yield piece
async def run_agent(patient_id: str, question: str) -> AsyncGenerator[dict, None]:
    """
    Run the tool-calling agent loop with streaming support.

    Prompts the LLM for up to MAX_STEPS rounds, buffering streamed tokens to
    classify each response as a tool call, thinking text, or a final answer.
    Tool calls are executed and their (truncated) results fed back into the
    conversation history; answers are streamed to the caller token-by-token.

    Yields events:
    - {"type": "status", "message": "..."}
    - {"type": "tool_call", "tool": "...", "args": {...}}
    - {"type": "tool_result", "tool": "...", "result": "..."}
    - {"type": "chart_data", "data": {...}}
    - {"type": "answer_start"}
    - {"type": "token", "content": "..."}
    - {"type": "answer_end"}
    - {"type": "error", "message": "..."}
    """
    system = build_system_prompt(patient_id)
    history = []
    yield {"type": "status", "message": "Analyzing your question..."}
    for step in range(MAX_STEPS):
        prompt = build_prompt(system, question, history)
        # Stream the response, buffering until we can tell tool call vs answer.
        full_response = ""
        is_streaming_answer = False
        try:
            async for token in stream_llm(prompt):
                full_response += token
                # Tool-call heuristics: any of these markers suggests the model
                # is emitting JSON for a tool invocation.
                has_tool_json = ('"tool"' in full_response and '"args"' in full_response)
                has_tool_marker = ("TOOL_CALL:" in full_response or
                                   "```tool" in full_response.lower() or
                                   "```json" in full_response.lower() or
                                   has_tool_json)
                if has_tool_marker:
                    # Keep buffering until braces balance, i.e. the JSON object
                    # is plausibly complete, then hand off to the parser below.
                    if full_response.count('{') > 0 and full_response.count('{') == full_response.count('}'):
                        break
                    continue  # Keep buffering
                # Partial tool marker ("TOOL", opening backtick) - wait for more tokens.
                stripped = full_response.strip().upper()
                if stripped.startswith("TOOL") or stripped.startswith("`"):
                    continue
                # Thinking preamble - keep buffering until the model commits,
                # but cap the buffer so we never stall indefinitely.
                thinking_patterns = ["thought", "thinking", "let me", "i need to", "i will", "step 1", "1."]
                has_thinking = any(p in full_response.lower()[:200] for p in thinking_patterns)
                if has_thinking and len(full_response) < 2000:
                    continue
                # No tool call or thinking patterns - stream as a direct answer.
                if "ANSWER:" in full_response:
                    if not is_streaming_answer:
                        is_streaming_answer = True
                        yield {"type": "answer_start", "content": ""}
                        # First flush: emit everything buffered after the marker.
                        answer_part = full_response.split("ANSWER:", 1)[1]
                        if answer_part.strip():
                            yield {"type": "token", "content": answer_part}
                    else:
                        yield {"type": "token", "content": token}
                else:
                    # Direct answer without ANSWER: prefix.
                    if not is_streaming_answer:
                        is_streaming_answer = True
                        yield {"type": "answer_start", "content": ""}
                        yield {"type": "token", "content": full_response}
                    else:
                        yield {"type": "token", "content": token}
        except Exception as e:
            yield {"type": "error", "message": f"LLM error: {str(e)}"}
            return
        # If we were streaming an answer, we're done.
        if is_streaming_answer:
            yield {"type": "answer_end", "content": ""}
            return
        # Otherwise the buffered response should contain a tool call.
        full_response = filter_thinking(full_response)
        tool_call = parse_tool_call(full_response)
        if tool_call:
            tool_name = tool_call.get("tool", "")
            tool_args = tool_call.get("args", {})
            if "patient_id" not in tool_args:
                tool_args["patient_id"] = patient_id
            yield {"type": "tool_call", "tool": tool_name, "args": tool_args}
            # Execute tool
            result = execute_tool(tool_name, tool_args)
            # Chart tools short-circuit: emit chart data plus a canned caption.
            if tool_name in ["get_vital_chart_data", "get_lab_chart_data", "compare_before_after_treatment"]:
                try:
                    parsed = json.loads(result)
                    if "chart_type" in parsed and "error" not in parsed:
                        yield {"type": "chart_data", "data": parsed}
                        chart_title = parsed.get("title", "chart")
                        if "summary" in parsed:
                            summary_text = "\n".join(parsed["summary"])
                            yield {"type": "answer_start", "content": ""}
                            yield {"type": "token", "content": f"Here's your {chart_title.lower()}.\n\n**Changes:** {summary_text}\n\nDiscuss these results with your healthcare provider."}
                            yield {"type": "answer_end", "content": ""}
                        else:
                            yield {"type": "answer_start", "content": ""}
                            yield {"type": "token", "content": f"Here's your {chart_title.lower()}. If you notice any concerning patterns, please discuss with your healthcare provider."}
                            yield {"type": "answer_end", "content": ""}
                        return
                # FIX: was a bare `except:`. Because this try contains yields,
                # a bare except also swallows GeneratorExit raised when the
                # consumer closes this async generator mid-yield, leading to
                # "async generator ignored GeneratorExit". Catch only the
                # expected parse failures (bad JSON, or a non-container result)
                # and fall through to the plain tool-result path.
                except (json.JSONDecodeError, TypeError):
                    pass
            # Show the (truncated) tool result to the client.
            display_result = result[:500] + "..." if len(result) > 500 else result
            yield {"type": "tool_result", "tool": tool_name, "result": display_result}
            # Record the exchange (truncated harder) for the next prompt round.
            history_result = result[:300] + "\n... [truncated]" if len(result) > 300 else result
            history.append({"role": "assistant", "content": full_response})
            history.append({"role": "tool_result", "tool": tool_name, "content": history_result})
        else:
            # No tool call detected - treat the buffered response as the answer.
            yield {"type": "answer_start", "content": ""}
            yield {"type": "token", "content": filter_thinking(full_response)}
            yield {"type": "answer_end", "content": ""}
            return
    # Max steps reached - force a final answer from the gathered information.
    yield {"type": "status", "message": "Generating final answer..."}
    prompt = build_prompt(system, question, history)
    prompt += "\nProvide your ANSWER now based on the information gathered:"
    try:
        yield {"type": "answer_start", "content": ""}
        async for token in stream_llm(prompt):
            yield {"type": "token", "content": token}
        yield {"type": "answer_end", "content": ""}
    except Exception as e:
        yield {"type": "error", "message": f"Failed to generate answer: {str(e)}"}
async def run_agent_simple(patient_id: str, question: str) -> str:
    """Convenience wrapper: run the agent and return only the final answer text.

    FIX: the previous version waited for an "answer" event, which run_agent
    never emits - the answer arrives as "answer_start" / "token" / "answer_end"
    events - so it always returned "". Accumulate "token" contents instead.
    """
    parts = []
    async for event in run_agent(patient_id, question):
        if event["type"] == "token":
            parts.append(event["content"])
        elif event["type"] == "error":
            # An error event terminates the agent loop; report it directly.
            return f"Error: {event['message']}"
    return "".join(parts).strip()