|
|
|
|
|
""" |
|
|
Endpoint warm-up check — verifies the MedGemma endpoint is online |
|
|
before running experiments. Handles scale-to-zero cold starts. |
|
|
|
|
|
Usage: |
|
|
from tracks.shared.endpoint_check import wait_for_endpoint |
|
|
await wait_for_endpoint() # blocks until endpoint responds or gives up |
|
|
""" |
|
|
from __future__ import annotations |
|
|
|
|
|
import asyncio |
|
|
import logging |
|
|
import sys |
|
|
import time |
|
|
from pathlib import Path |
|
|
|
|
|
# Directory two levels above the one that contains this file.
# NOTE(review): presumably the backend root that holds the `app` package — confirm.
BACKEND_DIR = Path(__file__).resolve().parent.parent.parent
# Prepend it to sys.path (only once) so the `app.config` import below can
# resolve when this module is loaded from outside that directory.
if str(BACKEND_DIR) not in sys.path:
    sys.path.insert(0, str(BACKEND_DIR))

# Must come after the sys.path adjustment above.
from app.config import settings

# Module-level logger named after this module.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
async def check_endpoint_health() -> tuple[bool, str]:
    """
    Probe the endpoint with a minimal chat-completion request.

    A short-lived ``AsyncOpenAI`` client is built from ``settings``; when a
    setting is empty the API key falls back to ``"not-needed"``, the base URL
    to ``http://localhost:8000/v1`` and the model id to ``"tgi"``. The probe
    asks for at most 4 tokens with a 30 second client timeout.

    Returns:
        ``(is_healthy, message)``. ``is_healthy`` is True only when the
        request completed and a reply could be read; ``message`` is a short
        human-readable status line.

    No ``Exception`` escapes this function: every failure (including a
    missing ``openai`` package) is reported as ``(False, message)``.
    """
    try:
        # Imported inside the try block, so a missing SDK is reported as an
        # unhealthy result instead of being raised to the caller.
        from openai import AsyncOpenAI

        # Use the client as a context manager so its HTTP resources are
        # released after every probe; this function is polled repeatedly and
        # an unclosed client per call would accumulate open connections.
        async with AsyncOpenAI(
            api_key=settings.medgemma_api_key or "not-needed",
            base_url=settings.medgemma_base_url or "http://localhost:8000/v1",
            timeout=30.0,
        ) as client:
            resp = await client.chat.completions.create(
                model=settings.medgemma_model_id or "tgi",
                messages=[{"role": "user", "content": "Say OK"}],
                max_tokens=4,
                temperature=0.0,
            )

        # ``content`` may be None; treat that as an empty reply.
        text = resp.choices[0].message.content or ""
        return True, f"Endpoint alive: {text.strip()[:30]}"
    except Exception as e:
        msg = str(e)
        # Prefer a structured HTTP status when the exception carries one and
        # keep the substring match as a fallback for other error types.
        if getattr(e, "status_code", None) == 503 or "503" in msg:
            return False, "Endpoint returned 503 — model loading or scaled to zero"
        # Case-insensitive match covers both "Connection ..." and "connect ...".
        if "connect" in msg.lower():
            return False, f"Connection error: {msg[:120]}"
        return False, f"Endpoint error: {msg[:120]}"
|
|
|
|
|
|
|
|
async def wait_for_endpoint(
    max_wait_sec: int = 600,
    poll_interval_sec: int = 30,
    quiet: bool = False,
) -> bool:
    """
    Wait for the MedGemma endpoint to become healthy.

    Performs one immediate health check, then polls every
    ``poll_interval_sec`` seconds until the endpoint responds or
    ``max_wait_sec`` seconds have passed since the call started. The final
    sleep is shortened to the time remaining, so the waiting itself does not
    run past ``max_wait_sec`` (an in-flight health check may still add its
    own duration).

    Args:
        max_wait_sec: Total time budget in seconds, measured from entry.
        poll_interval_sec: Delay in seconds between consecutive health checks.
        quiet: When True, suppress the status messages printed to stdout.

    Returns:
        True if the endpoint came online, False if the wait timed out.
    """
    t0 = time.monotonic()
    deadline = t0 + max_wait_sec

    # Fast path: the endpoint may already be warm.
    ok, msg = await check_endpoint_health()
    if ok:
        if not quiet:
            print(f"[endpoint] {msg}")
        return True

    if not quiet:
        print(f"[endpoint] {msg}")
        print(f"[endpoint] Waiting up to {max_wait_sec}s for endpoint to come online...")
        print("[endpoint] If endpoint is paused, resume it at: https://ui.endpoints.huggingface.co/")

    # The immediate check above counts as attempt 1.
    attempt = 1
    while True:
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            break

        # Never sleep past the deadline: clamp the pause to what is left.
        await asyncio.sleep(min(poll_interval_sec, remaining))
        elapsed = int(time.monotonic() - t0)

        ok, msg = await check_endpoint_health()
        if ok:
            if not quiet:
                print(f"[endpoint] Online after {elapsed}s - {msg}")
            return True

        attempt += 1
        if not quiet:
            print(f"[endpoint] Attempt {attempt} ({elapsed}s elapsed): {msg}")

    if not quiet:
        print(f"[endpoint] TIMEOUT after {max_wait_sec}s — endpoint never came online")
        print("[endpoint] Check https://ui.endpoints.huggingface.co/ and resume the endpoint")
    return False
|
|
|