Spaces:
Running
Running
| #!/usr/bin/env python3 | |
| """ | |
| MedGemma Pre-Visit Assessment Server (HuggingFace Spaces Version) | |
| """ | |
| import os | |
| import json | |
| import sqlite3 | |
| from datetime import datetime | |
| from typing import Optional | |
| from contextlib import asynccontextmanager | |
| import httpx | |
| from fastapi import FastAPI, HTTPException | |
| from fastapi.responses import FileResponse, StreamingResponse | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel | |
# --- Configuration (all overridable via environment variables) ---
# Base URL of the llama.cpp server hosting the MedGemma model.
LLAMA_SERVER_URL = os.getenv("LLAMA_SERVER_URL", "http://localhost:8081")
# Remote HeAR (Health Acoustic Representations) server; empty string disables audio analysis.
HEAR_SERVER_URL = os.getenv("HEAR_SERVER_URL", "")  # Empty = disabled
# Path to the SQLite database holding the FHIR-derived patient data.
DB_PATH = os.getenv("DB_PATH", "data/fhir.db")
# Headers for LLM requests (the ngrok header skips ngrok's browser interstitial page).
LLM_HEADERS = {
    "Content-Type": "application/json",
    "ngrok-skip-browser-warning": "true"
}
# Pydantic models
class ChatRequest(BaseModel):
    """Request body for the chat and agent-chat endpoints."""
    patient_id: str  # ID of the patient whose record provides context
    message: str  # The user's question for this turn
    include_context: bool = True  # If True, prepend the patient's chart summary to the prompt
    skin_image_data: Optional[str] = None  # Base64 encoded skin image for analysis
    conversation_history: Optional[list] = None  # List of {"role": "user"|"assistant", "content": "..."}
class ChatResponse(BaseModel):
    """Response body for the non-streaming chat endpoint."""
    response: str  # Model-generated answer text
    tokens_used: Optional[int] = None  # Token accounting; not populated by chat_endpoint
# Database helpers
def get_db():
    """Open a connection to the SQLite FHIR database with dict-style row access."""
    connection = sqlite3.connect(DB_PATH)
    connection.row_factory = sqlite3.Row
    return connection
def dict_from_row(row):
    """Convert a sqlite3.Row to a plain dict; pass None/falsy rows through as None."""
    if not row:
        return None
    return dict(row)
# Lifespan
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan: probe the llama-server once at startup (non-fatal), then yield.

    Fix: the module imports `asynccontextmanager` but never applied it; FastAPI's
    `lifespan=` parameter expects an async context manager, not a bare async
    generator function, so the decorator is now applied.
    """
    async with httpx.AsyncClient() as client:
        try:
            resp = await client.get(f"{LLAMA_SERVER_URL}/health", headers=LLM_HEADERS, timeout=5.0)
            if resp.status_code == 200:
                print(f"✓ Connected to llama-server at {LLAMA_SERVER_URL}")
            else:
                print(f"⚠ llama-server returned status {resp.status_code}")
        except Exception as e:
            # Startup continues even when the LLM backend is down; chat calls fail later.
            print(f"⚠ Could not connect to llama-server at {LLAMA_SERVER_URL}: {e}")
            print("  LLM features will not work until server is available")
    yield
# Application instance; `lifespan` runs the startup connectivity check.
app = FastAPI(title="MedGemma Pre-Visit Assessment", lifespan=lifespan)
# Wide-open CORS so the Space can be called/embedded from arbitrary origins.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Ensure the static assets directory exists before the frontend is served from it.
os.makedirs("static", exist_ok=True)
# ============================================================================
# API Routes
# ============================================================================
async def serve_frontend():
    """Serve the single-page frontend HTML."""
    index_file = "static/index.html"
    return FileResponse(index_file)
async def list_patients():
    """Return every patient, enriched with computed age and display-name fields."""
    conn = get_db()
    try:
        rows = conn.execute("""
            SELECT id, given_name, family_name, birth_date, gender
            FROM patients ORDER BY family_name, given_name
        """).fetchall()
        patients = []
        for row in rows:
            record = dict_from_row(row)
            # Approximate whole-year age from elapsed days.
            born = datetime.strptime(record["birth_date"], "%Y-%m-%d")
            record["age"] = (datetime.now() - born).days // 365
            full_name = f"{record['given_name']} {record['family_name']}"
            record["name"] = full_name
            record["display_name"] = full_name
            patients.append(record)
        return {"patients": patients}
    finally:
        conn.close()
async def get_patient(patient_id: str):
    """Fetch one patient by id, enriched with age and name fields; 404 if absent."""
    conn = get_db()
    try:
        row = conn.execute("SELECT * FROM patients WHERE id = ?", (patient_id,)).fetchone()
        patient = dict_from_row(row)
        if patient is None:
            raise HTTPException(status_code=404, detail="Patient not found")
        born = datetime.strptime(patient["birth_date"], "%Y-%m-%d")
        patient["age"] = (datetime.now() - born).days // 365
        patient["name"] = f"{patient['given_name']} {patient['family_name']}"
        patient["display_name"] = patient["name"]
        return patient
    finally:
        conn.close()
async def get_conditions(patient_id: str):
    """List the patient's conditions, newest onset first."""
    conn = get_db()
    try:
        rows = conn.execute("""
            SELECT id, code, display, clinical_status, onset_date
            FROM conditions WHERE patient_id = ?
            ORDER BY onset_date DESC
        """, (patient_id,)).fetchall()
        return {"conditions": [dict_from_row(r) for r in rows]}
    finally:
        conn.close()
