frabbani
Add graph based workflow ui filtering to agent_v2..............
21fc0a3
#!/usr/bin/env python3
"""
Health data tools for the MedGemma agent.
Enhanced with:
1. Dynamic Tool Registry - Tools can be registered/unregistered at runtime
2. Tool Search - Reduce token usage by searching for relevant tools
3. MCP Compatibility - Ready for Model Context Protocol integration
These tools allow the LLM to query specific data rather than
loading everything into context at once.
"""
import sqlite3
import json as json_module
import statistics as stats_module
from datetime import datetime, timedelta
from typing import Optional, Callable, Dict, Any, List, Set
from dataclasses import dataclass, field
import os
import re
# Location of the SQLite database; overridable through the DB_PATH environment variable.
DB_PATH = os.environ.get("DB_PATH", "data/fhir.db")


def get_db():
    """Open a SQLite connection to ``DB_PATH`` whose rows allow access by column name."""
    connection = sqlite3.connect(DB_PATH)
    connection.row_factory = sqlite3.Row
    return connection
# =============================================================================
# DYNAMIC TOOL REGISTRY WITH MCP COMPATIBILITY
# =============================================================================
@dataclass
class ToolParameter:
    """Definition of a tool parameter.

    ``ToolDefinition.to_mcp_schema`` turns each instance into one entry of
    the tool's ``inputSchema`` properties.
    """
    name: str
    type: str # "string", "integer", "boolean", "array"
    description: str
    required: bool = True  # listed under the schema's "required" names when True
    default: Any = None  # only emitted in the schema when not None
    enum: Optional[List[str]] = None # For constrained values
@dataclass
class ToolDefinition:
    """
    Complete definition of a tool.

    Bundles the metadata shown to the LLM with the callable that implements
    the tool. MCP-compatible structure for future integration.
    """
    name: str
    description: str
    parameters: List[ToolParameter]
    handler: Callable
    category: str = "general"
    requires_patient_id: bool = True
    returns_chart: bool = False
    returns_json: bool = False
    version: str = "1.0.0"
    # MCP-specific fields
    mcp_compatible: bool = True
    mcp_annotations: Dict[str, Any] = field(default_factory=dict)

    def to_mcp_schema(self) -> Dict:
        """Convert to MCP tool schema format.

        Returns:
            A dict with ``name``, ``description``, a JSON-schema style
            ``inputSchema`` built from ``parameters``, and ``annotations``
            (category, chart flag, version, plus any ``mcp_annotations``).
        """
        properties = {}
        required = []
        for param in self.parameters:
            prop = {
                "type": param.type,
                "description": param.description
            }
            if param.enum:
                prop["enum"] = param.enum
            if param.default is not None:
                prop["default"] = param.default
            properties[param.name] = prop
            if param.required:
                required.append(param.name)
        return {
            "name": self.name,
            "description": self.description,
            "inputSchema": {
                "type": "object",
                "properties": properties,
                "required": required
            },
            "annotations": {
                "category": self.category,
                "returnsChart": self.returns_chart,
                "version": self.version,
                # Custom annotations come last so they can override the defaults above.
                **self.mcp_annotations
            }
        }

    def to_llm_description(self, detail_level: str = "full") -> str:
        """Generate description for LLM context.

        Args:
            detail_level: "name_only" returns just the tool name, "brief"
                returns the name with the description cut to 80 characters,
                anything else returns the full description and parameters.

        Returns:
            The formatted description text.
        """
        if detail_level == "name_only":
            return self.name
        if detail_level == "brief":
            # Only mark the text as truncated when something was actually cut off.
            if len(self.description) > 80:
                summary = self.description[:80] + "..."
            else:
                summary = self.description
            return f"- {self.name}: {summary}"
        # Full description
        params_str = ", ".join([
            f"{p.name}: {p.description}" + (" (optional)" if not p.required else "")
            for p in self.parameters
        ])
        return f"- {self.name}: {self.description}\n Parameters: {params_str}"
class DynamicToolRegistry:
    """
    Dynamic registry for tool management.

    Features:
    - Runtime tool registration/unregistration
    - Category-based filtering
    - Keyword search for relevant tools
    - MCP-compatible schema generation
    - Context-aware tool selection
    """

    # Common words ignored when indexing and searching tool text.
    _STOP_WORDS = frozenset({
        'the', 'a', 'an', 'is', 'are', 'for', 'to', 'of', 'and', 'or', 'in', 'on', 'with'
    })

    def __init__(self):
        self._tools: Dict[str, ToolDefinition] = {}
        self._categories: Dict[str, Set[str]] = {}  # category -> tool names
        self._keyword_index: Dict[str, Set[str]] = {}  # keyword -> tool names

    def register(self,
                 name: str,
                 description: str,
                 parameters: List[ToolParameter],
                 category: str = "general",
                 requires_patient_id: bool = True,
                 returns_chart: bool = False,
                 returns_json: bool = False,
                 mcp_annotations: Optional[Dict] = None):
        """
        Decorator for registering tools dynamically.

        The decorated function is stored as the tool's handler and returned
        unchanged, so it stays directly callable.

        Usage:
            @registry.register(
                name="get_conditions",
                description="Get patient conditions",
                parameters=[ToolParameter("patient_id", "string", "Patient ID")],
                category="medical_records"
            )
            def get_conditions(patient_id: str) -> str:
                ...
        """
        def decorator(func: Callable) -> Callable:
            tool = ToolDefinition(
                name=name,
                description=description,
                parameters=parameters,
                handler=func,
                category=category,
                requires_patient_id=requires_patient_id,
                returns_chart=returns_chart,
                returns_json=returns_json,
                mcp_annotations=mcp_annotations or {}
            )
            self._register_tool(tool)
            return func
        return decorator

    def _register_tool(self, tool: ToolDefinition):
        """Internal method to register a tool and update indices."""
        # Re-registering a name replaces the old definition; drop its index
        # entries first so a changed category or description leaves nothing stale.
        previous = self._tools.get(tool.name)
        if previous is not None:
            self._remove_from_indices(previous)
        self._tools[tool.name] = tool
        # Update category index
        self._categories.setdefault(tool.category, set()).add(tool.name)
        # Update keyword index for search
        for keyword in self._extract_keywords(f"{tool.name} {tool.description}"):
            self._keyword_index.setdefault(keyword, set()).add(tool.name)

    def _remove_from_indices(self, tool: ToolDefinition) -> None:
        """Remove a tool's name from the category and keyword indices, pruning empty buckets."""
        members = self._categories.get(tool.category)
        if members is not None:
            members.discard(tool.name)
            if not members:
                # An empty category should no longer be reported by get_categories().
                del self._categories[tool.category]
        # Collect the affected keywords first: the dict must not change size while iterated.
        affected = [kw for kw, names in self._keyword_index.items() if tool.name in names]
        for keyword in affected:
            names = self._keyword_index[keyword]
            names.discard(tool.name)
            if not names:
                del self._keyword_index[keyword]

    def _extract_keywords(self, text: str) -> Set[str]:
        """Extract searchable keywords from text.

        Keywords are the lowercased words longer than two characters that
        are not stop words.
        """
        words = re.findall(r'\b\w+\b', text.lower())
        return {w for w in words if w not in self._STOP_WORDS and len(w) > 2}

    def unregister(self, name: str) -> bool:
        """Remove a tool from the registry.

        Returns:
            True if the tool existed and was removed, False otherwise.
        """
        tool = self._tools.pop(name, None)
        if tool is None:
            return False
        self._remove_from_indices(tool)
        return True

    def get(self, name: str) -> Optional[ToolDefinition]:
        """Get a specific tool by name, or None when it is not registered."""
        return self._tools.get(name)

    def get_all(self) -> List[ToolDefinition]:
        """Get all registered tools."""
        return list(self._tools.values())

    def get_by_category(self, category: str) -> List[ToolDefinition]:
        """Get all tools in a category, ordered by tool name."""
        # Sorting keeps the result stable; raw set iteration order is arbitrary.
        tool_names = sorted(self._categories.get(category, ()))
        return [self._tools[name] for name in tool_names]

    def get_categories(self) -> List[str]:
        """Get all available categories."""
        return list(self._categories.keys())

    def search(self, query: str, max_results: int = 5) -> List[ToolDefinition]:
        """
        Search for tools matching a query.

        Each query keyword found in a tool's name/description scores one
        point; a query that is a substring of the tool name scores three
        more. Results are ordered by score, ties broken by tool name.
        """
        scores: Dict[str, int] = {}
        for keyword in self._extract_keywords(query):
            for tool_name in self._keyword_index.get(keyword, ()):
                scores[tool_name] = scores.get(tool_name, 0) + 1
        # Also check for name matches (higher weight)
        needle = query.lower()
        for tool_name in self._tools:
            if needle in tool_name.lower():
                scores[tool_name] = scores.get(tool_name, 0) + 3
        # Highest score first; the name tie-break makes the ranking deterministic.
        ranked = sorted(scores.items(), key=lambda item: (-item[1], item[0]))
        return [self._tools[name] for name, _ in ranked[:max_results]]

    def get_tools_for_context(self,
                              categories: Optional[List[str]] = None,
                              query_keywords: Optional[List[str]] = None,
                              max_tools: int = 10) -> List[ToolDefinition]:
        """
        Get tools filtered by context.

        Used to reduce token usage by only including relevant tools.

        Args:
            categories: Restrict to these categories; all tools when empty.
            query_keywords: When given, tools are ordered by how many of
                these keywords appear in their name/description.
            max_tools: Maximum number of tools returned.
        """
        if categories:
            tools = []
            for cat in categories:
                tools.extend(self.get_by_category(cat))
        else:
            tools = self.get_all()
        if query_keywords:
            # Score by relevance
            def relevance_score(tool: ToolDefinition) -> int:
                text = f"{tool.name} {tool.description}".lower()
                return sum(1 for kw in query_keywords if kw.lower() in text)
            tools = sorted(tools, key=relevance_score, reverse=True)
        return tools[:max_tools]

    def execute(self, tool_name: str, args: Dict[str, Any]) -> str:
        """Execute a tool by name with given arguments.

        Returns the handler's result, or a JSON string with an ``error`` key
        when the tool is unknown or the handler raises.
        """
        tool = self._tools.get(tool_name)
        if not tool:
            return json_module.dumps({"error": f"Unknown tool: {tool_name}"})
        try:
            return tool.handler(**args)
        except Exception as e:
            # Boundary with the agent: report the failure as data instead of raising.
            return json_module.dumps({"error": f"Error executing {tool_name}: {str(e)}"})

    # =========================================================================
    # MCP Compatibility Methods
    # =========================================================================
    def to_mcp_tools_list(self) -> List[Dict]:
        """Generate MCP-compatible tools list."""
        return [tool.to_mcp_schema() for tool in self._tools.values()]

    def handle_mcp_tool_call(self, tool_name: str, arguments: Dict) -> Dict:
        """
        Handle an MCP tool call request.

        Returns MCP-compatible response format: a ``content`` list with one
        text item, plus ``isError: True`` for unknown tools or handler errors.
        """
        tool = self._tools.get(tool_name)
        if not tool:
            return {
                "isError": True,
                "content": [{"type": "text", "text": f"Unknown tool: {tool_name}"}]
            }
        try:
            result = tool.handler(**arguments)
            if tool.returns_json:
                try:
                    parsed = json_module.loads(result)
                except (TypeError, ValueError):
                    # Not valid JSON after all; fall through and return the raw result.
                    pass
                else:
                    return {
                        "content": [{"type": "text", "text": json_module.dumps(parsed, indent=2)}]
                    }
            return {
                "content": [{"type": "text", "text": result}]
            }
        except Exception as e:
            return {
                "isError": True,
                "content": [{"type": "text", "text": f"Error: {str(e)}"}]
            }

    def generate_mcp_server_info(self) -> Dict:
        """Generate MCP server information."""
        return {
            "name": "medgemma-health-tools",
            "version": "1.0.0",
            "description": "Health data tools for MedGemma medical AI assistant",
            "capabilities": {
                "tools": True,
                "resources": False,
                "prompts": False
            }
        }
