bshepp committed on
Commit ·
c800712
1
Parent(s): 13d4b74
feat: add MedGemma readiness gate to prevent cold-start pipeline failures
Browse files
src/backend/app/api/health.py
CHANGED
|
@@ -23,4 +23,18 @@ async def config_check():
|
|
| 23 |
"medgemma_model_id": settings.medgemma_model_id,
|
| 24 |
"hf_token_set": bool(settings.hf_token),
|
| 25 |
"medgemma_max_tokens": settings.medgemma_max_tokens,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
}
|
|
|
|
| 23 |
"medgemma_model_id": settings.medgemma_model_id,
|
| 24 |
"hf_token_set": bool(settings.hf_token),
|
| 25 |
"medgemma_max_tokens": settings.medgemma_max_tokens,
|
| 26 |
+
}
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
@router.get("/api/health/model")
async def model_readiness():
    """Report whether the MedGemma endpoint currently passes a readiness probe.

    Returns:
        A dict with three keys: ``ready`` (the result of
        ``MedGemmaService.check_readiness()``), ``model_id`` (the configured
        model id) and ``base_url_set`` (whether a base URL is configured;
        the URL itself is not included in the response).
    """
    # Function-scope import, kept exactly where the original handler had it.
    from app.services.medgemma import MedGemmaService

    # A fresh service instance is created for every health request.
    is_ready = await MedGemmaService().check_readiness()
    return {
        "ready": is_ready,
        "model_id": settings.medgemma_model_id,
        "base_url_set": bool(settings.medgemma_base_url),
    }
|
src/backend/app/api/ws.py
CHANGED
|
@@ -11,12 +11,15 @@ from __future__ import annotations
|
|
| 11 |
|
| 12 |
import asyncio
|
| 13 |
import json
|
|
|
|
| 14 |
|
| 15 |
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
|
| 16 |
|
| 17 |
from app.agent.orchestrator import Orchestrator
|
| 18 |
from app.models.schemas import CaseSubmission
|
|
|
|
| 19 |
|
|
|
|
| 20 |
router = APIRouter()
|
| 21 |
|
| 22 |
|
|
@@ -46,7 +49,37 @@ async def agent_websocket(websocket: WebSocket):
|
|
| 46 |
# Send acknowledgment
|
| 47 |
await websocket.send_json({
|
| 48 |
"type": "ack",
|
| 49 |
-
"message": "Case received.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
})
|
| 51 |
|
| 52 |
# Run the orchestrator and stream updates
|
|
|
|
| 11 |
|
| 12 |
import asyncio
|
| 13 |
import json
|
| 14 |
+
import logging
|
| 15 |
|
| 16 |
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
|
| 17 |
|
| 18 |
from app.agent.orchestrator import Orchestrator
|
| 19 |
from app.models.schemas import CaseSubmission
|
| 20 |
+
from app.services.medgemma import MedGemmaService
|
| 21 |
|
| 22 |
+
logger = logging.getLogger(__name__)
|
| 23 |
router = APIRouter()
|
| 24 |
|
| 25 |
|
|
|
|
| 49 |
# Send acknowledgment
|
| 50 |
await websocket.send_json({
|
| 51 |
"type": "ack",
|
| 52 |
+
"message": "Case received. Checking model readiness...",
|
| 53 |
+
})
|
| 54 |
+
|
| 55 |
+
# ── Readiness gate: wait for MedGemma to be warm ──
|
| 56 |
+
medgemma = MedGemmaService()
|
| 57 |
+
|
| 58 |
+
async def _send_warming(elapsed: float, message: str):
    """Stream warm-up progress to client.

    Sends a ``warming_up`` JSON message carrying the status text and the
    elapsed time truncated to whole seconds.

    Args:
        elapsed: Seconds spent waiting so far.
        message: Human-readable status text to show the user.

    This is deliberately best-effort: a failed send must not abort the
    readiness wait, so the failure is logged at debug level and ignored.
    """
    try:
        await websocket.send_json({
            "type": "warming_up",
            "message": message,
            "elapsed_seconds": int(elapsed),
        })
    except Exception:
        # client may have disconnected — keep a trace instead of silently
        # discarding the error, but never raise from a progress update.
        logger.debug("Could not send warm-up update to client", exc_info=True)