async def get_medications(patient_id: str, status: Optional[str] = None):
    """List the patient's medications, optionally filtered by status, newest first."""
    conn = get_db()
    try:
        if status:
            cursor = conn.execute("""
                SELECT id, code, display, status, start_date
                FROM medications WHERE patient_id = ? AND status = ?
                ORDER BY start_date DESC
            """, (patient_id, status))
        else:
            cursor = conn.execute("""
                SELECT id, code, display, status, start_date
                FROM medications WHERE patient_id = ?
                ORDER BY start_date DESC
            """, (patient_id,))
        return {"medications": [dict_from_row(r) for r in cursor.fetchall()]}
    finally:
        conn.close()
async def get_observations(patient_id: str, code: Optional[str] = None, category: Optional[str] = None, limit: int = 100):
    """List observations, optionally filtered by code and/or category, newest first."""
    conn = get_db()
    try:
        sql = "SELECT * FROM observations WHERE patient_id = ?"
        args = [patient_id]
        if code:
            sql += " AND code = ?"
            args.append(code)
        if category:
            sql += " AND category = ?"
            args.append(category)
        sql += " ORDER BY effective_date DESC LIMIT ?"
        args.append(limit)
        rows = conn.execute(sql, args).fetchall()
        return {"observations": [dict_from_row(r) for r in rows]}
    finally:
        conn.close()
async def get_allergies(patient_id: str):
    """List the patient's recorded allergies."""
    conn = get_db()
    try:
        rows = conn.execute("""
            SELECT id, substance, reaction_display as reaction, criticality
            FROM allergies WHERE patient_id = ?
        """, (patient_id,)).fetchall()
        return {"allergies": [dict_from_row(r) for r in rows]}
    finally:
        conn.close()
async def get_encounters(patient_id: str, limit: int = 10):
    """List the patient's most recent encounters (default 10), newest first."""
    conn = get_db()
    try:
        rows = conn.execute("""
            SELECT id, status, class_code, class_display, type_code, type_display,
                   reason_code, reason_display, period_start, period_end
            FROM encounters WHERE patient_id = ?
            ORDER BY period_start DESC LIMIT ?
        """, (patient_id, limit)).fetchall()
        return {"encounters": [dict_from_row(r) for r in rows]}
    finally:
        conn.close()
async def get_immunizations(patient_id: str):
    """List the patient's immunizations, most recent first."""
    conn = get_db()
    try:
        rows = conn.execute("""
            SELECT id, vaccine_code, vaccine_display, status, occurrence_date
            FROM immunizations WHERE patient_id = ?
            ORDER BY occurrence_date DESC
        """, (patient_id,)).fetchall()
        return {"immunizations": [dict_from_row(r) for r in rows]}
    finally:
        conn.close()
async def get_procedures(patient_id: str):
    """List the patient's procedures, most recent first."""
    conn = get_db()
    try:
        rows = conn.execute("""
            SELECT id, code, display, status, performed_date
            FROM procedures WHERE patient_id = ?
            ORDER BY performed_date DESC
        """, (patient_id,)).fetchall()
        return {"procedures": [dict_from_row(r) for r in rows]}
    finally:
        conn.close()
# ============================================================================
# LLM Integration
# ============================================================================
def build_patient_context(patient_id: str) -> str:
    """Assemble a plain-text chart summary for a patient, used as LLM prompt context.

    Sections included (only when non-empty): demographics, conditions, active
    medications, last 10 vital signs, last 10 lab results, allergies. Returns
    "Patient not found." when the id does not exist. The exact text layout is
    part of the prompt contract — do not reformat casually.
    """
    conn = get_db()
    try:
        cursor = conn.execute("SELECT * FROM patients WHERE id = ?", (patient_id,))
        patient = dict_from_row(cursor.fetchone())
        if not patient:
            return "Patient not found."
        # Age in whole years, approximated as elapsed days // 365.
        birth = datetime.strptime(patient["birth_date"], "%Y-%m-%d")
        age = (datetime.now() - birth).days // 365
        context = f"""PATIENT INFORMATION:
Name: {patient['given_name']} {patient['family_name']}
Age: {age} years old
Gender: {patient['gender']}
Birth Date: {patient['birth_date']}
"""
        # Conditions (all rows for the patient, regardless of clinical_status).
        cursor = conn.execute("SELECT display, clinical_status, onset_date FROM conditions WHERE patient_id = ?", (patient_id,))
        conditions = cursor.fetchall()
        if conditions:
            context += "ACTIVE CONDITIONS:\n"
            for c in conditions:
                context += f"- {c['display']} (since {c['onset_date'] or 'unknown'})\n"
            context += "\n"
        # Only medications explicitly marked active.
        cursor = conn.execute("SELECT display, status, start_date FROM medications WHERE patient_id = ? AND status = 'active'", (patient_id,))
        meds = cursor.fetchall()
        if meds:
            context += "CURRENT MEDICATIONS:\n"
            for m in meds:
                context += f"- {m['display']} (started {m['start_date'] or 'unknown'})\n"
            context += "\n"
        # Ten most recent vital-sign observations.
        cursor = conn.execute("""
            SELECT display, value_quantity, value_unit, effective_date
            FROM observations WHERE patient_id = ? AND category = 'vital-signs'
            ORDER BY effective_date DESC LIMIT 10
        """, (patient_id,))
        vitals = cursor.fetchall()
        if vitals:
            context += "RECENT VITAL SIGNS:\n"
            for v in vitals:
                context += f"- {v['display']}: {v['value_quantity']} {v['value_unit'] or ''} ({v['effective_date']})\n"
            context += "\n"
        # Ten most recent laboratory observations.
        cursor = conn.execute("""
            SELECT display, value_quantity, value_unit, effective_date
            FROM observations WHERE patient_id = ? AND category = 'laboratory'
            ORDER BY effective_date DESC LIMIT 10
        """, (patient_id,))
        labs = cursor.fetchall()
        if labs:
            context += "RECENT LAB RESULTS:\n"
            for l in labs:
                context += f"- {l['display']}: {l['value_quantity']} {l['value_unit'] or ''} ({l['effective_date']})\n"
            context += "\n"
        # Allergies, with the reaction appended only when recorded.
        cursor = conn.execute("SELECT substance, reaction_display, criticality FROM allergies WHERE patient_id = ?", (patient_id,))
        allergies = cursor.fetchall()
        if allergies:
            context += "ALLERGIES:\n"
            for a in allergies:
                context += f"- {a['substance']}"
                if a['reaction_display']:
                    context += f" (reaction: {a['reaction_display']})"
                context += "\n"
        return context
    finally:
        conn.close()
