Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| """ | |
| Health data tools for the MedGemma agent. | |
| Enhanced with: | |
| 1. Dynamic Tool Registry - Tools can be registered/unregistered at runtime | |
| 2. Tool Search - Reduce token usage by searching for relevant tools | |
| 3. MCP Compatibility - Ready for Model Context Protocol integration | |
| These tools allow the LLM to query specific data rather than | |
| loading everything into context at once. | |
| """ | |
| import sqlite3 | |
| import json as json_module | |
| import statistics as stats_module | |
| from datetime import datetime, timedelta | |
| from typing import Optional, Callable, Dict, Any, List, Set | |
| from dataclasses import dataclass, field | |
| import os | |
| import re | |
# SQLite file holding the imported FHIR data; overridable via the DB_PATH
# environment variable.
DB_PATH = os.environ.get("DB_PATH", "data/fhir.db")
def get_db():
    """Open a fresh SQLite connection to ``DB_PATH``.

    The connection uses ``sqlite3.Row`` as its row factory, so columns can be
    read by name. The caller owns the connection and is responsible for
    closing it.
    """
    connection = sqlite3.connect(DB_PATH)
    connection.row_factory = sqlite3.Row
    return connection
| # ============================================================================= | |
| # DYNAMIC TOOL REGISTRY WITH MCP COMPATIBILITY | |
| # ============================================================================= | |
@dataclass
class ToolParameter:
    """Definition of a tool parameter.

    Declared as a dataclass so it can be built positionally, e.g.
    ``ToolParameter("patient_id", "string", "Patient ID")``. Without the
    decorator the class has no ``__init__`` that accepts these arguments.

    Attributes:
        name: Argument name passed to the tool handler.
        type: JSON-schema-style type name ("string", "integer", "boolean", "array").
        description: Text shown to the LLM and in MCP schemas.
        required: Whether callers must supply the argument.
        default: Default value advertised in schemas (None means "no default").
        enum: Optional list of allowed values.
    """
    name: str
    type: str  # "string", "integer", "boolean", "array"
    description: str
    required: bool = True
    default: Any = None
    enum: Optional[List[str]] = None  # For constrained values
@dataclass
class ToolDefinition:
    """
    Complete definition of a tool.
    MCP-compatible structure for future integration.

    Declared as a dataclass: the registry constructs it with keyword
    arguments, and ``mcp_annotations`` relies on ``field(default_factory=dict)``
    to give each instance its own dict.

    Attributes:
        name: Unique tool name used for lookup and dispatch.
        description: Text describing the tool to the LLM / MCP client.
        parameters: Parameter definitions rendered into schemas and prompts.
        handler: Callable implementing the tool.
        category: Grouping key used for category filtering.
        requires_patient_id: Flag recorded with the tool (not read by the code
            in this class).
        returns_chart: Advertised in the MCP annotations as ``returnsChart``.
        returns_json: Whether the handler's output is expected to be JSON text.
        version: Tool version reported in the MCP annotations.
        mcp_compatible: Flag recorded with the tool (not read by this class).
        mcp_annotations: Extra keys merged into the MCP ``annotations`` object.
    """
    name: str
    description: str
    parameters: List["ToolParameter"]
    handler: Callable
    category: str = "general"
    requires_patient_id: bool = True
    returns_chart: bool = False
    returns_json: bool = False
    version: str = "1.0.0"
    # MCP-specific fields
    mcp_compatible: bool = True
    mcp_annotations: Dict[str, Any] = field(default_factory=dict)

    def to_mcp_schema(self) -> Dict:
        """Convert to MCP tool schema format (name, description, inputSchema, annotations)."""
        properties = {}
        required = []
        for param in self.parameters:
            prop = {
                "type": param.type,
                "description": param.description
            }
            if param.enum:
                prop["enum"] = param.enum
            if param.default is not None:
                prop["default"] = param.default
            properties[param.name] = prop
            if param.required:
                required.append(param.name)
        return {
            "name": self.name,
            "description": self.description,
            "inputSchema": {
                "type": "object",
                "properties": properties,
                "required": required
            },
            "annotations": {
                "category": self.category,
                "returnsChart": self.returns_chart,
                "version": self.version,
                # Spread last, so custom annotations can override the defaults above.
                **self.mcp_annotations
            }
        }

    def to_llm_description(self, detail_level: str = "full") -> str:
        """Generate description for LLM context.

        Args:
            detail_level: "name_only" returns just the name. "brief" returns the
                name plus a description truncated to 80 characters. Any other
                value returns the full description and the parameter list.
        """
        if detail_level == "name_only":
            return self.name
        if detail_level == "brief":
            # Add an ellipsis only when the description was actually cut.
            summary = self.description
            if len(summary) > 80:
                summary = summary[:80] + "..."
            return f"- {self.name}: {summary}"
        # Full description
        params_str = ", ".join([
            f"{p.name}: {p.description}" + (" (optional)" if not p.required else "")
            for p in self.parameters
        ])
        return f"- {self.name}: {self.description}\n Parameters: {params_str}"
