Spaces:
Running on Zero
Running on Zero
| from __future__ import annotations | |
| import importlib.util | |
| import shutil | |
| from collections.abc import Callable | |
| from dataclasses import asdict, dataclass | |
| from typing import Any | |
| import requests | |
| from models.base import BackendStatus | |
| from models.model_catalog import ModelInfo | |
| from models.server_helpers import host_port_args, request_openai_chat_text | |
| from tracking.trackio_client import TrackingClient | |
| class VLLMConfig: | |
| base_url: str = "http://127.0.0.1:8000" | |
| host: str = "127.0.0.1" | |
| port: int = 8000 | |
| tensor_parallel_size: int = 1 | |
| dtype: str = "auto" | |
| max_model_len: int = 4096 | |
| temperature: float = 0.7 | |
| max_tokens: int = 512 | |
| timeout_seconds: float = 120 | |
| class VLLMRunPlan: | |
| start_command: list[str] | |
| stop_note: str | |
| health_url: str | |
| metrics_url: str | |
| chat_url: str | |
| startup_downloads: bool | |
| auto_model_load: bool | |
| def to_dict(self) -> dict[str, Any]: | |
| return asdict(self) | |
| class VLLMService: | |
| """vLLM OpenAI-compatible client and local server command planner.""" | |
| def __init__( | |
| self, | |
| model: ModelInfo, | |
| config: VLLMConfig | None = None, | |
| get_func: Callable[..., requests.Response] = requests.get, | |
| post_func: Callable[..., requests.Response] = requests.post, | |
| ) -> None: | |
| self.model = model | |
| self.config = config or VLLMConfig() | |
| self.get_func = get_func | |
| self.post_func = post_func | |
| def status( | |
| base_url: str = "http://127.0.0.1:8000", | |
| which_func: Callable[[str], str | None] = shutil.which, | |
| find_spec: Callable[[str], object | None] = importlib.util.find_spec, | |
| get_func: Callable[..., requests.Response] = requests.get, | |
| ) -> BackendStatus: | |
| has_vllm = which_func("vllm") is not None or find_spec("vllm") is not None | |
| try: | |
| response = get_func(f"{base_url.rstrip('/')}/health", timeout=2) | |
| except requests.RequestException as exc: | |
| detail = f"vLLM server is not reachable: {exc}" | |
| if not has_vllm: | |
| detail += " Python package vllm or vllm CLI is also not installed." | |
| return BackendStatus("vllm", False, detail) | |
| if response.ok: | |
| return BackendStatus("vllm", True, "vLLM server is reachable.") | |
| return BackendStatus("vllm", False, f"vLLM responded with HTTP {response.status_code}.") | |
| def run_plan(self) -> VLLMRunPlan: | |
| return build_vllm_run_plan(self.model, self.config) | |
| def chat(self, system_prompt: str, user_prompt: str) -> str: | |
| status = self.status(self.config.base_url, get_func=self.get_func) | |
| if not status.available: | |
| return ( | |
| "[vLLM unavailable]\n\n" | |
| f"{status.detail}\n\n" | |
| "Install vLLM, start the planned local server command, then retry." | |
| ) | |
| cfg = self.config | |
| sampling = (cfg.temperature, cfg.max_tokens, cfg.timeout_seconds) | |
| return request_openai_chat_text( | |
| self.post_func, | |
| cfg.base_url, | |
| self.model.hf_id, | |
| system_prompt, | |
| user_prompt, | |
| *sampling, | |
| "vLLM", | |
| ) | |
| def build_vllm_run_plan(model: ModelInfo, config: VLLMConfig | None = None) -> VLLMRunPlan: | |
| cfg = config or VLLMConfig() | |
| base_url = cfg.base_url.rstrip("/") | |
| return VLLMRunPlan( | |
| start_command=[ | |
| "vllm", | |
| "serve", | |
| model.hf_id, | |
| *host_port_args(cfg.host, cfg.port), | |
| "--tensor-parallel-size", | |
| str(cfg.tensor_parallel_size), | |
| "--dtype", | |
| cfg.dtype, | |
| "--max-model-len", | |
| str(cfg.max_model_len), | |
| ], | |
| stop_note="Stop the foreground vLLM process with Ctrl+C or your process manager.", | |
| health_url=f"{base_url}/health", | |
| metrics_url=f"{base_url}/metrics", | |
| chat_url=f"{base_url}/v1/chat/completions", | |
| startup_downloads=False, | |
| auto_model_load=False, | |
| ) | |
| def parse_vllm_metrics(metrics_text: str) -> dict[str, float]: | |
| parsed: dict[str, float] = {} | |
| for line in metrics_text.splitlines(): | |
| if not line or line.startswith("#"): | |
| continue | |
| name, _, value = line.partition(" ") | |
| metric_name = name.split("{", 1)[0] | |
| try: | |
| parsed[metric_name] = float(value.strip()) | |
| except ValueError: | |
| continue | |
| return parsed | |
| def fetch_vllm_metrics( | |
| base_url: str = "http://127.0.0.1:8000", | |
| get_func: Callable[..., requests.Response] = requests.get, | |
| ) -> dict[str, float]: | |
| response = get_func(f"{base_url.rstrip('/')}/metrics", timeout=5) | |
| response.raise_for_status() | |
| return parse_vllm_metrics(response.text) | |
| def log_vllm_benchmark( | |
| metrics: dict[str, float], | |
| model_id: str, | |
| client: TrackingClient | None = None, | |
| ) -> str: | |
| tracking_client = client or TrackingClient() | |
| path = tracking_client.log( | |
| "vllm_benchmark", | |
| { | |
| "model_id": model_id, | |
| "metrics": metrics, | |
| }, | |
| ) | |
| return str(path) | |