#!/usr/bin/env python3
"""
Optimized OpenLLM Inference Server
This module provides an optimized inference server with:
- Model and response caching with memory monitoring
- Batch endpoint with concurrent multi-prompt processing
- Response streaming for real-time generation
- Performance monitoring and metrics
- Concurrent request handling via a worker thread pool
Author: Louis Chua Bean Chong
License: GPLv3
"""
import asyncio
import json
import time
import threading
from concurrent.futures import ThreadPoolExecutor
from typing import Optional, List, Dict, Any, AsyncGenerator
from collections import deque
import torch
import torch.nn.functional as F
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
import uvicorn
import logging
import psutil
import os
import sys
from pathlib import Path
# Add current directory to path for imports
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from model import GPTConfig, GPTModel
from quantization import QuantizedModel, quantize_model_dynamic
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class OptimizedInferenceEngine:
"""
Optimized inference engine with caching and batching.
This engine provides high-performance inference with:
    - Response caching with FIFO eviction
    - A thread pool for concurrent request processing
    - Optional dynamic quantization for reduced memory usage (CPU only)
    - Performance monitoring and metrics
"""
def __init__(self,
model_path: str,
device: str = "auto",
use_quantization: bool = True,
cache_size: int = 1000,
max_batch_size: int = 32,
num_workers: int = 4):
"""
Initialize optimized inference engine.
Args:
model_path: Path to the model
device: Device to use ("auto", "cpu", "cuda")
use_quantization: Whether to use quantization
cache_size: Size of response cache
max_batch_size: Maximum batch size for processing
num_workers: Number of worker threads
"""
self.model_path = model_path
self.device = self._get_device(device)
self.use_quantization = use_quantization
self.cache_size = cache_size
self.max_batch_size = max_batch_size
self.num_workers = num_workers
# Initialize components
self.model = None
self.tokenizer = None
self.quantized_model = None
self.response_cache = {}
self.request_queue = deque()
self.processing_lock = threading.Lock()
# Performance metrics
self.metrics = {
"total_requests": 0,
"cache_hits": 0,
"cache_misses": 0,
"avg_generation_time": 0.0,
"total_generation_time": 0.0,
"requests_per_second": 0.0
}
# Thread pool for concurrent processing
self.executor = ThreadPoolExecutor(max_workers=num_workers)
# Load model
self._load_model()
logger.info(f"OptimizedInferenceEngine initialized on {self.device}")
def _get_device(self, device: str) -> torch.device:
"""Get the appropriate device."""
if device == "auto":
if torch.cuda.is_available():
return torch.device("cuda")
else:
return torch.device("cpu")
else:
return torch.device(device)
def _load_model(self):
"""Load and optimize the model."""
try:
logger.info(f"Loading model from {self.model_path}")
# Load model configuration
config_path = Path(self.model_path) / "config.json"
if config_path.exists():
with open(config_path, 'r') as f:
config_data = json.load(f)
config = GPTConfig(**config_data)
else:
# Use default config
config = GPTConfig.small()
# Create model
self.model = GPTModel(config, use_checkpoint=False) # No checkpointing for inference
# Load model weights
model_path = Path(self.model_path) / "pytorch_model.bin"
if model_path.exists():
self.model.load_state_dict(torch.load(model_path, map_location=self.device))
logger.info("Model weights loaded successfully")
else:
logger.warning("No model weights found, using initialized weights")
# Move model to device
self.model.to(self.device)
self.model.eval()
            # Apply dynamic quantization if requested (PyTorch dynamic
            # quantization is CPU-only)
            if self.use_quantization and self.device.type == "cpu":
logger.info("Applying dynamic quantization")
self.quantized_model = QuantizedModel(self.model)
self.quantized_model.quantize_dynamic()
logger.info("Quantization completed")
# Load tokenizer
tokenizer_path = Path(self.model_path) / "tokenizer.model"
if tokenizer_path.exists():
import sentencepiece as spm
self.tokenizer = spm.SentencePieceProcessor()
self.tokenizer.load(str(tokenizer_path))
logger.info("Tokenizer loaded successfully")
else:
logger.warning("No tokenizer found")
logger.info("Model loading completed")
except Exception as e:
logger.error(f"Failed to load model: {e}")
raise
    def _get_cache_key(self, prompt: str, **kwargs) -> str:
        """Generate a deterministic cache key for a request."""
        import hashlib
        # Sort kwargs so the key is independent of argument order; md5 is fine
        # here because this is a cache key, not a security-sensitive hash
        key_data = f"{prompt}_{sorted(kwargs.items(), key=lambda kv: kv[0])}"
        return hashlib.md5(key_data.encode()).hexdigest()
def _check_cache(self, cache_key: str) -> Optional[List[str]]:
"""Check if response is cached."""
if cache_key in self.response_cache:
self.metrics["cache_hits"] += 1
return self.response_cache[cache_key]
else:
self.metrics["cache_misses"] += 1
return None
def _update_cache(self, cache_key: str, response: List[str]):
"""Update response cache."""
        if len(self.response_cache) >= self.cache_size:
            # Evict the oldest entry (dicts preserve insertion order in
            # Python 3.7+, so this is FIFO eviction)
            oldest_key = next(iter(self.response_cache))
            del self.response_cache[oldest_key]
self.response_cache[cache_key] = response
def _tokenize(self, text: str) -> torch.Tensor:
"""Tokenize text using the loaded tokenizer."""
if self.tokenizer is None:
# Fallback to simple tokenization
return torch.tensor([ord(c) % 1000 for c in text], dtype=torch.long)
tokens = self.tokenizer.encode_as_ids(text)
return torch.tensor(tokens, dtype=torch.long)
def _detokenize(self, tokens: torch.Tensor) -> str:
"""Detokenize tokens to text."""
if self.tokenizer is None:
# Fallback to simple detokenization
return ''.join([chr(t % 1000) for t in tokens.tolist()])
return self.tokenizer.decode(tokens.tolist())
def generate(self,
prompt: str,
max_length: int = 256,
temperature: float = 0.7,
top_k: Optional[int] = 40,
top_p: Optional[float] = 0.9,
num_return_sequences: int = 1,
stop_sequences: Optional[List[str]] = None) -> List[str]:
"""
Generate text with optimizations.
Args:
prompt: Input prompt
max_length: Maximum generation length
temperature: Sampling temperature
top_k: Top-k sampling parameter
top_p: Nucleus sampling parameter
num_return_sequences: Number of sequences to generate
stop_sequences: Stop generation at these sequences
Returns:
List of generated texts
"""
start_time = time.time()
        # Check cache first (the key must cover every parameter that changes output)
        cache_key = self._get_cache_key(
            prompt, max_length=max_length, temperature=temperature, top_k=top_k,
            top_p=top_p, num_return_sequences=num_return_sequences,
            stop_sequences=stop_sequences)
cached_response = self._check_cache(cache_key)
if cached_response:
return cached_response
# Tokenize input
input_tokens = self._tokenize(prompt)
input_tokens = input_tokens.unsqueeze(0).to(self.device) # Add batch dimension
        # Generate each requested sequence independently: generate() is only
        # exercised with batch size 1 here, and sampling makes each call distinct.
        # (top_p is applied only in the streaming path; model.generate() is only
        # known to accept temperature/top_k/do_sample.)
        generated_texts = []
        with torch.no_grad():
            for _ in range(num_return_sequences):
                if self.quantized_model and self.quantized_model.is_quantized:
                    # Use quantized model
                    generated_tokens = self.quantized_model.quantized_model.generate(
                        input_tokens,
                        max_new_tokens=max_length,
                        temperature=temperature,
                        top_k=top_k,
                        do_sample=True
                    )
                else:
                    # Use regular model
                    generated_tokens = self.model.generate(
                        input_tokens,
                        max_new_tokens=max_length,
                        temperature=temperature,
                        top_k=top_k,
                        do_sample=True
                    )
                # Keep only the newly generated tokens (strip the prompt)
                generated_part = generated_tokens[0, input_tokens.size(1):]
                text = self._detokenize(generated_part)
                # Truncate at the first stop sequence, if present
                if stop_sequences:
                    for stop_seq in stop_sequences:
                        if stop_seq in text:
                            text = text[:text.find(stop_seq)]
                            break
                generated_texts.append(text)
# Update cache
self._update_cache(cache_key, generated_texts)
# Update metrics
generation_time = time.time() - start_time
self.metrics["total_requests"] += 1
self.metrics["total_generation_time"] += generation_time
self.metrics["avg_generation_time"] = (
self.metrics["total_generation_time"] / self.metrics["total_requests"]
)
return generated_texts
async def generate_async(self,
prompt: str,
max_length: int = 256,
temperature: float = 0.7,
top_k: Optional[int] = 40,
top_p: Optional[float] = 0.9,
num_return_sequences: int = 1,
stop_sequences: Optional[List[str]] = None) -> List[str]:
"""
Asynchronous text generation.
Args:
Same as generate()
Returns:
List of generated texts
"""
        # Run the blocking generate() in the thread pool so the event loop stays free
        loop = asyncio.get_running_loop()
return await loop.run_in_executor(
self.executor,
self.generate,
prompt, max_length, temperature, top_k, top_p,
num_return_sequences, stop_sequences
)
async def generate_stream(self,
prompt: str,
max_length: int = 256,
temperature: float = 0.7,
top_k: Optional[int] = 40,
top_p: Optional[float] = 0.9,
stop_sequences: Optional[List[str]] = None) -> AsyncGenerator[str, None]:
"""
Stream generated text token by token.
Args:
Same as generate()
Yields:
Generated text tokens
"""
# Tokenize input
input_tokens = self._tokenize(prompt)
input_tokens = input_tokens.unsqueeze(0).to(self.device)
# Generate tokens one by one
current_tokens = input_tokens.clone()
        with torch.no_grad():
            for _ in range(max_length):
                # Crop the context to the model's block size so positions stay in range
                block_size = self.model.config.block_size
                context = (current_tokens if current_tokens.size(1) <= block_size
                           else current_tokens[:, -block_size:])
                # Get logits for the next token
                if self.quantized_model and self.quantized_model.is_quantized:
                    logits = self.quantized_model.quantized_model(context)
                else:
                    logits = self.model(context)
                # Sample next token: temperature, then top-k, then nucleus (top-p)
                logits = logits[:, -1, :] / max(temperature, 1e-8)
                if top_k is not None:
                    v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                    logits[logits < v[:, [-1]]] = -float("inf")
                if top_p is not None and top_p < 1.0:
                    sorted_logits, sorted_idx = torch.sort(logits, descending=True)
                    cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
                    remove = cum_probs > top_p
                    remove[..., 1:] = remove[..., :-1].clone()
                    remove[..., 0] = False  # always keep the most probable token
                    logits[remove.scatter(1, sorted_idx, remove)] = -float("inf")
                probs = F.softmax(logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
# Add to sequence
current_tokens = torch.cat([current_tokens, next_token], dim=1)
                # Convert the new token to text (per-token SentencePiece decode can
                # drop word-boundary whitespace; acceptable for a streaming preview)
token_text = self._detokenize(next_token[0])
yield token_text
# Check for stop sequences
if stop_sequences:
full_text = self._detokenize(current_tokens[0, len(input_tokens[0]):])
for stop_seq in stop_sequences:
if stop_seq in full_text:
return
def get_metrics(self) -> Dict[str, Any]:
"""Get performance metrics."""
memory_usage = psutil.virtual_memory().percent
return {
**self.metrics,
"memory_usage_percent": memory_usage,
"cache_size": len(self.response_cache),
"max_cache_size": self.cache_size,
"cache_hit_rate": (
self.metrics["cache_hits"] /
(self.metrics["cache_hits"] + self.metrics["cache_misses"])
if (self.metrics["cache_hits"] + self.metrics["cache_misses"]) > 0 else 0
),
"device": str(self.device),
"quantization_enabled": self.quantized_model is not None
}
def cleanup(self):
"""Clean up resources."""
if self.executor:
self.executor.shutdown(wait=True)
# Clear cache
self.response_cache.clear()
logger.info("Inference engine cleaned up")
# Request/Response models
class GenerationRequest(BaseModel):
"""Request model for text generation."""
prompt: str = Field(..., description="Input text prompt")
max_length: int = Field(256, description="Maximum generation length", ge=1, le=2048)
temperature: float = Field(0.7, description="Sampling temperature", ge=0.0, le=2.0)
top_k: Optional[int] = Field(40, description="Top-k sampling parameter", ge=1, le=1000)
top_p: Optional[float] = Field(0.9, description="Nucleus sampling parameter", ge=0.1, le=1.0)
num_return_sequences: int = Field(1, description="Number of sequences to generate", ge=1, le=5)
stop_sequences: Optional[List[str]] = Field(None, description="Stop generation at these sequences")
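# Example /generate payload matching the fields above:
#   {"prompt": "The meaning of life is", "max_length": 64,
#    "temperature": 0.8, "top_k": 40, "stop_sequences": ["\n\n"]}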
class GenerationResponse(BaseModel):
"""Response model for text generation."""
generated_text: List[str]
prompt: str
generation_time: float
parameters: Dict[str, Any]
class BatchGenerationRequest(BaseModel):
"""Request model for batch text generation."""
prompts: List[str] = Field(..., description="List of input prompts")
max_length: int = Field(256, description="Maximum generation length", ge=1, le=2048)
temperature: float = Field(0.7, description="Sampling temperature", ge=0.0, le=2.0)
top_k: Optional[int] = Field(40, description="Top-k sampling parameter", ge=1, le=1000)
top_p: Optional[float] = Field(0.9, description="Nucleus sampling parameter", ge=0.1, le=1.0)
stop_sequences: Optional[List[str]] = Field(None, description="Stop generation at these sequences")
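# Example /generate/batch payload matching the fields above:
#   {"prompts": ["First prompt", "Second prompt"], "max_length": 32}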
class BatchGenerationResponse(BaseModel):
"""Response model for batch text generation."""
generated_texts: List[List[str]]
prompts: List[str]
generation_time: float
parameters: Dict[str, Any]
# Global inference engine
inference_engine: Optional[OptimizedInferenceEngine] = None
# FastAPI app
app = FastAPI(
title="Optimized OpenLLM Inference API",
description="High-performance REST API for OpenLLM text generation",
version="0.1.0",
docs_url="/docs",
redoc_url="/redoc",
)
# CORS middleware (wildcard origins are convenient for development;
# restrict allow_origins before exposing the server publicly)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
@app.on_event("startup")
async def startup_event():
"""Initialize inference engine on startup."""
logger.info("🚀 Starting Optimized OpenLLM Inference Server...")
global inference_engine
if inference_engine is None:
logger.warning("No model loaded - server will return 503 for generation requests")
@app.on_event("shutdown")
async def shutdown_event():
"""Clean up resources on shutdown."""
global inference_engine
if inference_engine:
inference_engine.cleanup()
logger.info("Server shutdown complete")
@app.post("/generate", response_model=GenerationResponse)
async def generate_text(request: GenerationRequest):
"""Generate text from prompt with optimizations."""
if inference_engine is None:
raise HTTPException(status_code=503, detail="Model not loaded")
start_time = time.time()
try:
# Generate text asynchronously
generated_texts = await inference_engine.generate_async(
prompt=request.prompt,
max_length=request.max_length,
temperature=request.temperature,
top_k=request.top_k,
top_p=request.top_p,
num_return_sequences=request.num_return_sequences,
stop_sequences=request.stop_sequences,
)
generation_time = time.time() - start_time
return GenerationResponse(
generated_text=generated_texts,
prompt=request.prompt,
generation_time=generation_time,
parameters={
"max_length": request.max_length,
"temperature": request.temperature,
"top_k": request.top_k,
"top_p": request.top_p,
"num_return_sequences": request.num_return_sequences,
},
)
except Exception as e:
logger.error(f"Generation failed: {e}")
raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}")
@app.post("/generate/stream")
async def generate_text_stream(request: GenerationRequest):
"""Generate text with streaming response."""
if inference_engine is None:
raise HTTPException(status_code=503, detail="Model not loaded")
async def generate_stream():
try:
async for token in inference_engine.generate_stream(
prompt=request.prompt,
max_length=request.max_length,
temperature=request.temperature,
top_k=request.top_k,
top_p=request.top_p,
stop_sequences=request.stop_sequences,
):
yield f"data: {json.dumps({'token': token})}\n\n"
yield f"data: {json.dumps({'done': True})}\n\n"
except Exception as e:
logger.error(f"Streaming generation failed: {e}")
yield f"data: {json.dumps({'error': str(e)})}\n\n"
return StreamingResponse(
generate_stream(),
        media_type="text/event-stream",
headers={"Cache-Control": "no-cache", "Connection": "keep-alive"}
)
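# Example: consuming the SSE-style stream from Python (a sketch using the
# third-party `requests` library, installed separately):
#
#   import json
#   import requests
#   with requests.post("http://localhost:8000/generate/stream",
#                      json={"prompt": "Hi", "max_length": 16}, stream=True) as r:
#       for line in r.iter_lines():
#           if line.startswith(b"data: "):
#               event = json.loads(line[len(b"data: "):])
#               if event.get("done") or event.get("error"):
#                   break
#               print(event["token"], end="", flush=True)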
@app.post("/generate/batch", response_model=BatchGenerationResponse)
async def generate_text_batch(request: BatchGenerationRequest):
"""Generate text for multiple prompts in batch."""
if inference_engine is None:
raise HTTPException(status_code=503, detail="Model not loaded")
start_time = time.time()
try:
# Process prompts in parallel
tasks = []
for prompt in request.prompts:
task = inference_engine.generate_async(
prompt=prompt,
max_length=request.max_length,
temperature=request.temperature,
top_k=request.top_k,
top_p=request.top_p,
num_return_sequences=1,
stop_sequences=request.stop_sequences,
)
tasks.append(task)
# Wait for all tasks to complete
generated_texts = await asyncio.gather(*tasks)
generation_time = time.time() - start_time
return BatchGenerationResponse(
generated_texts=generated_texts,
prompts=request.prompts,
generation_time=generation_time,
parameters={
"max_length": request.max_length,
"temperature": request.temperature,
"top_k": request.top_k,
"top_p": request.top_p,
"num_prompts": len(request.prompts),
},
)
except Exception as e:
logger.error(f"Batch generation failed: {e}")
raise HTTPException(status_code=500, detail=f"Batch generation failed: {str(e)}")
@app.get("/health")
async def health_check():
"""Health check endpoint."""
global inference_engine
if inference_engine is None:
return {"status": "unhealthy", "message": "Model not loaded"}
try:
# Quick generation test
test_result = await inference_engine.generate_async(
prompt="Hello",
max_length=5,
temperature=0.7
)
return {
"status": "healthy",
"model_loaded": True,
"test_generation": len(test_result) > 0
}
except Exception as e:
return {
"status": "unhealthy",
"message": f"Generation test failed: {str(e)}"
}
@app.get("/metrics")
async def get_metrics():
"""Get performance metrics."""
global inference_engine
if inference_engine is None:
return {"error": "Model not loaded"}
return inference_engine.get_metrics()
@app.get("/info")
async def get_model_info():
"""Get model information."""
global inference_engine
if inference_engine is None:
return {"error": "Model not loaded"}
model = inference_engine.model
if model is None:
return {"error": "Model not available"}
return {
"model_name": model.config.model_name,
"vocab_size": model.config.vocab_size,
"n_layer": model.config.n_layer,
"n_head": model.config.n_head,
"n_embd": model.config.n_embd,
"block_size": model.config.block_size,
"parameters": model.get_num_params(),
"device": str(inference_engine.device),
"quantization_enabled": inference_engine.quantized_model is not None,
"cache_size": len(inference_engine.response_cache),
"max_cache_size": inference_engine.cache_size,
}
def create_optimized_server(model_path: str,
host: str = "0.0.0.0",
port: int = 8000,
device: str = "auto",
use_quantization: bool = True,
cache_size: int = 1000,
max_batch_size: int = 32,
num_workers: int = 4) -> FastAPI:
"""
Create an optimized inference server.
Args:
model_path: Path to the model
        host: Server host (not used here; pass it to uvicorn.run when serving)
        port: Server port (not used here; pass it to uvicorn.run when serving)
device: Device to use
use_quantization: Whether to use quantization
cache_size: Size of response cache
max_batch_size: Maximum batch size
num_workers: Number of worker threads
Returns:
FastAPI app instance
"""
global inference_engine
# Initialize inference engine
inference_engine = OptimizedInferenceEngine(
model_path=model_path,
device=device,
use_quantization=use_quantization,
cache_size=cache_size,
max_batch_size=max_batch_size,
num_workers=num_workers
)
return app
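# Example: embedding the app in another process instead of using the CLI below
# (a sketch):
#
#   app = create_optimized_server("models/small", use_quantization=False)
#   uvicorn.run(app, host="127.0.0.1", port=8080)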
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description="Optimized OpenLLM Inference Server")
parser.add_argument("--model_path", type=str, required=True, help="Path to model")
parser.add_argument("--host", type=str, default="0.0.0.0", help="Server host")
parser.add_argument("--port", type=int, default=8000, help="Server port")
parser.add_argument("--device", type=str, default="auto", help="Device to use")
parser.add_argument("--use_quantization", action="store_true", help="Use quantization")
parser.add_argument("--cache_size", type=int, default=1000, help="Cache size")
parser.add_argument("--max_batch_size", type=int, default=32, help="Max batch size")
parser.add_argument("--num_workers", type=int, default=4, help="Number of workers")
args = parser.parse_args()
# Create server
app = create_optimized_server(
model_path=args.model_path,
host=args.host,
port=args.port,
device=args.device,
use_quantization=args.use_quantization,
cache_size=args.cache_size,
max_batch_size=args.max_batch_size,
num_workers=args.num_workers
)
# Run server
uvicorn.run(app, host=args.host, port=args.port)