class DynamicToolRegistry:
    """
    Dynamic registry for tool management.

    Features:
    - Runtime tool registration/unregistration
    - Category-based filtering
    - Keyword search for relevant tools
    - MCP-compatible schema generation
    - Context-aware tool selection

    Three structures are kept in sync: ``_tools`` (name -> definition),
    ``_categories`` (category -> tool names) and ``_keyword_index``
    (keyword -> tool names). Empty entries are pruned on unregister, so the
    indices only ever describe tools that are currently registered.
    """

    # Words too common to be useful as search keywords.
    _STOP_WORDS = frozenset({
        'the', 'a', 'an', 'is', 'are', 'for', 'to', 'of',
        'and', 'or', 'in', 'on', 'with',
    })

    def __init__(self):
        self._tools: Dict[str, "ToolDefinition"] = {}
        self._categories: Dict[str, Set[str]] = {}
        self._keyword_index: Dict[str, Set[str]] = {}  # keyword -> tool names

    def register(self,
                 name: str,
                 description: str,
                 parameters: List["ToolParameter"],
                 category: str = "general",
                 requires_patient_id: bool = True,
                 returns_chart: bool = False,
                 returns_json: bool = False,
                 mcp_annotations: Optional[Dict] = None):
        """
        Decorator for registering tools dynamically.

        The decorated function is returned unchanged. A ``ToolDefinition``
        wrapping it is added to this registry, replacing any tool already
        registered under ``name``.

        Usage:
            @registry.register(
                name="get_conditions",
                description="Get patient conditions",
                parameters=[ToolParameter("patient_id", "string", "Patient ID")],
                category="medical_records"
            )
            def get_conditions(patient_id: str) -> str:
                ...
        """
        def decorator(func: Callable) -> Callable:
            tool = ToolDefinition(
                name=name,
                description=description,
                parameters=parameters,
                handler=func,
                category=category,
                requires_patient_id=requires_patient_id,
                returns_chart=returns_chart,
                returns_json=returns_json,
                mcp_annotations=mcp_annotations or {}
            )
            self._register_tool(tool)
            return func
        return decorator

    def _register_tool(self, tool: "ToolDefinition"):
        """Register *tool* (replacing any same-named tool) and update the indices."""
        # Drop a previous definition first. Otherwise its old category and
        # keywords would keep pointing at the new tool.
        if tool.name in self._tools:
            self.unregister(tool.name)
        self._tools[tool.name] = tool
        self._categories.setdefault(tool.category, set()).add(tool.name)
        for keyword in self._extract_keywords(f"{tool.name} {tool.description}"):
            self._keyword_index.setdefault(keyword, set()).add(tool.name)

    def _extract_keywords(self, text: str) -> Set[str]:
        """Extract searchable keywords from text.

        Words are lower-cased. Snake_case tokens are indexed both whole and by
        their parts, so "get_allergies" also matches a query for "allergies".
        Stop words and words of two characters or fewer are dropped.
        """
        keywords: Set[str] = set()
        for word in re.findall(r'\b\w+\b', text.lower()):
            keywords.add(word)
            if '_' in word:
                keywords.update(word.split('_'))
        return {w for w in keywords if w not in self._STOP_WORDS and len(w) > 2}

    def unregister(self, name: str) -> bool:
        """Remove a tool from the registry. Returns False if it was not registered."""
        tool = self._tools.pop(name, None)
        if tool is None:
            return False
        # Remove from the category index; drop the category once it is empty.
        members = self._categories.get(tool.category)
        if members is not None:
            members.discard(name)
            if not members:
                del self._categories[tool.category]
        # Remove from the keyword index, pruning keywords left with no tools.
        emptied = []
        for keyword, names in self._keyword_index.items():
            names.discard(name)
            if not names:
                emptied.append(keyword)
        for keyword in emptied:
            del self._keyword_index[keyword]
        return True

    def get(self, name: str) -> Optional["ToolDefinition"]:
        """Get a specific tool by name (None if unknown)."""
        return self._tools.get(name)

    def get_all(self) -> List["ToolDefinition"]:
        """Get all registered tools in registration order."""
        return list(self._tools.values())

    def get_by_category(self, category: str) -> List["ToolDefinition"]:
        """Get all tools in a category, sorted by name for a stable order."""
        return [self._tools[name] for name in sorted(self._categories.get(category, ()))]

    def get_categories(self) -> List[str]:
        """Get all categories that currently contain at least one tool."""
        return list(self._categories.keys())

    def search(self, query: str, max_results: int = 5) -> List["ToolDefinition"]:
        """
        Search for tools matching a query.

        Each query keyword found in a tool's name or description adds 1 point.
        The whole query appearing inside a tool name adds 3. Results are
        ordered by score, with ties broken by tool name.
        """
        scores: Dict[str, int] = {}
        for keyword in self._extract_keywords(query):
            for tool_name in self._keyword_index.get(keyword, ()):
                scores[tool_name] = scores.get(tool_name, 0) + 1
        # Also check for exact name matches (higher weight)
        query_lower = query.lower()
        for tool_name in self._tools:
            if query_lower in tool_name.lower():
                scores[tool_name] = scores.get(tool_name, 0) + 3
        ranked = sorted(scores.items(), key=lambda item: (-item[1], item[0]))
        return [self._tools[name] for name, _ in ranked[:max_results]]

    def get_tools_for_context(self,
                              categories: List[str] = None,
                              query_keywords: List[str] = None,
                              max_tools: int = 10) -> List["ToolDefinition"]:
        """
        Get tools filtered by context.

        Used to reduce token usage by only including relevant tools. Tools are
        restricted to *categories* (if given), then ordered by how many of
        *query_keywords* appear in their name or description.
        """
        if categories:
            # dict.fromkeys drops repeated categories but keeps the caller's order.
            tools = [tool for cat in dict.fromkeys(categories)
                     for tool in self.get_by_category(cat)]
        else:
            tools = self.get_all()
        if query_keywords:
            lowered = [kw.lower() for kw in query_keywords]

            def relevance_score(tool: "ToolDefinition") -> int:
                text = f"{tool.name} {tool.description}".lower()
                return sum(1 for kw in lowered if kw in text)

            tools = sorted(tools, key=relevance_score, reverse=True)
        return tools[:max_tools]

    def execute(self, tool_name: str, args: Dict[str, Any]) -> str:
        """Execute a tool by name with keyword arguments.

        Handler errors are reported as a JSON ``{"error": ...}`` string rather
        than raised.
        """
        tool = self._tools.get(tool_name)
        if not tool:
            return json_module.dumps({"error": f"Unknown tool: {tool_name}"})
        try:
            return tool.handler(**args)
        except Exception as e:  # boundary: report any handler failure to the agent
            return json_module.dumps({"error": f"Error executing {tool_name}: {str(e)}"})

    # =========================================================================
    # MCP Compatibility Methods
    # =========================================================================
    def to_mcp_tools_list(self) -> List[Dict]:
        """Generate MCP-compatible tools list."""
        return [tool.to_mcp_schema() for tool in self._tools.values()]

    def handle_mcp_tool_call(self, tool_name: str, arguments: Dict) -> Dict:
        """
        Handle an MCP tool call request.

        Returns an MCP-style response: ``{"content": [...]}``, plus
        ``"isError": True`` when the tool is unknown or its handler raises.
        JSON-returning tools have their output pretty-printed when it parses.
        """
        tool = self._tools.get(tool_name)
        if not tool:
            return {
                "isError": True,
                "content": [{"type": "text", "text": f"Unknown tool: {tool_name}"}]
            }
        try:
            result = tool.handler(**arguments)
        except Exception as e:  # boundary: surface handler failures to the client
            return {
                "isError": True,
                "content": [{"type": "text", "text": f"Error: {str(e)}"}]
            }
        if tool.returns_json:
            try:
                parsed = json_module.loads(result)
            except (TypeError, ValueError):
                # Not valid JSON after all; fall back to the raw text below.
                pass
            else:
                return {
                    "content": [{"type": "text", "text": json_module.dumps(parsed, indent=2)}]
                }
        return {
            "content": [{"type": "text", "text": result}]
        }

    def generate_mcp_server_info(self) -> Dict:
        """Generate MCP server information."""
        return {
            "name": "medgemma-health-tools",
            "version": "1.0.0",
            "description": "Health data tools for MedGemma medical AI assistant",
            "capabilities": {
                "tools": True,
                "resources": False,
                "prompts": False
            }
        }
| # ============================================================================= | |
| # GLOBAL REGISTRY INSTANCE | |
| # ============================================================================= | |
# Module-level registry instance read by the meta-tool helpers in this module
# (search_tools, get_tool_schema, list_tool_categories, get_tools_description).
registry = DynamicToolRegistry()
| # ============================================================================= | |
| # TOOL SEARCH FUNCTIONS (Reduce Token Usage) | |
| # ============================================================================= | |
def search_tools(query: str, category: str = None, max_results: int = 5) -> str:
    """
    Search for tools matching a query.

    Returns condensed tool info (a JSON string) to reduce token usage. This is
    a META-TOOL that helps the agent find relevant tools without loading all
    tool definitions into context.

    Args:
        query: Free-text search terms.
        category: If given, only tools in this category are considered.
        max_results: Maximum number of matches returned.

    Returns:
        JSON with ``matches`` (name, truncated description, category,
        returns_chart) and a ``hint``. When nothing matches, it also includes
        ``available_categories``.
    """
    if category:
        # Rank with the same keyword scoring used when no category is given,
        # restricted to this category. Matching only a literal substring of the
        # whole query found nothing for multi-word queries.
        ranked = [t for t in registry.search(query, len(registry.get_all()))
                  if t.category == category]
        seen = {t.name for t in ranked}
        # Still include literal substring matches (e.g. text inside a
        # description word) so results are a superset of the old behaviour.
        query_lower = query.lower()
        literal = [
            t for t in registry.get_by_category(category)
            if t.name not in seen
            and (query_lower in t.name.lower() or query_lower in t.description.lower())
        ]
        tools = ranked + literal
    else:
        tools = registry.search(query, max_results)
    if not tools:
        return json_module.dumps({
            "matches": [],
            "hint": "No tools found. Try broader search terms.",
            "available_categories": registry.get_categories()
        })
    return json_module.dumps({
        "matches": [
            {
                "name": t.name,
                "description": (t.description[:100] + "...") if len(t.description) > 100 else t.description,
                "category": t.category,
                "returns_chart": t.returns_chart
            }
            for t in tools[:max_results]
        ],
        "hint": "Use GET_TOOL_SCHEMA to get full parameters for a specific tool."
    })
def get_tool_schema(tool_name: str) -> str:
    """
    Get the full schema for a specific tool.

    Called on demand to keep the initial context small.

    Returns:
        JSON with the tool's name, description, parameters (name, type,
        description, required, default, and ``enum`` when the parameter
        restricts its values), category, returns_chart and returns_json. For
        an unknown tool, returns JSON with ``error`` and ``available_tools``.
    """
    tool = registry.get(tool_name)
    if not tool:
        return json_module.dumps({
            "error": f"Tool '{tool_name}' not found",
            "available_tools": [t.name for t in registry.get_all()]
        })

    def describe_param(p) -> Dict[str, Any]:
        entry = {
            "name": p.name,
            "type": p.type,
            "description": p.description,
            "required": p.required,
            "default": p.default
        }
        # Expose allowed values so the agent can pass a valid argument
        # (to_mcp_schema already includes them).
        if p.enum:
            entry["enum"] = p.enum
        return entry

    return json_module.dumps({
        "name": tool.name,
        "description": tool.description,
        "parameters": [describe_param(p) for p in tool.parameters],
        "category": tool.category,
        "returns_chart": tool.returns_chart,
        "returns_json": tool.returns_json
    })
def list_tool_categories() -> str:
    """Return a JSON object mapping each category to its tool count and tool names."""
    summary = {}
    for category_name in registry.get_categories():
        names = [tool.name for tool in registry.get_by_category(category_name)]
        summary[category_name] = {"count": len(names), "tools": names}
    return json_module.dumps(summary)
