Spaces:
Sleeping
Sleeping
#!/usr/bin/env python3
"""
MedGemma Agent with Tool Calling
Simple agent loop that:
1. Receives a question
2. Decides which tools to call
3. Executes tools and gathers results
4. Synthesizes a final answer
"""
import os
import json
import re
from typing import AsyncGenerator, Optional, Dict
import httpx
from tools import get_tools_description, execute_tool

# Base URL of the llama.cpp server that hosts the model (override via env).
LLAMA_SERVER_URL = os.getenv("LLAMA_SERVER_URL", "http://localhost:8081")
MAX_STEPS = 5  # Max tool calls per question
# Headers for LLM requests (ngrok requires this to skip browser warning)
LLM_HEADERS = {
    "Content-Type": "application/json",
    "ngrok-skip-browser-warning": "true"
}
def build_system_prompt(patient_id: str) -> str:
    """Build the system prompt with tool descriptions.

    The returned text instructs the model how to emit TOOL_CALL: JSON,
    when to use tools vs. answer directly, and to prefix its final reply
    with "ANSWER:" — markers that parse_tool_call()/has_answer() rely on.
    NOTE: the template body is consumed verbatim by the model; do not
    reword it casually.
    """
    # Tool listing comes from the project-local tools module.
    tools_desc = get_tools_description()
    return f"""You are MedGemma, a helpful medical AI assistant with access to a patient's health records.
Patient ID: {patient_id}
{tools_desc}
HOW TO USE TOOLS:
When you need information, respond with a tool call in this format:
TOOL_CALL: {{"tool": "tool_name", "args": {{"param1": "value1"}}}}
WHEN TO USE TOOLS vs ANSWER DIRECTLY:
- USE TOOLS when user asks about THEIR specific data: "show MY blood pressure", "what are MY medications"
- ANSWER DIRECTLY for general health questions: "is walking good?", "what is diabetes?", "how does aspirin work?"
- You can combine both: get patient data THEN provide personalized advice
CHART TOOL GUIDELINES:
- Use get_vital_chart_data for VITALS: blood pressure, heart rate, weight, temperature, oxygen
- Use get_lab_chart_data for LABS: cholesterol, A1c, glucose, kidney function
- Use these chart tools when user asks to "show", "display", "graph", "trend", or "visualize"
EXAMPLES:
- "Show my blood pressure" → get_vital_chart_data with vital_type="blood_pressure"
- "Show my cholesterol" → get_lab_chart_data with lab_type="cholesterol"
- "How is my A1c trending?" → get_lab_chart_data with lab_type="a1c"
- "Is walking good for health?" → ANSWER directly (general knowledge)
- "Is walking good for MY heart given my conditions?" → get_conditions, then synthesize answer
GENERAL GUIDELINES:
1. Use get_recent_vitals or get_lab_results for TEXT summaries only
2. Use chart tools for any visual/trend/graph request
3. Be specific - include numbers, dates, and medication names
4. For general health questions, you can answer from medical knowledge
5. Remind users to consult their healthcare provider for medical decisions
When ready to give your final answer, start with "ANSWER:" followed by your response."""
def build_prompt(system: str, question: str, history: list) -> str:
    """Assemble the Gemma-turn-format prompt from system text, question, and history.

    History entries are dicts with a "role" of "assistant" (model output) or
    "tool_result" (tool output fed back as a user turn); other roles are ignored.
    The prompt always ends with an open model turn for the next generation.
    """
    segments = [f"<start_of_turn>user\n{system}\nQuestion: {question}\n<end_of_turn>\n"]
    for turn in history:
        role = turn["role"]
        if role == "assistant":
            segments.append(f"<start_of_turn>model\n{turn['content']}\n<end_of_turn>\n")
        elif role == "tool_result":
            segments.append(
                f"<start_of_turn>user\nTool result ({turn['tool']}):\n{turn['content']}\n\n"
                "Continue or provide your ANSWER:\n<end_of_turn>\n"
            )
    segments.append("<start_of_turn>model\n")
    return "".join(segments)
def _balanced_json(text: str) -> Optional[str]:
    """Return the prefix of *text* spanning the first balanced {...} group, or None.

    Counts braces character by character. It does not account for braces that
    appear inside JSON string values, which is acceptable for the simple
    tool-call payloads the model emits.
    """
    depth = 0
    for i, ch in enumerate(text):
        if ch == '{':
            depth += 1
        elif ch == '}':
            depth -= 1
            if depth == 0:
                return text[:i + 1]
    return None  # unbalanced / incomplete