# =============================================================================
# GLOBAL REGISTRY INSTANCE
# =============================================================================
# Shared module-level registry; the @registry.register decorators in this file add tools to it.
registry = DynamicToolRegistry()
# =============================================================================
# TOOL SEARCH FUNCTIONS (Reduce Token Usage)
# =============================================================================
def search_tools(query: str, category: str = None, max_results: int = 5) -> str:
    """
    Find registered tools relevant to a query and return condensed info as JSON.

    This is a META-TOOL: it lets the agent discover tools without having
    every tool definition loaded into its context.

    With a category, tools of that category are kept when the query is a
    substring of their name or description; otherwise the registry's keyword
    search is used. The result holds at most ``max_results`` matches, or an
    empty match list plus the available categories when nothing is found.
    """
    if not category:
        candidates = registry.search(query, max_results)
    else:
        in_category = registry.get_by_category(category)
        needle = query.lower()
        candidates = [
            tool for tool in in_category
            if needle in tool.name.lower() or needle in tool.description.lower()
        ]

    if not candidates:
        return json_module.dumps({
            "matches": [],
            "hint": "No tools found. Try broader search terms.",
            "available_categories": registry.get_categories()
        })

    def condensed(text):
        # Descriptions longer than 100 characters are cut and marked with an ellipsis.
        if len(text) > 100:
            return text[:100] + "..."
        return text

    matches = []
    for tool in candidates[:max_results]:
        matches.append({
            "name": tool.name,
            "description": condensed(tool.description),
            "category": tool.category,
            "returns_chart": tool.returns_chart
        })
    return json_module.dumps({
        "matches": matches,
        "hint": "Use GET_TOOL_SCHEMA to get full parameters for a specific tool."
    })
def get_tool_schema(tool_name: str) -> str:
    """
    Return the full schema of one tool as a JSON string.

    Fetched on demand so the initial context can stay small. For an unknown
    tool the JSON holds an error message and the names of all registered tools.
    """
    tool = registry.get(tool_name)
    if tool:
        parameter_specs = []
        for param in tool.parameters:
            parameter_specs.append({
                "name": param.name,
                "type": param.type,
                "description": param.description,
                "required": param.required,
                "default": param.default
            })
        schema = {
            "name": tool.name,
            "description": tool.description,
            "parameters": parameter_specs,
            "category": tool.category,
            "returns_chart": tool.returns_chart,
            "returns_json": tool.returns_json
        }
        return json_module.dumps(schema)

    known_names = [registered.name for registered in registry.get_all()]
    return json_module.dumps({
        "error": f"Tool '{tool_name}' not found",
        "available_tools": known_names
    })
def list_tool_categories() -> str:
    """Return a JSON object mapping each category to its tool count and tool names."""
    overview = {}
    for category in registry.get_categories():
        names = [tool.name for tool in registry.get_by_category(category)]
        overview[category] = {"count": len(names), "tools": names}
    return json_module.dumps(overview)
# =============================================================================
# TOOL DESCRIPTIONS GENERATOR (for LLM prompts)
# =============================================================================
def get_tools_description(detail_level: str = "full",
                          categories: List[str] = None,
                          max_tools: int = None) -> str:
    """
    Build the "Available tools" text block handed to the LLM.

    Args:
        detail_level: "name_only", "brief", or "full"
        categories: Only include tools from these categories
        max_tools: Cap on the number of tools listed
    """
    if categories:
        selected = [
            tool
            for category in categories
            for tool in registry.get_by_category(category)
        ]
    else:
        selected = registry.get_all()
    if max_tools:
        selected = selected[:max_tools]

    parts = ["Available tools:\n"]
    for tool in selected:
        # Each tool's description is followed by an empty entry, i.e. a blank line.
        parts.extend((tool.to_llm_description(detail_level), ""))
    return "\n".join(parts)
def get_condensed_tools_prompt() -> str:
    """
    Build a short tools prompt that points the agent at the search meta-tools.

    Listing only the meta-tools, followed by the registry's current
    categories, keeps token usage low for agents with many tools.
    """
    instructions = """You have access to health data tools. Instead of all tools being listed here,
use these META-TOOLS to find what you need:
1. SEARCH_TOOLS: {"query": "search term", "category": "optional_category"}
- Search for relevant tools by keyword
- Returns tool names and brief descriptions
2. GET_TOOL_SCHEMA: {"tool_name": "name"}
- Get full parameters for a specific tool
- Call this before using a tool you found via search
3. LIST_CATEGORIES: {}
- See all tool categories: vitals, labs, medications, conditions, charts, analysis
WORKFLOW:
1. Search for relevant tools based on user's question
2. Get the schema for tools you want to use
3. Call the tool with proper parameters
Tool categories available: """
    category_names = ", ".join(registry.get_categories())
    return instructions + category_names