| # ============================================================================= | |
| # TOOL DESCRIPTIONS GENERATOR (for LLM prompts) | |
| # ============================================================================= | |
def get_tools_description(detail_level: str = "full",
                          categories: List[str] = None,
                          max_tools: int = None) -> str:
    """
    Render the registered tools as text for an LLM prompt.

    Args:
        detail_level: "name_only", "brief", or "full" (passed to each tool's
            ``to_llm_description``).
        categories: When given, only tools from these categories, in this order.
        max_tools: When truthy, only the first ``max_tools`` tools are listed.

    Returns:
        "Available tools:" followed by one entry per tool, with a blank line
        after each entry.
    """
    selected = (
        [tool for cat in categories for tool in registry.get_by_category(cat)]
        if categories
        else registry.get_all()
    )
    if max_tools:
        selected = selected[:max_tools]
    output = ["Available tools:\n"]
    for tool in selected:
        output.extend((tool.to_llm_description(detail_level), ""))
    return "\n".join(output)
def get_condensed_tools_prompt() -> str:
    """
    Build a minimal tools prompt that points the agent at the meta-tools.

    Instead of listing every tool, the prompt explains SEARCH_TOOLS,
    GET_TOOL_SCHEMA and LIST_CATEGORIES. It ends with the categories
    currently in the registry, which keeps token usage low for agents with
    many tools.
    """
    instructions = """You have access to health data tools. Instead of all tools being listed here,
use these META-TOOLS to find what you need:
1. SEARCH_TOOLS: {"query": "search term", "category": "optional_category"}
- Search for relevant tools by keyword
- Returns tool names and brief descriptions
2. GET_TOOL_SCHEMA: {"tool_name": "name"}
- Get full parameters for a specific tool
- Call this before using a tool you found via search
3. LIST_CATEGORIES: {}
- See all tool categories: vitals, labs, medications, conditions, charts, analysis
WORKFLOW:
1. Search for relevant tools based on user's question
2. Get the schema for tools you want to use
3. Call the tool with proper parameters
Tool categories available: """
    category_names = ", ".join(registry.get_categories())
    return instructions + category_names
| # ============================================================================= | |
| # VITAL SIGN CODES | |
| # ============================================================================= | |
# Vital-sign observation codes, keyed by the vital_type names the tools accept.
# 'blood_pressure' lists the systolic and diastolic codes plus a third code.
VITAL_CODES = {
    'blood_pressure': ['8480-6', '8462-4', '85354-9'],
    'blood_pressure_systolic': ['8480-6'],
    'blood_pressure_diastolic': ['8462-4'],
    'heart_rate': ['8867-4'],
    'temperature': ['8310-5', '8331-1'],
    'weight': ['29463-7'],
    'height': ['8302-2'],
    'respiratory_rate': ['9279-1'],
    'oxygen_saturation': ['2708-6', '59408-5'],
    'bmi': ['39156-5'],
}

# Lab codes for bar charts: lab group -> list of (code, display label, unit).
# NOTE(review): both glucose codes carry the label 'Fasting Glucose'; confirm
# the label intended for '2345-7'.
LAB_CODES = {
    'a1c': [
        ('4548-4', 'HbA1c', '%'),
    ],
    'glucose': [
        ('2345-7', 'Fasting Glucose', 'mg/dL'),
        ('1558-6', 'Fasting Glucose', 'mg/dL'),
    ],
    'cholesterol': [
        ('2093-3', 'Total Cholesterol', 'mg/dL'),
        ('2085-9', 'HDL', 'mg/dL'),
        ('13457-7', 'LDL', 'mg/dL'),
        ('2571-8', 'Triglycerides', 'mg/dL'),
    ],
    'kidney': [
        ('2160-0', 'Creatinine', 'mg/dL'),
        ('33914-3', 'eGFR', 'mL/min'),
    ],
}
| # ============================================================================= | |
| # REGISTERED TOOLS | |
| # ============================================================================= | |
def _search_tools(query: str, category: str = None, max_results: int = 5) -> str:
    """Thin wrapper that forwards its arguments to ``search_tools``."""
    return search_tools(query=query, category=category, max_results=max_results)
def _get_tool_schema(tool_name: str) -> str:
    """Thin wrapper that forwards to ``get_tool_schema``."""
    return get_tool_schema(tool_name=tool_name)
def _list_tool_categories() -> str:
    """Thin wrapper that forwards to ``list_tool_categories``."""
    categories_json = list_tool_categories()
    return categories_json
def compute_age_years(birth_date: Optional[str], today: Optional[datetime] = None) -> Optional[int]:
    """
    Return the number of completed years between *birth_date* and *today*.

    Args:
        birth_date: String whose first 10 characters are ``YYYY-MM-DD``.
        today: Reference moment; defaults to ``datetime.now()``.

    Returns:
        Age in whole years, or None when *birth_date* is empty or does not
        start with a ``YYYY-MM-DD`` date.
    """
    if not birth_date:
        return None
    try:
        born = datetime.strptime(birth_date[:10], "%Y-%m-%d")
    except ValueError:
        return None
    now = today or datetime.now()
    # Use calendar arithmetic: subtract a year if this year's birthday is
    # still ahead. Dividing elapsed days by 365 ignores leap days and
    # overstates the age in the days just before a birthday.
    had_birthday = (now.month, now.day) >= (born.month, born.day)
    return now.year - born.year - (0 if had_birthday else 1)


def get_patient_summary(patient_id: str) -> str:
    """
    Build a plain-text overview of one patient's record.

    Includes name, age and gender, per-table record counts, up to 10
    conditions, up to 10 active medications and the observation date range.

    Args:
        patient_id: Matched against ``patients.id`` and each table's
            ``patient_id`` column.

    Returns:
        The summary text, or "Patient not found." when no patient row matches.
    """
    conn = get_db()
    try:
        cursor = conn.execute("SELECT * FROM patients WHERE id = ?", (patient_id,))
        patient = cursor.fetchone()
        if not patient:
            return "Patient not found."

        counts = {}
        # Table names come from this fixed list, never from input, so building
        # the SQL with an f-string is safe here.
        for table in ['conditions', 'medications', 'observations', 'allergies', 'encounters', 'immunizations', 'procedures']:
            try:
                cursor = conn.execute(f"SELECT COUNT(*) FROM {table} WHERE patient_id = ?", (patient_id,))
                counts[table] = cursor.fetchone()[0]
            except sqlite3.OperationalError:
                # Table or column absent in this database: report zero records.
                counts[table] = 0

        cursor = conn.execute(
            "SELECT display, clinical_status FROM conditions WHERE patient_id = ? LIMIT 10",
            (patient_id,)
        )
        conditions = [f"{row['display']} ({row['clinical_status']})" for row in cursor.fetchall()]

        cursor = conn.execute(
            "SELECT display FROM medications WHERE patient_id = ? AND status = 'active' LIMIT 10",
            (patient_id,)
        )
        medications = [row['display'] for row in cursor.fetchall()]

        cursor = conn.execute("""
            SELECT MIN(effective_date) as earliest, MAX(effective_date) as latest
            FROM observations WHERE patient_id = ?
        """, (patient_id,))
        obs_range = cursor.fetchone()
        earliest = obs_range['earliest'][:10] if obs_range['earliest'] else 'N/A'
        latest = obs_range['latest'][:10] if obs_range['latest'] else 'N/A'

        # A missing or partial birth date no longer crashes the whole summary.
        age = compute_age_years(patient["birth_date"])
        age_text = f"{age} years old" if age is not None else "Unknown"

        summary = f"""Patient Summary:
- Name: {patient['given_name']} {patient['family_name']}
- Age: {age_text}
- Gender: {patient['gender']}
Available Data:
- Conditions: {counts['conditions']} records
- Medications: {counts['medications']} records
- Observations (vitals/labs): {counts['observations']} records
- Allergies: {counts['allergies']} records
- Encounters: {counts['encounters']} records
- Immunizations: {counts['immunizations']} records
- Procedures: {counts['procedures']} records
Conditions: {', '.join(conditions) if conditions else 'None recorded'}
Active Medications: {', '.join(medications) if medications else 'None'}
Observation date range: {earliest} to {latest}
"""
        return summary
    finally:
        conn.close()
def get_conditions(patient_id: str, status: Optional[str] = None) -> str:
    """
    List a patient's conditions, most recent onset first.

    Args:
        patient_id: Patient whose ``conditions`` rows are read.
        status: If truthy, only rows whose ``clinical_status`` equals it.

    Returns:
        One line per condition (display, status, onset and optional resolution
        date), or "No conditions found." when nothing matches.
    """
    conn = get_db()
    try:
        sql = """
            SELECT display, clinical_status, onset_date, abatement_date
            FROM conditions WHERE patient_id = ?"""
        params = [patient_id]
        if status:
            sql += " AND clinical_status = ?"
            params.append(status)
        sql += " ORDER BY onset_date DESC"
        rows = conn.execute(sql, params).fetchall()
        if not rows:
            return "No conditions found."
        output = ["Conditions:\n"]
        for row in rows:
            since = row['onset_date'][:10] if row['onset_date'] else 'Unknown'
            resolved = f" (resolved: {row['abatement_date'][:10]})" if row['abatement_date'] else ""
            output.append(f"- {row['display']} [{row['clinical_status']}] - since {since}{resolved}")
        return "\n".join(output)
    finally:
        conn.close()
