# handler.py
from __future__ import annotations

import json
import os
import shlex
import socket
import subprocess
import time
from typing import Any, Dict, List, Union
from urllib.error import URLError, HTTPError
from urllib.request import Request, urlopen
def _is_port_open(host: str, port: int, timeout_s: float = 0.5) -> bool:
try:
with socket.create_connection((host, port), timeout=timeout_s):
return True
except OSError:
return False
def _http_json(method: str, url: str, payload: Dict[str, Any] | None = None, timeout_s: float = 60.0) -> Dict[str, Any]:
    """Issue an HTTP request with an optional JSON body and return the decoded JSON reply.

    An empty response body decodes to ``{}``.  HTTP and transport errors are
    re-raised as RuntimeError with the response/details included.
    """
    body = None if payload is None else json.dumps(payload).encode("utf-8")
    request = Request(
        url,
        data=body,
        headers={"Content-Type": "application/json"},
        method=method.upper(),
    )
    try:
        with urlopen(request, timeout=timeout_s) as response:
            text = response.read().decode("utf-8")
            return json.loads(text) if text else {}
    except HTTPError as exc:
        detail = exc.read().decode("utf-8", errors="replace")
        raise RuntimeError(f"HTTP {exc.code} from {url}: {detail}") from exc
    except URLError as exc:
        raise RuntimeError(f"Request to {url} failed: {exc}") from exc
def _wait_for_server(host: str, port: int, health_url: str, timeout_s: int = 600) -> None:
    """Block until the server accepts TCP connections and answers its health check.

    Both phases share a single deadline of *timeout_s* seconds; raises
    RuntimeError if the server is not healthy in time.
    """
    deadline = time.time() + timeout_s
    # Phase 1: wait for the TCP port to start accepting connections.
    while time.time() < deadline:
        if _is_port_open(host, port):
            break
        time.sleep(0.5)
    # Phase 2: poll the health endpoint until it responds successfully.
    while time.time() < deadline:
        try:
            _http_json("GET", health_url, payload=None, timeout_s=2.0)
        except Exception:
            time.sleep(0.5)
        else:
            return
    raise RuntimeError(f"SGLang server not ready within {timeout_s}s (health={health_url})")
def _coerce_messages(inputs: Any) -> List[Dict[str, str]]:
if isinstance(inputs, str):
return [{"role": "user", "content": inputs}]
if isinstance(inputs, list):
if all(isinstance(x, dict) for x in inputs):
msgs: List[Dict[str, str]] = []
for m in inputs:
role = str(m.get("role", "user"))
content = m.get("content", "")
msgs.append({"role": role, "content": "" if content is None else str(content)})
return msgs
if all(isinstance(x, str) for x in inputs):
return [{"role": "user", "content": "\n".join(inputs)}]
return [{"role": "user", "content": json.dumps(inputs, ensure_ascii=False)}]
def _map_params(hf_params: Dict[str, Any]) -> Dict[str, Any]:
hf_params = hf_params or {}
out: Dict[str, Any] = {"stream": False}
max_new = hf_params.get("max_new_tokens", hf_params.get("max_tokens"))
if max_new is not None:
out["max_tokens"] = int(max_new)
for k in ("temperature", "top_p", "seed", "stop", "presence_penalty", "frequency_penalty"):
if k in hf_params and hf_params[k] is not None:
out[k] = hf_params[k]
return out
class EndpointHandler:
    """HF Inference Endpoints handler that proxies requests to a local SGLang server.

    On construction it launches an SGLang server for the model mounted at
    *model_dir* (unless something is already listening on the configured
    port), waits for it to become healthy, and then forwards each request to
    the server's OpenAI-compatible chat-completions endpoint.
    """

    def __init__(self, model_dir: str, **_: Any) -> None:
        self.model_dir = model_dir
        # Local SGLang server address.
        self.host = os.getenv("SGLANG_HOST", "127.0.0.1")
        self.port = int(os.getenv("SGLANG_PORT", "30000"))
        self.health_url = f"http://{self.host}:{self.port}/health"
        self.chat_url = f"http://{self.host}:{self.port}/v1/chat/completions"
        # Model path inside endpoint container (repo is mounted here).
        model_path = os.getenv("SGLANG_MODEL_PATH", model_dir)
        tokenizer_path = os.getenv("SGLANG_TOKENIZER_PATH", model_path)
        tp_size = int(os.getenv("SGLANG_TP_SIZE", "1"))
        # If the endpoint base image already has SGLang installed, this works.
        # If not, you must use an SGLang-based image (recommended) rather than
        # pip-installing it here.
        launch_cmd = os.getenv("SGLANG_LAUNCH_CMD", "").strip()
        if launch_cmd:
            # Fix: shlex.split (not str.split) so quoted arguments containing
            # spaces survive, e.g. SGLANG_LAUNCH_CMD='... --chat-template "my tpl"'.
            cmd = shlex.split(launch_cmd)
        else:
            cmd = [
                "python", "-m", "sglang.launch_server",
                "--model-path", model_path,
                "--tokenizer-path", tokenizer_path,
                "--host", "0.0.0.0",
                "--port", str(self.port),
                "--tp-size", str(tp_size),
            ]
        # Helpful optional knobs, appended to whichever command was built.
        if os.getenv("SGLANG_CHUNKED_PREFILL_SIZE"):
            cmd += ["--chunked-prefill-size", os.environ["SGLANG_CHUNKED_PREFILL_SIZE"]]
        if os.getenv("SGLANG_MAX_RUNNING_REQUESTS"):
            cmd += ["--max-running-requests", os.environ["SGLANG_MAX_RUNNING_REQUESTS"]]
        self.proc = None
        # Only launch if nothing is already listening; this allows attaching
        # to an externally managed server on the same port.
        if not _is_port_open(self.host, self.port):
            self.proc = subprocess.Popen(cmd, env=os.environ.copy())
        _wait_for_server(
            self.host,
            self.port,
            self.health_url,
            timeout_s=int(os.getenv("SGLANG_STARTUP_TIMEOUT", "600")),
        )
        self.served_model_name = os.getenv("SGLANG_SERVED_MODEL_NAME", "ALIA-40b-instruct-nvfp4")

    def __call__(self, data: Dict[str, Any]) -> Union[str, Dict[str, Any]]:
        """Handle one inference request.

        *data* follows the HF Inference convention: an ``inputs`` field
        (string, message list, or arbitrary JSON) plus optional
        ``parameters``.  Returns the generated text by default, a
        ``{"generated_text", "raw"}`` dict when ``parameters.details`` is
        truthy, or the raw server response if it does not have the expected
        chat-completion shape.
        """
        inputs = data.get("inputs", data)
        params = data.get("parameters", {}) or {}
        payload: Dict[str, Any] = {
            "model": self.served_model_name,
            "messages": _coerce_messages(inputs),
            **_map_params(params),
        }
        # Optional passthrough for tool calling / response_format if you use it
        for k in ("response_format", "tools", "tool_choice"):
            if k in params and params[k] is not None:
                payload[k] = params[k]
        out = _http_json(
            "POST",
            self.chat_url,
            payload=payload,
            timeout_s=float(os.getenv("SGLANG_REQUEST_TIMEOUT", "300")),
        )
        # Return plain text by default (HF UI friendly)
        try:
            text = out["choices"][0]["message"]["content"]
        except Exception:
            return out
        if bool(params.get("details")):
            return {"generated_text": text, "raw": out}
        return text