async def call_llama_server(prompt: str) -> str:
    """Send a non-streaming completion request to llama-server and return the text.

    Raises HTTPException 503 when the server is unreachable, 500 on other errors.
    """
    payload = {
        "prompt": prompt,
        "n_predict": 1024,
        "temperature": 0.7,
        "stop": ["<end_of_turn>", "</s>", "<|im_end|>"],
        "stream": False
    }
    async with httpx.AsyncClient(timeout=300.0) as client:
        try:
            response = await client.post(
                f"{LLAMA_SERVER_URL}/completion",
                headers=LLM_HEADERS,
                json=payload
            )
            response.raise_for_status()
            body = response.json()
            return body.get("content", "").strip()
        except httpx.ConnectError:
            raise HTTPException(status_code=503, detail="Cannot connect to llama-server")
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"LLM error: {str(e)}")
async def stream_llama_server(prompt: str):
    """Stream completion text from llama-server as SSE, suppressing <think>...</think>.

    Yields plain text fragments. Any span between <think> and </think> tags is
    dropped (model reasoning not meant for the user), including spans whose tags
    arrive split across SSE chunks.
    """
    async with httpx.AsyncClient(timeout=300.0) as client:
        async with client.stream(
            "POST",
            f"{LLAMA_SERVER_URL}/completion",
            headers=LLM_HEADERS,
            json={
                "prompt": prompt,
                "n_predict": 1024,
                "temperature": 0.7,
                "stop": ["<end_of_turn>", "</s>", "<|im_end|>"],
                "stream": True
            }
        ) as response:
            buffer = ""          # text received but not yet emitted
            in_thinking = False  # True while inside a <think>...</think> span
            async for line in response.aiter_lines():
                if line.startswith("data: "):
                    data = line[6:]
                    if data.strip() == "[DONE]":
                        break
                    try:
                        chunk = json.loads(data)
                        content = chunk.get("content", "")
                        if content:
                            buffer += content
                            # Drain the buffer, alternating between emit and suppress modes.
                            while True:
                                if not in_thinking:
                                    think_start = buffer.find("<think>")
                                    if think_start != -1:
                                        # Emit text before the tag, then start suppressing.
                                        if think_start > 0:
                                            yield buffer[:think_start]
                                        buffer = buffer[think_start + 7:]  # 7 == len("<think>")
                                        in_thinking = True
                                    else:
                                        # Hold back the last 7 chars in case a "<think>"
                                        # tag is split across chunk boundaries.
                                        safe_end = len(buffer) - 7
                                        if safe_end > 0:
                                            yield buffer[:safe_end]
                                            buffer = buffer[safe_end:]
                                        break
                                else:
                                    think_end = buffer.find("</think>")
                                    if think_end != -1:
                                        buffer = buffer[think_end + 8:]  # 8 == len("</think>")
                                        in_thinking = False
                                    else:
                                        # Closing tag not yet received; wait for more content.
                                        break
                    except json.JSONDecodeError:
                        # Ignore malformed SSE payload lines.
                        pass
            # Flush trailing text (dropped if the stream ended inside a <think> span).
            if buffer and not in_thinking:
                yield buffer
async def chat_endpoint(request: ChatRequest):
    """Non-streaming chat: wrap the question (plus optional chart context) in the
    Gemma turn template and return the model's full answer."""
    patient_context = build_patient_context(request.patient_id) if request.include_context else ""
    prompt = f"""<start_of_turn>user
{patient_context}
Patient Question: {request.message}
Please provide a helpful, accurate response based on the patient's health information above.<end_of_turn>
<start_of_turn>model
"""
    answer = await call_llama_server(prompt)
    return ChatResponse(response=answer)
async def chat_stream_endpoint(request: ChatRequest):
    """Streaming chat: same prompt as chat_endpoint, emitted as SSE data events
    followed by a final [DONE] marker."""
    patient_context = build_patient_context(request.patient_id) if request.include_context else ""
    prompt = f"""<start_of_turn>user
{patient_context}
Patient Question: {request.message}
Please provide a helpful, accurate response based on the patient's health information above.<end_of_turn>
<start_of_turn>model
"""

    async def generate():
        async for chunk in stream_llama_server(prompt):
            yield f"data: {json.dumps({'content': chunk})}\n\n"
        yield "data: [DONE]\n\n"

    return StreamingResponse(generate(), media_type="text/event-stream")
async def health_check():
    """Report combined health of the llama-server and the SQLite database.

    Fixes: the two bare `except:` clauses (which also swallowed KeyboardInterrupt/
    SystemExit) are narrowed to `except Exception`, and the probe connection is
    closed via try/finally so it cannot leak on a failed query.
    """
    llama_status = "unknown"
    async with httpx.AsyncClient(timeout=5.0) as client:
        try:
            resp = await client.get(f"{LLAMA_SERVER_URL}/health", headers=LLM_HEADERS)
            llama_status = "connected" if resp.status_code == 200 else "error"
        except Exception:
            llama_status = "disconnected"
    db_status = "unknown"
    try:
        conn = get_db()
        try:
            conn.execute("SELECT 1")
        finally:
            conn.close()
        db_status = "connected"
    except Exception:
        db_status = "error"
    return {
        "status": "healthy" if llama_status == "connected" and db_status == "connected" else "degraded",
        "llama_server": llama_status,
        "database": db_status,
        "llama_url": LLAMA_SERVER_URL
    }
