File size: 3,207 Bytes
28f1212
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
# [Shared: Track Utilities]
"""
Endpoint warm-up check — verifies the MedGemma endpoint is online
before running experiments. Handles scale-to-zero cold starts.

Usage:
    from tracks.shared.endpoint_check import wait_for_endpoint
    await wait_for_endpoint()  # blocks until endpoint responds or gives up
"""
from __future__ import annotations

import asyncio
import logging
import sys
import time
from pathlib import Path

# Directory two levels above this file's own directory. Presumably the backend
# root (the usage example imports this module as tracks.shared.endpoint_check)
# -- confirm against the repo layout.
BACKEND_DIR = Path(__file__).resolve().parent.parent.parent
# Put it at the front of sys.path (once) so the `app.config` import below
# resolves even when this module is run from a different working directory.
if str(BACKEND_DIR) not in sys.path:
    sys.path.insert(0, str(BACKEND_DIR))

from app.config import settings

logger = logging.getLogger(__name__)


async def check_endpoint_health() -> tuple[bool, str]:
    """
    Send a minimal chat-completion request to the configured endpoint.

    Connection parameters come from ``settings`` (``medgemma_api_key``,
    ``medgemma_base_url``, ``medgemma_model_id``), each with a local
    fallback when unset. The client is opened as an async context manager
    so its HTTP connection pool is closed after every probe instead of
    leaking one pool per poll.

    Returns:
        ``(is_healthy, message)``. ``is_healthy`` is True only when the
        request completed and a first choice could be read. On failure the
        message classifies the error (503 / connection / other) and is
        truncated to 120 characters of the underlying error text.

    Never raises: every exception (including a missing ``openai`` package)
    is converted into a ``(False, message)`` result.
    """
    try:
        # Imported lazily so that a missing dependency is reported as an
        # unhealthy endpoint rather than breaking module import.
        from openai import AsyncOpenAI

        async with AsyncOpenAI(
            api_key=settings.medgemma_api_key or "not-needed",
            base_url=settings.medgemma_base_url or "http://localhost:8000/v1",
            timeout=30.0,
        ) as client:
            resp = await client.chat.completions.create(
                model=settings.medgemma_model_id or "tgi",
                messages=[{"role": "user", "content": "Say OK"}],
                max_tokens=4,
                temperature=0.0,
            )
        text = resp.choices[0].message.content or ""
        return True, f"Endpoint alive: {text.strip()[:30]}"
    except Exception as e:
        # Some exceptions stringify to "", so fall back to the class name
        # to keep the status message informative.
        msg = str(e) or type(e).__name__
        # Prefer the structured HTTP status when the exception carries one;
        # keep the substring check as a fallback for wrapped errors.
        if getattr(e, "status_code", None) == 503 or "503" in msg:
            return False, "Endpoint returned 503 — model loading or scaled to zero"
        # Case-insensitive match also covers "Connection ..." messages.
        if "connect" in msg.lower():
            return False, f"Connection error: {msg[:120]}"
        return False, f"Endpoint error: {msg[:120]}"


async def wait_for_endpoint(
    max_wait_sec: int = 600,
    poll_interval_sec: int = 30,
    quiet: bool = False,
) -> bool:
    """
    Wait for the MedGemma endpoint to become healthy.

    Performs one immediate health check, then re-checks after each sleep of
    ``poll_interval_sec`` seconds. The sleep is capped to the time left in
    the budget, so sleeping never pushes the total wait past
    ``max_wait_sec``; only the duration of the final health check itself
    can exceed it.

    Args:
        max_wait_sec: Total time budget in seconds, measured from the call.
        poll_interval_sec: Pause between consecutive health checks.
        quiet: When True, suppress the status messages printed to stdout.

    Returns:
        True as soon as a health check succeeds, False if the budget is
        exhausted without a successful check.
    """
    t0 = time.monotonic()

    ok, msg = await check_endpoint_health()
    if ok:
        if not quiet:
            print(f"[endpoint] {msg}")
        return True

    if not quiet:
        print(f"[endpoint] {msg}")
        print(f"[endpoint] Waiting up to {max_wait_sec}s for endpoint to come online...")
        print("[endpoint] If endpoint is paused, resume it at: https://ui.endpoints.huggingface.co/")

    attempt = 1  # the immediate check above counts as attempt 1
    while True:
        remaining = max_wait_sec - (time.monotonic() - t0)
        if remaining <= 0:
            break
        # Never sleep past the deadline: the final pause is shortened to
        # whatever is left of the budget.
        await asyncio.sleep(min(poll_interval_sec, remaining))
        elapsed = int(time.monotonic() - t0)
        ok, msg = await check_endpoint_health()
        if ok:
            if not quiet:
                print(f"[endpoint] Online after {elapsed}s - {msg}")
            return True
        attempt += 1
        if not quiet:
            print(f"[endpoint] Attempt {attempt} ({elapsed}s elapsed): {msg}")

    if not quiet:
        print(f"[endpoint] TIMEOUT after {max_wait_sec}s — endpoint never came online")
        print("[endpoint] Check https://ui.endpoints.huggingface.co/ and resume the endpoint")
    return False