# TEMPO-BIAS/src/llms/vllm_model.py
"""
vLLM-based model interface for high-performance LLM serving.
"""
import os
import sys
import logging
import subprocess
import time
import signal
import requests
from typing import List, Dict, Optional, Tuple
from dataclasses import dataclass
from .base_model import BaseModel
from ..constants import SUPPORTED_MODELS, MODEL_METADATA
logger = logging.getLogger(__name__)
@dataclass
class VLLMServerConfig:
"""Configuration for vLLM server."""
host: str = "localhost"
port: int = 8000
model: str = ""
max_model_len: int = 4096
gpu_memory_utilization: float = 0.9
dtype: str = "auto"
tensor_parallel_size: int = 1
trust_remote_code: bool = True
@property
def api_base(self) -> str:
return f"http://{self.host}:{self.port}/v1"
class VLLMServer:
"""
Manages a vLLM server instance for serving LLMs.
Usage:
server = VLLMServer(model_name="mistral-7b-instruct")
server.start()
# Use the server...
server.stop()
    Or as a context manager (the server is stopped automatically on exit):
        with VLLMServer(model_name="mistral-7b-instruct") as server:
            ...  # Use the server
    """
def __init__(
self,
model_name: str,
host: str = "localhost",
port: int = 8000,
max_model_len: int = 4096,
gpu_memory_utilization: float = 0.9,
tensor_parallel_size: int = 1,
**kwargs
):
        # Resolve a short model alias to its HuggingFace ID; names not in
        # SUPPORTED_MODELS are passed through as raw HuggingFace IDs.
        self.model_name = model_name
        self.hf_model_id = SUPPORTED_MODELS.get(model_name, model_name)
self.config = VLLMServerConfig(
host=host,
port=port,
model=self.hf_model_id,
max_model_len=max_model_len,
gpu_memory_utilization=gpu_memory_utilization,
tensor_parallel_size=tensor_parallel_size,
)
self.process = None
self._started = False
def start(self, wait_for_ready: bool = True, timeout: int = 300) -> bool:
"""
Start the vLLM server.
Args:
wait_for_ready: Wait for server to be ready before returning
timeout: Maximum time to wait for server (seconds)
Returns:
True if server started successfully
"""
if self._started:
logger.warning("Server already started")
return True
        # Use the current interpreter so the server runs in the same environment.
        cmd = [
            sys.executable, "-m", "vllm.entrypoints.openai.api_server",
"--model", self.config.model,
"--host", self.config.host,
"--port", str(self.config.port),
"--max-model-len", str(self.config.max_model_len),
"--gpu-memory-utilization", str(self.config.gpu_memory_utilization),
"--tensor-parallel-size", str(self.config.tensor_parallel_size),
]
if self.config.trust_remote_code:
cmd.append("--trust-remote-code")
logger.info(f"Starting vLLM server with command: {' '.join(cmd)}")
        try:
            self.process = subprocess.Popen(
                cmd,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                # Start a new session so stop() can signal the entire
                # process group (POSIX only).
                preexec_fn=os.setsid,
            )
            if wait_for_ready:
                ready = self._wait_for_ready(timeout)
                if not ready:
                    # Don't leave a half-started server process behind.
                    self.stop()
                return ready
            self._started = True
            return True
except Exception as e:
logger.error(f"Failed to start vLLM server: {e}")
return False
def _wait_for_ready(self, timeout: int = 300) -> bool:
"""Wait for server to be ready."""
start_time = time.time()
health_url = f"{self.config.api_base}/models"
while time.time() - start_time < timeout:
try:
response = requests.get(health_url, timeout=5)
if response.status_code == 200:
logger.info("vLLM server is ready!")
self._started = True
return True
except requests.exceptions.RequestException:
pass
            # Bail out early if the server process has already died.
            if self.process and self.process.poll() is not None:
                stderr = self.process.stderr.read().decode() if self.process.stderr else ""
                logger.error(f"vLLM server process died: {stderr}")
                return False
            logger.info("Waiting for vLLM server to start...")
            time.sleep(2)
logger.error(f"vLLM server failed to start within {timeout} seconds")
return False
def stop(self):
"""Stop the vLLM server."""
if self.process:
try:
os.killpg(os.getpgid(self.process.pid), signal.SIGTERM)
self.process.wait(timeout=10)
except Exception as e:
logger.warning(f"Error stopping server: {e}")
                try:
                    # Escalate to SIGKILL if graceful shutdown failed.
                    os.killpg(os.getpgid(self.process.pid), signal.SIGKILL)
                except OSError:
                    pass
finally:
self.process = None
self._started = False
logger.info("vLLM server stopped")
def is_running(self) -> bool:
"""Check if server is running."""
if not self._started:
return False
        try:
            response = requests.get(f"{self.config.api_base}/models", timeout=5)
            return response.status_code == 200
        except requests.exceptions.RequestException:
            return False
def __enter__(self):
self.start()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.stop()
class VLLMModel(BaseModel):
"""
vLLM-based model for LLM inference using OpenAI-compatible API.
Can connect to an existing vLLM server or manage its own.
Usage:
# Connect to existing server
model = VLLMModel(model_name="mistral-7b-instruct", api_base="http://localhost:8000/v1")
# Or with managed server
model = VLLMModel(model_name="mistral-7b-instruct", start_server=True)
"""
def __init__(
self,
model_name: str,
api_base: Optional[str] = None,
api_key: str = "EMPTY",
start_server: bool = False,
server_config: Optional[Dict] = None,
**kwargs
):
super().__init__(model_name)
        # Resolve a short model alias to its HuggingFace ID.
        self.hf_model_id = SUPPORTED_MODELS.get(model_name, model_name)
self.api_key = api_key
self.server = None
# Start server if requested
if start_server:
config = server_config or {}
self.server = VLLMServer(model_name, **config)
self.server.start()
self.api_base = self.server.config.api_base
else:
self.api_base = api_base or "http://localhost:8000/v1"
# Get model metadata
self.metadata = MODEL_METADATA.get(model_name, {})
def generate(
self,
prompt: str,
max_tokens: int = 512,
temperature: float = 0.7,
top_p: float = 0.95,
stop: Optional[List[str]] = None,
**kwargs
) -> str:
"""Generate a response from the model."""
        payload = {
            "model": self.hf_model_id,
            "prompt": prompt,
            "max_tokens": max_tokens,
            "temperature": temperature,
            "top_p": top_p,
        }
        if stop:
            payload["stop"] = stop
        # Forward any extra sampling parameters the caller supplied; they are
        # passed through to the OpenAI-compatible endpoint unvalidated.
        payload.update(kwargs)
try:
response = requests.post(
f"{self.api_base}/completions",
json=payload,
headers={"Authorization": f"Bearer {self.api_key}"},
timeout=120
)
response.raise_for_status()
result = response.json()
return result["choices"][0]["text"].strip()
except Exception as e:
logger.error(f"Error generating response: {e}")
return ""
def generate_chat(
self,
messages: List[Dict[str, str]],
max_tokens: int = 512,
temperature: float = 0.7,
top_p: float = 0.95,
**kwargs
) -> str:
"""Generate a chat response."""
        payload = {
            "model": self.hf_model_id,
            "messages": messages,
            "max_tokens": max_tokens,
            "temperature": temperature,
            "top_p": top_p,
        }
        # Forward any extra sampling parameters the caller supplied.
        payload.update(kwargs)
try:
response = requests.post(
f"{self.api_base}/chat/completions",
json=payload,
headers={"Authorization": f"Bearer {self.api_key}"},
timeout=120
)
response.raise_for_status()
result = response.json()
return result["choices"][0]["message"]["content"].strip()
except Exception as e:
logger.error(f"Error generating chat response: {e}")
return ""
def generate_batch(
self,
prompts: List[str],
max_tokens: int = 512,
temperature: float = 0.7,
**kwargs
) -> List[str]:
"""Generate responses for a batch of prompts."""
# vLLM handles batching internally, but we can also send multiple requests
responses = []
for prompt in prompts:
response = self.generate(prompt, max_tokens, temperature, **kwargs)
responses.append(response)
return responses
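
    # Illustrative sketch, not part of the original pipeline: since the vLLM
    # server batches concurrent requests internally, issuing requests from a
    # thread pool typically improves batch throughput. The method name and
    # max_workers default are assumptions, not an established API.
    def generate_batch_concurrent(
        self,
        prompts: List[str],
        max_tokens: int = 512,
        temperature: float = 0.7,
        max_workers: int = 8,
        **kwargs
    ) -> List[str]:
        """Generate responses for a batch of prompts via concurrent requests."""
        from concurrent.futures import ThreadPoolExecutor

        def _one(prompt: str) -> str:
            return self.generate(prompt, max_tokens, temperature, **kwargs)

        # executor.map yields results in input order.
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            return list(executor.map(_one, prompts))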
def get_response(
self,
idx: int,
stage: str,
messages: List[Dict[str, str]],
langcode: Optional[str] = None
    ) -> Tuple[str, float]:
"""
Get response compatible with the pipeline interface.
Returns:
Tuple of (response_string, cost)
"""
response = self.generate_chat(messages)
return response, 0.0 # vLLM is local, no cost
    def __del__(self):
        """Stop the managed server, if any (best-effort cleanup)."""
        # __del__ may run on a partially initialized instance, so guard the
        # attribute lookup.
        server = getattr(self, "server", None)
        if server:
            server.stop()
class VLLMModelFactory:
"""Factory for creating VLLMModel instances."""
@staticmethod
def create(
model_name: str,
api_base: Optional[str] = None,
**kwargs
) -> VLLMModel:
"""Create a VLLMModel instance."""
return VLLMModel(model_name, api_base=api_base, **kwargs)
@staticmethod
def list_models() -> List[str]:
"""List available models."""
return list(SUPPORTED_MODELS.keys())
@staticmethod
def get_model_info(model_name: str) -> Dict:
"""Get model metadata."""
return MODEL_METADATA.get(model_name, {})
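
# Minimal usage sketch, assuming a vLLM server is already listening on
# localhost:8000 and that "mistral-7b-instruct" is a key in SUPPORTED_MODELS
# (as in the docstrings above); adjust both for your setup.
if __name__ == "__main__":
    model = VLLMModelFactory.create(
        "mistral-7b-instruct",
        api_base="http://localhost:8000/v1",
    )
    reply = model.generate_chat(
        [{"role": "user", "content": "Say hello in one sentence."}]
    )
    print(reply)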