# =============================================================================
# VITAL SIGN CODES
# =============================================================================
# Maps a vital-sign name to the observation codes matched against the
# observations table's ``code`` column.
# NOTE(review): these look like LOINC codes — confirm against the data source.
VITAL_CODES = {
    'blood_pressure': ['8480-6', '8462-4', '85354-9'],
    'blood_pressure_systolic': ['8480-6'],
    'blood_pressure_diastolic': ['8462-4'],
    'heart_rate': ['8867-4'],
    'temperature': ['8310-5', '8331-1'],
    'weight': ['29463-7'],
    'height': ['8302-2'],
    'respiratory_rate': ['9279-1'],
    'oxygen_saturation': ['2708-6', '59408-5'],
    'bmi': ['39156-5']
}
# Lab codes for bar charts
# Each group maps to a list of tuples that appear to be
# (observation code, display label, unit) — confirm against the chart code that reads them.
LAB_CODES = {
    'a1c': [('4548-4', 'HbA1c', '%')],
    'glucose': [('2345-7', 'Fasting Glucose', 'mg/dL'), ('1558-6', 'Fasting Glucose', 'mg/dL')],
    'cholesterol': [
        ('2093-3', 'Total Cholesterol', 'mg/dL'),
        ('2085-9', 'HDL', 'mg/dL'),
        ('13457-7', 'LDL', 'mg/dL'),
        ('2571-8', 'Triglycerides', 'mg/dL')
    ],
    'kidney': [
        ('2160-0', 'Creatinine', 'mg/dL'),
        ('33914-3', 'eGFR', 'mL/min')
    ]
}
# =============================================================================
# REGISTERED TOOLS
# =============================================================================
@registry.register(
    name="search_tools",
    description="Search for available tools by keyword. Use this to find relevant tools without loading all tool definitions. Returns tool names and brief descriptions.",
    parameters=[
        ToolParameter("query", "string", "Search keywords (e.g., 'blood pressure', 'medications', 'chart')"),
        ToolParameter("category", "string", "Optional category filter", required=False),
        ToolParameter("max_results", "integer", "Maximum results to return", required=False, default=5)
    ],
    category="meta",
    requires_patient_id=False,
    returns_json=True
)
def _search_tools(query: str, category: str = None, max_results: int = 5) -> str:
    """Registered handler for the ``search_tools`` meta-tool; forwards its arguments to ``search_tools``."""
    return search_tools(query, category, max_results)
@registry.register(
    name="get_tool_schema",
    description="Get the full parameter schema for a specific tool. Call this after search_tools to get details before using a tool.",
    parameters=[
        ToolParameter("tool_name", "string", "Name of the tool to get schema for")
    ],
    category="meta",
    requires_patient_id=False,
    returns_json=True
)
def _get_tool_schema(tool_name: str) -> str:
    """Registered handler for the ``get_tool_schema`` meta-tool; forwards to ``get_tool_schema``."""
    return get_tool_schema(tool_name)
@registry.register(
    name="list_tool_categories",
    description="List all available tool categories and their tools. Useful for understanding what data is accessible.",
    parameters=[],
    category="meta",
    requires_patient_id=False,
    returns_json=True
)
def _list_tool_categories() -> str:
    """Registered handler for the ``list_tool_categories`` meta-tool; forwards to ``list_tool_categories``."""
    return list_tool_categories()
@registry.register(
    name="get_patient_summary",
    description="Get a high-level summary of what health data is available for the patient. Call this first to understand what data exists.",
    parameters=[
        ToolParameter("patient_id", "string", "The patient's ID")
    ],
    category="summary",
    requires_patient_id=True
)
def get_patient_summary(patient_id: str) -> str:
    """Get a summary of available health data.

    Args:
        patient_id: ID of the patient row to summarise.

    Returns:
        A text summary with demographics, record counts per table, up to ten
        conditions and active medications, and the observation date range;
        "Patient not found." when no patient row matches.
    """
    conn = get_db()
    try:
        patient = conn.execute(
            "SELECT * FROM patients WHERE id = ?", (patient_id,)
        ).fetchone()
        if not patient:
            return "Patient not found."

        counts = {}
        # Table names come from this fixed tuple, never from caller input, so
        # interpolating them into the SQL text is safe.
        for table in ('conditions', 'medications', 'observations', 'allergies',
                      'encounters', 'immunizations', 'procedures'):
            try:
                cursor = conn.execute(
                    f"SELECT COUNT(*) FROM {table} WHERE patient_id = ?", (patient_id,)
                )
                counts[table] = cursor.fetchone()[0]
            except sqlite3.Error:
                # Best effort: a table that is missing or unreadable counts as empty.
                counts[table] = 0

        cursor = conn.execute(
            "SELECT display, clinical_status FROM conditions WHERE patient_id = ? LIMIT 10",
            (patient_id,)
        )
        conditions = [f"{row['display']} ({row['clinical_status']})" for row in cursor.fetchall()]

        cursor = conn.execute(
            "SELECT display FROM medications WHERE patient_id = ? AND status = 'active' LIMIT 10",
            (patient_id,)
        )
        medications = [row['display'] for row in cursor.fetchall()]

        cursor = conn.execute("""
            SELECT MIN(effective_date) as earliest, MAX(effective_date) as latest
            FROM observations WHERE patient_id = ?
        """, (patient_id,))
        obs_range = cursor.fetchone()
        earliest = obs_range['earliest'][:10] if obs_range['earliest'] else 'N/A'
        latest = obs_range['latest'][:10] if obs_range['latest'] else 'N/A'

        # Calendar-based age: dividing elapsed days by 365 drifts with leap years.
        # A missing or unparseable birth date is reported as unknown instead of crashing.
        age_text = "Unknown"
        try:
            # Only the date part is parsed, so a trailing time component is tolerated.
            birth = datetime.strptime(patient["birth_date"][:10], "%Y-%m-%d")
        except (TypeError, ValueError):
            birth = None
        if birth is not None:
            today = datetime.now()
            had_birthday = (today.month, today.day) >= (birth.month, birth.day)
            age = today.year - birth.year - (0 if had_birthday else 1)
            age_text = f"{age} years old"

        summary = f"""Patient Summary:
- Name: {patient['given_name']} {patient['family_name']}
- Age: {age_text}
- Gender: {patient['gender']}
Available Data:
- Conditions: {counts['conditions']} records
- Medications: {counts['medications']} records
- Observations (vitals/labs): {counts['observations']} records
- Allergies: {counts['allergies']} records
- Encounters: {counts['encounters']} records
- Immunizations: {counts['immunizations']} records
- Procedures: {counts['procedures']} records
Conditions: {', '.join(conditions) if conditions else 'None recorded'}
Active Medications: {', '.join(medications) if medications else 'None'}
Observation date range: {earliest} to {latest}
"""
        return summary
    finally:
        conn.close()
@registry.register(
    name="get_conditions",
    description="Get the patient's medical conditions and diagnoses (e.g., diabetes, hypertension, asthma).",
    parameters=[
        ToolParameter("patient_id", "string", "The patient's ID"),
        ToolParameter("status", "string", "Filter by status: 'active', 'resolved', or leave empty for all", required=False)
    ],
    category="medical_records"
)
def get_conditions(patient_id: str, status: Optional[str] = None) -> str:
    """List the patient's conditions, newest onset first, optionally filtered by clinical status."""
    conn = get_db()
    try:
        if not status:
            cursor = conn.execute("""
                SELECT display, clinical_status, onset_date, abatement_date
                FROM conditions WHERE patient_id = ?
                ORDER BY onset_date DESC
            """, (patient_id,))
        else:
            cursor = conn.execute("""
                SELECT display, clinical_status, onset_date, abatement_date
                FROM conditions WHERE patient_id = ? AND clinical_status = ?
                ORDER BY onset_date DESC
            """, (patient_id, status))
        rows = cursor.fetchall()
        if not rows:
            return "No conditions found."

        def describe(row):
            # Dates are shortened to their first ten characters (the date part).
            since = row['onset_date'][:10] if row['onset_date'] else 'Unknown'
            resolved = ""
            if row['abatement_date']:
                resolved = f" (resolved: {row['abatement_date'][:10]})"
            return f"- {row['display']} [{row['clinical_status']}] - since {since}{resolved}"

        return "\n".join(["Conditions:\n"] + [describe(row) for row in rows])
    finally:
        conn.close()
@registry.register(
    name="get_medications",
    description="Get the patient's medications.",
    parameters=[
        ToolParameter("patient_id", "string", "The patient's ID"),
        ToolParameter("status", "string", "Filter by status: 'active', 'stopped', or leave empty for all", required=False)
    ],
    category="medical_records"
)
def get_medications(patient_id: str, status: Optional[str] = None) -> str:
    """List the patient's medications, newest start date first, optionally filtered by status."""
    conn = get_db()
    try:
        if not status:
            cursor = conn.execute("""
                SELECT display, status, dosage_text, dosage_route, start_date
                FROM medications WHERE patient_id = ?
                ORDER BY start_date DESC
            """, (patient_id,))
        else:
            cursor = conn.execute("""
                SELECT display, status, dosage_text, dosage_route, start_date
                FROM medications WHERE patient_id = ? AND status = ?
                ORDER BY start_date DESC
            """, (patient_id, status))
        rows = cursor.fetchall()
        if not rows:
            return "No medications found."

        output = ["Medications:\n"]
        for med in rows:
            dosage_text = med['dosage_text'] or 'No dosage specified'
            route_suffix = f" ({med['dosage_route']})" if med['dosage_route'] else ""
            started = med['start_date'][:10] if med['start_date'] else 'Unknown'
            # Three output lines per medication: heading, dosage, start date.
            output.extend([
                f"- {med['display']} [{med['status']}]",
                f" Dosage: {dosage_text}{route_suffix}",
                f" Started: {started}",
            ])
        return "\n".join(output)
    finally:
        conn.close()