def get_medications(patient_id: str, status: Optional[str] = None) -> str:
    """
    List a patient's medications, newest ``start_date`` first.

    Args:
        patient_id: Patient whose ``medications`` rows are read.
        status: If truthy, only rows whose ``status`` equals it.

    Returns:
        Three lines per medication (name/status, dosage with optional route,
        start date), or "No medications found." when nothing matches.
    """
    conn = get_db()
    try:
        sql = """
            SELECT display, status, dosage_text, dosage_route, start_date
            FROM medications WHERE patient_id = ?"""
        params = [patient_id]
        if status:
            sql += " AND status = ?"
            params.append(status)
        sql += " ORDER BY start_date DESC"
        rows = conn.execute(sql, params).fetchall()
        if not rows:
            return "No medications found."
        output = ["Medications:\n"]
        for med in rows:
            route_suffix = f" ({med['dosage_route']})" if med['dosage_route'] else ""
            started = med['start_date'][:10] if med['start_date'] else 'Unknown'
            output.extend([
                f"- {med['display']} [{med['status']}]",
                f" Dosage: {med['dosage_text'] or 'No dosage specified'}{route_suffix}",
                f" Started: {started}",
            ])
        return "\n".join(output)
    finally:
        conn.close()
def get_allergies(patient_id: str) -> str:
    """
    List a patient's recorded allergies.

    For each allergy, prints the substance and category, the reaction, and a
    severity. The severity is taken from ``reaction_severity``, falling back
    to ``criticality``.

    Returns:
        The formatted list, or "No known allergies." when there are no rows.
    """
    conn = get_db()
    try:
        rows = conn.execute("""
            SELECT substance, reaction_display, reaction_severity, criticality, category
            FROM allergies WHERE patient_id = ?
        """, (patient_id,)).fetchall()
        if not rows:
            return "No known allergies."
        output = ["Allergies:\n"]
        for allergy in rows:
            kind = allergy['category'] or 'Unknown type'
            reaction_text = allergy['reaction_display'] or 'Unknown reaction'
            severity_text = allergy['reaction_severity'] or allergy['criticality'] or 'Unknown severity'
            output += [
                f"- {allergy['substance']} ({kind})",
                f" Reaction: {reaction_text}",
                f" Severity: {severity_text}",
            ]
        return "\n".join(output)
    finally:
        conn.close()
def get_recent_vitals(patient_id: str, days: int = 30, vital_type: Optional[str] = None) -> str:
    """
    List vital-sign observations recorded within the last *days* days.

    Args:
        patient_id: Patient whose observations are read.
        days: Look-back window ending now.
        vital_type: Optional ``VITAL_CODES`` key. A known key restricts the
            query to that key's codes. An empty or unknown value falls back to
            all ``vital-signs`` observations.

    Returns:
        Readings grouped by display name, at most 5 per group, drawn from the
        newest 50 rows. Returns a "No vital signs found" message when empty.
    """
    conn = get_db()
    try:
        since = (datetime.now() - timedelta(days=days)).strftime("%Y-%m-%d")
        code_filter = VITAL_CODES[vital_type] if vital_type and vital_type in VITAL_CODES else None

        sql = """
            SELECT display, value_quantity, value_unit, effective_date
            FROM observations
            WHERE patient_id = ? AND category = 'vital-signs' AND effective_date >= ?"""
        params = [patient_id, since]
        if code_filter is not None:
            sql += f" AND code IN ({','.join('?' * len(code_filter))})"
            params.extend(code_filter)
        sql += " ORDER BY effective_date DESC LIMIT 50"

        rows = conn.execute(sql, params).fetchall()
        if not rows:
            return f"No vital signs found in the last {days} days."

        grouped = {}
        for row in rows:
            grouped.setdefault(row['display'], []).append(row)

        output = [f"Vital Signs (last {days} days):\n"]
        for display_name, readings in grouped.items():
            output.append(f"\n{display_name}:")
            for reading in readings[:5]:
                when = reading['effective_date'][:10] if reading['effective_date'] else 'Unknown'
                output.append(f" {when}: {reading['value_quantity']} {reading['value_unit'] or ''}")
        return "\n".join(output)
    finally:
        conn.close()
def get_lab_results(patient_id: str, days: int = 90) -> str:
    """
    List laboratory observations from the last *days* days (newest first, max 50).

    Each line shows the display name, the value and unit, an optional
    interpretation, and the date. The numeric ``value_quantity`` is preferred;
    ``value_string`` (or 'N/A') is used only when no numeric value is stored.

    Returns:
        The formatted list, or a "No lab results found" message.
    """
    conn = get_db()
    try:
        cutoff = (datetime.now() - timedelta(days=days)).strftime("%Y-%m-%d")
        cursor = conn.execute("""
            SELECT display, value_quantity, value_unit, value_string,
                   effective_date, interpretation
            FROM observations
            WHERE patient_id = ? AND category = 'laboratory' AND effective_date >= ?
            ORDER BY effective_date DESC
            LIMIT 50
        """, (patient_id, cutoff))
        labs = cursor.fetchall()
        if not labs:
            return f"No lab results found in the last {days} days."
        lines = [f"Lab Results (last {days} days):\n"]
        for lab in labs:
            date = lab['effective_date'][:10] if lab['effective_date'] else 'Unknown'
            # Compare against None rather than testing truthiness: a measured
            # value of 0 is a real result and must not become 'N/A'.
            if lab['value_quantity'] is not None:
                value = lab['value_quantity']
            else:
                value = lab['value_string'] or 'N/A'
            unit = lab['value_unit'] or ''
            interp = f" [{lab['interpretation']}]" if lab['interpretation'] else ""
            lines.append(f"- {lab['display']}: {value} {unit}{interp} ({date})")
        return "\n".join(lines)
    finally:
        conn.close()
def get_encounters(patient_id: str, limit: int = 10) -> str:
    """
    List a patient's most recent healthcare encounters.

    Args:
        patient_id: Patient whose ``encounters`` rows are read.
        limit: Maximum number of encounters (newest ``period_start`` first).

    Returns:
        Two lines per encounter (date and type, then reason), or
        "No encounters found.".
    """
    conn = get_db()
    try:
        rows = conn.execute("""
            SELECT type_display, reason_display, period_start, period_end,
                   class_display, status
            FROM encounters WHERE patient_id = ?
            ORDER BY period_start DESC
            LIMIT ?
        """, (patient_id, limit)).fetchall()
        if not rows:
            return "No encounters found."
        output = [f"Healthcare Encounters (last {limit}):\n"]
        for enc in rows:
            started = enc['period_start'][:10] if enc['period_start'] else 'Unknown'
            visit_kind = enc['type_display'] or enc['class_display'] or 'Visit'
            why = enc['reason_display'] or 'No reason specified'
            output += [f"- {started}: {visit_kind}", f" Reason: {why}"]
        return "\n".join(output)
    finally:
        conn.close()
def get_immunizations(patient_id: str) -> str:
    """
    List a patient's immunizations, most recent first.

    Returns:
        One line per immunization (vaccine, date, status), or
        "No immunization records found.".
    """
    def day_of(value):
        # First 10 characters of a date string, or 'Unknown' when missing.
        return value[:10] if value else 'Unknown'

    conn = get_db()
    try:
        rows = conn.execute("""
            SELECT vaccine_display, status, occurrence_date
            FROM immunizations WHERE patient_id = ?
            ORDER BY occurrence_date DESC
        """, (patient_id,)).fetchall()
        if not rows:
            return "No immunization records found."
        entries = [
            f"- {imm['vaccine_display']} ({day_of(imm['occurrence_date'])}) [{imm['status']}]"
            for imm in rows
        ]
        return "\n".join(["Immunizations:\n", *entries])
    finally:
        conn.close()
def classify_trend(values: List[float], threshold_pct: float = 5.0) -> str:
    """
    Label the direction of a series by comparing its first and last thirds.

    The mean of the first ``len(values) // 3`` values is compared with the
    mean of the last ``len(values) // 3`` values. A relative change above
    ``threshold_pct`` percent is INCREASING, below ``-threshold_pct`` is
    DECREASING, and anything else is STABLE. Fewer than 3 values gives
    "Not enough data for trend". A first-third mean of 0 counts as no change.
    """
    if len(values) < 3:
        return "Not enough data for trend"
    window = len(values) // 3
    first_mean = sum(values[:window]) / window
    # Slice with the precomputed window. `values[-len(values)//3:]` parses as
    # `(-len(values)) // 3`, which floors to a larger window whenever the
    # length is not a multiple of 3, while the sum was still divided by
    # `len // 3`.
    last_mean = sum(values[-window:]) / window
    diff_pct = ((last_mean - first_mean) / first_mean) * 100 if first_mean != 0 else 0
    if diff_pct > threshold_pct:
        return f"INCREASING (up {diff_pct:.1f}%)"
    if diff_pct < -threshold_pct:
        return f"DECREASING (down {abs(diff_pct):.1f}%)"
    return "STABLE"


