"""
Optimized OpenLLM Inference Server

This module provides an optimized inference server with:
- Model caching and memory management
- Request batching for improved throughput
- Response streaming for real-time generation
- Performance monitoring and metrics
- Load balancing and concurrent processing

Author: Louis Chua Bean Chong
License: GPLv3
"""

import asyncio
import json
import time
import threading
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Optional, List, Dict, Any, AsyncGenerator
from collections import deque
import torch
import torch.nn.functional as F
from fastapi import FastAPI, HTTPException, BackgroundTasks
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
import uvicorn
import logging
import psutil
import os
import sys
from pathlib import Path

# Make the sibling modules (model.py, quantization.py) importable when this
# file is run as a script.
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

from model import GPTConfig, GPTModel
from quantization import QuantizedModel, quantize_model_dynamic

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class OptimizedInferenceEngine:
    """
    Optimized inference engine with caching and batching.

    This engine provides high-performance inference with:
    - Model caching and memory management
    - Request batching for improved throughput
    - Quantization support for reduced memory usage
    - Performance monitoring and metrics
    """

    def __init__(self,
                 model_path: str,
                 device: str = "auto",
                 use_quantization: bool = True,
                 cache_size: int = 1000,
                 max_batch_size: int = 32,
                 num_workers: int = 4):
        """
        Initialize the optimized inference engine.

        Args:
            model_path: Path to the model
            device: Device to use ("auto", "cpu", "cuda")
            use_quantization: Whether to use quantization
            cache_size: Maximum number of cached responses
            max_batch_size: Maximum batch size for processing
            num_workers: Number of worker threads
        """
        self.model_path = model_path
        self.device = self._get_device(device)
        self.use_quantization = use_quantization
        self.cache_size = cache_size
        self.max_batch_size = max_batch_size
        self.num_workers = num_workers

        # Model, tokenizer, and request-handling state.
        self.model = None
        self.tokenizer = None
        self.quantized_model = None
        self.response_cache = {}
        self.request_queue = deque()
        self.processing_lock = threading.Lock()

        # Performance metrics, exposed via get_metrics() and the /metrics endpoint.
        self.metrics = {
            "total_requests": 0,
            "cache_hits": 0,
            "cache_misses": 0,
            "avg_generation_time": 0.0,
            "total_generation_time": 0.0,
            "requests_per_second": 0.0,
        }

        # Worker pool used by generate_async() to keep the event loop responsive.
        self.executor = ThreadPoolExecutor(max_workers=num_workers)

        self._load_model()

        logger.info(f"OptimizedInferenceEngine initialized on {self.device}")

    def _get_device(self, device: str) -> torch.device:
        """Get the appropriate device."""
        if device == "auto":
            if torch.cuda.is_available():
                return torch.device("cuda")
            else:
                return torch.device("cpu")
        else:
            return torch.device(device)

    def _load_model(self):
        """Load and optimize the model."""
        try:
            logger.info(f"Loading model from {self.model_path}")

            # Load the model configuration if present; otherwise fall back to the
            # small default configuration.
            config_path = Path(self.model_path) / "config.json"
            if config_path.exists():
                with open(config_path, 'r') as f:
                    config_data = json.load(f)
                config = GPTConfig(**config_data)
            else:
                config = GPTConfig.small()

            self.model = GPTModel(config, use_checkpoint=False)

            # Load pretrained weights if they exist alongside the config.
            model_path = Path(self.model_path) / "pytorch_model.bin"
            if model_path.exists():
                self.model.load_state_dict(torch.load(model_path, map_location=self.device))
                logger.info("Model weights loaded successfully")
            else:
                logger.warning("No model weights found, using initialized weights")

            self.model.to(self.device)
            self.model.eval()

            # Dynamic quantization is applied only for CPU inference.
            if self.use_quantization and self.device.type == "cpu":
                logger.info("Applying dynamic quantization")
                self.quantized_model = QuantizedModel(self.model)
                self.quantized_model.quantize_dynamic()
                logger.info("Quantization completed")

            # Load the SentencePiece tokenizer if present.
            tokenizer_path = Path(self.model_path) / "tokenizer.model"
            if tokenizer_path.exists():
                import sentencepiece as spm
                self.tokenizer = spm.SentencePieceProcessor()
                self.tokenizer.load(str(tokenizer_path))
                logger.info("Tokenizer loaded successfully")
            else:
                logger.warning("No tokenizer found")

            logger.info("Model loading completed")

        except Exception as e:
            logger.error(f"Failed to load model: {e}")
            raise

    def _get_cache_key(self, prompt: str, **kwargs) -> str:
        """Generate a cache key for a request."""
        # MD5 is used only for cache lookup, not for any security purpose.
        import hashlib
        key_data = f"{prompt}_{kwargs}"
        return hashlib.md5(key_data.encode()).hexdigest()

    def _check_cache(self, cache_key: str) -> Optional[List[str]]:
        """Check if a response is cached."""
        if cache_key in self.response_cache:
            self.metrics["cache_hits"] += 1
            return self.response_cache[cache_key]
        else:
            self.metrics["cache_misses"] += 1
            return None

    def _update_cache(self, cache_key: str, response: List[str]):
        """Update the response cache."""
        if len(self.response_cache) >= self.cache_size:
            # Evict the oldest entry (dicts preserve insertion order, so this is FIFO).
            oldest_key = next(iter(self.response_cache))
            del self.response_cache[oldest_key]

        self.response_cache[cache_key] = response

    def _tokenize(self, text: str) -> torch.Tensor:
        """Tokenize text using the loaded tokenizer."""
        if self.tokenizer is None:
            # Fallback: crude character-level encoding when no tokenizer is available.
            return torch.tensor([ord(c) % 1000 for c in text], dtype=torch.long)

        tokens = self.tokenizer.encode_as_ids(text)
        return torch.tensor(tokens, dtype=torch.long)

    def _detokenize(self, tokens: torch.Tensor) -> str:
        """Detokenize tokens to text."""
        if self.tokenizer is None:
            # Fallback: inverse of the character-level encoding used in _tokenize.
            return ''.join([chr(t % 1000) for t in tokens.tolist()])

        return self.tokenizer.decode(tokens.tolist())

    def generate(self,
                 prompt: str,
                 max_length: int = 256,
                 temperature: float = 0.7,
                 top_k: Optional[int] = 40,
                 top_p: Optional[float] = 0.9,
                 num_return_sequences: int = 1,
                 stop_sequences: Optional[List[str]] = None) -> List[str]:
        """
        Generate text with optimizations.

        Args:
            prompt: Input prompt
            max_length: Maximum generation length
            temperature: Sampling temperature
            top_k: Top-k sampling parameter
            top_p: Nucleus sampling parameter (accepted for API compatibility;
                not applied by the underlying generate call)
            num_return_sequences: Number of sequences to generate
            stop_sequences: Stop generation at these sequences

        Returns:
            List of generated texts
        """
        start_time = time.time()

        # Return a cached response if this exact request has been seen before.
        cache_key = self._get_cache_key(prompt, max_length=max_length,
                                        temperature=temperature, top_k=top_k, top_p=top_p,
                                        num_return_sequences=num_return_sequences)
        cached_response = self._check_cache(cache_key)
        if cached_response:
            return cached_response

        input_tokens = self._tokenize(prompt)
        input_tokens = input_tokens.unsqueeze(0).to(self.device)

        # Sample each requested sequence independently; one generate() call
        # produces a single continuation.
        generated_texts = []
        with torch.no_grad():
            for _ in range(num_return_sequences):
                if self.quantized_model and self.quantized_model.is_quantized:
                    # Prefer the quantized model when it is available.
                    generated_tokens = self.quantized_model.quantized_model.generate(
                        input_tokens,
                        max_new_tokens=max_length,
                        temperature=temperature,
                        top_k=top_k,
                        do_sample=True
                    )
                else:
                    generated_tokens = self.model.generate(
                        input_tokens,
                        max_new_tokens=max_length,
                        temperature=temperature,
                        top_k=top_k,
                        do_sample=True
                    )

                # Decode only the newly generated tokens (everything after the prompt).
                generated_part = generated_tokens[0, len(input_tokens[0]):]
                text = self._detokenize(generated_part)

                # Truncate at the first stop sequence, if any.
                if stop_sequences:
                    for stop_seq in stop_sequences:
                        if stop_seq in text:
                            text = text[:text.find(stop_seq)]
                            break

                generated_texts.append(text)

        self._update_cache(cache_key, generated_texts)

        # Update performance metrics.
        generation_time = time.time() - start_time
        self.metrics["total_requests"] += 1
        self.metrics["total_generation_time"] += generation_time
        self.metrics["avg_generation_time"] = (
            self.metrics["total_generation_time"] / self.metrics["total_requests"]
        )

        return generated_texts

    async def generate_async(self,
                             prompt: str,
                             max_length: int = 256,
                             temperature: float = 0.7,
                             top_k: Optional[int] = 40,
                             top_p: Optional[float] = 0.9,
                             num_return_sequences: int = 1,
                             stop_sequences: Optional[List[str]] = None) -> List[str]:
        """
        Asynchronous text generation.

        Args:
            Same as generate()

        Returns:
            List of generated texts
        """
        # Run the blocking generate() call in the worker pool so the event loop
        # stays responsive while inference is in progress.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            self.executor,
            self.generate,
            prompt, max_length, temperature, top_k, top_p,
            num_return_sequences, stop_sequences
        )

    async def generate_stream(self,
                              prompt: str,
                              max_length: int = 256,
                              temperature: float = 0.7,
                              top_k: Optional[int] = 40,
                              top_p: Optional[float] = 0.9,
                              stop_sequences: Optional[List[str]] = None) -> AsyncGenerator[str, None]:
        """
        Stream generated text token by token.

        Args:
            Same as generate()

        Yields:
            Generated text tokens
        """
        input_tokens = self._tokenize(prompt)
        input_tokens = input_tokens.unsqueeze(0).to(self.device)

        current_tokens = input_tokens.clone()

        with torch.no_grad():
            for _ in range(max_length):
                # Forward pass, using the quantized model when available.
                if self.quantized_model and self.quantized_model.is_quantized:
                    logits = self.quantized_model.quantized_model(current_tokens)
                else:
                    logits = self.model(current_tokens)

                # Apply temperature, then top-k filtering (top_p is accepted for
                # API compatibility but not applied here).
                logits = logits[:, -1, :] / temperature
                if top_k is not None:
                    v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                    logits[logits < v[:, [-1]]] = -float("inf")

                probs = F.softmax(logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)

                current_tokens = torch.cat([current_tokens, next_token], dim=1)

                # Emit the newly sampled token immediately.
                token_text = self._detokenize(next_token[0])
                yield token_text

                # Stop once any stop sequence appears in the text generated so far.
                if stop_sequences:
                    full_text = self._detokenize(current_tokens[0, len(input_tokens[0]):])
                    for stop_seq in stop_sequences:
                        if stop_seq in full_text:
                            return

    def get_metrics(self) -> Dict[str, Any]:
        """Get performance metrics."""
        memory_usage = psutil.virtual_memory().percent

        return {
            **self.metrics,
            "memory_usage_percent": memory_usage,
            "cache_size": len(self.response_cache),
            "max_cache_size": self.cache_size,
            "cache_hit_rate": (
                self.metrics["cache_hits"] /
                (self.metrics["cache_hits"] + self.metrics["cache_misses"])
                if (self.metrics["cache_hits"] + self.metrics["cache_misses"]) > 0 else 0
            ),
            "device": str(self.device),
            "quantization_enabled": self.quantized_model is not None
        }

    def cleanup(self):
        """Clean up resources."""
        if self.executor:
            self.executor.shutdown(wait=True)

        self.response_cache.clear()

        logger.info("Inference engine cleaned up")


class GenerationRequest(BaseModel):
    """Request model for text generation."""
    prompt: str = Field(..., description="Input text prompt")
    max_length: int = Field(256, description="Maximum generation length", ge=1, le=2048)
    temperature: float = Field(0.7, description="Sampling temperature", ge=0.0, le=2.0)
    top_k: Optional[int] = Field(40, description="Top-k sampling parameter", ge=1, le=1000)
    top_p: Optional[float] = Field(0.9, description="Nucleus sampling parameter", ge=0.1, le=1.0)
    num_return_sequences: int = Field(1, description="Number of sequences to generate", ge=1, le=5)
    stop_sequences: Optional[List[str]] = Field(None, description="Stop generation at these sequences")


class GenerationResponse(BaseModel):
    """Response model for text generation."""
    generated_text: List[str]
    prompt: str
    generation_time: float
    parameters: Dict[str, Any]


class BatchGenerationRequest(BaseModel):
    """Request model for batch text generation."""
    prompts: List[str] = Field(..., description="List of input prompts")
    max_length: int = Field(256, description="Maximum generation length", ge=1, le=2048)
    temperature: float = Field(0.7, description="Sampling temperature", ge=0.0, le=2.0)
    top_k: Optional[int] = Field(40, description="Top-k sampling parameter", ge=1, le=1000)
    top_p: Optional[float] = Field(0.9, description="Nucleus sampling parameter", ge=0.1, le=1.0)
    stop_sequences: Optional[List[str]] = Field(None, description="Stop generation at these sequences")


class BatchGenerationResponse(BaseModel):
    """Response model for batch text generation."""
    generated_texts: List[List[str]]
    prompts: List[str]
    generation_time: float
    parameters: Dict[str, Any]
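
# Illustrative request bodies for the endpoints defined below (the prompt strings
# and parameter values are examples only):
#
#   POST /generate        {"prompt": "Once upon a time", "max_length": 64, "temperature": 0.8}
#   POST /generate/batch  {"prompts": ["Hello", "Bonjour"], "max_length": 32}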


# Global inference engine instance; populated by create_optimized_server().
inference_engine: Optional[OptimizedInferenceEngine] = None

app = FastAPI(
    title="Optimized OpenLLM Inference API",
    description="High-performance REST API for OpenLLM text generation",
    version="0.1.0",
    docs_url="/docs",
    redoc_url="/redoc",
)

# Allow cross-origin requests from any origin; tighten these settings for production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.on_event("startup")
async def startup_event():
    """Initialize inference engine on startup."""
    logger.info("🚀 Starting Optimized OpenLLM Inference Server...")
    global inference_engine
    if inference_engine is None:
        logger.warning("No model loaded - server will return 503 for generation requests")


@app.on_event("shutdown")
async def shutdown_event():
    """Clean up resources on shutdown."""
    global inference_engine
    if inference_engine:
        inference_engine.cleanup()
    logger.info("Server shutdown complete")


@app.post("/generate", response_model=GenerationResponse)
async def generate_text(request: GenerationRequest):
    """Generate text from a prompt with optimizations."""
    if inference_engine is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    start_time = time.time()

    try:
        generated_texts = await inference_engine.generate_async(
            prompt=request.prompt,
            max_length=request.max_length,
            temperature=request.temperature,
            top_k=request.top_k,
            top_p=request.top_p,
            num_return_sequences=request.num_return_sequences,
            stop_sequences=request.stop_sequences,
        )

        generation_time = time.time() - start_time

        return GenerationResponse(
            generated_text=generated_texts,
            prompt=request.prompt,
            generation_time=generation_time,
            parameters={
                "max_length": request.max_length,
                "temperature": request.temperature,
                "top_k": request.top_k,
                "top_p": request.top_p,
                "num_return_sequences": request.num_return_sequences,
            },
        )

    except Exception as e:
        logger.error(f"Generation failed: {e}")
        raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}")


@app.post("/generate/stream")
async def generate_text_stream(request: GenerationRequest):
    """Generate text with a streaming (Server-Sent Events) response."""
    if inference_engine is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    async def generate_stream():
        try:
            async for token in inference_engine.generate_stream(
                prompt=request.prompt,
                max_length=request.max_length,
                temperature=request.temperature,
                top_k=request.top_k,
                top_p=request.top_p,
                stop_sequences=request.stop_sequences,
            ):
                yield f"data: {json.dumps({'token': token})}\n\n"

            yield f"data: {json.dumps({'done': True})}\n\n"

        except Exception as e:
            logger.error(f"Streaming generation failed: {e}")
            yield f"data: {json.dumps({'error': str(e)})}\n\n"

    return StreamingResponse(
        generate_stream(),
        media_type="text/event-stream",
        headers={"Cache-Control": "no-cache", "Connection": "keep-alive"}
    )
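
# A client reading /generate/stream receives Server-Sent Events like the following
# (illustrative output, not captured from a real run):
#
#   data: {"token": "Hello"}
#   data: {"token": " world"}
#   data: {"done": true}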


@app.post("/generate/batch", response_model=BatchGenerationResponse)
async def generate_text_batch(request: BatchGenerationRequest):
    """Generate text for multiple prompts in batch."""
    if inference_engine is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    start_time = time.time()

    try:
        # Fan the prompts out as concurrent tasks on the engine's worker pool.
        tasks = []
        for prompt in request.prompts:
            task = inference_engine.generate_async(
                prompt=prompt,
                max_length=request.max_length,
                temperature=request.temperature,
                top_k=request.top_k,
                top_p=request.top_p,
                num_return_sequences=1,
                stop_sequences=request.stop_sequences,
            )
            tasks.append(task)

        generated_texts = await asyncio.gather(*tasks)

        generation_time = time.time() - start_time

        return BatchGenerationResponse(
            generated_texts=generated_texts,
            prompts=request.prompts,
            generation_time=generation_time,
            parameters={
                "max_length": request.max_length,
                "temperature": request.temperature,
                "top_k": request.top_k,
                "top_p": request.top_p,
                "num_prompts": len(request.prompts),
            },
        )

    except Exception as e:
        logger.error(f"Batch generation failed: {e}")
        raise HTTPException(status_code=500, detail=f"Batch generation failed: {str(e)}")


@app.get("/health")
async def health_check():
    """Health check endpoint."""
    global inference_engine

    if inference_engine is None:
        return {"status": "unhealthy", "message": "Model not loaded"}

    try:
        # Run a tiny test generation to confirm the model can produce output.
        test_result = await inference_engine.generate_async(
            prompt="Hello",
            max_length=5,
            temperature=0.7
        )

        return {
            "status": "healthy",
            "model_loaded": True,
            "test_generation": len(test_result) > 0
        }

    except Exception as e:
        return {
            "status": "unhealthy",
            "message": f"Generation test failed: {str(e)}"
        }


@app.get("/metrics")
async def get_metrics():
    """Get performance metrics."""
    global inference_engine

    if inference_engine is None:
        return {"error": "Model not loaded"}

    return inference_engine.get_metrics()
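
# Example /metrics response shape (values are illustrative; the keys come from
# OptimizedInferenceEngine.get_metrics above):
#
#   {"total_requests": 42, "cache_hits": 10, "cache_misses": 32,
#    "avg_generation_time": 0.81, "total_generation_time": 34.0,
#    "requests_per_second": 0.0, "memory_usage_percent": 61.2,
#    "cache_size": 32, "max_cache_size": 1000, "cache_hit_rate": 0.238,
#    "device": "cpu", "quantization_enabled": true}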


@app.get("/info")
async def get_model_info():
    """Get model information."""
    global inference_engine

    if inference_engine is None:
        return {"error": "Model not loaded"}

    model = inference_engine.model
    if model is None:
        return {"error": "Model not available"}

    return {
        "model_name": model.config.model_name,
        "vocab_size": model.config.vocab_size,
        "n_layer": model.config.n_layer,
        "n_head": model.config.n_head,
        "n_embd": model.config.n_embd,
        "block_size": model.config.block_size,
        "parameters": model.get_num_params(),
        "device": str(inference_engine.device),
        "quantization_enabled": inference_engine.quantized_model is not None,
        "cache_size": len(inference_engine.response_cache),
        "max_cache_size": inference_engine.cache_size,
    }


def create_optimized_server(model_path: str,
                            host: str = "0.0.0.0",
                            port: int = 8000,
                            device: str = "auto",
                            use_quantization: bool = True,
                            cache_size: int = 1000,
                            max_batch_size: int = 32,
                            num_workers: int = 4) -> FastAPI:
    """
    Create an optimized inference server.

    Args:
        model_path: Path to the model
        host: Server host (accepted for interface compatibility; not used inside this function)
        port: Server port (accepted for interface compatibility; not used inside this function)
        device: Device to use
        use_quantization: Whether to use quantization
        cache_size: Size of response cache
        max_batch_size: Maximum batch size
        num_workers: Number of worker threads

    Returns:
        FastAPI app instance
    """
    global inference_engine

    inference_engine = OptimizedInferenceEngine(
        model_path=model_path,
        device=device,
        use_quantization=use_quantization,
        cache_size=cache_size,
        max_batch_size=max_batch_size,
        num_workers=num_workers
    )

    return app
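
# Programmatic usage sketch (the model directory is a placeholder; it is expected
# to contain config.json, pytorch_model.bin, and tokenizer.model as loaded above):
#
#   app = create_optimized_server(model_path="./models/small", use_quantization=True)
#   uvicorn.run(app, host="127.0.0.1", port=8000)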


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Optimized OpenLLM Inference Server")
    parser.add_argument("--model_path", type=str, required=True, help="Path to model")
    parser.add_argument("--host", type=str, default="0.0.0.0", help="Server host")
    parser.add_argument("--port", type=int, default=8000, help="Server port")
    parser.add_argument("--device", type=str, default="auto", help="Device to use")
    parser.add_argument("--use_quantization", action="store_true", help="Use quantization")
    parser.add_argument("--cache_size", type=int, default=1000, help="Cache size")
    parser.add_argument("--max_batch_size", type=int, default=32, help="Max batch size")
    parser.add_argument("--num_workers", type=int, default=4, help="Number of workers")

    args = parser.parse_args()

    app = create_optimized_server(
        model_path=args.model_path,
        host=args.host,
        port=args.port,
        device=args.device,
        use_quantization=args.use_quantization,
        cache_size=args.cache_size,
        max_batch_size=args.max_batch_size,
        num_workers=args.num_workers
    )

    uvicorn.run(app, host=args.host, port=args.port)