@registry.register(
    name="get_allergies",
    description="Get the patient's known allergies and adverse reactions.",
    parameters=[
        ToolParameter("patient_id", "string", "The patient's ID")
    ],
    category="medical_records"
)
def get_allergies(patient_id: str) -> str:
    """List the patient's recorded allergies with their reaction and severity."""
    conn = get_db()
    try:
        rows = conn.execute("""
            SELECT substance, reaction_display, reaction_severity, criticality, category
            FROM allergies WHERE patient_id = ?
        """, (patient_id,)).fetchall()
        if not rows:
            return "No known allergies."

        output = ["Allergies:\n"]
        for allergy in rows:
            # Severity falls back to criticality when no reaction severity is recorded.
            severity = allergy['reaction_severity'] or allergy['criticality'] or 'Unknown severity'
            reaction = allergy['reaction_display'] or 'Unknown reaction'
            kind = allergy['category'] or 'Unknown type'
            output += [
                f"- {allergy['substance']} ({kind})",
                f" Reaction: {reaction}",
                f" Severity: {severity}",
            ]
        return "\n".join(output)
    finally:
        conn.close()
@registry.register(
    name="get_recent_vitals",
    description="Get a TEXT SUMMARY of recent vital signs. Use this ONLY for quick text summaries when user does NOT want a chart. For any visual/chart/trend request, use get_vital_chart_data instead.",
    parameters=[
        ToolParameter("patient_id", "string", "The patient's ID"),
        ToolParameter("days", "integer", "Number of days to look back. Default is 30.", required=False, default=30),
        ToolParameter("vital_type", "string", "Specific vital: 'blood_pressure', 'heart_rate', 'temperature', 'weight', 'respiratory_rate', 'oxygen_saturation'", required=False)
    ],
    category="vitals"
)
def get_recent_vitals(patient_id: str, days: int = 30, vital_type: Optional[str] = None) -> str:
    """Summarise vital signs from the last ``days`` days as text, grouped by display name.

    At most 50 observations are read and at most five readings are shown per
    vital. A ``vital_type`` that is not a key of ``VITAL_CODES`` is ignored,
    so all vital signs are returned.
    """
    conn = get_db()
    try:
        since = (datetime.now() - timedelta(days=days)).strftime("%Y-%m-%d")
        filter_by_code = bool(vital_type) and vital_type in VITAL_CODES
        if not filter_by_code:
            cursor = conn.execute("""
                SELECT display, value_quantity, value_unit, effective_date
                FROM observations
                WHERE patient_id = ? AND category = 'vital-signs' AND effective_date >= ?
                ORDER BY effective_date DESC
                LIMIT 50
            """, (patient_id, since))
        else:
            codes = VITAL_CODES[vital_type]
            # One "?" placeholder per code keeps the IN (...) clause parameterized.
            placeholders = ','.join('?' for _ in codes)
            cursor = conn.execute(f"""
                SELECT display, value_quantity, value_unit, effective_date
                FROM observations
                WHERE patient_id = ? AND category = 'vital-signs'
                AND effective_date >= ? AND code IN ({placeholders})
                ORDER BY effective_date DESC
                LIMIT 50
            """, [patient_id, since] + codes)
        rows = cursor.fetchall()
        if not rows:
            return f"No vital signs found in the last {days} days."

        # Group by display name; dict insertion order keeps first-seen order.
        grouped = {}
        for row in rows:
            grouped.setdefault(row['display'], []).append(row)

        output = [f"Vital Signs (last {days} days):\n"]
        for display, readings in grouped.items():
            output.append(f"\n{display}:")
            for reading in readings[:5]:
                when = reading['effective_date'][:10] if reading['effective_date'] else 'Unknown'
                unit = reading['value_unit'] or ''
                output.append(f" {when}: {reading['value_quantity']} {unit}")
        return "\n".join(output)
    finally:
        conn.close()
@registry.register(
    name="get_lab_results",
    description="Get the patient's laboratory test results (blood tests, A1c, cholesterol, etc.).",
    parameters=[
        ToolParameter("patient_id", "string", "The patient's ID"),
        ToolParameter("days", "integer", "Number of days to look back. Default is 90.", required=False, default=90)
    ],
    category="labs"
)
def get_lab_results(patient_id: str, days: int = 90) -> str:
    """Get laboratory results.

    Args:
        patient_id: ID of the patient whose observations are read.
        days: How many days to look back from now.

    Returns:
        A text list of up to 50 laboratory observations, newest first, or a
        "No lab results found" message when there are none in the window.
    """
    conn = get_db()
    try:
        cutoff = (datetime.now() - timedelta(days=days)).strftime("%Y-%m-%d")
        cursor = conn.execute("""
            SELECT display, value_quantity, value_unit, value_string,
            effective_date, interpretation
            FROM observations
            WHERE patient_id = ? AND category = 'laboratory' AND effective_date >= ?
            ORDER BY effective_date DESC
            LIMIT 50
        """, (patient_id, cutoff))
        labs = cursor.fetchall()
        if not labs:
            return f"No lab results found in the last {days} days."
        lines = [f"Lab Results (last {days} days):\n"]
        for lab in labs:
            date = lab['effective_date'][:10] if lab['effective_date'] else 'Unknown'
            # Test for None explicitly: a numeric result of 0 is a real value
            # and must not fall back to the string value or 'N/A'.
            if lab['value_quantity'] is not None:
                value = lab['value_quantity']
            else:
                value = lab['value_string'] or 'N/A'
            unit = lab['value_unit'] or ''
            interp = f" [{lab['interpretation']}]" if lab['interpretation'] else ""
            lines.append(f"- {lab['display']}: {value} {unit}{interp} ({date})")
        return "\n".join(lines)
    finally:
        conn.close()
@registry.register(
    name="get_encounters",
    description="Get the patient's healthcare encounters (office visits, hospitalizations, etc.).",
    parameters=[
        ToolParameter("patient_id", "string", "The patient's ID"),
        ToolParameter("limit", "integer", "Maximum number of encounters to return. Default is 10.", required=False, default=10)
    ],
    category="medical_records"
)
def get_encounters(patient_id: str, limit: int = 10) -> str:
    """List up to ``limit`` of the patient's encounters, most recent first."""
    conn = get_db()
    try:
        rows = conn.execute("""
            SELECT type_display, reason_display, period_start, period_end,
            class_display, status
            FROM encounters WHERE patient_id = ?
            ORDER BY period_start DESC
            LIMIT ?
        """, (patient_id, limit)).fetchall()
        if not rows:
            return "No encounters found."

        output = [f"Healthcare Encounters (last {limit}):\n"]
        for encounter in rows:
            when = encounter['period_start'][:10] if encounter['period_start'] else 'Unknown'
            # Prefer the specific type, then the encounter class, then a generic label.
            kind = encounter['type_display'] or encounter['class_display'] or 'Visit'
            why = encounter['reason_display'] or 'No reason specified'
            output += [f"- {when}: {kind}", f" Reason: {why}"]
        return "\n".join(output)
    finally:
        conn.close()
@registry.register(
    name="get_immunizations",
    description="Get the patient's immunization/vaccination history.",
    parameters=[
        ToolParameter("patient_id", "string", "The patient's ID")
    ],
    category="medical_records"
)
def get_immunizations(patient_id: str) -> str:
    """List the patient's immunizations, most recent occurrence first."""
    conn = get_db()
    try:
        rows = conn.execute("""
            SELECT vaccine_display, status, occurrence_date
            FROM immunizations WHERE patient_id = ?
            ORDER BY occurrence_date DESC
        """, (patient_id,)).fetchall()
        if not rows:
            return "No immunization records found."

        def describe(record):
            given = record['occurrence_date'][:10] if record['occurrence_date'] else 'Unknown'
            return f"- {record['vaccine_display']} ({given}) [{record['status']}]"

        return "\n".join(["Immunizations:\n", *(describe(record) for record in rows)])
    finally:
        conn.close()