def parse_tool_call(text: str) -> Optional[Dict]:
    """Extract a tool call ({"tool": ..., "args": ...}) from model output.

    Tries the formats the model has been observed to emit, in order:
      1. TOOL_CALL: {...}
      2. ```tool_call / ```tool fenced JSON
      3. ```json fenced JSON (last occurrence wins, in case of thinking text)
      4. Any bare JSON object with "tool" and "args" keys (flat args only)

    Returns the parsed dict, or None when no parseable tool call is found.
    """
    # Format 1: TOOL_CALL: {...}
    match = re.search(r'TOOL_CALL:\s*(\{.*)', text, re.IGNORECASE | re.DOTALL)
    if match:
        candidate = _balanced_json(match.group(1))
        if candidate:
            try:
                return json.loads(candidate)
            except json.JSONDecodeError:
                pass
    # Format 2: ```tool_call\n{...}\n``` or ```tool\n{...}\n```
    # Fix: the previous non-greedy \{.*?\} stopped at the first '}', which
    # broke on nested "args" objects. Brace-match instead, like Format 1.
    match = re.search(r'```(?:tool_call|tool)\s*\n?\s*(\{.*)', text, re.IGNORECASE | re.DOTALL)
    if match:
        candidate = _balanced_json(match.group(1))
        if candidate:
            try:
                return json.loads(candidate)
            except json.JSONDecodeError:
                pass
    # Format 3: ```json\n{"tool":...}\n``` - find last occurrence (in case of thinking)
    matches = re.findall(r'```json\s*\n?\s*(\{[^`]*\})\s*\n?```', text, re.IGNORECASE | re.DOTALL)
    for m in reversed(matches):  # Check from last to first
        try:
            parsed = json.loads(m)
            if "tool" in parsed and "args" in parsed:
                return parsed
        except json.JSONDecodeError:
            pass
    # Format 4: Just find any JSON with "tool" and "args" keys.
    # The args group is \{[^}]*\}, so only un-nested args match here.
    for match in re.finditer(r'\{\s*"tool"\s*:\s*"([^"]+)"\s*,\s*"args"\s*:\s*(\{[^}]*\})\s*\}', text):
        try:
            return json.loads(match.group(0))
        except json.JSONDecodeError:
            pass
    return None
def extract_answer(text: str) -> str:
    """Return the text following the first recognized answer marker.

    Markers are checked in a fixed priority order; when none is present,
    the whole text (stripped) is returned.
    """
    for marker in ("ANSWER:", "Answer:", "FINAL ANSWER:", "Final Answer:"):
        _, found, tail = text.partition(marker)
        if found:
            return tail.strip()
    return text.strip()
def has_answer(text: str) -> bool:
    """Return True when the text contains one of the final-answer markers."""
    for marker in ("ANSWER:", "Answer:", "FINAL ANSWER:", "Final Answer:"):
        if marker in text:
            return True
    return False
def filter_thinking(text: str) -> str:
    """Strip model 'thinking' artifacts from generated text.

    Removes explicit <think>...</think> sections; if the remainder starts
    with a free-form "thought ..." preamble (MedGemma sometimes emits this),
    only the actionable part is kept — TOOL_CALL: takes precedence over
    ANSWER:. When neither marker exists, the text is returned as-is.
    """
    cleaned = re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL)
    if cleaned.lower().strip().startswith('thought'):
        # Prefer a tool call over an answer, matching the agent loop's priority.
        for pattern in (r'(TOOL_CALL:.*)', r'(ANSWER:.*)'):
            found = re.search(pattern, cleaned, re.IGNORECASE | re.DOTALL)
            if found:
                cleaned = found.group(1)
                break
    return cleaned.strip()
async def call_llm(prompt: str) -> str:
    """Request a complete (non-streaming) completion and return its text.

    Posts to the llama-server /completion endpoint and returns the
    stripped "content" field of the JSON response ("" when absent).
    Raises httpx.HTTPStatusError on a non-2xx response.
    """
    payload = {
        "prompt": prompt,
        "n_predict": 1024,
        "temperature": 0.7,
        "stop": ["<end_of_turn>", "</s>", "<|im_end|>"],
        "stream": False,
    }
    async with httpx.AsyncClient(timeout=300.0) as client:
        response = await client.post(
            f"{LLAMA_SERVER_URL}/completion",
            headers=LLM_HEADERS,
            json=payload,
        )
        response.raise_for_status()
        body = response.json()
    return body.get("content", "").strip()