|
| 68 |
+
|
| 69 |
+
ready = await medgemma.wait_until_ready(on_waiting=_send_warming)
|
| 70 |
+
if not ready:
|
| 71 |
+
await websocket.send_json({
|
| 72 |
+
"type": "error",
|
| 73 |
+
"message": (
|
| 74 |
+
"MedGemma model did not become ready within the timeout. "
|
| 75 |
+
"The endpoint may be starting up — please try again in a minute."
|
| 76 |
+
),
|
| 77 |
+
})
|
| 78 |
+
return
|
| 79 |
+
|
| 80 |
+
await websocket.send_json({
|
| 81 |
+
"type": "model_ready",
|
| 82 |
+
"message": "MedGemma is ready. Starting agent pipeline...",
|
| 83 |
})
|
| 84 |
|
| 85 |
# Run the orchestrator and stream updates
|
src/backend/app/services/medgemma.py
CHANGED
|
@@ -27,6 +27,10 @@ T = TypeVar("T", bound=BaseModel)
|
|
| 27 |
MAX_API_RETRIES = 3
|
| 28 |
RETRY_BASE_DELAY = 5.0 # seconds, doubles on each retry
|
| 29 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
|
| 31 |
class MedGemmaService:
|
| 32 |
"""
|
|
@@ -58,6 +62,71 @@ class MedGemmaService:
|
|
| 58 |
)
|
| 59 |
return self._client
|
| 60 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
async def generate(
|
| 62 |
self,
|
| 63 |
prompt: str,
|
|
|
|
| 27 |
MAX_API_RETRIES = 3
|
| 28 |
RETRY_BASE_DELAY = 5.0 # seconds, doubles on each retry
|
| 29 |
|
| 30 |
+
# Readiness probe configuration
|
| 31 |
+
READINESS_TIMEOUT = 180 # max seconds to wait for model warm-up
|
| 32 |
+
READINESS_POLL_INTERVAL = 5 # seconds between readiness checks
|
| 33 |
+
|
| 34 |
|
| 35 |
class MedGemmaService:
|
| 36 |
"""
|
|
|
|
| 62 |
)
|
| 63 |
return self._client
|
| 64 |
|
| 65 |
+
async def check_readiness(self) -> bool:
    """
    Lightweight probe to check if the MedGemma endpoint is warm and
    accepting requests.

    In "api" mode this sends a tiny chat completion (a single "ping" user
    message, max_tokens=1, temperature 0.0) through the client returned by
    ``_get_client()``. In any other mode no request is made.

    Returns:
        True if the service is not in "api" mode, or if the endpoint
        answered with at least one choice. False if the response carried
        no choices or if the probe raised any exception (the error is
        logged at debug level, never propagated).
    """
    if self._mode != "api":
        return True  # local mode is always "ready"
    try:
        client = await self._get_client()
        response = await client.chat.completions.create(
            model=settings.medgemma_model_id,
            messages=[{"role": "user", "content": "ping"}],
            max_tokens=1,
            temperature=0.0,
        )
        return bool(response.choices)
    except Exception as e:
        # Lazy %-style arguments, consistent with the other log calls in
        # this service; the message is only formatted if DEBUG is enabled.
        logger.debug("Readiness probe failed: %s", e)
        return False
