bshepp committed on
Commit
c800712
·
1 Parent(s): 13d4b74

feat: add MedGemma readiness gate to prevent cold-start pipeline failures

Browse files
src/backend/app/api/health.py CHANGED
@@ -23,4 +23,18 @@ async def config_check():
23
  "medgemma_model_id": settings.medgemma_model_id,
24
  "hf_token_set": bool(settings.hf_token),
25
  "medgemma_max_tokens": settings.medgemma_max_tokens,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  }
 
23
  "medgemma_model_id": settings.medgemma_model_id,
24
  "hf_token_set": bool(settings.hf_token),
25
  "medgemma_max_tokens": settings.medgemma_max_tokens,
26
+ }
27
+
28
+
29
+ @router.get("/api/health/model")
30
+ async def model_readiness():
31
+ """Check if the MedGemma endpoint is warm and accepting requests."""
32
+ from app.services.medgemma import MedGemmaService
33
+
34
+ service = MedGemmaService()
35
+ ready = await service.check_readiness()
36
+ return {
37
+ "ready": ready,
38
+ "model_id": settings.medgemma_model_id,
39
+ "base_url_set": bool(settings.medgemma_base_url),
40
  }
src/backend/app/api/ws.py CHANGED
@@ -11,12 +11,15 @@ from __future__ import annotations
11
 
12
  import asyncio
13
  import json
 
14
 
15
  from fastapi import APIRouter, WebSocket, WebSocketDisconnect
16
 
17
  from app.agent.orchestrator import Orchestrator
18
  from app.models.schemas import CaseSubmission
 
19
 
 
20
  router = APIRouter()
21
 
22
 
@@ -46,7 +49,37 @@ async def agent_websocket(websocket: WebSocket):
46
  # Send acknowledgment
47
  await websocket.send_json({
48
  "type": "ack",
49
- "message": "Case received. Starting agent pipeline...",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
  })
51
 
52
  # Run the orchestrator and stream updates
 
11
 
12
  import asyncio
13
  import json
14
+ import logging
15
 
16
  from fastapi import APIRouter, WebSocket, WebSocketDisconnect
17
 
18
  from app.agent.orchestrator import Orchestrator
19
  from app.models.schemas import CaseSubmission
20
+ from app.services.medgemma import MedGemmaService
21
 
22
+ logger = logging.getLogger(__name__)
23
  router = APIRouter()
24
 
25
 
 
49
  # Send acknowledgment
50
  await websocket.send_json({
51
  "type": "ack",
52
+ "message": "Case received. Checking model readiness...",
53
+ })
54
+
55
+ # ── Readiness gate: wait for MedGemma to be warm ──
56
+ medgemma = MedGemmaService()
57
+
58
+ async def _send_warming(elapsed: float, message: str):
59
+ """Stream warm-up progress to client."""
60
+ try:
61
+ await websocket.send_json({
62
+ "type": "warming_up",
63
+ "message": message,
64
+ "elapsed_seconds": int(elapsed),
65
+ })
66
+ except Exception:
67
+ pass # client may have disconnected
68
+
69
+ ready = await medgemma.wait_until_ready(on_waiting=_send_warming)
70
+ if not ready:
71
+ await websocket.send_json({
72
+ "type": "error",
73
+ "message": (
74
+ "MedGemma model did not become ready within the timeout. "
75
+ "The endpoint may be starting up — please try again in a minute."
76
+ ),
77
+ })
78
+ return
79
+
80
+ await websocket.send_json({
81
+ "type": "model_ready",
82
+ "message": "MedGemma is ready. Starting agent pipeline...",
83
  })
84
 
85
  # Run the orchestrator and stream updates
src/backend/app/services/medgemma.py CHANGED
@@ -27,6 +27,10 @@ T = TypeVar("T", bound=BaseModel)
27
  MAX_API_RETRIES = 3
28
  RETRY_BASE_DELAY = 5.0 # seconds, doubles on each retry
29
 
 
 
 
 
30
 
31
  class MedGemmaService:
