# safe_rag/generator/vllm_server.py
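"""vLLM-backed generation server for SafeRAG.

Wraps a vLLM LLM engine and exposes synchronous, batched, and
asyncio-friendly generation helpers.
"""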
from vllm import LLM, SamplingParams
from typing import List, Dict, Any, Optional
import logging
import asyncio
import functools
from concurrent.futures import ThreadPoolExecutor

logger = logging.getLogger(__name__)

class VLLMServer:
    """Wrapper around a vLLM engine with sync, batched, and async generation helpers."""

    def __init__(self, model_name: str = "openai/gpt-oss-20b",
                 tensor_parallel_size: int = 1,
                 gpu_memory_utilization: float = 0.9):
        self.model_name = model_name
        self.tensor_parallel_size = tensor_parallel_size
        self.gpu_memory_utilization = gpu_memory_utilization
        self.llm = None
        # Thread pool that runs blocking generate() calls off the event loop.
        self.executor = ThreadPoolExecutor(max_workers=4)

    def initialize(self):
        """Load the vLLM engine; generate() calls this lazily if needed."""
        try:
            self.llm = LLM(
                model=self.model_name,
                tensor_parallel_size=self.tensor_parallel_size,
                gpu_memory_utilization=self.gpu_memory_utilization,
                trust_remote_code=True,
            )
            logger.info(f"Initialized vLLM with model: {self.model_name}")
        except Exception as e:
            logger.error(f"Failed to initialize vLLM: {e}")
            raise

    def generate(self, prompts: List[str],
                 max_tokens: int = 512,
                 temperature: float = 0.7,
                 top_p: float = 0.9,
                 stop: Optional[List[str]] = None) -> List[Dict[str, Any]]:
        """Generate completions for a list of prompts."""
        if self.llm is None:
            self.initialize()
        sampling_params = SamplingParams(
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            stop=stop,
        )
        try:
            outputs = self.llm.generate(prompts, sampling_params)
            results = []
            for output in outputs:
                completion = output.outputs[0]
                results.append({
                    'text': completion.text,
                    'prompt': output.prompt,
                    'finish_reason': completion.finish_reason,
                    'token_ids': completion.token_ids,
                    'logprobs': getattr(completion, 'logprobs', None),
                })
            return results
        except Exception as e:
            logger.error(f"Generation failed: {e}")
            raise

    def generate_single(self, prompt: str, **kwargs) -> str:
        """Generate text for a single prompt."""
        results = self.generate([prompt], **kwargs)
        return results[0]['text'] if results else ""

    def generate_batch(self, prompts: List[str], batch_size: int = 8, **kwargs) -> List[str]:
        """Generate text for many prompts, submitting at most batch_size at a time.

        vLLM batches requests internally; this cap only bounds how many
        prompts are handed to the engine per call.
        """
        all_results = []
        for i in range(0, len(prompts), batch_size):
            batch_prompts = prompts[i:i + batch_size]
            batch_results = self.generate(batch_prompts, **kwargs)
            all_results.extend([r['text'] for r in batch_results])
        return all_results

    async def generate_async(self, prompts: List[str], **kwargs) -> List[Dict[str, Any]]:
        """Run generate() in the thread pool so it does not block the event loop."""
        loop = asyncio.get_running_loop()
        # run_in_executor only forwards positional args, so bind kwargs first.
        return await loop.run_in_executor(
            self.executor, functools.partial(self.generate, prompts, **kwargs)
        )

    def get_model_info(self) -> Dict[str, Any]:
        """Return the model configuration and initialization state."""
        return {
            'model_name': self.model_name,
            'tensor_parallel_size': self.tensor_parallel_size,
            'gpu_memory_utilization': self.gpu_memory_utilization,
            'is_initialized': self.llm is not None,
        }

    def cleanup(self):
        """Shut down the thread pool used for async generation."""
        if self.executor:
            self.executor.shutdown(wait=True)
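

# Minimal usage sketch, assuming a CUDA GPU with enough memory for the
# configured model (the 20B default is large; pass a smaller model_name to
# try this on modest hardware). The prompts below are illustrative only.
if __name__ == "__main__":
    server = VLLMServer()
    try:
        answer = server.generate_single(
            "What is retrieval-augmented generation?",
            max_tokens=64,
            temperature=0.2,
        )
        print(answer)
        # The async path wraps the same blocking generate() in the thread pool.
        results = asyncio.run(server.generate_async(["Hello"], max_tokens=16))
        print(results[0]['text'])
        print(server.get_model_info())
    finally:
        server.cleanup()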