| # ============================================================================ | |
| # MCP (Model Context Protocol) Endpoints | |
| # ============================================================================ | |
| from tools import mcp_interface | |
async def mcp_initialize(request: dict = None):
    """MCP Initialize - Return server capabilities.

    The request payload is accepted but not inspected.
    """
    server_info = mcp_interface.get_server_info()
    return server_info
async def mcp_list_tools():
    """MCP List Tools - Return available tools in MCP format."""
    tools = mcp_interface.list_tools()
    return tools
async def mcp_call_tool(request: dict):
    """MCP Call Tool - Execute a named tool with its arguments and return the result."""
    tool_name = request.get("name")
    tool_args = request.get("arguments", {})
    return mcp_interface.call_tool(tool_name, tool_args)
async def mcp_status():
    """Get MCP status including connected external servers."""
    status = {
        "protocol_version": mcp_interface.PROTOCOL_VERSION,
        "local_tools": len(mcp_interface.registry.get_all()),
        "external_tools": len(mcp_interface.external_tools),
        "connected_servers": mcp_interface.list_connected_servers()
    }
    return status
async def mcp_connect_server(request: dict):
    """Connect to an external MCP server and discover its tools."""
    url = request.get("server_url")
    name = request.get("server_name")
    if url:
        return mcp_interface.connect_server(url, name)
    return {"success": False, "error": "server_url required"}
async def mcp_disconnect_server(request: dict):
    """Disconnect from an external MCP server."""
    url = request.get("server_url")
    if not url:
        return {"success": False, "error": "server_url required"}
    return {"success": mcp_interface.disconnect_server(url)}
async def mcp_register_tool(request: dict):
    """Manually register an external tool without a full MCP server handshake."""
    name = request.get("name")
    description = request.get("description")
    parameters = request.get("parameters", {"type": "object", "properties": {}})
    handler_url = request.get("handler_url")
    if name and description and handler_url:
        registered = mcp_interface.register_tool_manually(name, description, parameters, handler_url)
        return {"success": registered, "tool_name": name}
    return {"success": False, "error": "name, description, and handler_url required"}
async def mcp_get_all_tools():
    """Get all tools (local + external) in MCP format."""
    combined = mcp_interface.get_all_tools()
    return {"tools": combined}
# ============================================================================
# Agent endpoints (v2 manual graph + LangGraph)
# ============================================================================
from agent_v2 import run_agent_v2
# Toggle: set USE_LANGGRAPH=true to use LangGraph agent
USE_LANGGRAPH = os.getenv("USE_LANGGRAPH", "false").lower() == "true"
# LangGraph is optional: if its import fails, agent_chat_endpoint falls back to
# the manual v2 agent regardless of the toggle.
try:
    from agent_langgraph import run_agent_langgraph
    LANGGRAPH_AVAILABLE = True
    print(f"[SERVER] LangGraph agent available (active: {USE_LANGGRAPH})")
except ImportError as e:
    LANGGRAPH_AVAILABLE = False
    print(f"[SERVER] LangGraph agent not available: {e}")
async def agent_chat_endpoint(request: ChatRequest):
    """Run the tool-using agent for one user turn and stream its events as SSE.

    Dispatches to the LangGraph agent when USE_LANGGRAPH is set and the import
    succeeded, otherwise to the manual v2 agent. Each agent event is emitted as
    a `data: <json>` SSE line, followed by a final `data: [DONE]` marker.
    """
    async def generate():
        try:
            has_image = request.skin_image_data is not None and len(request.skin_image_data) > 0
            agent_type = "langgraph" if (USE_LANGGRAPH and LANGGRAPH_AVAILABLE) else "v2"
            print(f"[SERVER] Agent chat ({agent_type}) - question: '{request.message[:50]}', has_skin_image: {has_image}")
            if USE_LANGGRAPH and LANGGRAPH_AVAILABLE:
                async for event in run_agent_langgraph(
                    request.patient_id,
                    request.message,
                    skin_image_data=request.skin_image_data,
                    conversation_history=request.conversation_history
                ):
                    yield f"data: {json.dumps(event)}\n\n"
            else:
                async for event in run_agent_v2(
                    request.patient_id,
                    request.message,
                    skin_image_data=request.skin_image_data,
                    conversation_history=request.conversation_history
                ):
                    yield f"data: {json.dumps(event)}\n\n"
        except Exception as e:
            # Surface agent failures as an error event instead of breaking the stream.
            yield f"data: {json.dumps({'type': 'error', 'message': str(e)})}\n\n"
        # [DONE] is sent on both the success and error paths.
        yield "data: [DONE]\n\n"
    return StreamingResponse(
        generate(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no"  # disable nginx proxy buffering so events flush immediately
        }
    )
# Simple test endpoint to verify SSE streaming works
async def test_stream():
    """Emit three numbered SSE events (0.1s apart) and a final [DONE] marker.

    Fix: `import asyncio` was executed inside the loop body on every iteration;
    hoisted to the top of the function.
    """
    import asyncio  # kept function-local, matching the original's scoping

    async def generate():
        for i in range(3):
            yield f"data: {{\"count\": {i}}}\n\n"
            await asyncio.sleep(0.1)
        yield "data: [DONE]\n\n"

    return StreamingResponse(
        generate(),
        media_type="text/event-stream"
    )
