| | """ |
| | vLLM-based model interface for high-performance LLM serving. |
| | """ |
| |
|
| | import os |
| | import logging |
| | import subprocess |
| | import time |
| | import signal |
| | import requests |
| | from typing import List, Dict, Any, Optional, Union |
| | from dataclasses import dataclass |
| |
|
| | from .base_model import BaseModel |
| | from ..constants import SUPPORTED_MODELS, MODEL_METADATA, VLLM_DEFAULT_SETTINGS |
| |
|
| | logger = logging.getLogger(__name__) |
| |
|
| |
|
@dataclass
class VLLMServerConfig:
    """Settings describing one vLLM server: where it listens and how it loads the model."""

    # Address the server is reached at; both feed into ``api_base``.
    host: str = "localhost"
    port: int = 8000
    # Model identifier given to the server; empty until a caller fills it in.
    model: str = ""
    # Context-length limit (tokens).
    max_model_len: int = 4096
    # Fraction of GPU memory the server may use.
    gpu_memory_utilization: float = 0.9
    # Weight/activation dtype; "auto" leaves the choice to the server.
    dtype: str = "auto"
    # Number of GPUs the model is sharded across.
    tensor_parallel_size: int = 1
    # Whether model repositories may run their own custom code when loading.
    trust_remote_code: bool = True

    @property
    def api_base(self) -> str:
        """Return the root URL of the server's ``/v1`` HTTP API."""
        base_url = f"http://{self.host}:{self.port}/v1"
        return base_url