|
| 86 |
+
|
| 87 |
+
async def wait_until_ready(
    self,
    timeout: float = READINESS_TIMEOUT,
    poll_interval: float = READINESS_POLL_INTERVAL,
    on_waiting: Optional[Any] = None,
) -> bool:
    """
    Poll check_readiness() until the model is warm or timeout expires.

    ``timeout`` is enforced as a budget for the whole wait: every probe is
    bounded by the time still remaining, and the pause between probes is
    shortened so the loop does not sleep past the deadline. At least one
    probe is always attempted, even when ``timeout`` is zero or negative.

    Args:
        timeout: Maximum seconds to wait.
        poll_interval: Seconds between probes.
        on_waiting: Optional async callback(elapsed_seconds, message) invoked
            each time we're still waiting — used to stream status to
            the client.

    Returns:
        True if the model became ready, False if timeout was reached.
    """
    import time

    start = time.monotonic()
    attempt = 0
    while True:
        attempt += 1
        remaining = timeout - (time.monotonic() - start)
        if remaining > 0:
            # Bound the probe itself: without this, a request that hangs
            # on a cold endpoint could block far beyond `timeout` and no
            # progress callbacks would be delivered in the meantime.
            try:
                ready = await asyncio.wait_for(
                    self.check_readiness(), timeout=remaining
                )
            except asyncio.TimeoutError:
                ready = False
        elif attempt == 1:
            # Non-positive timeout: still perform the single probe that
            # callers got before, without a deadline.
            ready = await self.check_readiness()
        else:
            # Budget was used up during the last callback/pause.
            ready = False

        elapsed = time.monotonic() - start
        if ready:
            logger.info("MedGemma readiness probe succeeded (%.1fs)", elapsed)
            return True

        if elapsed >= timeout:
            logger.error("MedGemma readiness timeout after %.0fs", elapsed)
            return False

        msg = (
            f"Warming up MedGemma model... "
            f"({int(elapsed)}s elapsed, attempt {attempt})"
        )
        logger.info(msg)
        if on_waiting:
            await on_waiting(elapsed, msg)

        # Never sleep past the deadline; the next iteration reports the
        # timeout if nothing is left of the budget.
        pause = min(poll_interval, timeout - (time.monotonic() - start))
        await asyncio.sleep(max(pause, 0))