async def stream_llm(prompt: str) -> AsyncGenerator[str, None]:
    """Yield completion text chunks from the llama-server SSE stream.

    Parses "data: ..." server-sent-event lines, stops at the "[DONE]"
    sentinel, and silently skips lines that are not valid JSON. Only
    non-empty "content" chunks are yielded.
    """
    payload = {
        "prompt": prompt,
        "n_predict": 1024,
        "temperature": 0.7,
        "stop": ["<end_of_turn>", "</s>", "<|im_end|>"],
        "stream": True,
    }
    async with httpx.AsyncClient(timeout=300.0) as client:
        async with client.stream(
            "POST",
            f"{LLAMA_SERVER_URL}/completion",
            headers=LLM_HEADERS,
            json=payload,
        ) as response:
            async for raw_line in response.aiter_lines():
                # SSE frames we care about all start with "data: ".
                if not raw_line.startswith("data: "):
                    continue
                data = raw_line[6:]
                if data.strip() == "[DONE]":
                    break
                try:
                    piece = json.loads(data).get("content", "")
                except json.JSONDecodeError:
                    continue  # partial / malformed frame — ignore
                if piece:
                    yield piece
async def run_agent(patient_id: str, question: str) -> AsyncGenerator[dict, None]:
    """
    Run the agent loop with streaming support.
    Yields events:
    - {"type": "status", "message": "..."}
    - {"type": "tool_call", "tool": "...", "args": {...}}
    - {"type": "tool_result", "tool": "...", "result": "..."}
    - {"type": "chart_data", "data": {...}}
    - {"type": "answer_start"}
    - {"type": "token", "content": "..."}
    - {"type": "answer_end"}
    - {"type": "error", "message": "..."}
    """
    system = build_system_prompt(patient_id)
    history = []  # assistant / tool_result turns fed back into each prompt
    yield {"type": "status", "message": "Analyzing your question..."}
    for step in range(MAX_STEPS):
        prompt = build_prompt(system, question, history)
        # Stream the response and detect tool calls vs answers
        full_response = ""
        # NOTE(review): is_tool_call is set below but never read after the
        # stream loop — tool detection ultimately relies on parse_tool_call().
        is_tool_call = False
        is_streaming_answer = False
        try:
            async for token in stream_llm(prompt):
                full_response += token
                # Check for tool call patterns anywhere in response
                has_tool_json = ('"tool"' in full_response and '"args"' in full_response)
                has_tool_marker = ("TOOL_CALL:" in full_response or
                                   "```tool" in full_response.lower() or
                                   "```json" in full_response.lower() or
                                   has_tool_json)
                # If we see tool patterns, keep buffering until JSON is complete
                if has_tool_marker:
                    # Balanced braces are taken as "JSON payload finished".
                    if full_response.count('{') > 0 and full_response.count('{') == full_response.count('}'):
                        is_tool_call = True
                        break
                    continue  # Keep buffering
                # Check for PARTIAL tool markers - keep buffering
                stripped = full_response.strip().upper()
                if stripped.startswith("TOOL") or stripped.startswith("`"):
                    continue  # Wait for more tokens
                # Check for thinking patterns - keep buffering
                thinking_patterns = ["thought", "thinking", "let me", "i need to", "i will", "step 1", "1."]
                has_thinking = any(p in full_response.lower()[:200] for p in thinking_patterns)
                if has_thinking:
                    # Model is thinking - keep buffering until we see what it decides
                    # But set a limit to avoid infinite buffering
                    if len(full_response) < 2000:
                        continue
                # No tool call or thinking patterns - stream as direct answer
                if "ANSWER:" in full_response:
                    if not is_streaming_answer:
                        # First token after the marker: flush the buffered tail.
                        is_streaming_answer = True
                        yield {"type": "answer_start", "content": ""}
                        answer_part = full_response.split("ANSWER:", 1)[1]
                        if answer_part.strip():
                            yield {"type": "token", "content": answer_part}
                    else:
                        yield {"type": "token", "content": token}
                else:
                    # Direct answer without ANSWER: prefix
                    if not is_streaming_answer:
                        # Flush everything buffered so far, then stream per-token.
                        is_streaming_answer = True
                        yield {"type": "answer_start", "content": ""}
                        yield {"type": "token", "content": full_response}
                    else:
                        yield {"type": "token", "content": token}
        except Exception as e:
            yield {"type": "error", "message": f"LLM error: {str(e)}"}
            return
        # If we were streaming an answer, we're done
        if is_streaming_answer:
            yield {"type": "answer_end", "content": ""}
            return
        # Handle tool call
        full_response = filter_thinking(full_response)
        tool_call = parse_tool_call(full_response)
        if tool_call:
            tool_name = tool_call.get("tool", "")
            tool_args = tool_call.get("args", {})
            # Always scope tool calls to the current patient.
            if "patient_id" not in tool_args:
                tool_args["patient_id"] = patient_id
            yield {"type": "tool_call", "tool": tool_name, "args": tool_args}
            # Execute tool
            result = execute_tool(tool_name, tool_args)
            # For chart tools, return immediately
            if tool_name in ["get_vital_chart_data", "get_lab_chart_data", "compare_before_after_treatment"]:
                try:
                    parsed = json.loads(result)
                    if "chart_type" in parsed and "error" not in parsed:
                        yield {"type": "chart_data", "data": parsed}
                        chart_title = parsed.get("title", "chart")
                        if "summary" in parsed:
                            summary_text = "\n".join(parsed["summary"])
                            yield {"type": "answer_start", "content": ""}
                            yield {"type": "token", "content": f"Here's your {chart_title.lower()}.\n\n**Changes:** {summary_text}\n\nDiscuss these results with your healthcare provider."}
                            yield {"type": "answer_end", "content": ""}
                        else:
                            yield {"type": "answer_start", "content": ""}
                            yield {"type": "token", "content": f"Here's your {chart_title.lower()}. If you notice any concerning patterns, please discuss with your healthcare provider."}
                            yield {"type": "answer_end", "content": ""}
                        return
                except:
                    # Malformed chart JSON: fall through and treat it as a
                    # plain tool result below.
                    pass
            # Show tool result
            display_result = result[:500] + "..." if len(result) > 500 else result
            yield {"type": "tool_result", "tool": tool_name, "result": display_result}
            # Add to history (truncated harder than the display copy to keep
            # the next prompt short)
            history_result = result[:300] + "\n... [truncated]" if len(result) > 300 else result
            history.append({"role": "assistant", "content": full_response})
            history.append({"role": "tool_result", "tool": tool_name, "content": history_result})
        else:
            # No tool call detected - treat response as answer
            yield {"type": "answer_start", "content": ""}
            yield {"type": "token", "content": filter_thinking(full_response)}
            yield {"type": "answer_end", "content": ""}
            return
    # Max steps reached - stream final answer
    yield {"type": "status", "message": "Generating final answer..."}
    prompt = build_prompt(system, question, history)
    prompt += "\nProvide your ANSWER now based on the information gathered:"
    try:
        yield {"type": "answer_start", "content": ""}
        async for token in stream_llm(prompt):
            # Tokens are forwarded verbatim here; thinking blocks and the
            # ANSWER: prefix are NOT stripped on this final pass.
            yield {"type": "token", "content": token}
        yield {"type": "answer_end", "content": ""}
    except Exception as e:
        yield {"type": "error", "message": f"Failed to generate answer: {str(e)}"}
async def run_agent_simple(patient_id: str, question: str) -> str:
    """Simple interface - returns just the final answer.

    Consumes the run_agent() event stream and concatenates the streamed
    "token" events into one answer string.

    Args:
        patient_id: Patient whose records the agent may query.
        question: The user's question.

    Returns:
        The final answer text, or an "Error: ..." string on failure.
    """
    # Bug fix: run_agent() never yields an "answer" event — the answer
    # arrives as incremental "token" events between answer_start and
    # answer_end — so the old check left the result permanently empty.
    parts = []
    async for event in run_agent(patient_id, question):
        if event["type"] == "token":
            parts.append(event["content"])
        elif event["type"] == "error":
            return f"Error: {event['message']}"
    return "".join(parts)