# Debug endpoint to test database directly
async def debug_medications(patient_id: str):
    """Debug helper: dump medication rows and patient-id samples for one patient."""
    conn = get_db()
    try:
        # Does the patient exist at all?
        patient_row = conn.execute(
            "SELECT id, given_name, family_name FROM patients WHERE id = ?", (patient_id,)
        ).fetchone()
        # All medication rows for this patient.
        med_rows = [dict(r) for r in conn.execute(
            "SELECT id, display, status FROM medications WHERE patient_id = ?", (patient_id,)
        ).fetchall()]
        # Distinct patient ids present in the medications table.
        med_ids = [r[0] for r in conn.execute(
            "SELECT DISTINCT patient_id FROM medications"
        ).fetchall()]
        # A small sample of patient ids present in the observations table.
        obs_ids = [r[0] for r in conn.execute(
            "SELECT DISTINCT patient_id FROM observations LIMIT 5"
        ).fetchall()]
        return {
            "queried_patient_id": patient_id,
            "patient_found": patient_row is not None,
            "patient_name": f"{patient_row['given_name']} {patient_row['family_name']}" if patient_row else None,
            "medications_count": len(med_rows),
            "medications": med_rows[:5],
            "all_medication_patient_ids": med_ids,
            "sample_observation_patient_ids": obs_ids
        }
    finally:
        conn.close()
async def debug_bp(patient_id: str):
    """Debug helper: summarize blood-pressure observations for one patient
    (LOINC 8480-6 = systolic, 8462-4 = diastolic).

    Fix: removed an unreachable `return StreamingResponse(generate(), ...)`
    block left after the try/finally — the try always returns first, and the
    stray code referenced an undefined `generate`.
    """
    conn = get_db()
    try:
        # Count BP observations per code.
        cursor = conn.execute("""
            SELECT code, display, COUNT(*) as count
            FROM observations
            WHERE patient_id = ? AND code IN ('8480-6', '8462-4')
            GROUP BY code, display
        """, (patient_id,))
        bp_counts = [dict(row) for row in cursor.fetchall()]
        # Sample of the five most recent systolic readings.
        cursor = conn.execute("""
            SELECT code, value_quantity, effective_date
            FROM observations
            WHERE patient_id = ? AND code = '8480-6'
            ORDER BY effective_date DESC LIMIT 5
        """, (patient_id,))
        sample_systolic = [dict(row) for row in cursor.fetchall()]
        # All patient ids that have any BP data.
        cursor = conn.execute("""
            SELECT DISTINCT patient_id FROM observations
            WHERE code IN ('8480-6', '8462-4')
        """)
        bp_patient_ids = [row[0] for row in cursor.fetchall()]
        return {
            "queried_patient_id": patient_id,
            "bp_observation_counts": bp_counts,
            "sample_systolic": sample_systolic,
            "patient_ids_with_bp_data": bp_patient_ids
        }
    finally:
        conn.close()
# ============================================================================
# Audio Analysis (proxies to remote HeAR server)
# ============================================================================
from fastapi import File, UploadFile
async def audio_analyzer_status():
    """Report availability of the remote HeAR audio-analysis server.

    Fix: when the remote /status call succeeded with a non-200 code, the
    original fell off the end of the function and returned None (the endpoint
    returned null); every path now returns a status dict.
    """
    if not HEAR_SERVER_URL:
        return {
            "available": False,
            "model": None,
            "message": "Audio analysis not configured. Set HEAR_SERVER_URL.",
            "capabilities": []
        }
    # Check remote HeAR server
    async with httpx.AsyncClient(timeout=5.0) as client:
        try:
            resp = await client.get(f"{HEAR_SERVER_URL}/status", headers=LLM_HEADERS)
            if resp.status_code == 200:
                data = resp.json()
                return {
                    "available": data.get("available", True),
                    "model": "HeAR (Remote)",
                    "model_type": "HeAR (Health Acoustic Representations)",
                    "message": "Connected to remote HeAR server",
                    "capabilities": data.get("capabilities", ["cough_detection", "covid_risk_screening", "tb_risk_screening"])
                }
            # Non-200 response: previously returned None here.
            return {
                "available": False,
                "model": None,
                "message": f"HeAR server returned status {resp.status_code}",
                "capabilities": []
            }
        except Exception as e:
            return {
                "available": False,
                "model": None,
                "message": f"Cannot connect to HeAR server: {str(e)}",
                "capabilities": []
            }
async def analyze_audio(audio: UploadFile = File(...)):
    """Proxy an uploaded audio recording to the remote HeAR server for analysis."""
    if not HEAR_SERVER_URL:
        return {"success": False, "error": "Audio analysis not configured"}
    try:
        payload = await audio.read()
        async with httpx.AsyncClient(timeout=60.0) as client:
            resp = await client.post(
                f"{HEAR_SERVER_URL}/analyze",
                files={"audio": ("recording.webm", payload, "audio/webm")},
                headers={"ngrok-skip-browser-warning": "true"}
            )
            if resp.status_code == 200:
                return resp.json()
            return {"success": False, "error": f"HeAR server error: {resp.status_code}"}
    except Exception as e:
        return {"success": False, "error": str(e)}
# ============================================================================
# Skin Analysis (proxies to remote Health Foundation server with Derm Foundation)
# ============================================================================
# Health Foundation URL (same server handles both HeAR audio and Derm skin).
# Falls back to HEAR_SERVER_URL when that is set, else a local default.
HEALTH_FOUNDATION_URL = os.getenv("HEALTH_FOUNDATION_URL", HEAR_SERVER_URL or "http://localhost:8082")
class SkinAnalysisRequest(BaseModel):
    """Request body for direct skin-image analysis."""
    patient_id: str
    image_data: str  # Base64 encoded image; a data: URL prefix is tolerated and stripped
