workbench / models /vllm_runner.py
GitHub Actions
Initial ZeroGPU deployment with spaces shim
7f9dfed
Raw
History Blame Contribute Delete
5.16 kB
from __future__ import annotations
import importlib.util
import shutil
from collections.abc import Callable
from dataclasses import asdict, dataclass
from typing import Any
import requests
from models.base import BackendStatus
from models.model_catalog import ModelInfo
from models.server_helpers import host_port_args, request_openai_chat_text
from tracking.trackio_client import TrackingClient
@dataclass(frozen=True)
class VLLMConfig:
    """Immutable settings for reaching and launching a vLLM server.

    The same object feeds two consumers in this module: the server start
    command (host, port, tensor parallelism, dtype, context length) and the
    chat client (base URL, sampling values, request timeout).
    """

    # Root URL used to derive the /health, /metrics and chat endpoints.
    base_url: str = "http://127.0.0.1:8000"

    # Bind address and port handed to ``host_port_args`` for the start command.
    host: str = "127.0.0.1"
    port: int = 8000

    # Rendered as ``--tensor-parallel-size`` in the start command.
    tensor_parallel_size: int = 1

    # Rendered as ``--dtype`` in the start command.
    dtype: str = "auto"

    # Rendered as ``--max-model-len`` in the start command.
    max_model_len: int = 4096

    # Sampling values forwarded with every chat request.
    temperature: float = 0.7
    max_tokens: int = 512

    # Timeout forwarded with every chat request.
    timeout_seconds: float = 120
@dataclass(frozen=True)
class VLLMRunPlan:
    """Immutable description of how to start, probe and stop a vLLM server."""

    # Argument vector for launching the server process.
    start_command: list[str]

    # Human-readable instruction for shutting the server down.
    stop_note: str

    # Endpoints derived from the configured base URL.
    health_url: str
    metrics_url: str
    chat_url: str

    # Boolean flags carried along with the plan; their interpretation is left
    # to whoever consumes the plan (not visible in this module).
    startup_downloads: bool
    auto_model_load: bool

    def to_dict(self) -> dict[str, Any]:
        """Return the plan as a plain dictionary keyed by field name."""
        plan_as_dict: dict[str, Any] = asdict(self)
        return plan_as_dict
class VLLMService:
    """Client for a vLLM OpenAI-compatible server plus a local launch planner.

    The HTTP callables are injectable so the service can be exercised without
    a live server.
    """

    def __init__(
        self,
        model: ModelInfo,
        config: VLLMConfig | None = None,
        get_func: Callable[..., requests.Response] = requests.get,
        post_func: Callable[..., requests.Response] = requests.post,
    ) -> None:
        """Store the model, the effective config and the HTTP callables.

        Args:
            model: Catalog entry; its ``hf_id`` is sent as the model name.
            config: Connection/sampling settings; defaults to ``VLLMConfig()``.
            get_func: Callable used for the health probe.
            post_func: Callable used for the chat request.
        """
        self.model = model
        self.config = config or VLLMConfig()
        self.get_func = get_func
        self.post_func = post_func

    @staticmethod
    def status(
        base_url: str = "http://127.0.0.1:8000",
        which_func: Callable[[str], str | None] = shutil.which,
        find_spec: Callable[[str], object | None] = importlib.util.find_spec,
        get_func: Callable[..., requests.Response] = requests.get,
    ) -> BackendStatus:
        """Probe ``<base_url>/health`` and report whether vLLM is usable.

        Args:
            base_url: Server root; a trailing slash is tolerated.
            which_func: Locates the ``vllm`` executable on PATH.
            find_spec: Locates the ``vllm`` Python package.
            get_func: Performs the HTTP GET (called with ``timeout=2``).

        Returns:
            A ``BackendStatus`` for backend ``"vllm"``: available when the
            health endpoint answers with an OK response, otherwise
            unavailable with an explanatory detail message.
        """
        # The package lookup only runs when the CLI is not found on PATH.
        installed = not (which_func("vllm") is None and find_spec("vllm") is None)
        health_url = f"{base_url.rstrip('/')}/health"

        try:
            response = get_func(health_url, timeout=2)
        except requests.RequestException as exc:
            pieces = [f"vLLM server is not reachable: {exc}"]
            if not installed:
                # Extra hint only when neither the CLI nor the package exists.
                pieces.append(" Python package vllm or vllm CLI is also not installed.")
            return BackendStatus("vllm", False, "".join(pieces))

        if not response.ok:
            return BackendStatus(
                "vllm", False, f"vLLM responded with HTTP {response.status_code}."
            )
        return BackendStatus("vllm", True, "vLLM server is reachable.")

    def run_plan(self) -> VLLMRunPlan:
        """Return the launch plan for this service's model and config."""
        plan = build_vllm_run_plan(self.model, self.config)
        return plan

    def chat(self, system_prompt: str, user_prompt: str) -> str:
        """Send one system/user prompt pair to the server and return text.

        The server is health-checked first. When it is unavailable, a
        bracketed notice containing the status detail is returned instead of
        raising.

        Args:
            system_prompt: Text passed as the system prompt.
            user_prompt: Text passed as the user prompt.

        Returns:
            Whatever ``request_openai_chat_text`` returns on the happy path,
            or the "[vLLM unavailable]" notice otherwise.
        """
        settings = self.config
        backend = self.status(settings.base_url, get_func=self.get_func)

        if backend.available:
            return request_openai_chat_text(
                self.post_func,
                settings.base_url,
                self.model.hf_id,
                system_prompt,
                user_prompt,
                settings.temperature,
                settings.max_tokens,
                settings.timeout_seconds,
                "vLLM",
            )

        notice_parts = (
            "[vLLM unavailable]\n\n",
            f"{backend.detail}\n\n",
            "Install vLLM, start the planned local server command, then retry.",
        )
        return "".join(notice_parts)
