# WebWorld-8B-Onnx / api_server.py
# Author: Prince-1
# Commit message: "Add files using upload-large-folder tool"
# Revision: 5abb996 (verified)
"""
FastAPI server for Qwen ONNX model inference
Run with: uvicorn api_server:app --reload --host 0.0.0.0 --port 8000
"""
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
from typing import List, Optional
import onnxruntime_genai as og
from pathlib import Path
import logging
# Configure root logging at INFO level and grab this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# The ASGI application every route below is registered on.
app = FastAPI(version="1.0", title="Qwen ONNX Model API")

# The ONNX model files are expected to sit in the same directory as this script.
MODEL_DIR = Path(__file__).parent

# Shared inference handles; they stay None until the startup hook assigns them.
model: Optional["og.Model"] = None
tokenizer: Optional["og.Tokenizer"] = None
@app.on_event("startup")
async def startup_event():
    """Load the ONNX model and its tokenizer into the module-level globals.

    Runs once when the application starts. A failure is logged together with
    its traceback and then re-raised, so it propagates into the server's
    startup sequence instead of leaving the app running with no model.
    """
    # NOTE(review): FastAPI deprecates on_event in favor of lifespan handlers;
    # kept here so the startup contract is unchanged.
    global model, tokenizer
    try:
        logger.info("Loading model from %s", MODEL_DIR)
        model = og.Model(str(MODEL_DIR))
        tokenizer = og.Tokenizer(model)
        logger.info("Model loaded successfully")
    except Exception as e:
        # logger.exception records the traceback; logger.error dropped it,
        # which hides where model loading actually failed.
        logger.exception("Failed to load model: %s", e)
        raise
# Request/Response models
class GenerateRequest(BaseModel):
    """Text generation request"""

    # Request body for POST /generate. The bounds below are enforced by
    # pydantic before the handler runs.
    prompt: str = Field(default=..., description="Input prompt")  # required
    max_length: int = Field(
        default=100,
        ge=1,
        le=2048,
        description="Maximum output length",
    )
    temperature: float = Field(
        default=0.6,
        ge=0.0,
        le=2.0,
        description="Temperature for sampling",
    )
    top_p: float = Field(
        default=0.95,
        ge=0.0,
        le=1.0,
        description="Top-p for nucleus sampling",
    )
    top_k: int = Field(
        default=20,
        ge=1,
        le=100,
        description="Top-k for sampling",
    )
class GenerateResponse(BaseModel):
    """Text generation response"""

    prompt: str = Field(default=...)  # the request prompt, echoed back
    generated_text: str = Field(default=...)  # the decoded continuation
    total_length: int = Field(default=...)  # token count reported by /generate
class Message(BaseModel):
    """Chat message"""

    # A plain str: the roles listed in the description are documented but
    # not validated.
    role: str = Field(
        default=...,
        description="Message role: system, user, or assistant",
    )
    content: str = Field(default=..., description="Message content")
class ChatRequest(BaseModel):
    """Chat inference request"""

    # Request body for POST /chat: the conversation so far plus the same
    # bounds-checked sampling controls as GenerateRequest, with a larger
    # default length budget.
    messages: List[Message] = Field(
        default=...,
        description="Conversation messages",
    )
    max_length: int = Field(
        default=200,
        ge=1,
        le=2048,
        description="Maximum output length",
    )
    temperature: float = Field(
        default=0.6,
        ge=0.0,
        le=2.0,
        description="Temperature for sampling",
    )
    top_p: float = Field(
        default=0.95,
        ge=0.0,
        le=1.0,
        description="Top-p for nucleus sampling",
    )
    top_k: int = Field(
        default=20,
        ge=1,
        le=100,
        description="Top-k for sampling",
    )
class ChatResponse(BaseModel):
    """Chat inference response"""

    # The conversation returned by /chat, ending with the assistant turn it added.
    messages: List[Message] = Field(default=...)
    # The generated assistant reply on its own.
    assistant_response: str = Field(default=...)
