File size: 6,030 Bytes
e48c5e9
 
 
 
 
 
 
 
636b91c
 
 
e48c5e9
 
 
 
 
 
 
 
 
 
636b91c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e48c5e9
636b91c
e48c5e9
 
 
 
636b91c
e48c5e9
 
636b91c
 
 
 
 
e48c5e9
 
 
 
 
 
 
 
 
 
 
636b91c
e48c5e9
 
 
 
 
 
636b91c
 
 
e48c5e9
636b91c
 
 
e48c5e9
636b91c
e48c5e9
 
 
 
 
 
 
 
 
 
636b91c
e48c5e9
 
 
636b91c
 
e48c5e9
636b91c
 
 
 
e48c5e9
636b91c
 
e48c5e9
 
 
 
 
636b91c
 
 
 
 
 
e48c5e9
 
636b91c
 
 
 
 
e48c5e9
636b91c
e48c5e9
636b91c
e48c5e9
 
 
 
 
 
636b91c
e48c5e9
 
 
 
636b91c
 
e48c5e9
 
636b91c
e48c5e9
 
 
 
636b91c
 
 
e48c5e9
 
 
 
 
636b91c
 
e48c5e9
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
# handler.py
from __future__ import annotations

import atexit
import json
import os
import shlex
import socket
import subprocess
import sys
import time
from typing import Any, Dict, List, Union
from urllib.error import URLError, HTTPError
from urllib.request import Request, urlopen


def _is_port_open(host: str, port: int, timeout_s: float = 0.5) -> bool:
    try:
        with socket.create_connection((host, port), timeout=timeout_s):
            return True
    except OSError:
        return False


def _http_json(method: str, url: str, payload: Dict[str, Any] | None = None, timeout_s: float = 60.0) -> Dict[str, Any]:
    data = None
    headers = {"Content-Type": "application/json"}
    if payload is not None:
        data = json.dumps(payload).encode("utf-8")
    req = Request(url, data=data, headers=headers, method=method.upper())
    try:
        with urlopen(req, timeout=timeout_s) as resp:
            body = resp.read().decode("utf-8")
            return json.loads(body) if body else {}
    except HTTPError as e:
        body = e.read().decode("utf-8", errors="replace")
        raise RuntimeError(f"HTTP {e.code} from {url}: {body}") from e
    except URLError as e:
        raise RuntimeError(f"Request to {url} failed: {e}") from e


def _wait_for_server(host: str, port: int, health_url: str, timeout_s: int = 600) -> None:
    start = time.time()
    # wait for port
    while time.time() - start < timeout_s:
        if _is_port_open(host, port):
            break
        time.sleep(0.5)
    # wait for health
    while time.time() - start < timeout_s:
        try:
            _http_json("GET", health_url, payload=None, timeout_s=2.0)
            return
        except Exception:
            time.sleep(0.5)
    raise RuntimeError(f"SGLang server not ready within {timeout_s}s (health={health_url})")


def _coerce_messages(inputs: Any) -> List[Dict[str, str]]:
    if isinstance(inputs, str):
        return [{"role": "user", "content": inputs}]
    if isinstance(inputs, list):
        if all(isinstance(x, dict) for x in inputs):
            msgs: List[Dict[str, str]] = []
            for m in inputs:
                role = str(m.get("role", "user"))
                content = m.get("content", "")
                msgs.append({"role": role, "content": "" if content is None else str(content)})
            return msgs
        if all(isinstance(x, str) for x in inputs):
            return [{"role": "user", "content": "\n".join(inputs)}]
    return [{"role": "user", "content": json.dumps(inputs, ensure_ascii=False)}]


def _map_params(hf_params: Dict[str, Any]) -> Dict[str, Any]:
    hf_params = hf_params or {}
    out: Dict[str, Any] = {"stream": False}

    max_new = hf_params.get("max_new_tokens", hf_params.get("max_tokens"))
    if max_new is not None:
        out["max_tokens"] = int(max_new)

    for k in ("temperature", "top_p", "seed", "stop", "presence_penalty", "frequency_penalty"):
        if k in hf_params and hf_params[k] is not None:
            out[k] = hf_params[k]

    return out