@registry.register(
    name="analyze_vital_trend",
    description="Analyze trends in a specific vital sign over time. Returns statistics and trend direction.",
    parameters=[
        ToolParameter("patient_id", "string", "The patient's ID"),
        ToolParameter("vital_type", "string", "The vital to analyze: 'blood_pressure_systolic', 'blood_pressure_diastolic', 'heart_rate', 'weight', 'temperature'"),
        ToolParameter("days", "integer", "Number of days to analyze. Default is 30.", required=False, default=30)
    ],
    category="analysis"
)
def analyze_vital_trend(patient_id: str, vital_type: str, days: int = 30) -> str:
    """Analyze trends in vital signs.

    Args:
        patient_id: ID of the patient whose observations are read.
        vital_type: A key of ``VITAL_CODES``.
        days: How many days to look back from now.

    Returns:
        A text report with reading count, date range, average/min/max, a
        trend label and the five most recent values; or a message when the
        vital type is unknown or no readings exist in the window.

    The trend compares the mean of the oldest third of the readings with the
    mean of the newest third; a change of more than 5% in either direction is
    reported as INCREASING or DECREASING, anything else as STABLE.
    """
    # Validate before touching the database so a bad request opens no connection.
    if vital_type not in VITAL_CODES:
        return f"Unknown vital type: {vital_type}. Available: {', '.join(VITAL_CODES.keys())}"
    conn = get_db()
    try:
        codes = VITAL_CODES[vital_type]
        cutoff = (datetime.now() - timedelta(days=days)).strftime("%Y-%m-%d")
        placeholders = ','.join(['?' for _ in codes])
        cursor = conn.execute(f"""
            SELECT value_quantity, effective_date
            FROM observations
            WHERE patient_id = ? AND code IN ({placeholders})
            AND effective_date >= ? AND value_quantity IS NOT NULL
            ORDER BY effective_date ASC
        """, [patient_id] + codes + [cutoff])
        readings = cursor.fetchall()
        if not readings:
            return f"No {vital_type} readings found in the last {days} days."
        values = [r['value_quantity'] for r in readings]
        dates = [r['effective_date'][:10] for r in readings]
        avg_val = sum(values) / len(values)
        min_val = min(values)
        max_val = max(values)
        if len(values) >= 3:
            # Compute the window size once. Writing the slice as
            # values[-len(values)//3:] negates before dividing, and floor
            # division then rounds away from zero, so the slice would hold more
            # elements than the divisor whenever the count is not a multiple of 3.
            third = len(values) // 3
            first_third = sum(values[:third]) / third
            last_third = sum(values[-third:]) / third
            diff_pct = ((last_third - first_third) / first_third) * 100 if first_third != 0 else 0
            if diff_pct > 5:
                trend = f"INCREASING (up {diff_pct:.1f}%)"
            elif diff_pct < -5:
                trend = f"DECREASING (down {abs(diff_pct):.1f}%)"
            else:
                trend = "STABLE"
        else:
            trend = "Not enough data for trend"
        return f"""{vital_type.replace('_', ' ').title()} Analysis (last {days} days):
Readings: {len(values)}
Date range: {dates[0]} to {dates[-1]}
Statistics:
- Average: {avg_val:.1f}
- Minimum: {min_val:.1f}
- Maximum: {max_val:.1f}
Trend: {trend}
Recent values: {', '.join([str(v) for v in values[-5:]])}
"""
    finally:
        conn.close()
@registry.register(
    name="get_vital_chart_data",
    description="Get vital sign data formatted for generating a LINE chart/graph. USE THIS TOOL (not get_recent_vitals) whenever the user wants to: see, show, display, visualize, or graph any vital sign; view trends over time; or asks for a chart. Returns data that renders as a visual chart.",
    parameters=[
        ToolParameter("patient_id", "string", "The patient's ID"),
        ToolParameter("vital_type", "string", "The vital to chart: 'blood_pressure', 'heart_rate', 'weight', 'temperature', 'respiratory_rate', 'oxygen_saturation'"),
        ToolParameter("days", "integer", "Number of days to include. Default is 30.", required=False, default=30)
    ],
    category="charts",
    returns_chart=True,
    returns_json=True
)
def get_vital_chart_data(patient_id: str, vital_type: str, days: int = 30) -> str:
    """
    Build chart-ready JSON for one vital sign over the last `days` days.

    Blood pressure produces two datasets (systolic and diastolic); every other
    vital produces a single line dataset.

    Args:
        patient_id: Patient whose observations are queried.
        vital_type: A key of VITAL_CODES (e.g. 'blood_pressure', 'heart_rate').
        days: Size of the look-back window, counted back from now.

    Returns:
        A JSON string. On success it contains "summary", "chart_type",
        "title", "statistics" and "datasets" (plus "unit" for single-series
        vitals). On failure it contains a single "error" key.
    """
    def compute_stats(values):
        """Return count/min/max/avg/latest for `values`, or None when empty."""
        if not values:
            return None
        return {
            "count": len(values),
            "min": round(min(values), 1),
            "max": round(max(values), 1),
            "avg": round(stats_module.mean(values), 1),
            "latest": round(values[-1], 1)
        }

    def describe(label, stats, unit):
        """Render one series' statistics as a sentence for the summary text."""
        if stats is None:
            return f"{label}: no readings."
        return f"{label}: min={stats['min']}, max={stats['max']}, avg={stats['avg']} {unit}."

    # Validate before touching the database so bad input never opens a connection.
    if vital_type not in VITAL_CODES:
        return json_module.dumps({"error": f"Unknown vital type: {vital_type}. Available: {', '.join(VITAL_CODES.keys())}"})

    codes = VITAL_CODES[vital_type]
    cutoff = (datetime.now() - timedelta(days=days)).strftime("%Y-%m-%d")

    conn = get_db()
    try:
        # For blood pressure, we need both systolic and diastolic
        if vital_type == 'blood_pressure':
            def fetch_component(code):
                """Return (chart points, raw values) for one BP component code."""
                rows = conn.execute("""
                    SELECT value_quantity, effective_date
                    FROM observations
                    WHERE patient_id = ? AND code = ?
                    AND effective_date >= ? AND value_quantity IS NOT NULL
                    ORDER BY effective_date ASC
                """, [patient_id, code, cutoff]).fetchall()
                points = [{"date": r['effective_date'][:10], "value": r['value_quantity']} for r in rows]
                return points, [r['value_quantity'] for r in rows]

            systolic, systolic_values = fetch_component('8480-6')
            diastolic, diastolic_values = fetch_component('8462-4')

            if not systolic and not diastolic:
                return json_module.dumps({"error": "No blood pressure data found"})

            systolic_stats = compute_stats(systolic_values)
            diastolic_stats = compute_stats(diastolic_values)

            # Either component may be missing on its own; its stats are then
            # None, so the reading count falls back to the component present.
            total_readings = (systolic_stats or diastolic_stats)["count"]
            summary_text = (
                f"STATISTICS (use these exact values): "
                f"Total readings: {total_readings}. "
                f"{describe('Systolic', systolic_stats, 'mmHg')} "
                f"{describe('Diastolic', diastolic_stats, 'mmHg')}"
            )

            return json_module.dumps({
                "summary": summary_text,
                "chart_type": "blood_pressure",
                "title": f"Blood Pressure (Last {days} Days)",
                "statistics": {
                    "systolic": systolic_stats,
                    "diastolic": diastolic_stats,
                },
                "datasets": [
                    {"label": "Systolic", "data": systolic, "color": "#e74c3c"},
                    {"label": "Diastolic", "data": diastolic, "color": "#3498db"}
                ]
            })

        placeholders = ','.join(['?' for _ in codes])
        readings = conn.execute(f"""
            SELECT value_quantity, value_unit, effective_date, display
            FROM observations
            WHERE patient_id = ? AND code IN ({placeholders})
            AND effective_date >= ? AND value_quantity IS NOT NULL
            ORDER BY effective_date ASC
        """, [patient_id] + codes + [cutoff]).fetchall()

        if not readings:
            return json_module.dumps({"error": f"No {vital_type} data found"})

        data = [{"date": r['effective_date'][:10], "value": r['value_quantity']} for r in readings]
        values = [r['value_quantity'] for r in readings]
        # Unit and display name come from the first row; fall back when the
        # column is NULL so the title never reads "None".
        unit = readings[0]['value_unit'] or ''
        display = readings[0]['display'] or vital_type

        vital_stats = compute_stats(values)
        summary_text = (
            f"STATISTICS (use these exact values): "
            f"Total readings: {vital_stats['count']}. "
            f"{describe(display, vital_stats, unit)}"
        )

        return json_module.dumps({
            "summary": summary_text,
            "chart_type": "line",
            "title": f"{display} (Last {days} Days)",
            "unit": unit,
            "statistics": vital_stats,
            "datasets": [
                {"label": display, "data": data, "color": "#667eea"}
            ]
        })
    finally:
        conn.close()