def analyze_vital_trend(patient_id: str, vital_type: str, days: int = 30) -> str:
    """
    Summarise one vital sign over the last *days* days.

    Args:
        patient_id: Patient whose observations are read.
        vital_type: Key of ``VITAL_CODES``.
        days: Look-back window ending now.

    Returns:
        Reading count, date range, average/min/max, a trend label from
        ``classify_trend``, and the five most recent values. Returns a message
        instead for an unknown vital type or when there are no readings.
    """
    conn = get_db()
    try:
        if vital_type not in VITAL_CODES:
            return f"Unknown vital type: {vital_type}. Available: {', '.join(VITAL_CODES.keys())}"
        # NOTE(review): VITAL_CODES['blood_pressure'] contains both the systolic
        # and diastolic codes, so for that key the statistics below mix both
        # series in one list. Confirm whether callers should use the
        # blood_pressure_systolic / blood_pressure_diastolic keys instead.
        codes = VITAL_CODES[vital_type]
        cutoff = (datetime.now() - timedelta(days=days)).strftime("%Y-%m-%d")
        placeholders = ','.join(['?' for _ in codes])
        cursor = conn.execute(f"""
            SELECT value_quantity, effective_date
            FROM observations
            WHERE patient_id = ? AND code IN ({placeholders})
            AND effective_date >= ? AND value_quantity IS NOT NULL
            ORDER BY effective_date ASC
        """, [patient_id] + codes + [cutoff])
        readings = cursor.fetchall()
        if not readings:
            return f"No {vital_type} readings found in the last {days} days."

        values = [r['value_quantity'] for r in readings]
        dates = [r['effective_date'][:10] for r in readings]
        avg_val = sum(values) / len(values)
        min_val = min(values)
        max_val = max(values)
        trend = classify_trend(values)

        return f"""{vital_type.replace('_', ' ').title()} Analysis (last {days} days):
Readings: {len(values)}
Date range: {dates[0]} to {dates[-1]}
Statistics:
- Average: {avg_val:.1f}
- Minimum: {min_val:.1f}
- Maximum: {max_val:.1f}
Trend: {trend}
Recent values: {', '.join([str(v) for v in values[-5:]])}
"""
    finally:
        conn.close()
def get_vital_chart_data(patient_id: str, vital_type: str, days: int = 30) -> str:
    """
    Get vital sign data formatted for charting.

    Args:
        patient_id: Patient whose observations are read.
        vital_type: Key of ``VITAL_CODES``. "blood_pressure" produces a
            two-series chart (systolic + diastolic); any other key produces a
            single line series.
        days: Look-back window ending now.

    Returns:
        A JSON string with ``summary`` (pre-computed statistics text),
        ``chart_type``, ``title``, ``statistics`` and ``datasets`` (lists of
        ``{"date", "value"}`` points in ascending date order). On failure it
        is a JSON object with a single ``error`` key.
    """
    def compute_stats(values):
        """Return count/min/max/avg/latest (rounded to 0.1), or None for no values."""
        if not values:
            return None
        return {
            "count": len(values),
            "min": round(min(values), 1),
            "max": round(max(values), 1),
            "avg": round(stats_module.mean(values), 1),
            "latest": round(values[-1], 1),
        }

    def describe(label, series_stats, unit):
        """Render the 'label: min=..., max=..., avg=... unit.' fragment of the summary."""
        if series_stats is None:
            return f"{label}: no readings."
        return (f"{label}: min={series_stats['min']}, max={series_stats['max']}, "
                f"avg={series_stats['avg']} {unit}.")

    def fetch_series(conn, code_list, cutoff):
        """Fetch non-null readings for *code_list* since *cutoff*, oldest first."""
        placeholders = ','.join(['?' for _ in code_list])
        return conn.execute(f"""
            SELECT value_quantity, value_unit, effective_date, display
            FROM observations
            WHERE patient_id = ? AND code IN ({placeholders})
            AND effective_date >= ? AND value_quantity IS NOT NULL
            ORDER BY effective_date ASC
        """, [patient_id] + list(code_list) + [cutoff]).fetchall()

    def to_points(rows):
        return [{"date": r['effective_date'][:10], "value": r['value_quantity']} for r in rows]

    if vital_type not in VITAL_CODES:
        return json_module.dumps({"error": f"Unknown vital type: {vital_type}. Available: {', '.join(VITAL_CODES.keys())}"})
    cutoff = (datetime.now() - timedelta(days=days)).strftime("%Y-%m-%d")

    conn = get_db()
    try:
        if vital_type == 'blood_pressure':
            # Systolic and diastolic are charted as separate series. Only their
            # individual codes are queried, not the rest of the
            # 'blood_pressure' code list.
            systolic_rows = fetch_series(conn, VITAL_CODES['blood_pressure_systolic'], cutoff)
            diastolic_rows = fetch_series(conn, VITAL_CODES['blood_pressure_diastolic'], cutoff)
            if not systolic_rows and not diastolic_rows:
                return json_module.dumps({"error": "No blood pressure data found"})

            systolic_stats = compute_stats([r['value_quantity'] for r in systolic_rows])
            diastolic_stats = compute_stats([r['value_quantity'] for r in diastolic_rows])
            # Either series may be empty (e.g. only diastolic recorded). Take
            # the total from whichever series exists instead of indexing None.
            total = (systolic_stats or diastolic_stats)["count"]
            summary_text = " ".join([
                "STATISTICS (use these exact values):",
                f"Total readings: {total}.",
                describe("Systolic", systolic_stats, "mmHg"),
                describe("Diastolic", diastolic_stats, "mmHg"),
            ])
            return json_module.dumps({
                "summary": summary_text,
                "chart_type": "blood_pressure",
                "title": f"Blood Pressure (Last {days} Days)",
                "statistics": {
                    "systolic": systolic_stats,
                    "diastolic": diastolic_stats,
                },
                "datasets": [
                    {"label": "Systolic", "data": to_points(systolic_rows), "color": "#e74c3c"},
                    {"label": "Diastolic", "data": to_points(diastolic_rows), "color": "#3498db"}
                ]
            })

        readings = fetch_series(conn, VITAL_CODES[vital_type], cutoff)
        if not readings:
            return json_module.dumps({"error": f"No {vital_type} data found"})
        # Fall back when the first row has no unit/display, so "None" never
        # leaks into the summary or title.
        unit = readings[0]['value_unit'] or ''
        display = readings[0]['display'] or vital_type
        vital_stats = compute_stats([r['value_quantity'] for r in readings])
        summary_text = " ".join([
            "STATISTICS (use these exact values):",
            f"Total readings: {vital_stats['count']}.",
            describe(display, vital_stats, unit),
        ])
        return json_module.dumps({
            "summary": summary_text,
            "chart_type": "line",
            "title": f"{display} (Last {days} Days)",
            "unit": unit,
            "statistics": vital_stats,
            "datasets": [
                {"label": display, "data": to_points(readings), "color": "#667eea"}
            ]
        })
    finally:
        conn.close()