| |
|
| |
|
class VLLMServer:
    """
    Manages a vLLM server instance for serving LLMs.

    The server runs as a child process in its own session (and therefore its
    own process group) so that ``stop()`` can signal the server together with
    any worker processes it spawned.

    Usage:
        server = VLLMServer(model_name="mistral-7b-instruct")
        server.start()
        # Use the server...
        server.stop()

    Or as context manager:
        with VLLMServer(model_name="mistral-7b-instruct") as server:
            # Use the server...
    """

    def __init__(
        self,
        model_name: str,
        host: str = "localhost",
        port: int = 8000,
        max_model_len: int = 4096,
        gpu_memory_utilization: float = 0.9,
        tensor_parallel_size: int = 1,
        **kwargs
    ):
        """
        Build the server configuration; no process is launched yet.

        Args:
            model_name: Key of SUPPORTED_MODELS, or any other identifier,
                which is then handed to vLLM unchanged.
            host: Interface the server binds to.
            port: TCP port the server listens on.
            max_model_len: Value for ``--max-model-len``.
            gpu_memory_utilization: Value for ``--gpu-memory-utilization``.
            tensor_parallel_size: Value for ``--tensor-parallel-size``.
            **kwargs: Accepted for call-site compatibility; currently unused.
        """
        # Known aliases resolve through SUPPORTED_MODELS; everything else is
        # passed through verbatim.
        self.hf_model_id = SUPPORTED_MODELS.get(model_name, model_name)
        self.model_name = model_name

        self.config = VLLMServerConfig(
            host=host,
            port=port,
            model=self.hf_model_id,
            max_model_len=max_model_len,
            gpu_memory_utilization=gpu_memory_utilization,
            tensor_parallel_size=tensor_parallel_size,
        )

        self.process = None
        self._started = False
        # Temporary file that receives the child's stdout/stderr.
        self._log_file = None

    def _build_command(self) -> List[str]:
        """Return the argv list used to launch the vLLM OpenAI-compatible server."""
        import sys

        cmd = [
            # Use the interpreter running this code so the child sees the same
            # environment; a bare "python" may resolve to a different install.
            sys.executable or "python",
            "-m", "vllm.entrypoints.openai.api_server",
            "--model", self.config.model,
            "--host", self.config.host,
            "--port", str(self.config.port),
            "--max-model-len", str(self.config.max_model_len),
            "--gpu-memory-utilization", str(self.config.gpu_memory_utilization),
            "--tensor-parallel-size", str(self.config.tensor_parallel_size),
        ]
        if self.config.trust_remote_code:
            cmd.append("--trust-remote-code")
        return cmd

    def start(self, wait_for_ready: bool = True, timeout: int = 300) -> bool:
        """
        Start the vLLM server.

        Args:
            wait_for_ready: Wait for server to be ready before returning
            timeout: Maximum time to wait for server (seconds)

        Returns:
            True if server started successfully. On failure the child process
            (if one was launched) is stopped before returning False.
        """
        import tempfile

        if self._started:
            logger.warning("Server already started")
            return True

        cmd = self._build_command()
        logger.info(f"Starting vLLM server with command: {' '.join(cmd)}")

        try:
            # Send output to a file rather than subprocess.PIPE: nobody drains
            # the pipes while the server runs, so a chatty server would block
            # once the OS pipe buffer is full.
            self._log_file = tempfile.TemporaryFile()
            self.process = subprocess.Popen(
                cmd,
                stdout=self._log_file,
                stderr=subprocess.STDOUT,
                # Same effect as preexec_fn=os.setsid, without running Python
                # code between fork and exec.
                start_new_session=True,
            )

            if wait_for_ready:
                ready = self._wait_for_ready(timeout)
            else:
                self._started = True
                ready = True
        except Exception as e:
            logger.error(f"Failed to start vLLM server: {e}")
            ready = False

        if not ready:
            # Do not leave a half-started server (and its log file) behind.
            self.stop()
        return ready

    def _read_log_tail(self, max_bytes: int = 8192) -> str:
        """Return up to ``max_bytes`` from the end of the captured server output."""
        log = getattr(self, "_log_file", None)
        if log is None:
            return ""
        try:
            fd = log.fileno()
            size = os.fstat(fd).st_size
            offset = max(0, size - max_bytes)
            # pread leaves the file offset alone; the offset is shared with
            # any process that still writes to this file.
            data = os.pread(fd, size - offset, offset)
        except (OSError, ValueError):
            return ""
        return data.decode(errors="replace")

    def _close_log(self) -> None:
        """Close the captured-output file, if one is open."""
        log = getattr(self, "_log_file", None)
        if log is None:
            return
        try:
            log.close()
        except OSError:
            # Best effort: the file is temporary and already unlinked.
            pass
        self._log_file = None

    def _wait_for_ready(self, timeout: int = 300) -> bool:
        """
        Poll the ``/models`` endpoint until it answers 200.

        Returns:
            True once the server responds; False if the child exits first or
            ``timeout`` seconds elapse.
        """
        deadline = time.monotonic() + timeout
        health_url = f"{self.config.api_base}/models"

        while time.monotonic() < deadline:
            # Check the child before the endpoint: if it has already exited,
            # a 200 can only come from some other server on the same port.
            if self.process and self.process.poll() is not None:
                logger.error(f"vLLM server process died: {self._read_log_tail()}")
                return False

            try:
                response = requests.get(health_url, timeout=5)
                if response.status_code == 200:
                    logger.info("vLLM server is ready!")
                    self._started = True
                    return True
            except requests.exceptions.RequestException:
                # Connection refused etc. is expected while the model loads.
                pass

            logger.info("Waiting for vLLM server to start...")
            time.sleep(2)

        logger.error(f"vLLM server failed to start within {timeout} seconds")
        return False

    @staticmethod
    def _signal_group(proc, sig) -> None:
        """Send ``sig`` to the process group led by ``proc``; ignore a vanished group."""
        try:
            # The child was started as a session leader, so its pid is also
            # its process-group id. Using the pid directly avoids calling
            # os.getpgid on a pid that may already have been reaped.
            os.killpg(proc.pid, sig)
        except ProcessLookupError:
            pass

    def stop(self):
        """Stop the vLLM server: SIGTERM first, SIGKILL if it does not exit in 10 s."""
        proc = self.process
        if proc:
            try:
                self._signal_group(proc, signal.SIGTERM)
                try:
                    proc.wait(timeout=10)
                except subprocess.TimeoutExpired:
                    logger.warning("vLLM server did not exit after SIGTERM; sending SIGKILL")
                    self._signal_group(proc, signal.SIGKILL)
                    # Reap the child so it does not linger as a zombie.
                    proc.wait(timeout=10)
            except Exception as e:
                logger.warning(f"Error stopping server: {e}")
            finally:
                self.process = None
                self._started = False
                logger.info("vLLM server stopped")
        self._close_log()

    def is_running(self) -> bool:
        """Check if server is running (process alive and ``/models`` answers 200)."""
        if not self._started:
            return False
        if self.process is not None and self.process.poll() is not None:
            return False
        try:
            response = requests.get(f"{self.config.api_base}/models", timeout=5)
        except requests.exceptions.RequestException:
            return False
        return response.status_code == 200

    def __enter__(self):
        """Start the server and return this instance; the start result is not checked here."""
        self.start()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Stop the server when the ``with`` block exits."""
        self.stop()
| |
|
| |
|
class VLLMModel(BaseModel):
    """
    vLLM-based model for LLM inference using OpenAI-compatible API.

    Can connect to an existing vLLM server or manage its own.

    Usage:
        # Connect to existing server
        model = VLLMModel(model_name="mistral-7b-instruct", api_base="http://localhost:8000/v1")

        # Or with managed server
        model = VLLMModel(model_name="mistral-7b-instruct", start_server=True)
    """

    def __init__(
        self,
        model_name: str,
        api_base: Optional[str] = None,
        api_key: str = "EMPTY",
        start_server: bool = False,
        server_config: Optional[Dict] = None,
        **kwargs
    ):
        """
        Set up the client and, optionally, launch a managed server.

        Args:
            model_name: Key of SUPPORTED_MODELS, or any other identifier,
                which is then sent to the server unchanged.
            api_base: Base URL of an existing server; defaults to
                ``http://localhost:8000/v1``. Ignored when ``start_server``
                is true.
            api_key: Bearer token sent with every request.
            start_server: Launch a VLLMServer owned by this instance.
            server_config: Keyword arguments forwarded to VLLMServer.
            **kwargs: Accepted for call-site compatibility; currently unused.
        """
        super().__init__(model_name)

        # Known aliases resolve through SUPPORTED_MODELS; everything else is
        # passed through verbatim.
        self.hf_model_id = SUPPORTED_MODELS.get(model_name, model_name)

        self.api_key = api_key
        self.server = None

        if start_server:
            self.server = VLLMServer(model_name, **(server_config or {}))
            if not self.server.start():
                # Construction still succeeds, as before, but the failure is
                # no longer silent.
                logger.error(
                    f"Managed vLLM server for {model_name} failed to start; "
                    "requests will fail until it becomes reachable"
                )
            self.api_base = self.server.config.api_base
        else:
            # Drop trailing slashes so endpoint URLs never contain "//".
            self.api_base = (api_base or "http://localhost:8000/v1").rstrip("/")

        self.metadata = MODEL_METADATA.get(model_name, {})

    def _post(self, endpoint: str, payload: Dict[str, Any]) -> Dict[str, Any]:
        """POST ``payload`` as JSON to ``{api_base}/{endpoint}`` and return the decoded reply.

        Raises whatever ``requests`` raises, including HTTPError for non-2xx
        responses; callers decide how to degrade.
        """
        response = requests.post(
            f"{self.api_base}/{endpoint}",
            json=payload,
            headers={"Authorization": f"Bearer {self.api_key}"},
            timeout=120,
        )
        response.raise_for_status()
        return response.json()

    def generate(
        self,
        prompt: str,
        max_tokens: int = 512,
        temperature: float = 0.7,
        top_p: float = 0.95,
        stop: Optional[List[str]] = None,
        **kwargs
    ) -> str:
        """
        Generate a response from the model via the ``/completions`` endpoint.

        Args:
            prompt: Text to complete.
            max_tokens: Upper bound on generated tokens.
            temperature: Sampling temperature.
            top_p: Nucleus-sampling cutoff.
            stop: Optional stop sequences; omitted from the request when empty.
            **kwargs: Accepted but not forwarded to the server.

        Returns:
            The stripped completion text, or "" if the request fails for any
            reason (the error is logged).
        """
        payload = {
            "model": self.hf_model_id,
            "prompt": prompt,
            "max_tokens": max_tokens,
            "temperature": temperature,
            "top_p": top_p,
        }

        if stop:
            payload["stop"] = stop

        try:
            result = self._post("completions", payload)
            return result["choices"][0]["text"].strip()
        except Exception as e:
            # Deliberate best effort: callers receive "" rather than an exception.
            logger.error(f"Error generating response: {e}")
            return ""

    def generate_chat(
        self,
        messages: List[Dict[str, str]],
        max_tokens: int = 512,
        temperature: float = 0.7,
        top_p: float = 0.95,
        **kwargs
    ) -> str:
        """
        Generate a chat response via the ``/chat/completions`` endpoint.

        Args:
            messages: Chat messages in OpenAI format.
            max_tokens: Upper bound on generated tokens.
            temperature: Sampling temperature.
            top_p: Nucleus-sampling cutoff.
            **kwargs: Accepted but not forwarded to the server.

        Returns:
            The stripped assistant message, or "" if the request fails or the
            reply carries no text content.
        """
        payload = {
            "model": self.hf_model_id,
            "messages": messages,
            "max_tokens": max_tokens,
            "temperature": temperature,
            "top_p": top_p,
        }

        try:
            result = self._post("chat/completions", payload)
            content = result["choices"][0]["message"]["content"]
            # A reply without text content yields "" directly instead of
            # going through the error path.
            return content.strip() if content else ""
        except Exception as e:
            # Deliberate best effort: callers receive "" rather than an exception.
            logger.error(f"Error generating chat response: {e}")
            return ""

    def generate_batch(
        self,
        prompts: List[str],
        max_tokens: int = 512,
        temperature: float = 0.7,
        **kwargs
    ) -> List[str]:
        """
        Generate responses for a batch of prompts, one request per prompt.

        Returns:
            One string per prompt, in order; a failed prompt yields "".
        """
        return [
            self.generate(prompt, max_tokens, temperature, **kwargs)
            for prompt in prompts
        ]

    def get_response(
        self,
        idx: int,
        stage: str,
        messages: List[Dict[str, str]],
        langcode: Optional[str] = None
    ) -> tuple:
        """
        Get response compatible with the pipeline interface.

        ``idx``, ``stage`` and ``langcode`` are accepted for interface
        compatibility and are not used here.

        Returns:
            Tuple of (response_string, cost); cost is always 0.0.
        """
        response = self.generate_chat(messages)
        return response, 0.0

    def __del__(self):
        """Cleanup server if managed."""
        # __init__ may have raised before ``server`` was assigned.
        server = getattr(self, "server", None)
        if not server:
            return
        try:
            server.stop()
        except Exception:
            # Finalizers can run while the interpreter is shutting down and
            # module globals are already gone; nothing useful can be done.
            pass
| |
|
| |
|
class VLLMModelFactory:
    """Static helpers for building VLLMModel objects and inspecting the model registry."""

    @staticmethod
    def create(
        model_name: str,
        api_base: Optional[str] = None,
        **kwargs
    ) -> VLLMModel:
        """Build a VLLMModel for ``model_name``; extra keyword arguments go to its constructor."""
        instance = VLLMModel(model_name, api_base=api_base, **kwargs)
        return instance

    @staticmethod
    def list_models() -> List[str]:
        """Return the keys of SUPPORTED_MODELS as a list."""
        registered = SUPPORTED_MODELS.keys()
        return list(registered)

    @staticmethod
    def get_model_info(model_name: str) -> Dict:
        """Return the MODEL_METADATA entry for ``model_name``, or an empty dict if absent."""
        info = MODEL_METADATA.get(model_name, {})
        return info
| |
|