@registry.register(
    name="get_lab_chart_data",
    description="Get laboratory results formatted for a BAR chart. Use this when user wants to compare lab values, see lab history as a chart, or visualize lab trends. Good for: cholesterol comparison, A1c history, glucose trends, kidney function over time.",
    parameters=[
        ToolParameter("patient_id", "string", "The patient's ID"),
        ToolParameter("lab_type", "string", "The lab to chart: 'cholesterol' (shows Total, HDL, LDL, Triglycerides), 'a1c', 'glucose', 'kidney' (Creatinine, eGFR), 'all_latest' (comparison of recent labs)"),
        ToolParameter("periods", "integer", "Number of time periods to show. Default is 4.", required=False, default=4)
    ],
    category="charts",
    returns_chart=True,
    returns_json=True
)
def get_lab_chart_data(patient_id: str, lab_type: str, periods: int = 4) -> str:
    """
    Return bar-chart JSON for a patient's laboratory results.

    With lab_type 'all_latest' the chart holds the single most recent value
    of a fixed panel of labs. For any other lab_type (a key of LAB_CODES) it
    holds up to `periods` most recent readings per code, oldest first.
    Failures are reported as a JSON object with an "error" key.
    """
    def as_error(message):
        return json_module.dumps({"error": message})

    connection = get_db()
    try:
        if lab_type == 'all_latest':
            # Fixed panel: (LOINC code, short bar label)
            panel = (
                ('4548-4', 'HbA1c'),
                ('2093-3', 'Total Chol'),
                ('2085-9', 'HDL'),
                ('13457-7', 'LDL'),
                ('2571-8', 'Triglyc'),
                ('2345-7', 'Glucose'),
            )
            bars = []
            for loinc, short_label in panel:
                newest = connection.execute("""
                    SELECT value_quantity FROM observations
                    WHERE patient_id = ? AND code = ? AND value_quantity IS NOT NULL
                    ORDER BY effective_date DESC LIMIT 1
                """, (patient_id, loinc)).fetchone()
                if newest:
                    bars.append({"label": short_label, "value": newest['value_quantity']})

            if not bars:
                return as_error("No lab data found")

            chart = {
                "chart_type": "bar",
                "title": "Latest Lab Values",
                "datasets": [{"label": "Value", "data": bars, "color": "#667eea"}]
            }
            return json_module.dumps(chart)

        if lab_type not in LAB_CODES:
            return as_error(f"Unknown lab type: {lab_type}. Available: {', '.join(LAB_CODES.keys())}")

        series = []
        for loinc, label, unit in LAB_CODES[lab_type]:
            newest_first = connection.execute("""
                SELECT value_quantity, effective_date FROM observations
                WHERE patient_id = ? AND code = ? AND value_quantity IS NOT NULL
                ORDER BY effective_date DESC LIMIT ?
            """, (patient_id, loinc, periods)).fetchall()
            if not newest_first:
                continue
            # Query returns newest first; the chart wants chronological order.
            points = [
                {"date": entry['effective_date'][:10], "value": entry['value_quantity']}
                for entry in newest_first[::-1]
            ]
            series.append({"label": label, "data": points, "unit": unit})

        if not series:
            return as_error(f"No {lab_type} data found")

        chart = {
            "chart_type": "bar",
            "title": f"{lab_type.title()} Results (Last {periods} readings)",
            "datasets": series
        }
        return json_module.dumps(chart)
    finally:
        connection.close()
@registry.register(
    name="compare_before_after_treatment",
    description="Compare health metrics BEFORE vs AFTER starting a medication/treatment. Use when user asks about treatment effectiveness, medication impact, or before/after comparisons. Shows a bar chart comparing averages before and after treatment started.",
    parameters=[
        ToolParameter("patient_id", "string", "The patient's ID"),
        ToolParameter("medication_name", "string", "Part of the medication name to search for (e.g., 'metformin', 'lisinopril', 'atorvastatin')"),
        ToolParameter("metric_type", "string", "What to compare: 'blood_pressure', 'a1c', 'cholesterol', 'glucose', 'weight', 'heart_rate'")
    ],
    category="analysis",
    returns_chart=True,
    returns_json=True
)
def compare_before_after_treatment(patient_id: str, medication_name: str, metric_type: str) -> str:
    """
    Compare a metric's average before vs. after a medication was started.

    The earliest medication whose display name contains `medication_name`
    (case-insensitive) defines the treatment start date. For each observation
    code of the metric, readings dated before the start date are averaged
    against readings dated on or after it.

    Args:
        patient_id: Patient whose medications and observations are queried.
        medication_name: Substring of the medication display name.
        metric_type: One of 'blood_pressure', 'a1c', 'cholesterol', 'glucose',
            'weight', 'heart_rate'.

    Returns:
        A JSON string describing a "comparison" chart with "before" and
        "after" datasets and a list of summary sentences, or a JSON object
        with an "error" key when the medication, its start date, the metric
        type, or two-sided data cannot be found.
    """
    # A blank search term would turn the LIKE pattern into '%%' and silently
    # match whichever medication happens to be the earliest.
    search_term = (medication_name or "").strip()
    if not search_term:
        return json_module.dumps({"error": "A medication name is required"})

    conn = get_db()
    try:
        # Find the medication start date
        med_row = conn.execute("""
            SELECT display, start_date FROM medications
            WHERE patient_id = ? AND LOWER(display) LIKE ?
            ORDER BY start_date ASC LIMIT 1
        """, (patient_id, f"%{search_term.lower()}%")).fetchone()

        if not med_row:
            return json_module.dumps({"error": f"No medication found matching '{medication_name}'"})

        med_display = med_row['display']
        start_date = med_row['start_date'][:10] if med_row['start_date'] else None
        if not start_date:
            return json_module.dumps({"error": f"No start date found for {med_display}"})

        # Observation codes per metric. Labels must be unique within a metric
        # because they are the keys of the before/after chart data.
        metric_codes = {
            'blood_pressure': [('8480-6', 'Systolic BP'), ('8462-4', 'Diastolic BP')],
            'a1c': [('4548-4', 'HbA1c')],
            'cholesterol': [('2093-3', 'Total Cholesterol'), ('13457-7', 'LDL')],
            'glucose': [('2345-7', 'Glucose'), ('1558-6', 'Fasting Glucose')],
            'weight': [('29463-7', 'Weight')],
            'heart_rate': [('8867-4', 'Heart Rate')]
        }

        if metric_type not in metric_codes:
            return json_module.dumps({"error": f"Unknown metric type: {metric_type}"})

        results = {"before": {}, "after": {}}
        summary_parts = []

        for code, label in metric_codes[metric_type]:
            # One pass over the code's readings yields both averages; AVG
            # skips the NULLs produced by the non-matching CASE branch.
            row = conn.execute("""
                SELECT
                    AVG(CASE WHEN effective_date < ? THEN value_quantity END) AS before_avg,
                    AVG(CASE WHEN effective_date >= ? THEN value_quantity END) AS after_avg
                FROM observations
                WHERE patient_id = ? AND code = ? AND value_quantity IS NOT NULL
            """, (start_date, start_date, patient_id, code)).fetchone()

            before_avg = row['before_avg']
            after_avg = row['after_avg']
            # Only codes with data on both sides can be compared. Test against
            # None so a genuine average of 0 is not mistaken for "no data".
            if before_avg is None or after_avg is None:
                continue

            results["before"][label] = round(before_avg, 1)
            results["after"][label] = round(after_avg, 1)

            change = after_avg - before_avg
            change_pct = (change / before_avg) * 100 if before_avg else 0
            direction = "decreased" if change < 0 else "increased"
            summary_parts.append(f"{label} {direction} from {results['before'][label]} to {results['after'][label]} ({abs(change_pct):.1f}%)")

        if not results["before"]:
            return json_module.dumps({"error": f"No {metric_type} data found around treatment start date"})

        return json_module.dumps({
            "chart_type": "comparison",
            "title": f"{metric_type.title()} Before vs After {med_display}",
            "medication": med_display,
            "treatment_start_date": start_date,
            "summary": summary_parts,
            "datasets": [
                {"label": "Before Treatment", "data": results["before"], "color": "#e74c3c"},
                {"label": "After Treatment", "data": results["after"], "color": "#27ae60"}
            ]
        })
    finally:
        conn.close()
# =============================================================================
# BACKWARD COMPATIBILITY - Legacy function mappings
# =============================================================================
# Map old-style tool names to registry
# Static name -> callable lookup kept for callers that predate the registry.
# Only the functions listed here are reachable through this mapping.
TOOL_FUNCTIONS = {
# Tool-discovery helpers
"search_tools": _search_tools,
"get_tool_schema": _get_tool_schema,
"list_tool_categories": _list_tool_categories,
# Patient data tools
"get_patient_summary": get_patient_summary,
"get_conditions": get_conditions,
"get_medications": get_medications,
"get_allergies": get_allergies,
"get_recent_vitals": get_recent_vitals,
"get_lab_results": get_lab_results,
"get_encounters": get_encounters,
"get_immunizations": get_immunizations,
# Analysis and chart tools
"analyze_vital_trend": analyze_vital_trend,
"get_vital_chart_data": get_vital_chart_data,
"get_lab_chart_data": get_lab_chart_data,
"compare_before_after_treatment": compare_before_after_treatment,
}
# Legacy TOOLS list for backward compatibility.
# Built once when this statement runs at import time, so it holds the MCP
# schemas of the tools the registry returns at that moment only.
TOOLS = [tool.to_mcp_schema() for tool in registry.get_all()]
def execute_tool(tool_name: str, args: dict) -> str:
    """Legacy entry point: run the named tool through the shared registry."""
    outcome = registry.execute(tool_name, args)
    return outcome