def get_lab_chart_data(patient_id: str, lab_type: str, periods: int = 4) -> str:
    """Get lab results formatted for a BAR chart.

    Args:
        patient_id: Patient whose observations are queried.
        lab_type: A key of ``LAB_CODES``, or ``'all_latest'`` for the most
            recent value of a fixed panel of common labs.
        periods: Number of most recent readings to return per lab series
            (not used for ``'all_latest'``). Values below 1 are raised to 1;
            values that cannot be converted with ``int()`` fall back to 4.

    Returns:
        JSON string with ``chart_type``, ``title`` and ``datasets`` on
        success, or ``{"error": ...}`` when the lab type is unknown or no
        data exists.
    """
    # Tool arguments may arrive as strings or None. In SQLite, LIMIT 0 returns
    # no rows and a negative LIMIT means "no upper bound", so clamp to >= 1.
    try:
        periods = max(1, int(periods))
    except (TypeError, ValueError):
        periods = 4  # same as the signature default
    conn = get_db()
    try:
        if lab_type == 'all_latest':
            labs_to_show = [
                ('4548-4', 'HbA1c'),
                ('2093-3', 'Total Chol'),
                ('2085-9', 'HDL'),
                ('13457-7', 'LDL'),
                ('2571-8', 'Triglyc'),
                ('2345-7', 'Glucose'),
            ]
            data = []
            for code, label in labs_to_show:
                row = conn.execute("""
                    SELECT value_quantity FROM observations
                    WHERE patient_id = ? AND code = ? AND value_quantity IS NOT NULL
                    ORDER BY effective_date DESC LIMIT 1
                """, (patient_id, code)).fetchone()
                if row:
                    data.append({"label": label, "value": row['value_quantity']})
            if not data:
                return json_module.dumps({"error": "No lab data found"})
            return json_module.dumps({
                "chart_type": "bar",
                "title": "Latest Lab Values",
                "datasets": [{"label": "Value", "data": data, "color": "#667eea"}]
            })

        if lab_type not in LAB_CODES:
            return json_module.dumps({"error": f"Unknown lab type: {lab_type}. Available: {', '.join(LAB_CODES.keys())}"})

        datasets = []
        for code, label, unit in LAB_CODES[lab_type]:
            rows = conn.execute("""
                SELECT value_quantity, effective_date FROM observations
                WHERE patient_id = ? AND code = ? AND value_quantity IS NOT NULL
                ORDER BY effective_date DESC LIMIT ?
            """, (patient_id, code, periods)).fetchall()
            if rows:
                # Rows were fetched newest-first for LIMIT; chart them oldest-first.
                data = [{"date": r['effective_date'][:10], "value": r['value_quantity']} for r in reversed(rows)]
                datasets.append({"label": label, "data": data, "unit": unit})
        if not datasets:
            return json_module.dumps({"error": f"No {lab_type} data found"})
        return json_module.dumps({
            "chart_type": "bar",
            "title": f"{lab_type.replace('_', ' ').title()} Results (Last {periods} readings)",
            "datasets": datasets
        })
    finally:
        conn.close()
def compare_before_after_treatment(patient_id: str, medication_name: str, metric_type: str) -> str:
    """Compare health metrics before vs after starting a medication.

    The reference medication is the earliest-starting ``medications`` row
    whose display name contains ``medication_name`` (case-insensitive). Rows
    that have a start date are preferred over rows that do not. For each
    series of ``metric_type``, observations dated strictly before that start
    date are averaged as "before", and those on or after it as "after".
    Series without data on both sides are omitted.

    Args:
        patient_id: Patient whose records are queried.
        medication_name: Substring of the medication's display name.
        metric_type: One of ``blood_pressure``, ``a1c``, ``cholesterol``,
            ``glucose``, ``weight``, ``heart_rate``.

    Returns:
        JSON string with a ``"comparison"`` chart payload (per-series
        before/after averages rounded to one decimal plus summary sentences),
        or ``{"error": ...}``.
    """
    # label -> LOINC codes feeding that series. Codes that share a label are
    # pooled into one average so they cannot overwrite each other's results.
    metric_series = {
        'blood_pressure': [('Systolic BP', ['8480-6']), ('Diastolic BP', ['8462-4'])],
        'a1c': [('HbA1c', ['4548-4'])],
        'cholesterol': [('Total Cholesterol', ['2093-3']), ('LDL', ['13457-7'])],
        'glucose': [('Glucose', ['2345-7', '1558-6'])],
        'weight': [('Weight', ['29463-7'])],
        'heart_rate': [('Heart Rate', ['8867-4'])],
    }
    if metric_type not in metric_series:
        return json_module.dumps({"error": f"Unknown metric type: {metric_type}. Available: {', '.join(metric_series)}"})

    needle = (medication_name or "").strip().lower()
    if not needle:
        # An empty pattern would become LIKE '%%' and match any medication.
        return json_module.dumps({"error": "A medication name is required"})

    conn = get_db()
    try:
        # SQLite sorts NULL first in ascending order, so push undated rows last.
        # Otherwise an undated match could hide a dated one.
        med_row = conn.execute("""
            SELECT display, start_date FROM medications
            WHERE patient_id = ? AND LOWER(display) LIKE ?
            ORDER BY (start_date IS NULL OR start_date = ''), start_date ASC
            LIMIT 1
        """, (patient_id, f"%{needle}%")).fetchone()
        if not med_row:
            return json_module.dumps({"error": f"No medication found matching '{medication_name}'"})
        med_display = med_row['display']
        start_date = med_row['start_date'][:10] if med_row['start_date'] else None
        if not start_date:
            return json_module.dumps({"error": f"No start date found for {med_display}"})

        results = {"before": {}, "after": {}}
        summary_parts = []
        for label, codes in metric_series[metric_type]:
            placeholders = ','.join('?' for _ in codes)
            # AVG ignores NULLs, so each CASE averages only its side of the start date.
            row = conn.execute(f"""
                SELECT
                    AVG(CASE WHEN effective_date < ? THEN value_quantity END) AS before_avg,
                    AVG(CASE WHEN effective_date >= ? THEN value_quantity END) AS after_avg
                FROM observations
                WHERE patient_id = ? AND code IN ({placeholders})
                  AND value_quantity IS NOT NULL
            """, [start_date, start_date, patient_id, *codes]).fetchone()
            before_avg, after_avg = row['before_avg'], row['after_avg']
            # Compare against None: a legitimate average of 0.0 must not be skipped.
            if before_avg is None or after_avg is None:
                continue
            before_r, after_r = round(before_avg, 1), round(after_avg, 1)
            results["before"][label] = before_r
            results["after"][label] = after_r
            if before_r == after_r:
                summary_parts.append(f"{label} unchanged at {before_r}")
            else:
                change_pct = (after_avg - before_avg) / before_avg * 100 if before_avg else 0.0
                direction = "decreased" if after_r < before_r else "increased"
                summary_parts.append(f"{label} {direction} from {before_r} to {after_r} ({abs(change_pct):.1f}%)")

        if not results["before"]:
            return json_module.dumps({"error": f"No {metric_type} data found around treatment start date"})
        return json_module.dumps({
            "chart_type": "comparison",
            "title": f"{metric_type.replace('_', ' ').title()} Before vs After {med_display}",
            "medication": med_display,
            "treatment_start_date": start_date,
            "summary": summary_parts,
            "datasets": [
                {"label": "Before Treatment", "data": results["before"], "color": "#e74c3c"},
                {"label": "After Treatment", "data": results["after"], "color": "#27ae60"}
            ]
        })
    finally:
        conn.close()
| # ============================================================================= | |
| # BACKWARD COMPATIBILITY - Legacy function mappings | |
| # ============================================================================= | |
# Legacy lookup table: old-style tool name -> plain Python callable.
# Kept for callers that predate the registry. Keys appear in the same order as
# before, so iteration order is unchanged.
TOOL_FUNCTIONS = dict(
    search_tools=_search_tools,
    get_tool_schema=_get_tool_schema,
    list_tool_categories=_list_tool_categories,
    get_patient_summary=get_patient_summary,
    get_conditions=get_conditions,
    get_medications=get_medications,
    get_allergies=get_allergies,
    get_recent_vitals=get_recent_vitals,
    get_lab_results=get_lab_results,
    get_encounters=get_encounters,
    get_immunizations=get_immunizations,
    analyze_vital_trend=analyze_vital_trend,
    get_vital_chart_data=get_vital_chart_data,
    get_lab_chart_data=get_lab_chart_data,
    compare_before_after_treatment=compare_before_after_treatment,
)
# Legacy TOOLS list (MCP schemas) for backward compatibility.
# NOTE: this is evaluated once, at import time. Tools registered on `registry`
# after this line do not appear here.
TOOLS = [definition.to_mcp_schema() for definition in registry.get_all()]


def execute_tool(tool_name: str, args: dict) -> str:
    """Run ``tool_name`` with ``args`` through the global registry (legacy entry point)."""
    outcome = registry.execute(tool_name, args)
    return outcome
