from __future__ import annotations from typing import Any import gradio as gr import requests from models.model_catalog import ModelInfo from models.vllm_runner import ( VLLMConfig, VLLMService, build_vllm_run_plan, fetch_vllm_metrics, log_vllm_benchmark, ) from ui.progress import CLICK_PROGRESS from ui.server_controls import create_serving_controls def build_vllm_tab(catalog: dict[str, ModelInfo]) -> None: gr.Markdown("vLLM serving plans and local metrics checks.") controls = create_vllm_controls(catalog) command = gr.Textbox(label="vLLM command / status", interactive=False) output = gr.JSON(label="vLLM plan, status, or metrics") prepare_inputs = [ controls[key] for key in ( "selected", "base_url", "host", "port", "parallel", "dtype", "max_model_len", ) ] controls["prepare"].click( lambda *args: prepare_vllm(catalog, *args), prepare_inputs, [command, output], show_progress=CLICK_PROGRESS, ) controls["check"].click(check_vllm, controls["base_url"], output, show_progress=CLICK_PROGRESS) controls["metrics"].click( get_metrics, controls["base_url"], output, show_progress=CLICK_PROGRESS, ) controls["log_metrics"].click( log_current_metrics, [controls["selected"], controls["base_url"]], command, show_progress=CLICK_PROGRESS, ) def create_vllm_controls(catalog: dict[str, ModelInfo]) -> dict[str, Any]: controls = create_serving_controls(catalog, "vLLM", "http://127.0.0.1:8000", 8000) controls["dtype"] = gr.Textbox(label="dtype", value="auto") controls["max_model_len"] = gr.Number(label="Max model length", value=4096, precision=0) with gr.Row(): controls["prepare"] = gr.Button("Prepare vLLM command", variant="primary") controls["check"] = gr.Button("Check vLLM") controls["metrics"] = gr.Button("Fetch metrics") controls["log_metrics"] = gr.Button("Log benchmark") return controls def config_from_inputs( url: str, server_host: str, server_port: int | float, parallel_size: int | float, dtype_value: str, model_len: int | float, ) -> VLLMConfig: return VLLMConfig( base_url=url.strip() or "http://127.0.0.1:8000", host=server_host.strip() or "127.0.0.1", port=int(server_port), tensor_parallel_size=int(parallel_size), dtype=dtype_value.strip() or "auto", max_model_len=int(model_len), ) def prepare_vllm( catalog: dict[str, ModelInfo], model_id: str, url: str, server_host: str, server_port: int | float, parallel_size: int | float, dtype_value: str, model_len: int | float, ) -> tuple[str, dict]: config = config_from_inputs( url, server_host, server_port, parallel_size, dtype_value, model_len, ) plan = build_vllm_run_plan(catalog[model_id], config) return " ".join(plan.start_command), plan.to_dict() def check_vllm(url: str) -> dict: status = VLLMService.status(url.strip() or "http://127.0.0.1:8000") return {"backend": status.name, "available": status.available, "detail": status.detail} def get_metrics(url: str) -> dict: try: return fetch_vllm_metrics(url.strip() or "http://127.0.0.1:8000") except (OSError, requests.RequestException) as exc: return {"error": str(exc)} def log_current_metrics(model_id: str, url: str) -> str: parsed = get_metrics(url) if "error" in parsed: return str(parsed["error"]) return f"Logged vLLM benchmark to {log_vllm_benchmark(parsed, model_id)}"