""" vLLM-based model interface for high-performance LLM serving. """ import os import logging import subprocess import time import signal import requests from typing import List, Dict, Any, Optional, Union from dataclasses import dataclass from .base_model import BaseModel from ..constants import SUPPORTED_MODELS, MODEL_METADATA, VLLM_DEFAULT_SETTINGS logger = logging.getLogger(__name__) @dataclass class VLLMServerConfig: """Configuration for vLLM server.""" host: str = "localhost" port: int = 8000 model: str = "" max_model_len: int = 4096 gpu_memory_utilization: float = 0.9 dtype: str = "auto" tensor_parallel_size: int = 1 trust_remote_code: bool = True @property def api_base(self) -> str: return f"http://{self.host}:{self.port}/v1" class VLLMServer: """ Manages a vLLM server instance for serving LLMs. Usage: server = VLLMServer(model_name="mistral-7b-instruct") server.start() # Use the server... server.stop() Or as context manager: with VLLMServer(model_name="mistral-7b-instruct") as server: # Use the server... """ def __init__( self, model_name: str, host: str = "localhost", port: int = 8000, max_model_len: int = 4096, gpu_memory_utilization: float = 0.9, tensor_parallel_size: int = 1, **kwargs ): # Resolve model name to HuggingFace ID if model_name in SUPPORTED_MODELS: self.hf_model_id = SUPPORTED_MODELS[model_name] self.model_name = model_name else: self.hf_model_id = model_name self.model_name = model_name self.config = VLLMServerConfig( host=host, port=port, model=self.hf_model_id, max_model_len=max_model_len, gpu_memory_utilization=gpu_memory_utilization, tensor_parallel_size=tensor_parallel_size, ) self.process = None self._started = False def start(self, wait_for_ready: bool = True, timeout: int = 300) -> bool: """ Start the vLLM server. Args: wait_for_ready: Wait for server to be ready before returning timeout: Maximum time to wait for server (seconds) Returns: True if server started successfully """ if self._started: logger.warning("Server already started") return True cmd = [ "python", "-m", "vllm.entrypoints.openai.api_server", "--model", self.config.model, "--host", self.config.host, "--port", str(self.config.port), "--max-model-len", str(self.config.max_model_len), "--gpu-memory-utilization", str(self.config.gpu_memory_utilization), "--tensor-parallel-size", str(self.config.tensor_parallel_size), ] if self.config.trust_remote_code: cmd.append("--trust-remote-code") logger.info(f"Starting vLLM server with command: {' '.join(cmd)}") try: self.process = subprocess.Popen( cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, preexec_fn=os.setsid ) if wait_for_ready: return self._wait_for_ready(timeout) self._started = True return True except Exception as e: logger.error(f"Failed to start vLLM server: {e}") return False def _wait_for_ready(self, timeout: int = 300) -> bool: """Wait for server to be ready.""" start_time = time.time() health_url = f"{self.config.api_base}/models" while time.time() - start_time < timeout: try: response = requests.get(health_url, timeout=5) if response.status_code == 200: logger.info("vLLM server is ready!") self._started = True return True except requests.exceptions.RequestException: pass # Check if process died if self.process and self.process.poll() is not None: stderr = self.process.stderr.read().decode() if self.process.stderr else "" logger.error(f"vLLM server process died: {stderr}") return False time.sleep(2) logger.info("Waiting for vLLM server to start...") logger.error(f"vLLM server failed to start within {timeout} seconds") return False 
    def stop(self):
        """Stop the vLLM server."""
        if self.process:
            try:
                # Signal the whole process group created via preexec_fn=os.setsid.
                os.killpg(os.getpgid(self.process.pid), signal.SIGTERM)
                self.process.wait(timeout=10)
            except Exception as e:
                logger.warning(f"Error stopping server: {e}")
                try:
                    os.killpg(os.getpgid(self.process.pid), signal.SIGKILL)
                except Exception:
                    pass
            finally:
                self.process = None
                self._started = False
                logger.info("vLLM server stopped")

    def is_running(self) -> bool:
        """Check if the server is running."""
        if not self._started:
            return False
        try:
            response = requests.get(f"{self.config.api_base}/models", timeout=5)
            return response.status_code == 200
        except requests.exceptions.RequestException:
            return False

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.stop()


class VLLMModel(BaseModel):
    """
    vLLM-based model for LLM inference using the OpenAI-compatible API.

    Can connect to an existing vLLM server or manage its own.

    Usage:
        # Connect to an existing server
        model = VLLMModel(model_name="mistral-7b-instruct", api_base="http://localhost:8000/v1")

        # Or with a managed server
        model = VLLMModel(model_name="mistral-7b-instruct", start_server=True)
    """

    def __init__(
        self,
        model_name: str,
        api_base: Optional[str] = None,
        api_key: str = "EMPTY",
        start_server: bool = False,
        server_config: Optional[Dict] = None,
        **kwargs
    ):
        super().__init__(model_name)

        # Resolve model name
        if model_name in SUPPORTED_MODELS:
            self.hf_model_id = SUPPORTED_MODELS[model_name]
        else:
            self.hf_model_id = model_name

        self.api_key = api_key
        self.server = None

        # Start a managed server if requested
        if start_server:
            config = server_config or {}
            self.server = VLLMServer(model_name, **config)
            self.server.start()
            self.api_base = self.server.config.api_base
        else:
            self.api_base = api_base or "http://localhost:8000/v1"

        # Get model metadata
        self.metadata = MODEL_METADATA.get(model_name, {})

    def generate(
        self,
        prompt: str,
        max_tokens: int = 512,
        temperature: float = 0.7,
        top_p: float = 0.95,
        stop: Optional[List[str]] = None,
        **kwargs
    ) -> str:
        """Generate a response from the model."""
        payload = {
            "model": self.hf_model_id,
            "prompt": prompt,
            "max_tokens": max_tokens,
            "temperature": temperature,
            "top_p": top_p,
        }
        if stop:
            payload["stop"] = stop

        try:
            response = requests.post(
                f"{self.api_base}/completions",
                json=payload,
                headers={"Authorization": f"Bearer {self.api_key}"},
                timeout=120,
            )
            response.raise_for_status()
            result = response.json()
            return result["choices"][0]["text"].strip()
        except Exception as e:
            logger.error(f"Error generating response: {e}")
            return ""

    def generate_chat(
        self,
        messages: List[Dict[str, str]],
        max_tokens: int = 512,
        temperature: float = 0.7,
        top_p: float = 0.95,
        **kwargs
    ) -> str:
        """Generate a chat response."""
        payload = {
            "model": self.hf_model_id,
            "messages": messages,
            "max_tokens": max_tokens,
            "temperature": temperature,
            "top_p": top_p,
        }

        try:
            response = requests.post(
                f"{self.api_base}/chat/completions",
                json=payload,
                headers={"Authorization": f"Bearer {self.api_key}"},
                timeout=120,
            )
            response.raise_for_status()
            result = response.json()
            return result["choices"][0]["message"]["content"].strip()
        except Exception as e:
            logger.error(f"Error generating chat response: {e}")
            return ""

    def generate_batch(
        self,
        prompts: List[str],
        max_tokens: int = 512,
        temperature: float = 0.7,
        **kwargs
    ) -> List[str]:
        """Generate responses for a batch of prompts."""
        # vLLM batches requests server-side, so sequential HTTP calls are acceptable here.
        responses = []
        for prompt in prompts:
            response = self.generate(prompt, max_tokens, temperature, **kwargs)
            responses.append(response)
        return responses
    def get_response(
        self,
        idx: int,
        stage: str,
        messages: List[Dict[str, str]],
        langcode: Optional[str] = None,
    ) -> tuple:
        """
        Get a response compatible with the pipeline interface.

        Returns:
            Tuple of (response_string, cost)
        """
        response = self.generate_chat(messages)
        return response, 0.0  # vLLM runs locally, so there is no API cost

    def __del__(self):
        """Clean up the managed server, if any."""
        # getattr guards against __init__ failing before self.server is set.
        if getattr(self, "server", None):
            self.server.stop()


class VLLMModelFactory:
    """Factory for creating VLLMModel instances."""

    @staticmethod
    def create(
        model_name: str,
        api_base: Optional[str] = None,
        **kwargs
    ) -> VLLMModel:
        """Create a VLLMModel instance."""
        return VLLMModel(model_name, api_base=api_base, **kwargs)

    @staticmethod
    def list_models() -> List[str]:
        """List available models."""
        return list(SUPPORTED_MODELS.keys())

    @staticmethod
    def get_model_info(model_name: str) -> Dict:
        """Get model metadata."""
        return MODEL_METADATA.get(model_name, {})
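
# Minimal end-to-end usage sketch, guarded so importing this module stays
# side-effect free. It assumes vLLM is installed, a GPU is available, and
# "mistral-7b-instruct" is a key in SUPPORTED_MODELS; swap in whatever model
# name your deployment actually supports.
if __name__ == "__main__":
    with VLLMServer(model_name="mistral-7b-instruct") as server:
        model = VLLMModel(
            model_name="mistral-7b-instruct",
            api_base=server.config.api_base,
        )
        reply, cost = model.get_response(
            idx=0,
            stage="demo",
            messages=[{"role": "user", "content": "Say hello in one short sentence."}],
        )
        print(f"Response: {reply} (cost: {cost})")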