|
| 129 |
+
|
| 130 |
async def generate(
|
| 131 |
self,
|
| 132 |
prompt: str,
|
src/frontend/src/app/page.tsx
CHANGED
|
@@ -7,7 +7,7 @@ import { CDSReport } from "@/components/CDSReport";
|
|
| 7 |
import { useAgentWebSocket } from "@/hooks/useAgentWebSocket";
|
| 8 |
|
| 9 |
export default function Home() {
|
| 10 |
-
const { steps, report, isRunning, error, submitCase } = useAgentWebSocket();
|
| 11 |
const [hasSubmitted, setHasSubmitted] = useState(false);
|
| 12 |
|
| 13 |
const handleSubmit = (patientText: string) => {
|
|
@@ -72,6 +72,21 @@ export default function Home() {
|
|
| 72 |
<div className="lg:col-span-2">
|
| 73 |
{report ? (
|
| 74 |
<CDSReport report={report} />
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 75 |
) : isRunning ? (
|
| 76 |
<div className="flex items-center justify-center h-64 text-gray-400">
|
| 77 |
<div className="text-center">
|
|
|
|
| 7 |
import { useAgentWebSocket } from "@/hooks/useAgentWebSocket";
|
| 8 |
|
| 9 |
export default function Home() {
|
| 10 |
+
const { steps, report, isRunning, isWarmingUp, warmUpMessage, error, submitCase } = useAgentWebSocket();
|
| 11 |
const [hasSubmitted, setHasSubmitted] = useState(false);
|
| 12 |
|
| 13 |
const handleSubmit = (patientText: string) => {
|
|
|
|
| 72 |
<div className="lg:col-span-2">
|
| 73 |
{report ? (
|
| 74 |
<CDSReport report={report} />
|
| 75 |
+
) : isWarmingUp ? (
|
| 76 |
+
<div className="flex items-center justify-center h-64 text-amber-600">
|
| 77 |
+
<div className="text-center">
|
| 78 |
+
<div className="animate-pulse w-10 h-10 rounded-full bg-amber-100 flex items-center justify-center mx-auto mb-4">
|
| 79 |
+
<span className="text-xl">⚙</span>
|
| 80 |
+
</div>
|
| 81 |
+
<p className="font-medium">Model Warming Up</p>
|
| 82 |
+
<p className="text-sm text-amber-500 mt-1">
|
| 83 |
+
{warmUpMessage || "Waiting for MedGemma endpoint..."}
|
| 84 |
+
</p>
|
| 85 |
+
<p className="text-xs text-gray-400 mt-2">
|
| 86 |
+
This happens when the model scales from zero. Usually takes 1-2 minutes.
|
| 87 |
+
</p>
|
| 88 |
+
</div>
|
| 89 |
+
</div>
|
| 90 |
) : isRunning ? (
|
| 91 |
<div className="flex items-center justify-center h-64 text-gray-400">
|
| 92 |
<div className="text-center">
|
src/frontend/src/hooks/useAgentWebSocket.ts
CHANGED
|
@@ -22,6 +22,8 @@ interface UseAgentWebSocketReturn {
|
|
| 22 |
steps: Step[];
|
| 23 |
report: any | null;
|
| 24 |
isRunning: boolean;
|
|
|
|
|
|
|
| 25 |
error: string | null;
|
| 26 |
submitCase: (submission: CaseSubmission) => void;
|
| 27 |
}
|
|
@@ -45,6 +47,8 @@ export function useAgentWebSocket(): UseAgentWebSocketReturn {
|
|
| 45 |
const [steps, setSteps] = useState<Step[]>([]);
|
| 46 |
const [report, setReport] = useState<any | null>(null);
|
| 47 |
const [isRunning, setIsRunning] = useState(false);
|
|
|
|
|
|
|
| 48 |
const [error, setError] = useState<string | null>(null);
|
| 49 |
const wsRef = useRef<WebSocket | null>(null);
|
| 50 |
|
|
@@ -54,6 +58,8 @@ export function useAgentWebSocket(): UseAgentWebSocketReturn {
|
|
| 54 |
setReport(null);
|
| 55 |
setError(null);
|
| 56 |
setIsRunning(true);
|
|
|
|
|
|
|
| 57 |
|
| 58 |
// Close existing connection
|
| 59 |
if (wsRef.current) {
|
|
@@ -75,6 +81,16 @@ export function useAgentWebSocket(): UseAgentWebSocketReturn {
|
|
| 75 |
// Pipeline acknowledged
|
| 76 |
break;
|
| 77 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
case "step_update":
|
| 79 |
setSteps((prev) => {
|
| 80 |
const existing = prev.findIndex(
|
|
@@ -114,5 +130,5 @@ export function useAgentWebSocket(): UseAgentWebSocketReturn {
|
|
| 114 |
};
|
| 115 |
}, []);
|
| 116 |
|
| 117 |
-
return { steps, report, isRunning, error, submitCase };
|
| 118 |
}
|
|
|
|
| 22 |
steps: Step[];
|
| 23 |
report: any | null;
|
| 24 |
isRunning: boolean;
|
| 25 |
+
isWarmingUp: boolean;
|
| 26 |
+
warmUpMessage: string | null;
|
| 27 |
error: string | null;
|
| 28 |
submitCase: (submission: CaseSubmission) => void;
|
| 29 |
}
|
|
|
|
| 47 |
const [steps, setSteps] = useState<Step[]>([]);
|
| 48 |
const [report, setReport] = useState<any | null>(null);
|
| 49 |
const [isRunning, setIsRunning] = useState(false);
|
| 50 |
+
const [isWarmingUp, setIsWarmingUp] = useState(false);
|
| 51 |
+
const [warmUpMessage, setWarmUpMessage] = useState<string | null>(null);
|
| 52 |
const [error, setError] = useState<string | null>(null);
|
| 53 |
const wsRef = useRef<WebSocket | null>(null);
|
| 54 |
|
|
|
|
| 58 |
setReport(null);
|
| 59 |
setError(null);
|
| 60 |
setIsRunning(true);
|
| 61 |
+
setIsWarmingUp(false);
|
| 62 |
+
setWarmUpMessage(null);
|
| 63 |
|
| 64 |
// Close existing connection
|
| 65 |
if (wsRef.current) {
|
|
|
|
| 81 |
// Pipeline acknowledged
|
| 82 |
break;
|
| 83 |
|
| 84 |
+
case "warming_up":
|
| 85 |
+
setIsWarmingUp(true);
|
| 86 |
+
setWarmUpMessage(data.message);
|
| 87 |
+
break;
|
| 88 |
+
|
| 89 |
+
case "model_ready":
|
| 90 |
+
setIsWarmingUp(false);
|
| 91 |
+
setWarmUpMessage(null);
|
| 92 |
+
break;
|
| 93 |
+
|
| 94 |
case "step_update":
|
| 95 |
setSteps((prev) => {
|
| 96 |
const existing = prev.findIndex(
|
|
|
|
| 130 |
};
|
| 131 |
}, []);
|
| 132 |
|
| 133 |
+
return { steps, report, isRunning, isWarmingUp, warmUpMessage, error, submitCase };
|
| 134 |
}
|