async def skin_analysis_status():
    """Check if skin analysis (Derm Foundation) is available on the remote server.

    Fix: when the remote /status call succeeded with a non-200 code, the
    original fell off the end of the function and returned None (the endpoint
    returned null); every path now returns a status dict.
    """
    if not HEALTH_FOUNDATION_URL:
        return {
            "available": False,
            "model": None,
            "message": "Skin analysis not configured. Set HEALTH_FOUNDATION_URL.",
            "capabilities": []
        }
    async with httpx.AsyncClient(timeout=5.0) as client:
        try:
            resp = await client.get(
                f"{HEALTH_FOUNDATION_URL}/status",
                headers={"ngrok-skip-browser-warning": "true"}
            )
            if resp.status_code == 200:
                data = resp.json()
                return {
                    "available": data.get("derm_available", False),
                    "model": "Derm Foundation (google/derm-foundation)" if data.get("derm_available") else "Not loaded",
                    "message": "Connected to Health Foundation server",
                    "capabilities": ["skin_analysis", "derm_embeddings"] if data.get("derm_available") else []
                }
            # Non-200 response: previously returned None here.
            return {
                "available": False,
                "model": None,
                "message": f"Health Foundation server returned status {resp.status_code}",
                "capabilities": []
            }
        except Exception as e:
            return {
                "available": False,
                "model": None,
                "message": f"Cannot connect to Health Foundation server: {str(e)}",
                "capabilities": []
            }
async def analyze_skin_image(request: SkinAnalysisRequest):
    """
    Analyze a skin image using Derm Foundation model.
    This endpoint is for direct calls from the frontend.
    The agent can also call skin analysis via the analyze_skin_image tool.

    Returns a {"success": True, ...} payload with quality/embedding analysis on
    success, or {"success": False, "error": ...} on any failure (never raises).
    """
    if not HEALTH_FOUNDATION_URL:
        return {"success": False, "error": "Skin analysis not configured"}
    try:
        import base64
        # Decode base64 image
        image_data = request.image_data
        if ',' in image_data:
            # Remove data URL prefix (e.g., "data:image/png;base64,")
            image_data = image_data.split(',')[1]
        image_bytes = base64.b64decode(image_data)
        # Send to health foundation server (embedding omitted to keep the reply small)
        async with httpx.AsyncClient(timeout=60.0) as client:
            files = {"image": ("skin_image.png", image_bytes, "image/png")}
            data = {"include_embedding": "false"}
            resp = await client.post(
                f"{HEALTH_FOUNDATION_URL}/analyze/skin",
                files=files,
                data=data,
                headers={"ngrok-skip-browser-warning": "true"}
            )
            if resp.status_code != 200:
                return {
                    "success": False,
                    "error": f"Analysis server returned status {resp.status_code}"
                }
            result = resp.json()
            if not result.get("success"):
                return {
                    "success": False,
                    "error": result.get("error", "Analysis failed")
                }
            # Return successful result
            return {
                "success": True,
                "model": result.get("model", "Derm Foundation"),
                "image_quality": result.get("image_quality", {}),
                "embedding_analysis": result.get("embedding_analysis", {}),
                "recommendation": result.get("recommendation", ""),
                "disclaimer": "⚠️ FOR RESEARCH USE ONLY - NOT A DIAGNOSTIC TOOL"
            }
    except httpx.ConnectError:
        return {
            "success": False,
            "error": "Skin analysis service unavailable. Is the health foundation server running?"
        }
    except Exception as e:
        import traceback
        traceback.print_exc()
        return {
            "success": False,
            "error": str(e)
        }
# ============================================================================
# Pre-Visit Report Generation
# ============================================================================
from report_generator import generate_report, format_report_html, PreVisitReport
class ReportRequest(BaseModel):
    """Request body for pre-visit report generation."""
    patient_id: str
    conversation: list  # List of {"role": "user"|"assistant", "content": "..."}
    tool_results: list = []  # List of {"tool": "...", "facts": "..."}
    attachments: list = []  # List of {"type": "audio"|"chart"|"skin", "title": "...", "summary": "..."}
async def generate_report_endpoint(request: ReportRequest):
    """Generate a pre-visit summary report from the conversation plus chart data.

    Returns {"success": True, "report": ..., "html": ...} on success,
    {"success": False, "error": ...} on unexpected failure, and raises 404
    when the patient does not exist.

    Fixes: the DB connection is now closed via try/finally (it leaked whenever a
    query raised); HTTPException is re-raised instead of being swallowed by the
    broad handler into a 200 success=False body; and the redundant local
    `from datetime import datetime` was removed (already imported at module level).
    """
    try:
        conn = get_db()
        try:
            # Get patient info
            cursor = conn.execute("SELECT * FROM patients WHERE id = ?", (request.patient_id,))
            patient = cursor.fetchone()
            if not patient:
                raise HTTPException(status_code=404, detail="Patient not found")
            birth = datetime.strptime(patient["birth_date"], "%Y-%m-%d")
            age = (datetime.now() - birth).days // 365
            patient_info = {
                "name": f"{patient['given_name']} {patient['family_name']}",
                "age": age,
                "gender": patient['gender']
            }
            # Fetch immunizations
            cursor = conn.execute("""
                SELECT id, vaccine_code, vaccine_display, status, occurrence_date
                FROM immunizations WHERE patient_id = ?
                ORDER BY occurrence_date DESC
            """, (request.patient_id,))
            immunizations = [dict_from_row(row) for row in cursor.fetchall()]
            # Fetch procedures (surgical history)
            cursor = conn.execute("""
                SELECT id, code, display, status, performed_date
                FROM procedures WHERE patient_id = ?
                ORDER BY performed_date DESC
            """, (request.patient_id,))
            procedures = [dict_from_row(row) for row in cursor.fetchall()]
            # Fetch recent encounters
            cursor = conn.execute("""
                SELECT id, status, class_code, class_display, type_code, type_display,
                       reason_code, reason_display, period_start, period_end
                FROM encounters WHERE patient_id = ?
                ORDER BY period_start DESC LIMIT 10
            """, (request.patient_id,))
            encounters = [dict_from_row(row) for row in cursor.fetchall()]
            # Fetch allergies
            cursor = conn.execute("""
                SELECT id, substance, reaction_display as reaction, criticality, category
                FROM allergies WHERE patient_id = ?
            """, (request.patient_id,))
            allergies = [dict_from_row(row) for row in cursor.fetchall()]
        finally:
            conn.close()
        # Generate report with all data
        report = await generate_report(
            patient_info=patient_info,
            conversation_history=request.conversation,
            tool_results=request.tool_results,
            attachments=request.attachments,
            immunizations=immunizations,
            procedures=procedures,
            encounters=encounters,
            allergies=allergies
        )
        # Return both structured data and HTML
        return {
            "success": True,
            "report": report.to_dict(),
            "html": format_report_html(report)
        }
    except HTTPException:
        raise  # preserve proper HTTP status codes (e.g. 404) instead of a 200 error body
    except Exception as e:
        return {"success": False, "error": str(e)}
