| | |
from __future__ import annotations

import json
import os
import shlex
import socket
import subprocess
import sys
import time
from typing import Any, Dict, List, Union
from urllib.error import URLError, HTTPError
from urllib.request import Request, urlopen
| |
|
| |
|
| | def _is_port_open(host: str, port: int, timeout_s: float = 0.5) -> bool: |
| | try: |
| | with socket.create_connection((host, port), timeout=timeout_s): |
| | return True |
| | except OSError: |
| | return False |
| |
|
| |
|
| | def _http_json(method: str, url: str, payload: Dict[str, Any] | None = None, timeout_s: float = 60.0) -> Dict[str, Any]: |
| | data = None |
| | headers = {"Content-Type": "application/json"} |
| | if payload is not None: |
| | data = json.dumps(payload).encode("utf-8") |
| | req = Request(url, data=data, headers=headers, method=method.upper()) |
| | try: |
| | with urlopen(req, timeout=timeout_s) as resp: |
| | body = resp.read().decode("utf-8") |
| | return json.loads(body) if body else {} |
| | except HTTPError as e: |
| | body = e.read().decode("utf-8", errors="replace") |
| | raise RuntimeError(f"HTTP {e.code} from {url}: {body}") from e |
| | except URLError as e: |
| | raise RuntimeError(f"Request to {url} failed: {e}") from e |
| |
|
| |
|
| | def _wait_for_server(host: str, port: int, health_url: str, timeout_s: int = 600) -> None: |
| | start = time.time() |
| | |
| | while time.time() - start < timeout_s: |
| | if _is_port_open(host, port): |
| | break |
| | time.sleep(0.5) |
| | |
| | while time.time() - start < timeout_s: |
| | try: |
| | _http_json("GET", health_url, payload=None, timeout_s=2.0) |
| | return |
| | except Exception: |
| | time.sleep(0.5) |
| | raise RuntimeError(f"SGLang server not ready within {timeout_s}s (health={health_url})") |
| |
|
| |
|
| | def _coerce_messages(inputs: Any) -> List[Dict[str, str]]: |
| | if isinstance(inputs, str): |
| | return [{"role": "user", "content": inputs}] |
| | if isinstance(inputs, list): |
| | if all(isinstance(x, dict) for x in inputs): |
| | msgs: List[Dict[str, str]] = [] |
| | for m in inputs: |
| | role = str(m.get("role", "user")) |
| | content = m.get("content", "") |
| | msgs.append({"role": role, "content": "" if content is None else str(content)}) |
| | return msgs |
| | if all(isinstance(x, str) for x in inputs): |
| | return [{"role": "user", "content": "\n".join(inputs)}] |
| | return [{"role": "user", "content": json.dumps(inputs, ensure_ascii=False)}] |
| |
|
| |
|
| | def _map_params(hf_params: Dict[str, Any]) -> Dict[str, Any]: |
| | hf_params = hf_params or {} |
| | out: Dict[str, Any] = {"stream": False} |
| |
|
| | max_new = hf_params.get("max_new_tokens", hf_params.get("max_tokens")) |
| | if max_new is not None: |
| | out["max_tokens"] = int(max_new) |
| |
|
| | for k in ("temperature", "top_p", "seed", "stop", "presence_penalty", "frequency_penalty"): |
| | if k in hf_params and hf_params[k] is not None: |
| | out[k] = hf_params[k] |
| |
|
| | return out |
| |
|
| |
|
class EndpointHandler:
    """Inference Endpoints handler that proxies requests to a local SGLang server.

    On construction it launches an SGLang server subprocess (unless something is
    already listening on the configured port) and blocks until the server's
    health endpoint responds. `__call__` forwards each request to the server's
    OpenAI-compatible /v1/chat/completions endpoint.
    """

    def __init__(self, model_dir: str, **_: Any) -> None:
        self.model_dir = model_dir

        # Where the local SGLang server is expected to listen.
        self.host = os.getenv("SGLANG_HOST", "127.0.0.1")
        self.port = int(os.getenv("SGLANG_PORT", "30000"))

        self.health_url = f"http://{self.host}:{self.port}/health"
        self.chat_url = f"http://{self.host}:{self.port}/v1/chat/completions"

        cmd = self._build_launch_cmd()

        # Only spawn a server when nothing is already listening on the port
        # (e.g. a previously launched or externally managed server).
        self.proc = None
        if not _is_port_open(self.host, self.port):
            self.proc = subprocess.Popen(cmd, env=os.environ.copy())

        _wait_for_server(
            self.host,
            self.port,
            self.health_url,
            timeout_s=int(os.getenv("SGLANG_STARTUP_TIMEOUT", "600")),
        )

        self.served_model_name = os.getenv("SGLANG_SERVED_MODEL_NAME", "ALIA-40b-instruct-nvfp4")

    def _build_launch_cmd(self) -> List[str]:
        """Assemble the SGLang server launch command, honoring env overrides."""
        model_path = os.getenv("SGLANG_MODEL_PATH", self.model_dir)
        tokenizer_path = os.getenv("SGLANG_TOKENIZER_PATH", model_path)
        tp_size = int(os.getenv("SGLANG_TP_SIZE", "1"))

        launch_cmd = os.getenv("SGLANG_LAUNCH_CMD", "").strip()
        if launch_cmd:
            # Fix: shlex.split (not str.split) so quoted arguments containing
            # spaces survive tokenization.
            cmd = shlex.split(launch_cmd)
        else:
            cmd = [
                # Fix: sys.executable instead of bare "python" so the server
                # runs under the same interpreter/venv as this process.
                sys.executable, "-m", "sglang.launch_server",
                "--model-path", model_path,
                "--tokenizer-path", tokenizer_path,
                "--host", "0.0.0.0",
                "--port", str(self.port),
                "--tp-size", str(tp_size),
            ]

        # Optional tuning knobs, appended only when explicitly configured.
        if os.getenv("SGLANG_CHUNKED_PREFILL_SIZE"):
            cmd += ["--chunked-prefill-size", os.environ["SGLANG_CHUNKED_PREFILL_SIZE"]]
        if os.getenv("SGLANG_MAX_RUNNING_REQUESTS"):
            cmd += ["--max-running-requests", os.environ["SGLANG_MAX_RUNNING_REQUESTS"]]
        return cmd

    def __call__(self, data: Dict[str, Any]) -> Union[str, Dict[str, Any]]:
        """Handle one inference request.

        Accepts the HF payload shape {"inputs": ..., "parameters": {...}}.
        Returns the generated text; with parameters.details truthy, returns
        {"generated_text": ..., "raw": ...}; if the server response lacks the
        expected choices shape, returns the raw response dict.
        """
        inputs = data.get("inputs", data)
        params = data.get("parameters", {}) or {}

        payload: Dict[str, Any] = {
            "model": self.served_model_name,
            "messages": _coerce_messages(inputs),
            **_map_params(params),
        }

        # OpenAI-compatible extras forwarded verbatim when supplied.
        for k in ("response_format", "tools", "tool_choice"):
            if k in params and params[k] is not None:
                payload[k] = params[k]

        out = _http_json(
            "POST",
            self.chat_url,
            payload=payload,
            timeout_s=float(os.getenv("SGLANG_REQUEST_TIMEOUT", "300")),
        )

        # Fall back to the raw response when it lacks the expected shape.
        try:
            text = out["choices"][0]["message"]["content"]
        except Exception:
            return out

        if bool(params.get("details")):
            return {"generated_text": text, "raw": out}
        return text
| |
|