| # ============================================================================= | |
| # MCP SERVER INTERFACE (Model Context Protocol) | |
| # ============================================================================= | |
class MCPServerInterface:
    """
    MCP Server interface for integration with MCP clients.

    Implements the Model Context Protocol specification.

    This allows:
    1. External apps to discover and use our health tools
    2. Dynamic registration of external MCP tools
    3. Standard protocol for AI-tool integration

    External tools live in ``external_tools`` as MCP tool dicts carrying
    private routing keys:

    - ``_server_url``: the owning server (or the base of a webhook URL).
    - ``_original_name``: the tool's name on the remote server, set for
      tools discovered via ``connect_server``.
    - ``_handler_url``: the exact POST target, set for tools added via
      ``register_tool_manually``.
    """

    PROTOCOL_VERSION = "2024-11-05"

    def __init__(self, tool_registry: "DynamicToolRegistry"):
        self.registry = tool_registry
        self.external_tools: Dict[str, Dict] = {}  # Tools from external MCP servers, keyed by exposed name
        self.connected_servers: Dict[str, Dict] = {}  # Connected MCP servers, keyed by base URL

    @staticmethod
    def _public_schema(tool: Dict) -> Dict:
        """Return a copy of an external tool dict without its private ``_``-prefixed keys."""
        return {key: value for key, value in tool.items() if not key.startswith("_")}

    def get_server_info(self) -> Dict:
        """Return MCP server information (initialize response)."""
        return {
            "protocolVersion": self.PROTOCOL_VERSION,
            "serverInfo": {
                "name": "medgemma-health-tools",
                "version": "2.0.0"
            },
            "capabilities": {
                "tools": {"listChanged": True},
                "resources": {},
                "prompts": {}
            }
        }

    def list_tools(self) -> Dict:
        """Return list of available tools in MCP format.

        Private routing keys of external tools (server URLs etc.) are stripped
        so they are not exposed to MCP clients.
        """
        tools = [tool.to_mcp_schema() for tool in self.registry.get_all()]
        tools.extend(self._public_schema(info) for info in self.external_tools.values())
        return {"tools": tools}

    def call_tool(self, name: str, arguments: Dict) -> Dict:
        """Handle MCP tool call, routing to an external server or the local registry."""
        if name in self.external_tools:
            return self._call_external_tool(name, arguments)
        return self.registry.handle_mcp_tool_call(name, arguments)

    def _call_external_tool(self, name: str, arguments: Dict) -> Dict:
        """Call a tool on an external MCP server or webhook.

        Returns the remote JSON response, or an MCP error result
        (``isError: True``) when the tool is unknown, has no URL, or the HTTP
        call fails (including non-2xx responses).
        """
        tool_info = self.external_tools.get(name)
        if not tool_info:
            return {
                "isError": True,
                "content": [{"type": "text", "text": f"External tool not found: {name}"}]
            }
        server_url = tool_info.get("_server_url")
        handler_url = tool_info.get("_handler_url")
        if not (server_url or handler_url):
            return {
                "isError": True,
                "content": [{"type": "text", "text": f"No server URL for tool: {name}"}]
            }
        # Webhook tools are called at their exact registered URL. Discovered
        # tools go through the owning server's /tools/call endpoint.
        call_url = handler_url or f"{server_url}/tools/call"
        # The remote side only knows the un-prefixed name.
        remote_name = tool_info.get("_original_name") or name
        try:
            import httpx
            response = httpx.post(
                call_url,
                json={"name": remote_name, "arguments": arguments},
                timeout=30.0
            )
            # Without this, an HTTP error body would be returned as a successful result.
            response.raise_for_status()
            return response.json()
        except Exception as e:
            return {
                "isError": True,
                "content": [{"type": "text", "text": f"Error calling external tool: {str(e)}"}]
            }

    # =========================================================================
    # MCP Client - Connect to external MCP servers
    # =========================================================================

    def connect_server(self, server_url: str, server_name: Optional[str] = None) -> Dict:
        """
        Connect to an external MCP server and discover its tools.

        Discovered tools are exposed as ``"<server_name or 'ext'>_<tool name>"``.
        Reconnecting to an already-connected URL replaces that server's
        previously discovered tools. Nothing is registered unless both the
        initialize and tools/list calls succeed.

        Args:
            server_url: Base URL of the MCP server
            server_name: Optional friendly name for the server

        Returns:
            Dict with connection status and discovered tools
        """
        try:
            import httpx

            init_response = httpx.post(
                f"{server_url}/initialize",
                json={
                    "protocolVersion": self.PROTOCOL_VERSION,
                    "clientInfo": {"name": "medgemma-agent", "version": "2.0.0"},
                    "capabilities": {}
                },
                timeout=10.0
            )
            init_response.raise_for_status()
            server_info = init_response.json()
            # Read this before mutating state so a malformed reply registers nothing.
            server_details = server_info.get("serverInfo", {})

            tools_response = httpx.post(f"{server_url}/tools/list", timeout=10.0)
            tools_response.raise_for_status()
            tools_data = tools_response.json()

            prefix = server_name or 'ext'
            new_tools: Dict[str, Dict] = {}
            for tool in tools_data.get("tools", []):
                if not isinstance(tool, dict) or not tool.get("name"):
                    continue  # a tool without a name cannot be called
                tool_name = tool["name"]
                prefixed_name = f"{prefix}_{tool_name}"  # prefix avoids clashes with local tools
                tool["_original_name"] = tool_name
                tool["_server_url"] = server_url
                tool["name"] = prefixed_name
                new_tools[prefixed_name] = tool

            # Drop tools from any previous connection to this URL so removed or
            # renamed remote tools don't linger.
            self._drop_server_tools(server_url)
            self.external_tools.update(new_tools)
            discovered_tools = list(new_tools)

            self.connected_servers[server_url] = {
                "name": server_name or server_url,
                "info": server_info,
                "tools": discovered_tools
            }
            print(f"[MCP] Connected to {server_url}, discovered {len(discovered_tools)} tools")
            return {
                "success": True,
                "server": server_details,
                "tools_discovered": discovered_tools
            }
        except Exception as e:
            print(f"[MCP] Failed to connect to {server_url}: {e}")
            return {
                "success": False,
                "error": str(e)
            }

    def _drop_server_tools(self, server_url: str) -> None:
        """Forget ``server_url`` and remove the external tools it still owns."""
        info = self.connected_servers.pop(server_url, None) or {}
        for tool_name in info.get("tools", []):
            tool = self.external_tools.get(tool_name)
            # Another server using the same prefix may have replaced this entry
            # since; only delete it if it still belongs to server_url.
            if tool is not None and tool.get("_server_url") == server_url:
                del self.external_tools[tool_name]

    def disconnect_server(self, server_url: str) -> bool:
        """Disconnect from an external MCP server and remove its tools.

        Returns:
            False if ``server_url`` was not connected, True otherwise.
        """
        if server_url not in self.connected_servers:
            return False
        self._drop_server_tools(server_url)
        print(f"[MCP] Disconnected from {server_url}")
        return True

    def list_connected_servers(self) -> List[Dict]:
        """List all connected MCP servers with their tool counts."""
        return [
            {
                "url": url,
                "name": info["name"],
                "tools_count": len(info["tools"])
            }
            for url, info in self.connected_servers.items()
        ]

    # =========================================================================
    # Dynamic Tool Registration
    # =========================================================================

    def register_tool_manually(self,
                               name: str,
                               description: str,
                               parameters: Dict,
                               handler_url: str) -> bool:
        """
        Manually register an external tool without connecting to a full MCP server.

        Useful for simple webhook-based tools. Calls are POSTed to
        ``handler_url`` with a ``{"name", "arguments"}`` JSON body.

        Args:
            name: Tool name
            description: Tool description
            parameters: JSON Schema for parameters
            handler_url: URL to POST tool calls to

        Returns:
            Always True.
        """
        self.external_tools[name] = {
            "name": name,
            "description": description,
            "inputSchema": parameters,
            "_server_url": handler_url.rsplit('/', 1)[0] if '/' in handler_url else handler_url,
            "_handler_url": handler_url
        }
        print(f"[MCP] Registered external tool: {name}")
        return True

    def get_all_tools(self) -> List[Dict]:
        """Get all tools (local + external) in MCP format.

        External entries are returned as stored, including their private
        ``_``-prefixed routing keys.
        """
        tools = [tool.to_mcp_schema() for tool in self.registry.get_all()]
        tools.extend(self.external_tools.values())
        return tools
# Single module-wide MCP interface wrapping the shared tool registry.
mcp_interface = MCPServerInterface(tool_registry=registry)
| # ============================================================================= | |
| # UTILITY FUNCTIONS | |
| # ============================================================================= | |
def _names_of(tools) -> List[str]:
    """Project an iterable of tool definitions onto their ``name`` attributes."""
    return [definition.name for definition in tools]


def get_tool_names() -> List[str]:
    """Return the name of every tool currently in the registry."""
    return _names_of(registry.get_all())


def get_chart_tools() -> List[str]:
    """Return the names of registered tools flagged as producing chart data."""
    return _names_of(definition for definition in registry.get_all() if definition.returns_chart)


def get_tools_by_category(category: str) -> List[str]:
    """Return the names of registered tools belonging to ``category``."""
    return _names_of(registry.get_by_category(category))
| # ============================================================================= | |
| # SKIN ANALYSIS TOOL (calls Health Foundation / SCIN Classifier server) | |
| # ============================================================================= | |
# Base URL of the Health Foundation / SCIN classifier server.
# Lookup order: HEALTH_FOUNDATION_URL, then HEAR_SERVER_URL, then a local default.
# Empty values count as unset. Trailing slashes are stripped because request
# paths such as "/analyze/skin" are appended directly.
HEALTH_FOUNDATION_URL = (
    os.getenv("HEALTH_FOUNDATION_URL")
    or os.getenv("HEAR_SERVER_URL")
    or "http://localhost:8082"
).rstrip("/")
_SKIN_DISCLAIMER = "⚠️ FOR RESEARCH USE ONLY - NOT A DIAGNOSTIC TOOL"
_SKIN_DISCLAIMER_SHORT = "⚠️ FOR RESEARCH USE ONLY"