class TokenizeRequest(BaseModel):
    """Tokenization request"""

    # Required; any string, including the empty string, is accepted.
    text: str = Field(
        default=...,
        description="Text to tokenize",
    )
class TokenizeResponse(BaseModel):
    """Tokenization response"""

    text: str = Field(default=...)  # the input text, echoed back
    token_ids: List[int] = Field(default=...)  # ids from tokenizer.encode
    num_tokens: int = Field(default=...)  # number of entries in token_ids
# Health check
@app.get("/health")
async def health_check():
    """Check if model is loaded"""
    # Both handles are assigned by the startup hook; until then they are None.
    if model and tokenizer:
        status = "ok"
    else:
        status = "error"
    return {"status": status, "model": "Qwen3-ONNX"}
# Text generation endpoint
@app.post("/generate", response_model=GenerateResponse)
async def generate(request: GenerateRequest):
    """Generate text from a prompt"""
    # Returns the continuation of request.prompt with the prompt excluded.
    # total_length is the token count of prompt + continuation.
    #
    # request.max_length is the budget of *new* tokens, as its field
    # description ("Maximum output length") promises. onnxruntime-genai's
    # max_length search option counts prompt + output, so the prompt length
    # is added before it is passed through.
    #
    # Raises HTTPException 503 if the model is not loaded, and 500 on any
    # tokenization or generation failure.
    if not model or not tokenizer:
        raise HTTPException(status_code=503, detail="Model not loaded")
    try:
        input_tokens = tokenizer.encode(request.prompt)
        prompt_length = len(input_tokens)

        # Sampling parameters only take effect with do_sample=True. The schema
        # allows temperature == 0, which is treated as greedy decoding.
        do_sample = request.temperature > 0
        search_options = {
            "max_length": prompt_length + request.max_length,
            "do_sample": do_sample,
        }
        if do_sample:
            search_options.update(
                temperature=request.temperature,
                top_p=request.top_p,
                top_k=request.top_k,
            )
        params = og.GeneratorParams(model)
        params.set_search_options(**search_options)

        generator = og.Generator(model, params)
        generator.append_tokens(input_tokens)
        # NOTE(review): this loop runs synchronously inside an async handler,
        # so it blocks the event loop (and every other request) until done.
        while not generator.is_done():
            generator.generate_next_token()

        # The sequence holds prompt + continuation. Decode only the new tokens
        # rather than string-stripping the prompt, which silently fails
        # whenever decode(encode(prompt)) != prompt.
        output_tokens = generator.get_sequence(0)
        generated_text = tokenizer.decode(output_tokens[prompt_length:])

        return GenerateResponse(
            prompt=request.prompt,
            generated_text=generated_text.strip(),
            total_length=len(output_tokens),
        )
    except Exception as e:
        logger.exception("Generation error: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
# Chat endpoint
@app.post("/chat", response_model=ChatResponse)
async def chat(request: ChatRequest):
    """Chat inference with conversation history"""
    # Formats request.messages as a ChatML (Qwen) prompt and generates the
    # assistant turn. The reply is returned on its own and also appended to a
    # copy of the conversation.
    #
    # request.max_length is the budget of *new* tokens ("Maximum output
    # length"). onnxruntime-genai's max_length counts prompt + output, so the
    # prompt length is added.
    #
    # Raises HTTPException 503 if the model is not loaded, and 500 on any
    # tokenization or generation failure.
    if not model or not tokenizer:
        raise HTTPException(status_code=503, detail="Model not loaded")
    try:
        turns = [
            f"<|im_start|>{msg.role}\n{msg.content}<|im_end|>\n"
            for msg in request.messages
        ]
        # A trailing open assistant header cues the model to write the reply.
        prompt_text = "".join(turns) + "<|im_start|>assistant\n"
        input_tokens = tokenizer.encode(prompt_text)
        prompt_length = len(input_tokens)

        # Sampling parameters only take effect with do_sample=True. The schema
        # allows temperature == 0, which is treated as greedy decoding.
        do_sample = request.temperature > 0
        search_options = {
            "max_length": prompt_length + request.max_length,
            "do_sample": do_sample,
        }
        if do_sample:
            search_options.update(
                temperature=request.temperature,
                top_p=request.top_p,
                top_k=request.top_k,
            )
        params = og.GeneratorParams(model)
        params.set_search_options(**search_options)

        generator = og.Generator(model, params)
        generator.append_tokens(input_tokens)
        # NOTE(review): this loop runs synchronously inside an async handler,
        # so it blocks the event loop (and every other request) until done.
        while not generator.is_done():
            generator.generate_next_token()

        # The sequence holds prompt + completion. Decode only the completion
        # so the templated conversation is not echoed back as the reply.
        output_tokens = generator.get_sequence(0)
        response_text = tokenizer.decode(output_tokens[prompt_length:]).strip()

        # Copy the conversation (field-by-field, pydantic v1/v2 compatible) and
        # append the same stripped reply that is returned on its own.
        messages = [
            Message(role=msg.role, content=msg.content) for msg in request.messages
        ]
        messages.append(Message(role="assistant", content=response_text))
        return ChatResponse(messages=messages, assistant_response=response_text)
    except Exception as e:
        logger.exception("Chat error: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
# Tokenization endpoint
@app.post("/tokenize", response_model=TokenizeResponse)
async def tokenize(request: TokenizeRequest):
    """Tokenize text"""
    # Returns the token ids for request.text and how many there are.
    # Raises HTTPException 503 if the tokenizer is not loaded, and 500 if
    # encoding fails.
    if not tokenizer:
        raise HTTPException(status_code=503, detail="Tokenizer not loaded")
    try:
        # tokenizer.encode can return an array of numpy ints (as
        # onnxruntime-genai does) rather than a list of Python ints. Convert
        # explicitly so the List[int] response field validates and serializes.
        token_ids = [int(token_id) for token_id in tokenizer.encode(request.text)]
        return TokenizeResponse(
            text=request.text,
            token_ids=token_ids,
            num_tokens=len(token_ids),
        )
    except Exception as e:
        logger.exception("Tokenization error: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
# Model info endpoint
@app.get("/info")
async def model_info():
    """Get model information"""
    # Reads genai_config.json from MODEL_DIR, the onnxruntime-genai config file
    # for the model directory og.Model is loaded from.
    # - Generation defaults come from its "search" section; a missing key
    #   yields None.
    # - context_length and vocab_size come from its "model" section, falling
    #   back to the previous hard-coded Qwen3 values when absent.
    #
    # Raises HTTPException 503 if the model is not loaded, and 500 if the
    # config file cannot be read or parsed.
    import json

    if not model:
        raise HTTPException(status_code=503, detail="Model not loaded")
    try:
        config_path = MODEL_DIR / "genai_config.json"
        genai_config = json.loads(config_path.read_text(encoding="utf-8"))
        model_section = genai_config.get("model", {})
        search = genai_config.get("search", {})
        return {
            "model_type": "Qwen3",
            "model_dir": str(MODEL_DIR),
            "context_length": model_section.get("context_length", 40960),
            "vocab_size": model_section.get("vocab_size", 151936),
            "default_max_length": search.get("max_length"),
            "default_temperature": search.get("temperature"),
            "default_top_p": search.get("top_p"),
            "default_top_k": search.get("top_k"),
        }
    except Exception as e:
        logger.exception("Info error: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
if __name__ == "__main__":
    # Allows `python api_server.py` as an alternative to the uvicorn CLI
    # command in the module docstring (this path does not enable --reload).
    import uvicorn

    server_options = {"host": "0.0.0.0", "port": 8000, "log_level": "info"}
    uvicorn.run(app, **server_options)