# Source: Hugging Face repo file "handler.py", uploaded by user crodri (commit 636b91c, verified).
# handler.py
from __future__ import annotations

import json
import os
import shlex
import socket
import subprocess
import time
from typing import Any, Dict, List, Union
from urllib.error import HTTPError, URLError
from urllib.request import Request, urlopen
def _is_port_open(host: str, port: int, timeout_s: float = 0.5) -> bool:
try:
with socket.create_connection((host, port), timeout=timeout_s):
return True
except OSError:
return False
def _http_json(method: str, url: str, payload: Dict[str, Any] | None = None, timeout_s: float = 60.0) -> Dict[str, Any]:
data = None
headers = {"Content-Type": "application/json"}
if payload is not None:
data = json.dumps(payload).encode("utf-8")
req = Request(url, data=data, headers=headers, method=method.upper())
try:
with urlopen(req, timeout=timeout_s) as resp:
body = resp.read().decode("utf-8")
return json.loads(body) if body else {}
except HTTPError as e:
body = e.read().decode("utf-8", errors="replace")
raise RuntimeError(f"HTTP {e.code} from {url}: {body}") from e
except URLError as e:
raise RuntimeError(f"Request to {url} failed: {e}") from e
def _wait_for_server(host: str, port: int, health_url: str, timeout_s: int = 600) -> None:
start = time.time()
# wait for port
while time.time() - start < timeout_s:
if _is_port_open(host, port):
break
time.sleep(0.5)
# wait for health
while time.time() - start < timeout_s:
try:
_http_json("GET", health_url, payload=None, timeout_s=2.0)
return
except Exception:
time.sleep(0.5)
raise RuntimeError(f"SGLang server not ready within {timeout_s}s (health={health_url})")
def _coerce_messages(inputs: Any) -> List[Dict[str, str]]:
if isinstance(inputs, str):
return [{"role": "user", "content": inputs}]
if isinstance(inputs, list):
if all(isinstance(x, dict) for x in inputs):
msgs: List[Dict[str, str]] = []
for m in inputs:
role = str(m.get("role", "user"))
content = m.get("content", "")
msgs.append({"role": role, "content": "" if content is None else str(content)})
return msgs
if all(isinstance(x, str) for x in inputs):
return [{"role": "user", "content": "\n".join(inputs)}]
return [{"role": "user", "content": json.dumps(inputs, ensure_ascii=False)}]
def _map_params(hf_params: Dict[str, Any]) -> Dict[str, Any]:
hf_params = hf_params or {}
out: Dict[str, Any] = {"stream": False}
max_new = hf_params.get("max_new_tokens", hf_params.get("max_tokens"))
if max_new is not None:
out["max_tokens"] = int(max_new)
for k in ("temperature", "top_p", "seed", "stop", "presence_penalty", "frequency_penalty"):
if k in hf_params and hf_params[k] is not None:
out[k] = hf_params[k]
return out
class EndpointHandler:
    """HF Inference Endpoints handler that fronts a local SGLang server.

    On construction it launches ``sglang.launch_server`` for the mounted model
    (unless something is already listening on the target port) and waits for
    the health endpoint. Each call forwards the request to the
    OpenAI-compatible /v1/chat/completions route and returns generated text.
    """

    def __init__(self, model_dir: str, **_: Any) -> None:
        self.model_dir = model_dir
        # Local SGLang server address.
        self.host = os.getenv("SGLANG_HOST", "127.0.0.1")
        self.port = int(os.getenv("SGLANG_PORT", "30000"))
        self.health_url = f"http://{self.host}:{self.port}/health"
        self.chat_url = f"http://{self.host}:{self.port}/v1/chat/completions"

        # Model path inside endpoint container (repo is mounted here).
        model_path = os.getenv("SGLANG_MODEL_PATH", model_dir)
        tokenizer_path = os.getenv("SGLANG_TOKENIZER_PATH", model_path)
        tp_size = int(os.getenv("SGLANG_TP_SIZE", "1"))

        # If the endpoint base image already has SGLang installed, this works.
        # If not, you must use an SGLang-based image (recommended) rather than
        # pip-installing it here.
        launch_cmd = os.getenv("SGLANG_LAUNCH_CMD", "").strip()
        if launch_cmd:
            # Fix: shlex.split honors shell quoting, so a launch command with
            # quoted paths/args containing spaces is no longer mangled the
            # way str.split() mangled it.
            cmd = shlex.split(launch_cmd)
        else:
            cmd = [
                "python", "-m", "sglang.launch_server",
                "--model-path", model_path,
                "--tokenizer-path", tokenizer_path,
                "--host", "0.0.0.0",
                "--port", str(self.port),
                "--tp-size", str(tp_size),
            ]
        # Helpful optional knobs.
        if os.getenv("SGLANG_CHUNKED_PREFILL_SIZE"):
            cmd += ["--chunked-prefill-size", os.environ["SGLANG_CHUNKED_PREFILL_SIZE"]]
        if os.getenv("SGLANG_MAX_RUNNING_REQUESTS"):
            cmd += ["--max-running-requests", os.environ["SGLANG_MAX_RUNNING_REQUESTS"]]

        # Reuse an already-running server (e.g. on handler re-instantiation);
        # self.proc stays None in that case so we never own a child process
        # we did not start.
        self.proc = None
        if not _is_port_open(self.host, self.port):
            self.proc = subprocess.Popen(cmd, env=os.environ.copy())
        _wait_for_server(
            self.host,
            self.port,
            self.health_url,
            timeout_s=int(os.getenv("SGLANG_STARTUP_TIMEOUT", "600")),
        )
        self.served_model_name = os.getenv("SGLANG_SERVED_MODEL_NAME", "ALIA-40b-instruct-nvfp4")

    def __call__(self, data: Dict[str, Any]) -> Union[str, Dict[str, Any]]:
        """Handle one inference request.

        Args:
            data: HF endpoint payload; `inputs` holds the prompt/messages and
                `parameters` holds generation options. If `inputs` is absent,
                the whole dict is treated as the input.

        Returns:
            The generated text by default; a dict with `generated_text` and
            the `raw` server response when `parameters.details` is truthy; or
            the raw response dict if it lacks the expected choices shape.

        Raises:
            RuntimeError: If the chat-completions request fails.
        """
        inputs = data.get("inputs", data)
        params = data.get("parameters", {}) or {}
        payload: Dict[str, Any] = {
            "model": self.served_model_name,
            "messages": _coerce_messages(inputs),
            **_map_params(params),
        }
        # Optional passthrough for tool calling / response_format if you use it.
        for k in ("response_format", "tools", "tool_choice"):
            if params.get(k) is not None:
                payload[k] = params[k]
        out = _http_json("POST", self.chat_url, payload=payload, timeout_s=float(os.getenv("SGLANG_REQUEST_TIMEOUT", "300")))
        # Return plain text by default (HF UI friendly).
        try:
            text = out["choices"][0]["message"]["content"]
        except Exception:
            # Unexpected response shape: surface the raw server output.
            return out
        if bool(params.get("details")):
            return {"generated_text": text, "raw": out}
        return text