class EndpointHandler:
    """Hugging Face Inference Endpoints handler backed by a local SGLang server.

    On construction it launches ``sglang.launch_server`` for the mounted model
    (unless a server is already listening on the configured port), waits for
    the health endpoint, then forwards each request to the OpenAI-compatible
    ``/v1/chat/completions`` route.
    """

    def __init__(self, model_dir: str, **_: Any) -> None:
        """Start (or attach to) the local SGLang server and wait until healthy.

        Args:
            model_dir: Path where the model repository is mounted.

        Raises:
            RuntimeError: if the server is not healthy within the startup timeout.
        """
        self.model_dir = model_dir

        # Local SGLang server address
        self.host = os.getenv("SGLANG_HOST", "127.0.0.1")
        self.port = int(os.getenv("SGLANG_PORT", "30000"))

        self.health_url = f"http://{self.host}:{self.port}/health"
        self.chat_url = f"http://{self.host}:{self.port}/v1/chat/completions"

        # Model path inside endpoint container (repo is mounted here)
        model_path = os.getenv("SGLANG_MODEL_PATH", model_dir)
        tokenizer_path = os.getenv("SGLANG_TOKENIZER_PATH", model_path)
        tp_size = int(os.getenv("SGLANG_TP_SIZE", "1"))

        # If the endpoint base image already has SGLang installed, this works.
        # If not, you must use an SGLang-based image (recommended) rather than pip-installing it here.
        launch_cmd = os.getenv("SGLANG_LAUNCH_CMD", "").strip()
        if launch_cmd:
            # shlex honors shell quoting, so paths/args containing spaces
            # survive (a plain str.split() would break them apart).
            cmd = shlex.split(launch_cmd)
        else:
            cmd = [
                # Use the running interpreter rather than whatever "python"
                # happens to be on PATH inside the container.
                sys.executable, "-m", "sglang.launch_server",
                "--model-path", model_path,
                "--tokenizer-path", tokenizer_path,
                "--host", "0.0.0.0",
                "--port", str(self.port),
                "--tp-size", str(tp_size),
            ]

            # Helpful optional knobs
            if os.getenv("SGLANG_CHUNKED_PREFILL_SIZE"):
                cmd += ["--chunked-prefill-size", os.environ["SGLANG_CHUNKED_PREFILL_SIZE"]]
            if os.getenv("SGLANG_MAX_RUNNING_REQUESTS"):
                cmd += ["--max-running-requests", os.environ["SGLANG_MAX_RUNNING_REQUESTS"]]

        self.proc = None
        if not _is_port_open(self.host, self.port):
            self.proc = subprocess.Popen(cmd, env=os.environ.copy())
            # Don't leave an orphaned server process behind when this
            # interpreter exits.
            atexit.register(self._shutdown)

        _wait_for_server(self.host, self.port, self.health_url, timeout_s=int(os.getenv("SGLANG_STARTUP_TIMEOUT", "600")))

        self.served_model_name = os.getenv("SGLANG_SERVED_MODEL_NAME", "ALIA-40b-instruct-nvfp4")

    def _shutdown(self) -> None:
        """Terminate the spawned SGLang server, escalating to kill if needed."""
        proc = getattr(self, "proc", None)
        if proc is not None and proc.poll() is None:
            proc.terminate()
            try:
                proc.wait(timeout=10)
            except subprocess.TimeoutExpired:
                proc.kill()

    def __call__(self, data: Dict[str, Any]) -> Union[str, Dict[str, Any]]:
        """Handle one inference request.

        Args:
            data: HF Endpoints payload; ``inputs`` (string / message list / other)
                and optional ``parameters`` (HF-style generation params).

        Returns:
            The generated text, a dict with ``generated_text`` and ``raw`` when
            ``parameters.details`` is truthy, or the raw server response when it
            has no ``choices[0].message.content``.
        """
        inputs = data.get("inputs", data)
        params = data.get("parameters", {}) or {}

        payload: Dict[str, Any] = {
            "model": self.served_model_name,
            "messages": _coerce_messages(inputs),
            **_map_params(params),
        }

        # Optional passthrough for tool calling / response_format if you use it
        for k in ("response_format", "tools", "tool_choice"):
            if k in params and params[k] is not None:
                payload[k] = params[k]

        out = _http_json("POST", self.chat_url, payload=payload, timeout_s=float(os.getenv("SGLANG_REQUEST_TIMEOUT", "300")))

        # Return plain text by default (HF UI friendly)
        try:
            text = out["choices"][0]["message"]["content"]
        except Exception:
            return out

        if bool(params.get("details")):
            return {"generated_text": text, "raw": out}
        return text