# =============================================================================
# MCP SERVER INTERFACE (Model Context Protocol)
# =============================================================================
class MCPServerInterface:
    """
    MCP (Model Context Protocol) facade over the local tool registry.

    This allows:
    1. External apps to discover and use our health tools
    2. Dynamic registration of external MCP tools
    3. Standard protocol for AI-tool integration

    External tools are stored as MCP schema dicts in `external_tools`, keyed
    by the name exposed locally. Keys starting with an underscore
    (`_server_url`, `_original_name`, `_handler_url`) are routing metadata
    added by this class.
    """

    PROTOCOL_VERSION = "2024-11-05"

    def __init__(self, tool_registry: "DynamicToolRegistry"):
        self.registry = tool_registry
        self.external_tools: Dict[str, Dict] = {}  # Tools from external MCP servers
        self.connected_servers: Dict[str, Dict] = {}  # Connected MCP servers

    def get_server_info(self) -> Dict:
        """Return MCP server information (initialize response)."""
        return {
            "protocolVersion": self.PROTOCOL_VERSION,
            "serverInfo": {
                "name": "medgemma-health-tools",
                "version": "2.0.0"
            },
            "capabilities": {
                "tools": {"listChanged": True},
                "resources": {},
                "prompts": {}
            }
        }

    def list_tools(self) -> Dict:
        """Return all tools (local + external) wrapped as {"tools": [...]}."""
        return {"tools": self.get_all_tools()}

    def call_tool(self, name: str, arguments: Dict) -> Dict:
        """
        Handle an MCP tool call.

        External tools take precedence; any other name is passed to the local
        registry's MCP handler.
        """
        if name in self.external_tools:
            return self._call_external_tool(name, arguments)
        return self.registry.handle_mcp_tool_call(name, arguments)

    @staticmethod
    def _error_result(message: str) -> Dict:
        """Build an MCP error result carrying a single text message."""
        return {
            "isError": True,
            "content": [{"type": "text", "text": message}]
        }

    @staticmethod
    def _external_call_target(tool_info: Dict, fallback_name: str) -> tuple:
        """
        Resolve where and under which name an external tool must be called.

        Returns:
            (url, remote_name). `url` is the explicit `_handler_url` when one
            was registered, otherwise `<_server_url>/tools/call`, or None when
            neither is known. `remote_name` is the tool's `_original_name`
            (the name the remote server knows, without the local prefix),
            falling back to `fallback_name`.
        """
        remote_name = tool_info.get("_original_name") or fallback_name
        handler_url = tool_info.get("_handler_url")
        if handler_url:
            return handler_url, remote_name
        server_url = tool_info.get("_server_url")
        if not server_url:
            return None, remote_name
        return f"{server_url.rstrip('/')}/tools/call", remote_name

    def _call_external_tool(self, name: str, arguments: Dict) -> Dict:
        """
        Call a tool on an external MCP server or webhook.

        Returns the JSON body of the remote response, or an MCP error result
        when the tool is unknown, has no URL, or the request fails.
        """
        tool_info = self.external_tools.get(name)
        if not tool_info:
            return self._error_result(f"External tool not found: {name}")

        target_url, remote_name = self._external_call_target(tool_info, name)
        if not target_url:
            return self._error_result(f"No server URL for tool: {name}")

        # Make HTTP request to external MCP server
        try:
            import httpx
            response = httpx.post(
                target_url,
                # The remote side only knows its own, un-prefixed tool name.
                json={"name": remote_name, "arguments": arguments},
                timeout=30.0
            )
            return response.json()
        except Exception as e:
            # Best-effort boundary: any transport or decoding failure is
            # reported to the caller as an MCP error result.
            return self._error_result(f"Error calling external tool: {str(e)}")

    # =========================================================================
    # MCP Client - Connect to external MCP servers
    # =========================================================================
    def connect_server(self, server_url: str, server_name: Optional[str] = None) -> Dict:
        """
        Connect to an external MCP server and discover its tools.

        Discovered tools are registered under `<server_name or 'ext'>_<name>`
        to avoid clashes. Reconnecting to the same URL replaces the tools
        registered by the previous connection.

        Args:
            server_url: Base URL of the MCP server
            server_name: Optional friendly name for the server

        Returns:
            Dict with connection status and discovered tools; on failure
            {"success": False, "error": <message>}.
        """
        try:
            import httpx
            base_url = server_url.rstrip('/')

            # Initialize connection
            init_response = httpx.post(
                f"{base_url}/initialize",
                json={
                    "protocolVersion": self.PROTOCOL_VERSION,
                    "clientInfo": {"name": "medgemma-agent", "version": "2.0.0"},
                    "capabilities": {}
                },
                timeout=10.0
            )
            # An HTTP error body must not be mistaken for server info.
            init_response.raise_for_status()
            server_info = init_response.json()

            # List available tools
            tools_response = httpx.post(f"{base_url}/tools/list", timeout=10.0)
            tools_response.raise_for_status()
            tools_data = tools_response.json()

            # Drop tools left over from an earlier connection to this URL so
            # that tools the server no longer offers do not linger.
            previous = self.connected_servers.get(server_url)
            if previous:
                for stale_name in previous.get("tools", []):
                    self.external_tools.pop(stale_name, None)

            # Register discovered tools
            prefix = server_name or 'ext'
            discovered_tools = []
            for tool in tools_data.get("tools", []):
                tool_name = tool.get("name")
                if not tool_name:
                    # A nameless tool cannot be called; skip it.
                    continue
                # Prefix with server name to avoid conflicts
                prefixed_name = f"{prefix}_{tool_name}"
                tool["_original_name"] = tool_name
                tool["_server_url"] = server_url
                tool["name"] = prefixed_name
                self.external_tools[prefixed_name] = tool
                discovered_tools.append(prefixed_name)

            # Store connection info
            self.connected_servers[server_url] = {
                "name": server_name or server_url,
                "info": server_info,
                "tools": discovered_tools
            }

            print(f"[MCP] Connected to {server_url}, discovered {len(discovered_tools)} tools")
            return {
                "success": True,
                "server": server_info.get("serverInfo", {}),
                "tools_discovered": discovered_tools
            }
        except Exception as e:
            print(f"[MCP] Failed to connect to {server_url}: {e}")
            return {
                "success": False,
                "error": str(e)
            }

    def disconnect_server(self, server_url: str) -> bool:
        """
        Disconnect from an external MCP server and remove its tools.

        Returns False when `server_url` is not a connected server.
        """
        if server_url not in self.connected_servers:
            return False

        # Remove tools from this server
        server_info = self.connected_servers[server_url]
        for tool_name in server_info.get("tools", []):
            self.external_tools.pop(tool_name, None)

        del self.connected_servers[server_url]
        print(f"[MCP] Disconnected from {server_url}")
        return True

    def list_connected_servers(self) -> List[Dict]:
        """List all connected MCP servers as {"url", "name", "tools_count"}."""
        return [
            {
                "url": url,
                "name": info["name"],
                "tools_count": len(info["tools"])
            }
            for url, info in self.connected_servers.items()
        ]

    # =========================================================================
    # Dynamic Tool Registration
    # =========================================================================
    def register_tool_manually(self,
                               name: str,
                               description: str,
                               parameters: Dict,
                               handler_url: str) -> bool:
        """
        Manually register an external tool without connecting to a full MCP server.
        Useful for simple webhook-based tools.

        Args:
            name: Tool name
            description: Tool description
            parameters: JSON Schema for parameters
            handler_url: URL to POST tool calls to

        Returns:
            Always True.
        """
        self.external_tools[name] = {
            "name": name,
            "description": description,
            "inputSchema": parameters,
            "_server_url": handler_url.rsplit('/', 1)[0] if '/' in handler_url else handler_url,
            # Calls for this tool are POSTed to this exact URL.
            "_handler_url": handler_url
        }
        print(f"[MCP] Registered external tool: {name}")
        return True

    def get_all_tools(self) -> List[Dict]:
        """Get all tools (local + external) in MCP format."""
        tools = [tool.to_mcp_schema() for tool in self.registry.get_all()]
        tools.extend(self.external_tools.values())
        return tools