| # ============================================================================= | |
| # EVALUATION ENDPOINT | |
| # ============================================================================= | |
async def run_evaluation(
    patients: int = 5,
    mode: str = "direct",
    error_rate: float = 0.15
):
    """
    Run the evaluation framework and return results.

    Parameters:
    - patients: Number of patients to test (default: 5)
    - mode:
        - 'direct': Perfect baseline (always 100%)
        - 'simulated': With fake errors (tests error detection)
        - 'agent': Real tool data retrieval (tests tools)
        - 'llm': Full LLM response accuracy (tests MedGemma text output)
    - error_rate: Error rate for simulated mode (default: 0.15)

    Results are printed to logs and returned as JSON. Any failure (including a
    missing `evaluation` package) is caught and reported as
    {"success": False, "error": ...} rather than raised.
    """
    try:
        # Imported lazily so the server can start even when the evaluation
        # package is not installed; a missing module is reported as an error.
        from evaluation.test_generator import generate_all_test_cases, get_test_summary
        from evaluation.metrics import aggregate_metrics, format_report

        print("=" * 60)
        print(f"EVALUATION STARTED - Mode: {mode}, Patients: {patients}")
        print("=" * 60)

        # Generate test cases and log a per-query-type breakdown.
        test_cases = generate_all_test_cases(num_patients=patients)
        summary = get_test_summary(test_cases)
        print(f"Generated {summary['total_cases']} test cases")
        for qtype, count in sorted(summary["by_type"].items()):
            print(f" {qtype}: {count}")

        if mode == "llm":
            # LLM mode aggregates and returns its own metrics payload.
            return await _run_llm_evaluation(test_cases, patients)

        if mode == "agent":
            evaluations = _run_agent_evaluation(test_cases)
        else:
            # 'direct' or 'simulated'
            evaluations = _run_offline_evaluation(test_cases, mode, error_rate)

        # Aggregate and report (for non-LLM modes).
        metrics = aggregate_metrics(evaluations)
        print("\n" + format_report(metrics))
        return {
            "success": True,
            "mode": mode,
            "patients_tested": patients,
            "metrics": metrics.to_dict()
        }
    except Exception as e:
        import traceback
        error_msg = f"Evaluation failed: {str(e)}"
        print(error_msg)
        print(traceback.format_exc())
        return {
            "success": False,
            "error": error_msg
        }


