#!/usr/bin/env python3
"""
Helion-2.5-Rnd Inference Server
High-performance inference server with vLLM backend
"""
import argparse
import json
import logging
import os
import time
from typing import AsyncGenerator, Dict, List, Optional, Union

import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel, Field
from vllm import AsyncEngineArgs, AsyncLLMEngine, SamplingParams
from vllm.utils import random_uuid
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
class ChatMessage(BaseModel):
"""Chat message format"""
role: str = Field(..., description="Role: system, user, or assistant")
content: str = Field(..., description="Message content")
class ChatCompletionRequest(BaseModel):
"""Chat completion request format"""
model: str = Field(default="DeepXR/Helion-2.5-Rnd")
messages: List[ChatMessage]
temperature: float = Field(default=0.7, ge=0.0, le=2.0)
top_p: float = Field(default=0.9, ge=0.0, le=1.0)
top_k: int = Field(default=50, ge=0)
max_tokens: int = Field(default=4096, ge=1)
stream: bool = Field(default=False)
stop: Optional[List[str]] = None
presence_penalty: float = Field(default=0.0, ge=-2.0, le=2.0)
frequency_penalty: float = Field(default=0.0, ge=-2.0, le=2.0)
repetition_penalty: float = Field(default=1.1, ge=1.0, le=2.0)
n: int = Field(default=1, ge=1, le=10)
logprobs: Optional[int] = None
echo: bool = Field(default=False)
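# Illustrative request body for ChatCompletionRequest (field values are
# examples; only the Field constraints above are enforced):
#
#   {
#     "messages": [
#       {"role": "system", "content": "You are a helpful assistant."},
#       {"role": "user", "content": "Summarize vLLM in one sentence."}
#     ],
#     "temperature": 0.7,
#     "max_tokens": 256,
#     "stream": true
#   }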
class CompletionRequest(BaseModel):
"""Text completion request format"""
model: str = Field(default="DeepXR/Helion-2.5-Rnd")
prompt: Union[str, List[str]]
temperature: float = Field(default=0.7, ge=0.0, le=2.0)
top_p: float = Field(default=0.9, ge=0.0, le=1.0)
max_tokens: int = Field(default=4096, ge=1)
stream: bool = Field(default=False)
stop: Optional[List[str]] = None
n: int = Field(default=1, ge=1, le=10)
class HelionInferenceServer:
"""Main inference server class"""
def __init__(
self,
model_path: str,
tensor_parallel_size: int = 2,
max_model_len: int = 131072,
gpu_memory_utilization: float = 0.95,
dtype: str = "bfloat16"
):
self.model_path = model_path
self.model_name = "DeepXR/Helion-2.5-Rnd"
# Initialize vLLM engine
engine_args = AsyncEngineArgs(
model=model_path,
tensor_parallel_size=tensor_parallel_size,
max_model_len=max_model_len,
gpu_memory_utilization=gpu_memory_utilization,
dtype=dtype,
trust_remote_code=True,
enforce_eager=False,
disable_log_stats=False,
)
logger.info(f"Initializing Helion-2.5-Rnd from {model_path}")
self.engine = AsyncLLMEngine.from_engine_args(engine_args)
logger.info("Engine initialized successfully")
# Statistics
self.request_count = 0
self.start_time = time.time()
def format_chat_prompt(self, messages: List[ChatMessage]) -> str:
"""Format chat messages into prompt"""
formatted = ""
for msg in messages:
formatted += f"<|im_start|>{msg.role}\n{msg.content}<|im_end|>\n"
formatted += "<|im_start|>assistant\n"
return formatted
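    # For example, a system + user exchange renders as:
    #
    #   <|im_start|>system
    #   You are a helpful assistant.<|im_end|>
    #   <|im_start|>user
    #   Hello<|im_end|>
    #   <|im_start|>assistant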
async def generate(
self,
prompt: str,
sampling_params: SamplingParams,
request_id: str
) -> AsyncGenerator[str, None]:
"""Generate text streaming"""
results_generator = self.engine.generate(
prompt,
sampling_params,
request_id
)
async for request_output in results_generator:
text = request_output.outputs[0].text
yield text
async def chat_completion(
self,
request: ChatCompletionRequest
) -> Union[Dict, AsyncGenerator]:
"""Handle chat completion request"""
request_id = f"helion-{random_uuid()}"
self.request_count += 1
# Format prompt
prompt = self.format_chat_prompt(request.messages)
# Create sampling parameters
sampling_params = SamplingParams(
temperature=request.temperature,
top_p=request.top_p,
top_k=request.top_k,
max_tokens=request.max_tokens,
stop=request.stop or ["<|im_end|>", "<|endoftext|>"],
presence_penalty=request.presence_penalty,
frequency_penalty=request.frequency_penalty,
repetition_penalty=request.repetition_penalty,
n=request.n,
logprobs=request.logprobs,
)
        if request.stream:
            # _stream_chat_completion is itself a coroutine that returns an
            # async generator, so it must be awaited before the generator is
            # handed to StreamingResponse.
            return await self._stream_chat_completion(
                prompt,
                sampling_params,
                request_id,
                request.model
            )
else:
return await self._complete_chat_completion(
prompt,
sampling_params,
request_id,
request.model
)
async def _complete_chat_completion(
self,
prompt: str,
sampling_params: SamplingParams,
request_id: str,
model: str
) -> Dict:
"""Non-streaming chat completion"""
final_output = None
async for request_output in self.engine.generate(
prompt, sampling_params, request_id
):
final_output = request_output
if final_output is None:
raise HTTPException(status_code=500, detail="Generation failed")
        # Note: only the first candidate is returned, even when n > 1.
        choice = {
"index": 0,
"message": {
"role": "assistant",
"content": final_output.outputs[0].text
},
"finish_reason": final_output.outputs[0].finish_reason
}
return {
"id": request_id,
"object": "chat.completion",
"created": int(time.time()),
"model": model,
"choices": [choice],
"usage": {
"prompt_tokens": len(final_output.prompt_token_ids),
"completion_tokens": len(final_output.outputs[0].token_ids),
"total_tokens": len(final_output.prompt_token_ids) + len(final_output.outputs[0].token_ids)
}
}
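    # The payload mirrors the OpenAI chat.completion schema, e.g. (token
    # counts elided):
    #
    #   {"id": "helion-...", "object": "chat.completion", "created": 1700000000,
    #    "model": "DeepXR/Helion-2.5-Rnd",
    #    "choices": [{"index": 0,
    #                 "message": {"role": "assistant", "content": "..."},
    #                 "finish_reason": "stop"}],
    #    "usage": {...}}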
async def _stream_chat_completion(
self,
prompt: str,
sampling_params: SamplingParams,
request_id: str,
model: str
) -> AsyncGenerator:
"""Streaming chat completion"""
async def generate():
previous_text = ""
async for request_output in self.engine.generate(
prompt, sampling_params, request_id
):
text = request_output.outputs[0].text
delta = text[len(previous_text):]
previous_text = text
chunk = {
"id": request_id,
"object": "chat.completion.chunk",
"created": int(time.time()),
"model": model,
"choices": [{
"index": 0,
"delta": {"content": delta},
"finish_reason": None
}]
}
yield f"data: {json.dumps(chunk)}\n\n"
# Final chunk
final_chunk = {
"id": request_id,
"object": "chat.completion.chunk",
"created": int(time.time()),
"model": model,
"choices": [{
"index": 0,
"delta": {},
"finish_reason": "stop"
}]
}
yield f"data: {json.dumps(final_chunk)}\n\n"
yield "data: [DONE]\n\n"
return generate()
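    # On the wire, each yield above becomes one server-sent event:
    #
    #   data: {"id": "helion-...", "object": "chat.completion.chunk", ...}
    #   data: {"id": "helion-...", "choices": [{"delta": {}, "finish_reason": "stop"}], ...}
    #   data: [DONE]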
# Initialize FastAPI app
app = FastAPI(
title="Helion-2.5-Rnd Inference API",
description="Advanced language model inference server",
version="2.5.0-rnd"
)
# Add CORS middleware
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
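# Note: allow_origins=["*"] together with allow_credentials=True is maximally
# permissive; restrict allow_origins to known hosts for production deployments.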
# Global server instance
server: Optional[HelionInferenceServer] = None
@app.on_event("startup")
async def startup_event():
"""Initialize the model on startup"""
global server
model_path = os.getenv("MODEL_PATH", "/models/helion")
tensor_parallel = int(os.getenv("TENSOR_PARALLEL_SIZE", "2"))
max_len = int(os.getenv("MAX_MODEL_LEN", "131072"))
gpu_util = float(os.getenv("GPU_MEMORY_UTILIZATION", "0.95"))
server = HelionInferenceServer(
model_path=model_path,
tensor_parallel_size=tensor_parallel,
max_model_len=max_len,
gpu_memory_utilization=gpu_util
)
logger.info("Helion-2.5-Rnd server started successfully")
@app.get("/")
async def root():
"""Root endpoint"""
return {
"model": "DeepXR/Helion-2.5-Rnd",
"version": "2.5.0-rnd",
"status": "ready",
"type": "research"
}
@app.get("/health")
async def health():
"""Health check endpoint"""
if server is None:
raise HTTPException(status_code=503, detail="Server not initialized")
return {
"status": "healthy",
"model": server.model_name,
"requests_served": server.request_count,
"uptime_seconds": int(time.time() - server.start_time)
}
@app.get("/v1/models")
async def list_models():
"""List available models"""
return {
"object": "list",
"data": [{
"id": "DeepXR/Helion-2.5-Rnd",
"object": "model",
"created": int(time.time()),
"owned_by": "DeepXR"
}]
}
@app.post("/v1/chat/completions")
async def chat_completions(request: ChatCompletionRequest):
"""Chat completion endpoint"""
if server is None:
raise HTTPException(status_code=503, detail="Server not initialized")
try:
result = await server.chat_completion(request)
if request.stream:
return StreamingResponse(
result,
media_type="text/event-stream"
)
else:
return JSONResponse(content=result)
except Exception as e:
logger.error(f"Error in chat completion: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
@app.post("/v1/completions")
async def completions(request: CompletionRequest):
"""Text completion endpoint"""
if server is None:
raise HTTPException(status_code=503, detail="Server not initialized")
    # Convert to chat format. The OpenAI-style prompt field may be a string
    # or a list of strings, but ChatMessage.content requires a single string.
    if isinstance(request.prompt, list):
        if len(request.prompt) != 1:
            raise HTTPException(status_code=400, detail="Batch prompts are not supported")
        prompt_text = request.prompt[0]
    else:
        prompt_text = request.prompt
    messages = [ChatMessage(role="user", content=prompt_text)]
chat_request = ChatCompletionRequest(
model=request.model,
messages=messages,
temperature=request.temperature,
top_p=request.top_p,
max_tokens=request.max_tokens,
stream=request.stream,
stop=request.stop,
n=request.n
)
return await chat_completions(chat_request)
def main():
"""Main entry point"""
parser = argparse.ArgumentParser(description="Helion-2.5-Rnd Inference Server")
parser.add_argument("--model", type=str, default="/models/helion")
parser.add_argument("--host", type=str, default="0.0.0.0")
parser.add_argument("--port", type=int, default=8000)
parser.add_argument("--tensor-parallel-size", type=int, default=2)
parser.add_argument("--max-model-len", type=int, default=131072)
parser.add_argument("--gpu-memory-utilization", type=float, default=0.95)
args = parser.parse_args()
    # Pass CLI arguments to the startup hook via environment variables
os.environ["MODEL_PATH"] = args.model
os.environ["TENSOR_PARALLEL_SIZE"] = str(args.tensor_parallel_size)
os.environ["MAX_MODEL_LEN"] = str(args.max_model_len)
os.environ["GPU_MEMORY_UTILIZATION"] = str(args.gpu_memory_utilization)
# Run server
uvicorn.run(
app,
host=args.host,
port=args.port,
log_level="info",
access_log=True
)
if __name__ == "__main__":
main()
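# Example invocation and smoke test (paths and ports are illustrative):
#
#   python inference/server.py --model /models/helion --tensor-parallel-size 2
#
#   curl http://localhost:8000/v1/chat/completions \
#     -H "Content-Type: application/json" \
#     -d '{"messages": [{"role": "user", "content": "Hello"}]}'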