# Module-level MCP interface wrapping the shared `registry`.
# Constructed once, when this statement runs at import time.
mcp_interface = MCPServerInterface(registry)
# =============================================================================
# UTILITY FUNCTIONS
# =============================================================================
def get_tool_names() -> List[str]:
    """Return the name of every tool currently held by the registry."""
    names = []
    for registered in registry.get_all():
        names.append(registered.name)
    return names
def get_chart_tools() -> List[str]:
    """Return the names of registered tools flagged as returning chart data."""
    chart_capable = filter(lambda candidate: candidate.returns_chart, registry.get_all())
    return [candidate.name for candidate in chart_capable]
def get_tools_by_category(category: str) -> List[str]:
    """Return the names of the tools the registry reports for `category`."""
    names = []
    for member in registry.get_by_category(category):
        names.append(member.name)
    return names
# =============================================================================
# SKIN ANALYSIS TOOL (calls Health Foundation / SCIN Classifier server)
# =============================================================================
# Base URL of the Health Foundation server. Lookup order:
# HEALTH_FOUNDATION_URL, then HEAR_SERVER_URL, then the localhost default.
HEALTH_FOUNDATION_URL = os.environ.get(
    "HEALTH_FOUNDATION_URL",
    os.environ.get("HEAR_SERVER_URL", "http://localhost:8082"),
)
def _decode_skin_image(image_data):
    """Return raw image bytes from a base64 string, a data URL, or bytes."""
    import base64
    if not isinstance(image_data, str):
        # Already binary (or None); the caller checks for emptiness.
        return image_data
    # Remove data URL prefix if present ("data:image/png;base64,<payload>")
    if ',' in image_data:
        image_data = image_data.split(',')[1]
    return base64.b64decode(image_data)


def _skin_likelihood(confidence) -> str:
    """Map an adjusted confidence (0..1) to a coarse likelihood label."""
    if confidence > 0.6:
        return "high"
    if confidence > 0.4:
        return "moderate"
    return "possible"


def _format_embedding_only_skin_result(result, symptoms) -> dict:
    """Build the agent payload for the fallback (no classifier) analysis."""
    quality = result.get("image_quality", {})
    return {
        "status": "success",
        "analysis_type": "embedding_only",
        "model": "Derm Foundation (classifier unavailable)",
        "image_quality_score": quality.get("score", 0),
        "image_quality_notes": quality.get("notes", []),
        "conditions": [],
        "symptoms_detected": [],
        "body_parts_detected": [],
        "user_reported_symptoms": symptoms,
        "llm_synthesis_prompt": (
            f"The skin image was analyzed but the condition classifier is unavailable. "
            f"Image quality score: {quality.get('score', 0)}/100. "
            f"User described: '{symptoms}'. "
            f"Please acknowledge their concern, note that specific condition identification isn't available, "
            f"and recommend they consult a dermatologist for a proper evaluation."
        ),
        "summary": "Image processed (embedding only - classifier unavailable)",
        "disclaimer": "⚠️ FOR RESEARCH USE ONLY - NOT A DIAGNOSTIC TOOL"
    }


def _format_skin_classification(result, symptoms) -> dict:
    """Build the agent payload from a full classifier response."""
    def labelled(items):
        # Top five label/confidence pairs, confidence as a percentage.
        return [
            {"name": item.get("label", ""), "confidence": round(item.get("confidence", 0) * 100, 1)}
            for item in items[:5]
        ]

    user_reported = result.get("user_reported", {})
    fitzpatrick = result.get("fitzpatrick_skin_type", {})
    symptom_agreement = result.get("symptom_agreement", {})

    return {
        "status": "success",
        "analysis_type": "full_classification",
        "model": "Derm Foundation + SCIN Classifier",
        # Conditions ranked by confidence
        "conditions": [
            {
                "name": c.get("condition", ""),
                "confidence": round(c.get("adjusted_confidence", 0) * 100, 1),
                "likelihood": _skin_likelihood(c.get("adjusted_confidence", 0))
            }
            for c in result.get("conditions", [])[:5]
        ],
        # Symptoms from image analysis
        "symptoms_from_image": labelled(result.get("predicted_symptoms", [])),
        # Body parts detected
        "body_parts_detected": labelled(result.get("predicted_body_parts", [])),
        # User reported symptoms (parsed)
        "user_reported": {
            "symptoms": user_reported.get("symptoms", []),
            "body_parts": user_reported.get("body_parts", []),
            "duration": user_reported.get("duration"),
            "raw_description": user_reported.get("raw_text", symptoms)
        },
        # Symptom agreement between user and model
        "symptom_match": {
            "agreed": symptom_agreement.get("agreed", []),
            "user_only": symptom_agreement.get("user_reported_only", []),
            "model_only": symptom_agreement.get("model_detected_only", [])
        },
        # Fitzpatrick skin type estimation
        "estimated_skin_type": {
            "type": fitzpatrick.get("label", "Unknown"),
            "confidence": round(fitzpatrick.get("confidence", 0) * 100, 1)
        },
        # Summary for quick reference
        "summary": result.get("summary", {}),
        # LLM synthesis prompt - USE THIS to generate the patient response
        "llm_synthesis_prompt": result.get("llm_prompt", ""),
        "disclaimer": result.get("disclaimer", "⚠️ FOR RESEARCH USE ONLY - NOT A DIAGNOSTIC TOOL")
    }


def analyze_skin_image(patient_id: str, image_data: str, symptoms: str = "") -> str:
    """
    Analyze a skin image via the Health Foundation server.

    The image is posted to the `/analyze/skin/classify` endpoint. When that
    endpoint answers 503, or 200 with a "classifier not loaded" body, the
    image is re-posted to `/analyze/skin` and an embedding-only result is
    returned instead.

    Args:
        patient_id: The patient's ID (not sent to the server by this function)
        image_data: Base64 encoded image data, optionally as a data URL;
            raw bytes are also accepted
        symptoms: Optional free-text symptom description (e.g., "itchy red rash on arm")

    Returns:
        JSON string. On success it carries "status", "analysis_type",
        "conditions", an "llm_synthesis_prompt" and a disclaimer. On any
        failure (missing or undecodable image, unreachable service, HTTP
        error, unsuccessful analysis) it carries an "error" key.
    """
    # Validate the payload first: a missing or undecodable image can be
    # rejected without contacting the analysis service at all.
    try:
        image_bytes = _decode_skin_image(image_data)
    except (ValueError, TypeError) as e:  # binascii.Error is a ValueError
        return json_module.dumps({
            "error": f"Skin analysis failed: {str(e)}",
            "disclaimer": "⚠️ FOR RESEARCH USE ONLY"
        })
    if not image_bytes:
        return json_module.dumps({
            "error": "Skin analysis failed: no image data provided",
            "disclaimer": "⚠️ FOR RESEARCH USE ONLY"
        })

    import httpx

    request_headers = {"ngrok-skip-browser-warning": "true"}
    try:
        with httpx.Client(timeout=90.0) as client:
            # First, try the classifier endpoint
            response = client.post(
                f"{HEALTH_FOUNDATION_URL}/analyze/skin/classify",
                files={"image": ("skin_image.png", image_bytes, "image/png")},
                data={"symptoms": symptoms or "", "top_k": "5"},
                headers=request_headers
            )

            # If classifier not available, fall back to basic embedding analysis
            classifier_unavailable = response.status_code == 503 or (
                response.status_code == 200
                and "classifier not loaded" in response.text.lower()
            )
            if classifier_unavailable:
                response = client.post(
                    f"{HEALTH_FOUNDATION_URL}/analyze/skin",
                    files={"image": ("skin_image.png", image_bytes, "image/png")},
                    data={"include_embedding": "false"},
                    headers=request_headers
                )

        response.raise_for_status()
        result = response.json()

        if not result.get("success"):
            return json_module.dumps({
                "error": result.get("error", "Analysis failed"),
                "disclaimer": "⚠️ FOR RESEARCH USE ONLY - NOT A DIAGNOSTIC TOOL"
            })

        if classifier_unavailable:
            return json_module.dumps(_format_embedding_only_skin_result(result, symptoms))
        return json_module.dumps(_format_skin_classification(result, symptoms))

    except httpx.ConnectError:
        return json_module.dumps({
            "error": "Skin analysis service unavailable. The Derm Foundation server may not be running.",
            "suggestion": "Ensure the health foundation server is running on the correct port.",
            "disclaimer": "⚠️ FOR RESEARCH USE ONLY"
        })
    except Exception as e:
        # Tool boundary: report the failure to the agent instead of raising,
        # but keep the traceback on stderr for debugging.
        import traceback
        traceback.print_exc()
        return json_module.dumps({
            "error": f"Skin analysis failed: {str(e)}",
            "disclaimer": "⚠️ FOR RESEARCH USE ONLY"
        })