"""Unit tests for the vLLM runner helpers, using in-process HTTP fakes."""

from __future__ import annotations

import tempfile
import unittest
from pathlib import Path
from typing import Any, cast

import requests

from models.model_catalog import load_model_catalog
from models.vllm_runner import (
    VLLMConfig,
    VLLMService,
    build_vllm_run_plan,
    fetch_vllm_metrics,
    log_vllm_benchmark,
    parse_vllm_metrics,
)
from tracking.trackio_client import TrackingClient, TrackingConfig, read_trace_rows


class FakeResponse:
    """Small stand-in for ``requests.Response`` used by the tests below.

    It exposes ``payload``, ``text``, ``status_code`` and ``ok`` attributes
    plus ``json()`` and ``raise_for_status()`` methods.
    """

    def __init__(
        self,
        payload: dict[str, Any] | None = None,
        text: str = "",
        status_code: int = 200,
    ) -> None:
        self.status_code = status_code
        self.text = text
        # A missing (or empty) payload is normalised to an empty dict.
        self.payload = payload if payload else {}
        # Status codes below 400 count as success.
        self.ok = self.status_code < 400

    def json(self) -> dict[str, Any]:
        """Return the stored payload."""
        return self.payload

    def raise_for_status(self) -> None:
        """Raise ``requests.HTTPError`` when the response is not ``ok``."""
        if self.ok:
            return
        raise requests.HTTPError(f"HTTP {self.status_code}")


class CapturingPost:
    """Callable POST double that records the URL and JSON body it receives."""

    def __init__(self) -> None:
        self.url = ""
        self.payload: dict[str, Any] = {}

    def __call__(self, url: str, **kwargs: Any) -> requests.Response:
        """Remember the call arguments and answer with a canned chat reply."""
        self.url = url
        # Copy the body so later mutation by the caller cannot alter the record.
        self.payload = dict(kwargs["json"])
        canned = FakeResponse({"choices": [{"message": {"content": "vllm answer"}}]})
        return cast(requests.Response, canned)


class VLLMRunnerTest(unittest.TestCase):
    """Exercises plan building, status, chat, metrics and benchmark logging."""

    def setUp(self) -> None:
        catalog = load_model_catalog("config/models.yaml")
        self.model = catalog["minicpm5_1b"]

    def test_builds_vllm_run_plan(self) -> None:
        """The run plan exposes a health URL and a ``serve`` start command."""
        config = VLLMConfig(port=8100, tensor_parallel_size=2, max_model_len=2048)
        run_plan = build_vllm_run_plan(self.model, config)

        # NOTE(review): the config sets port=8100 while the expected health URL
        # uses port 8000 -- confirm whether health_url is meant to follow `port`.
        self.assertEqual(run_plan.health_url, "http://127.0.0.1:8000/health")
        for expected_token in ("serve", "--tensor-parallel-size"):
            self.assertIn(expected_token, run_plan.start_command)
        self.assertFalse(run_plan.startup_downloads)

    def test_status_reports_missing_package_and_unreachable_server(self) -> None:
        """Status is unavailable and its detail mentions both failure causes."""

        def get_health(url: str, **kwargs: Any) -> requests.Response:
            del url, kwargs
            raise requests.ConnectionError("offline")

        report = VLLMService.status(
            "http://local-vllm",
            which_func=lambda name: None,
            find_spec=lambda name: None,
            get_func=get_health,
        )

        self.assertFalse(report.available)
        self.assertIn("offline", report.detail)
        self.assertIn("not installed", report.detail)

    def test_chat_posts_openai_compatible_payload(self) -> None:
        """``chat`` posts to the chat-completions route and returns the reply."""

        def get_health(url: str, **kwargs: Any) -> requests.Response:
            del url, kwargs
            return cast(requests.Response, FakeResponse({"status": "ok"}))

        recorder = CapturingPost()
        config = VLLMConfig(base_url="http://local-vllm")
        service = VLLMService(
            self.model,
            config,
            get_func=get_health,
            post_func=recorder,
        )

        reply = service.chat("system", "prompt")

        self.assertEqual(reply, "vllm answer")
        self.assertEqual(recorder.url, "http://local-vllm/v1/chat/completions")
        self.assertEqual(recorder.payload["model"], "openbmb/MiniCPM5-1B")

    def test_parses_prometheus_metrics(self) -> None:
        """Comment lines are ignored and labelled samples are keyed by bare name."""
        exposition = (
            "# HELP demo demo\n"
            "vllm:num_requests_running 2\n"
            'vllm:gpu_cache_usage_perc{gpu="0"} 0.42\n'
        )

        parsed = parse_vllm_metrics(exposition)

        self.assertEqual(parsed["vllm:num_requests_running"], 2.0)
        self.assertEqual(parsed["vllm:gpu_cache_usage_perc"], 0.42)

    def test_fetches_metrics(self) -> None:
        """Metrics are fetched from ``/metrics`` with a 5 second timeout."""

        def get_metrics(url: str, **kwargs: Any) -> requests.Response:
            # The request itself is verified inside the fake transport.
            self.assertEqual(url, "http://local-vllm/metrics")
            self.assertEqual(kwargs["timeout"], 5)
            body = FakeResponse(text="vllm:num_requests_running 1\n")
            return cast(requests.Response, body)

        fetched = fetch_vllm_metrics("http://local-vllm", get_metrics)

        self.assertEqual(fetched["vllm:num_requests_running"], 1.0)

    def test_logs_vllm_benchmark_to_tracking(self) -> None:
        """Logging a benchmark returns the trace path and writes one event row."""
        with tempfile.TemporaryDirectory() as workdir:
            trace_path = Path(workdir) / "traces.jsonl"
            tracker = TrackingClient(TrackingConfig(local_path=str(trace_path)))

            saved = log_vllm_benchmark({"latency": 1.2}, "model", tracker)

            self.assertEqual(saved, str(trace_path))
            first_row = read_trace_rows(trace_path)[0]
            self.assertEqual(first_row["event"], "vllm_benchmark")


if __name__ == "__main__":
    unittest.main()