cds-agent / src /backend /tracks /shared /endpoint_check.py
bshepp
Implement validation pipeline fixes (P1-P7) and experimental track system
28f1212
# [Shared: Track Utilities]
"""
Endpoint warm-up check — verifies the MedGemma endpoint is online
before running experiments. Handles scale-to-zero cold starts.

Usage:
    from tracks.shared.endpoint_check import wait_for_endpoint
    await wait_for_endpoint()  # blocks until the endpoint responds or the wait times out
"""
from __future__ import annotations
import asyncio
import logging
import sys
import time
from pathlib import Path
# Directory three levels above this file. It is put at the front of sys.path
# (once) so that the absolute ``app.config`` import that follows can be
# resolved. NOTE(review): presumably this is the backend root that contains the
# ``app`` package - confirm against the repository layout.
BACKEND_DIR = Path(__file__).resolve().parent.parent.parent
if str(BACKEND_DIR) not in sys.path:
    # Insert at index 0 so this directory takes precedence over later entries.
    sys.path.insert(0, str(BACKEND_DIR))
from app.config import settings
logger = logging.getLogger(__name__)
async def check_endpoint_health() -> tuple[bool, str]:
    """
    Send a minimal chat-completion request to the configured endpoint.

    The probe asks the model to "Say OK" with ``max_tokens=4`` and
    ``temperature=0.0`` and a 30 second client timeout. Connection details
    come from ``settings`` (``medgemma_api_key``, ``medgemma_base_url``,
    ``medgemma_model_id``), each with a local fallback when unset.

    Returns:
        ``(is_healthy, message)``. ``is_healthy`` is True only when the
        request completed; ``message`` is a short human-readable status
        (reply text truncated to 30 characters, error text to 120).

    This coroutine does not raise for ordinary failures: any ``Exception``
    (including a missing ``openai`` package) is converted into a
    ``(False, message)`` result.
    """
    try:
        # Imported lazily, inside the ``try``, so that an unavailable
        # ``openai`` package is reported as an unhealthy result.
        from openai import AsyncOpenAI

        # ``async with`` closes the client's underlying HTTP connections when
        # the probe finishes; this function is called repeatedly while
        # polling, so an unclosed client per call would accumulate.
        async with AsyncOpenAI(
            api_key=settings.medgemma_api_key or "not-needed",
            base_url=settings.medgemma_base_url or "http://localhost:8000/v1",
            timeout=30.0,
        ) as client:
            resp = await client.chat.completions.create(
                model=settings.medgemma_model_id or "tgi",
                messages=[{"role": "user", "content": "Say OK"}],
                max_tokens=4,
                temperature=0.0,
            )
        text = resp.choices[0].message.content or ""
        return True, f"Endpoint alive: {text.strip()[:30]}"
    except Exception as e:  # deliberate catch-all: a health probe reports, it does not raise
        msg = str(e)
        # Prefer the structured HTTP status when the exception carries one;
        # fall back to searching the text only when it does not.
        status = getattr(e, "status_code", None)
        if status == 503 or (status is None and "503" in msg):
            return False, "Endpoint returned 503 — model loading or scaled to zero"
        # Case-insensitive match covers both "Connection ..." and "connect ...".
        if "connect" in msg.lower():
            return False, f"Connection error: {msg[:120]}"
        return False, f"Endpoint error: {msg[:120]}"
async def wait_for_endpoint(
    max_wait_sec: int = 600,
    poll_interval_sec: int = 30,
    quiet: bool = False,
) -> bool:
    """
    Wait for the MedGemma endpoint to become healthy.

    Performs one immediate health check, then re-checks every
    ``poll_interval_sec`` seconds until a check succeeds or
    ``max_wait_sec`` seconds have elapsed since the call started.

    Args:
        max_wait_sec: Total time budget, in seconds, for the endpoint to
            come online.
        poll_interval_sec: Pause, in seconds, between health checks. The
            final pause is shortened so it never extends past
            ``max_wait_sec``.
        quiet: When True, suppress the status messages printed to stdout.

    Returns:
        True if the endpoint is online, False if the wait timed out.
    """
    t0 = time.monotonic()

    # Attempt 1: the endpoint may already be warm.
    ok, msg = await check_endpoint_health()
    if not quiet:
        print(f"[endpoint] {msg}")
    if ok:
        return True

    if not quiet:
        print(f"[endpoint] Waiting up to {max_wait_sec}s for endpoint to come online...")
        print("[endpoint] If endpoint is paused, resume it at: https://ui.endpoints.huggingface.co/")

    attempt = 1
    while True:
        remaining = max_wait_sec - (time.monotonic() - t0)
        if remaining <= 0:
            break

        # Never sleep past the deadline: a full interval here could overshoot
        # max_wait_sec by up to poll_interval_sec.
        await asyncio.sleep(min(poll_interval_sec, remaining))

        ok, msg = await check_endpoint_health()
        attempt += 1
        # Measured after the check so the reported time includes its duration.
        elapsed = int(time.monotonic() - t0)

        if ok:
            if not quiet:
                print(f"[endpoint] Online after {elapsed}s - {msg}")
            return True
        if not quiet:
            print(f"[endpoint] Attempt {attempt} ({elapsed}s elapsed): {msg}")

    if not quiet:
        print(f"[endpoint] TIMEOUT after {max_wait_sec}s — endpoint never came online")
        print("[endpoint] Check https://ui.endpoints.huggingface.co/ and resume the endpoint")
    return False