def _decode_skin_image(image_data):
    """Turn the tool's ``image_data`` argument into ``(image_bytes, mime_type)``.

    A ``str`` is treated as base64. It may be wrapped in a
    ``data:<mime>;base64,`` URL, whose MIME type is used when it is an
    ``image/*`` type. Any non-``str`` value (e.g. raw bytes) is returned
    unchanged. The MIME type defaults to ``image/png``.

    Raises:
        ValueError: (``binascii.Error`` is a subclass) if the base64 payload
            is malformed or the string is not ASCII.
    """
    import base64

    mime_type = "image/png"
    if not isinstance(image_data, str):
        return image_data, mime_type
    header, sep, payload = image_data.partition(',')
    if sep:
        declared = header[len("data:"):].split(';')[0].strip().lower() if header.startswith("data:") else ""
        if declared.startswith("image/") and len(declared) > len("image/"):
            mime_type = declared
        image_data = payload
    return base64.b64decode(image_data), mime_type


def _confidence_pct(value) -> float:
    """Convert a 0-1 confidence (None counts as 0) to a percentage rounded to 0.1."""
    return round((value or 0) * 100, 1)


def _likelihood_band(confidence) -> str:
    """Bucket a 0-1 confidence: > 0.6 'high', > 0.4 'moderate', otherwise 'possible'."""
    confidence = confidence or 0
    if confidence > 0.6:
        return "high"
    if confidence > 0.4:
        return "moderate"
    return "possible"


def _format_embedding_only(result: Dict, symptoms: str) -> str:
    """Build the agent payload for the fallback ``/analyze/skin`` (no classifier) response."""
    quality = result.get("image_quality") or {}
    score = quality.get("score", 0)
    return json_module.dumps({
        "status": "success",
        "analysis_type": "embedding_only",
        "model": "Derm Foundation (classifier unavailable)",
        "image_quality_score": score,
        "image_quality_notes": quality.get("notes", []),
        "conditions": [],
        "symptoms_detected": [],
        "body_parts_detected": [],
        "user_reported_symptoms": symptoms,
        "llm_synthesis_prompt": (
            f"The skin image was analyzed but the condition classifier is unavailable. "
            f"Image quality score: {score}/100. "
            f"User described: '{symptoms}'. "
            f"Please acknowledge their concern, note that specific condition identification isn't available, "
            f"and recommend they consult a dermatologist for a proper evaluation."
        ),
        "summary": "Image processed (embedding only - classifier unavailable)",
        "disclaimer": _SKIN_DISCLAIMER
    })


def _format_classification(result: Dict, symptoms: str) -> str:
    """Reshape a ``/analyze/skin/classify`` response into the agent's JSON payload."""
    # `or` guards: the server may send explicit nulls for these fields.
    conditions = result.get("conditions") or []
    predicted_symptoms = result.get("predicted_symptoms") or []
    predicted_body_parts = result.get("predicted_body_parts") or []
    user_reported = result.get("user_reported") or {}
    fitzpatrick = result.get("fitzpatrick_skin_type") or {}
    symptom_agreement = result.get("symptom_agreement") or {}

    analysis = {
        "status": "success",
        "analysis_type": "full_classification",
        "model": "Derm Foundation + SCIN Classifier",
        # Top 5 conditions, in the order the server returned them
        "conditions": [
            {
                "name": c.get("condition", ""),
                "confidence": _confidence_pct(c.get("adjusted_confidence")),
                "likelihood": _likelihood_band(c.get("adjusted_confidence")),
            }
            for c in conditions[:5]
        ],
        # Symptoms from image analysis
        "symptoms_from_image": [
            {"name": s.get("label", ""), "confidence": _confidence_pct(s.get("confidence"))}
            for s in predicted_symptoms[:5]
        ],
        # Body parts detected
        "body_parts_detected": [
            {"name": b.get("label", ""), "confidence": _confidence_pct(b.get("confidence"))}
            for b in predicted_body_parts[:5]
        ],
        # User reported symptoms (as parsed by the server)
        "user_reported": {
            "symptoms": user_reported.get("symptoms", []),
            "body_parts": user_reported.get("body_parts", []),
            "duration": user_reported.get("duration"),
            "raw_description": user_reported.get("raw_text", symptoms)
        },
        # Symptom agreement between user and model
        "symptom_match": {
            "agreed": symptom_agreement.get("agreed", []),
            "user_only": symptom_agreement.get("user_reported_only", []),
            "model_only": symptom_agreement.get("model_detected_only", [])
        },
        # Fitzpatrick skin type estimation
        "estimated_skin_type": {
            "type": fitzpatrick.get("label", "Unknown"),
            "confidence": _confidence_pct(fitzpatrick.get("confidence"))
        },
        # Summary for quick reference
        "summary": result.get("summary", {}),
        # LLM synthesis prompt - USE THIS to generate the patient response
        "llm_synthesis_prompt": result.get("llm_prompt", ""),
        "disclaimer": result.get("disclaimer", _SKIN_DISCLAIMER)
    }
    return json_module.dumps(analysis)


def analyze_skin_image(patient_id: str, image_data: str, symptoms: str = "") -> str:
    """
    Analyze a skin image using Google's Derm Foundation model + SCIN classifier.

    This tool calls the Health Foundation Server's ``/analyze/skin/classify``
    endpoint, which is expected to:
    1. Extract embeddings using Derm Foundation
    2. Run the trained SCIN classifier for condition prediction
    3. Parse user symptoms from text
    4. Combine predictions with user input
    5. Return structured data + LLM prompt for synthesis

    The tool falls back to the embedding-only ``/analyze/skin`` endpoint when
    the classify endpoint answers 404 or 503, or replies 200 with
    "classifier not loaded".

    Args:
        patient_id: The patient's ID (not currently sent to the analysis server)
        image_data: Base64 encoded image data, optionally as a ``data:`` URL
        symptoms: Optional free-text symptom description (e.g., "itchy red rash on arm")

    Returns:
        JSON string with conditions, symptoms, body parts, and LLM synthesis
        prompt, or ``{"error": ..., "disclaimer": ...}`` on failure.
    """
    symptoms = symptoms or ""  # avoid rendering None as 'None' in prompts

    try:
        image_bytes, mime_type = _decode_skin_image(image_data)
    except ValueError as e:  # binascii.Error subclasses ValueError
        return json_module.dumps({
            "error": f"Skin analysis failed: {str(e)}",
            "disclaimer": _SKIN_DISCLAIMER_SHORT
        })
    if not image_bytes:
        return json_module.dumps({
            "error": "No image data provided",
            "disclaimer": _SKIN_DISCLAIMER_SHORT
        })

    import httpx

    subtype = mime_type.split("/", 1)[1].split("+", 1)[0]
    upload = {"image": (f"skin_image.{subtype}", image_bytes, mime_type)}
    headers = {"ngrok-skip-browser-warning": "true"}
    try:
        with httpx.Client(timeout=90.0) as client:
            response = client.post(
                f"{HEALTH_FOUNDATION_URL}/analyze/skin/classify",
                files=upload,
                data={"symptoms": symptoms, "top_k": "5"},
                headers=headers
            )
            # 503 / "classifier not loaded": the server is up but the classifier isn't.
            # 404: the server has no classify route. Fall back to embedding analysis.
            classifier_unavailable = response.status_code in (404, 503) or (
                response.status_code == 200 and "classifier not loaded" in response.text.lower()
            )
            if classifier_unavailable:
                response = client.post(
                    f"{HEALTH_FOUNDATION_URL}/analyze/skin",
                    files=upload,
                    data={"include_embedding": "false"},
                    headers=headers
                )
            response.raise_for_status()
            result = response.json()

        if not result.get("success"):
            return json_module.dumps({
                "error": result.get("error", "Analysis failed"),
                "disclaimer": _SKIN_DISCLAIMER
            })
        if classifier_unavailable:
            return _format_embedding_only(result, symptoms)
        return _format_classification(result, symptoms)

    except httpx.ConnectError:
        return json_module.dumps({
            "error": "Skin analysis service unavailable. The Derm Foundation server may not be running.",
            "suggestion": "Ensure the health foundation server is running on the correct port.",
            "disclaimer": _SKIN_DISCLAIMER_SHORT
        })
    except Exception as e:
        import traceback
        traceback.print_exc()
        return json_module.dumps({
            "error": f"Skin analysis failed: {str(e)}",
            "disclaimer": _SKIN_DISCLAIMER_SHORT
        })