32
  """
@@ -58,6 +62,71 @@ class MedGemmaService:
58
  )
59
  return self._client
60
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
  async def generate(
62
  self,
63
  prompt: str,
 
27
  MAX_API_RETRIES = 3
28
  RETRY_BASE_DELAY = 5.0 # seconds, doubles on each retry
29
 
30
+ # Readiness probe configuration
31
+ READINESS_TIMEOUT = 180 # max seconds to wait for model warm-up
32
+ READINESS_POLL_INTERVAL = 5 # seconds between readiness checks
33
+
34
 
35
  class MedGemmaService:
36
  """
 
62
  )
63
  return self._client
64
 
65
async def check_readiness(self) -> bool:
    """
    Probe whether the MedGemma endpoint is warm and accepting requests.

    In API mode this sends a minimal single-token chat completion
    ("ping"); any non-API (local) mode is treated as always ready.

    Returns:
        True when the endpoint answered with at least one choice,
        False on any error (interpreted as "not ready yet").
    """
    if self._mode != "api":
        # Local mode has no remote endpoint that needs warming up.
        return True

    probe_messages = [{"role": "user", "content": "ping"}]
    try:
        client = await self._get_client()
        reply = await client.chat.completions.create(
            model=settings.medgemma_model_id,
            messages=probe_messages,
            max_tokens=1,
            temperature=0.0,
        )
        return bool(reply.choices)
    except Exception as exc:
        # Any failure here is transient by definition of a probe;
        # report "not ready" rather than propagating.
        logger.debug(f"Readiness probe failed: {exc}")
        return False
87
async def wait_until_ready(
    self,
    timeout: float = READINESS_TIMEOUT,
    poll_interval: float = READINESS_POLL_INTERVAL,
    on_waiting: Optional[Any] = None,
) -> bool:
    """
    Poll check_readiness() until the model is warm or the timeout expires.

    Args:
        timeout: Maximum seconds to wait overall.
        poll_interval: Seconds between probes. The final sleep is clamped
            to the remaining budget so one last probe lands at roughly the
            timeout boundary instead of overshooting it by up to a full
            poll_interval.
        on_waiting: Optional async callback(elapsed_seconds, message)
            invoked each time we're still waiting — used to stream status
            to the client. Expected to be an awaitable callable; exceptions
            it raises will propagate to the caller.

    Returns:
        True if the model became ready, False if the timeout was reached.
    """
    # Local import keeps the module's import-time surface unchanged.
    import time

    start = time.monotonic()
    attempt = 0
    while True:
        attempt += 1
        if await self.check_readiness():
            logger.info(
                "MedGemma readiness probe succeeded (%.1fs)",
                time.monotonic() - start,
            )
            return True

        elapsed = time.monotonic() - start
        if elapsed >= timeout:
            logger.error("MedGemma readiness timeout after %.0fs", elapsed)
            return False

        msg = (
            f"Warming up MedGemma model... "
            f"({int(elapsed)}s elapsed, attempt {attempt})"
        )
        logger.info(msg)
        if on_waiting:
            await on_waiting(elapsed, msg)

        # Clamp the last sleep so the final probe happens near the timeout
        # boundary rather than up to poll_interval seconds past it.
        await asyncio.sleep(min(poll_interval, timeout - elapsed))