def build_vllm_run_plan(model: ModelInfo, config: VLLMConfig | None = None) -> VLLMRunPlan:
    """Assemble the ``vllm serve`` command and endpoint URLs for a model.

    Args:
        model: Catalog entry whose ``hf_id`` is served.
        config: Launch settings; defaults to ``VLLMConfig()``.

    Returns:
        A ``VLLMRunPlan`` with the start command, a stop note, the health,
        metrics and chat URLs, and both boolean flags set to ``False``.
    """
    settings = config or VLLMConfig()
    # Normalise once so every derived URL has exactly one slash after the root.
    root = settings.base_url.rstrip("/")

    command: list[str] = ["vllm", "serve", model.hf_id]
    command.extend(host_port_args(settings.host, settings.port))
    command.extend(["--tensor-parallel-size", str(settings.tensor_parallel_size)])
    command.extend(["--dtype", settings.dtype])
    command.extend(["--max-model-len", str(settings.max_model_len)])

    return VLLMRunPlan(
        start_command=command,
        stop_note="Stop the foreground vLLM process with Ctrl+C or your process manager.",
        health_url=root + "/health",
        metrics_url=root + "/metrics",
        chat_url=root + "/v1/chat/completions",
        startup_downloads=False,
        auto_model_load=False,
    )
def parse_vllm_metrics(metrics_text: str) -> dict[str, float]:
    """Parse Prometheus text-format samples into ``{metric_name: value}``.

    Comment lines (``#``), blank lines and lines without a numeric value are
    ignored. Labels are discarded, so when a metric appears with several
    label sets the last sample in the text wins.

    Compared with a plain split on the first space, this parser also accepts
    samples whose label values contain spaces, samples followed by a
    timestamp, and lines indented or separated by tabs.

    Args:
        metrics_text: Body of a ``/metrics`` response.

    Returns:
        Mapping from bare metric name (labels stripped) to its float value.
    """
    parsed: dict[str, float] = {}
    for raw_line in metrics_text.splitlines():
        line = raw_line.strip()
        if not line or line.startswith("#"):
            continue

        head, brace, tail = line.partition("{")
        if brace and not any(ch.isspace() for ch in head):
            # ``name{labels} value [timestamp]``: label values may contain
            # spaces, so locate the end of the label block by its closing
            # brace instead of by whitespace.
            metric_name = head
            _, closing, remainder = tail.rpartition("}")
            if not closing:
                continue  # unterminated label block; treat as malformed
        else:
            # ``name value [timestamp]`` with no label block on the name.
            fields = line.split(None, 1)
            metric_name = fields[0]
            remainder = fields[1] if len(fields) > 1 else ""

        # The first token is the value; an optional timestamp may follow.
        tokens = remainder.split()
        if not metric_name or not tokens:
            continue
        try:
            parsed[metric_name] = float(tokens[0])
        except ValueError:
            continue
    return parsed
def fetch_vllm_metrics(
    base_url: str = "http://127.0.0.1:8000",
    get_func: Callable[..., requests.Response] = requests.get,
) -> dict[str, float]:
    """Download ``<base_url>/metrics`` and return the parsed samples.

    Args:
        base_url: Server root; a trailing slash is tolerated.
        get_func: Performs the HTTP GET (called with ``timeout=5``).

    Returns:
        The result of ``parse_vllm_metrics`` applied to the response text.

    Raises:
        Whatever ``response.raise_for_status()`` raises for an error status,
        plus anything ``get_func`` itself raises.
    """
    metrics_url = base_url.rstrip("/") + "/metrics"
    reply = get_func(metrics_url, timeout=5)
    # Fail loudly on HTTP errors rather than parsing an error page.
    reply.raise_for_status()
    body = reply.text
    return parse_vllm_metrics(body)
def log_vllm_benchmark(
    metrics: dict[str, float],
    model_id: str,
    client: TrackingClient | None = None,
) -> str:
    """Record a vLLM benchmark entry through a tracking client.

    Args:
        metrics: Metric name to value mapping to store.
        model_id: Identifier of the benchmarked model.
        client: Tracking client to log through; a new ``TrackingClient()``
            is created only when this is ``None``.

    Returns:
        ``str()`` of whatever the client's ``log`` call returns.
    """
    # Compare against None explicitly: a caller-supplied client that happens
    # to be falsy (for example one defining ``__len__`` or ``__bool__``) must
    # still be used rather than silently swapped for a default instance.
    tracking_client = TrackingClient() if client is None else client
    payload = {
        "model_id": model_id,
        "metrics": metrics,
    }
    path = tracking_client.log("vllm_benchmark", payload)
    return str(path)