async def _run_llm_evaluation(test_cases, patients):
    """Full LLM evaluation: call the live MedGemma backend and score responses.

    Part 1 checks numeric accuracy of vital-trend charts; Part 2 checks text
    recall for medication/condition/allergy lists. Returns the response dict.
    """
    from evaluation.expected_values import compute_expected_values
    from evaluation.llm_eval import (
        call_agent_endpoint,
        extract_numbers_from_chart,
        extract_numbers_from_text,
        compare_llm_response,
        aggregate_llm_results,
        LLMComparisonResult,
        evaluate_text_query,
        aggregate_text_results
    )

    print("\nRunning FULL LLM evaluation (this calls actual MedGemma)...")

    # === PART 1: NUMERIC EVALUATION (Vitals) ===
    print("\n--- PART 1: NUMERIC ACCURACY (Vital Charts) ---\n")
    vital_cases = [tc for tc in test_cases if tc["query_type"] == "vital_trend"]
    llm_results = []
    for i, test_case in enumerate(vital_cases[:4]):  # Limit to 4 live calls
        patient_id = test_case["patient_id"]
        query = test_case["query"]
        case_id = test_case["case_id"]
        expected = compute_expected_values(test_case)
        print(f" [{i+1}/{min(4, len(vital_cases))}] {query[:50]}...")
        llm_response = await call_agent_endpoint(patient_id, query, timeout=90.0)
        if llm_response.error:
            # Transport/model failure: record as an unsuccessful comparison.
            print(f" ERROR: {llm_response.error}")
            llm_results.append(LLMComparisonResult(
                case_id=case_id,
                query=query,
                success=False,
                errors=[llm_response.error]
            ))
        else:
            chart_nums = extract_numbers_from_chart(llm_response.chart_data)
            text_nums = extract_numbers_from_text(llm_response.raw_response)
            print(f" Chart numbers: {chart_nums}")
            print(f" Text numbers: {text_nums}")
            result = compare_llm_response(llm_response, expected)
            result.case_id = case_id
            llm_results.append(result)
            if result.success:
                print(f" ✓ PASS ({result.accuracy():.0%} accuracy)")
            else:
                print(f" ✗ FAIL ({result.accuracy():.0%} accuracy)")
                for err in result.errors[:3]:
                    print(f" - {err}")

    # === PART 2: TEXT EVALUATION (Medications, Conditions, Allergies) ===
    print("\n--- PART 2: TEXT ACCURACY (Medications, Conditions, Allergies) ---\n")
    text_cases = [tc for tc in test_cases if tc["query_type"] in ["medication_list", "condition_list", "allergy_list"]]
    text_results = []
    for i, test_case in enumerate(text_cases[:4]):  # Limit to 4 live calls
        patient_id = test_case["patient_id"]
        query = test_case["query"]
        query_type = test_case["query_type"]
        case_id = test_case["case_id"]
        expected = compute_expected_values(test_case)
        # Pick the expected-items list matching the query type.
        if query_type == "medication_list":
            expected_items = expected.get("medication_names", [])
        elif query_type == "condition_list":
            expected_items = expected.get("condition_names", [])
        elif query_type == "allergy_list":
            expected_items = expected.get("substances", [])
        else:
            expected_items = []
        print(f" [{i+1}/{min(4, len(text_cases))}] {query[:50]}...")
        print(f" Expected {len(expected_items)} items: {[x[:30] for x in expected_items[:3]]}...")
        result = await evaluate_text_query(
            patient_id, query, query_type, expected_items, case_id
        )
        text_results.append(result)
        if result.success:
            print(f" ✓ PASS ({result.accuracy:.0%} - found {len(result.found_items)}/{len(expected_items)})")
        else:
            print(f" ✗ FAIL ({result.accuracy:.0%} - found {len(result.found_items)}/{len(expected_items)})")
            if result.missing_items:
                print(f" Missing: {result.missing_items[:3]}")

    # === AGGREGATE RESULTS ===
    numeric_summary = aggregate_llm_results(llm_results)
    text_summary = aggregate_text_results(text_results) if text_results else {}
    print("\n" + "="*60)
    print("LLM RESPONSE ACCURACY REPORT")
    print("="*60)
    print("\n📊 NUMERIC ACCURACY (Vital Charts):")
    print(f" Test Cases: {numeric_summary['total_cases']}")
    print(f" Success Rate: {numeric_summary['success_rate']}")
    print(f" Number Accuracy: {numeric_summary['number_accuracy']}")
    if text_summary:
        print("\n📝 TEXT ACCURACY (Medications, Conditions, Allergies):")
        print(f" Test Cases: {text_summary['total_cases']}")
        print(f" Success Rate: {text_summary['success_rate']}")
        print(f" Item Recall: {text_summary['item_recall']}")
        if text_summary.get('by_type'):
            for qtype, stats in text_summary['by_type'].items():
                print(f" {qtype}: {stats['passed']}/{stats['total']} passed ({stats['avg_accuracy']})")
    print("="*60)
    return {
        "success": True,
        "mode": "llm",
        "patients_tested": patients,
        "metrics": {
            "numeric": numeric_summary,
            "text": text_summary
        }
    }


def _run_agent_evaluation(test_cases):
    """Run the real agent tool pipeline per case and return evaluation objects."""
    from evaluation.expected_values import compute_expected_values
    from evaluation.evaluator import evaluate_case
    from evaluation.agent_eval import run_agent_sync

    print("\nRunning REAL AGENT evaluation...")
    evaluations = []
    for i, test_case in enumerate(test_cases):
        expected = compute_expected_values(test_case)
        patient_id = test_case["patient_id"]
        query = test_case["query"]
        query_type = test_case["query_type"]
        parameters = test_case.get("parameters", {})
        # Run the actual agent with test case info.
        agent_response = run_agent_sync(patient_id, query, query_type, parameters)
        if agent_response.error:
            print(f" [WARN] Query failed: {query[:50]}... - {agent_response.error}")
            actual_facts = {"error": agent_response.error}
        else:
            actual_facts = agent_response.extracted_facts
        # Debug: show the first 3 cases (query + which tool the agent chose).
        if (i + 1) <= 3:
            print(f"\n Case: {test_case['case_id']}")
            print(f" Query: {query}")
            print(f" Tool: {agent_response.tool_called}")
        evaluations.append(evaluate_case(test_case, expected, actual_facts))
        if (i + 1) % 10 == 0:
            print(f" Processed {i + 1}/{len(test_cases)} cases...")
    return evaluations


def _run_offline_evaluation(test_cases, mode, error_rate):
    """Score expected values directly ('direct') or with injected errors ('simulated')."""
    from evaluation.expected_values import compute_expected_values
    from evaluation.evaluator import evaluate_case
    from evaluation.run_evaluation import introduce_errors

    evaluations = []
    for i, test_case in enumerate(test_cases):
        expected = compute_expected_values(test_case)
        if mode == "simulated":
            # Perturb a fraction of the expected facts to test error detection.
            actual_facts = introduce_errors(expected, error_rate)
        else:
            # 'direct' baseline: compare expected values against themselves.
            actual_facts = expected.copy()
        evaluations.append(evaluate_case(test_case, expected, actual_facts))
        if (i + 1) % 10 == 0:
            print(f" Processed {i + 1}/{len(test_cases)} cases...")
    return evaluations
if __name__ == "__main__":
    # Script entry point: report configured backends, then serve the FastAPI
    # app with uvicorn on the port HuggingFace Spaces provides (default 7860).
    import uvicorn

    serve_port = int(os.getenv("PORT", "7860"))
    hear_backend = HEAR_SERVER_URL or 'Not configured'
    print(f"Starting server on port {serve_port}...")
    print(f"LLM Backend: {LLAMA_SERVER_URL}")
    print(f"HeAR Backend: {hear_backend}")
    uvicorn.run(app, host="0.0.0.0", port=serve_port)