+
130
  async def generate(
131
  self,
132
  prompt: str,
src/frontend/src/app/page.tsx CHANGED
@@ -7,7 +7,7 @@ import { CDSReport } from "@/components/CDSReport";
7
  import { useAgentWebSocket } from "@/hooks/useAgentWebSocket";
8
 
9
  export default function Home() {
10
- const { steps, report, isRunning, error, submitCase } = useAgentWebSocket();
11
  const [hasSubmitted, setHasSubmitted] = useState(false);
12
 
13
  const handleSubmit = (patientText: string) => {
@@ -72,6 +72,21 @@ export default function Home() {
72
  <div className="lg:col-span-2">
73
  {report ? (
74
  <CDSReport report={report} />
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
  ) : isRunning ? (
76
  <div className="flex items-center justify-center h-64 text-gray-400">
77
  <div className="text-center">
 
7
  import { useAgentWebSocket } from "@/hooks/useAgentWebSocket";
8
 
9
  export default function Home() {
10
+ const { steps, report, isRunning, isWarmingUp, warmUpMessage, error, submitCase } = useAgentWebSocket();
11
  const [hasSubmitted, setHasSubmitted] = useState(false);
12
 
13
  const handleSubmit = (patientText: string) => {
 
72
  <div className="lg:col-span-2">
73
  {report ? (
74
  <CDSReport report={report} />
75
+ ) : isWarmingUp ? (
76
+ <div className="flex items-center justify-center h-64 text-amber-600">
77
+ <div className="text-center">
78
+ <div className="animate-pulse w-10 h-10 rounded-full bg-amber-100 flex items-center justify-center mx-auto mb-4">
79
+ <span className="text-xl">&#9881;</span>
80
+ </div>
81
+ <p className="font-medium">Model Warming Up</p>
82
+ <p className="text-sm text-amber-500 mt-1">
83
+ {warmUpMessage || "Waiting for MedGemma endpoint..."}
84
+ </p>
85
+ <p className="text-xs text-gray-400 mt-2">
86
+ This happens when the model scales from zero. Usually takes 1-2 minutes.
87
+ </p>
88
+ </div>
89
+ </div>
90
  ) : isRunning ? (
91
  <div className="flex items-center justify-center h-64 text-gray-400">
92
  <div className="text-center">
src/frontend/src/hooks/useAgentWebSocket.ts CHANGED
@@ -22,6 +22,8 @@ interface UseAgentWebSocketReturn {
22
  steps: Step[];
23
  report: any | null;
24
  isRunning: boolean;
 
 
25
  error: string | null;
26
  submitCase: (submission: CaseSubmission) => void;
27
  }
@@ -45,6 +47,8 @@ export function useAgentWebSocket(): UseAgentWebSocketReturn {
45
  const [steps, setSteps] = useState<Step[]>([]);
46
  const [report, setReport] = useState<any | null>(null);
47
  const [isRunning, setIsRunning] = useState(false);
 
 
48
  const [error, setError] = useState<string | null>(null);
49
  const wsRef = useRef<WebSocket | null>(null);
50
 
@@ -54,6 +58,8 @@ export function useAgentWebSocket(): UseAgentWebSocketReturn {
54
  setReport(null);
55
  setError(null);
56
  setIsRunning(true);
 
 
57
 
58
  // Close existing connection
59
  if (wsRef.current) {
@@ -75,6 +81,16 @@ export function useAgentWebSocket(): UseAgentWebSocketReturn {
75
  // Pipeline acknowledged
76
  break;
77
 
 
 
 
 
 
 
 
 
 
 
78
  case "step_update":
79
  setSteps((prev) => {
80
  const existing = prev.findIndex(
@@ -114,5 +130,5 @@ export function useAgentWebSocket(): UseAgentWebSocketReturn {
114
  };
115
  }, []);
116
 
117
- return { steps, report, isRunning, error, submitCase };
118
  }
 
22
  steps: Step[];
23
  report: any | null;
24
  isRunning: boolean;
25
+ isWarmingUp: boolean;
26
+ warmUpMessage: string | null;
27
  error: string | null;
28
  submitCase: (submission: CaseSubmission) => void;
29
  }
 
47
  const [steps, setSteps] = useState<Step[]>([]);
48
  const [report, setReport] = useState<any | null>(null);
49
  const [isRunning, setIsRunning] = useState(false);
50
+ const [isWarmingUp, setIsWarmingUp] = useState(false);
51
+ const [warmUpMessage, setWarmUpMessage] = useState<string | null>(null);
52
  const [error, setError] = useState<string | null>(null);
53
  const wsRef = useRef<WebSocket | null>(null);
54
 
 
58
  setReport(null);
59
  setError(null);
60
  setIsRunning(true);
61
+ setIsWarmingUp(false);
62
+ setWarmUpMessage(null);
63
 
64
  // Close existing connection
65
  if (wsRef.current) {
 
81
  // Pipeline acknowledged
82
  break;
83
 
84
+ case "warming_up":
85
+ setIsWarmingUp(true);
86
+ setWarmUpMessage(data.message);
87
+ break;
88
+
89
+ case "model_ready":
90
+ setIsWarmingUp(false);
91
+ setWarmUpMessage(null);
92
+ break;
93
+
94
  case "step_update":
95
  setSteps((prev) => {
96
  const existing = prev.findIndex(
 
130
  };
131
  }, []);
132
 
133
+ return { steps, report, isRunning, isWarmingUp, warmUpMessage, error, submitCase };
134
  }