| """ |
| FastAPI server for Qwen ONNX model inference |
| Run with: uvicorn api_server:app --reload --host 0.0.0.0 --port 8000 |
| """ |
|
|
import json
import logging
from pathlib import Path
from typing import List, Optional

import onnxruntime_genai as og
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
|
|
| |
# Route INFO-and-above records through the root logger's default handler.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# ASGI application object; the route decorators in this module register on it.
app = FastAPI(version="1.0", title="Qwen ONNX Model API")

# The startup hook loads the model from the directory containing this module.
MODEL_DIR = Path(__file__).parent

# Filled in by the startup hook; both stay None until loading has succeeded.
model = tokenizer = None
|
|
|
|
@app.on_event("startup")
async def startup_event():
    """Populate the module-level ``model`` and ``tokenizer`` when the app starts.

    Any exception raised while loading is logged and re-raised, so a failed
    load aborts startup instead of leaving the server running without a model.
    """
    global model, tokenizer

    model_path = str(MODEL_DIR)
    try:
        logger.info(f"Loading model from {MODEL_DIR}")
        # Assign the model first: the tokenizer is constructed from it.
        model = og.Model(model_path)
        tokenizer = og.Tokenizer(model)
    except Exception as load_error:
        logger.error(f"Failed to load model: {load_error}")
        raise
    else:
        logger.info("Model loaded successfully")
|
|
|
|
| |
class GenerateRequest(BaseModel):
    """Text generation request"""

    # Required: the raw text handed to the tokenizer.
    prompt: str = Field(default=..., description="Input prompt")

    # Optional generation settings; the bounds are enforced when the request
    # body is validated.
    max_length: int = Field(
        default=100, description="Maximum output length", ge=1, le=2048
    )
    temperature: float = Field(
        default=0.6, description="Temperature for sampling", ge=0.0, le=2.0
    )
    top_p: float = Field(
        default=0.95, description="Top-p for nucleus sampling", ge=0.0, le=1.0
    )
    top_k: int = Field(
        default=20, description="Top-k for sampling", ge=1, le=100
    )
|
|
|
|
class GenerateResponse(BaseModel):
    """Text generation response"""

    # The prompt from the request, echoed back unchanged.
    prompt: str
    # Text produced by the model, with surrounding whitespace stripped.
    generated_text: str
    # Length of the token sequence returned by the generator.
    total_length: int
|
|
|
|
class Message(BaseModel):
    """Chat message"""

    # Both fields are required. The role is a free-form string: the values
    # named in the description are not enforced by validation.
    role: str = Field(
        default=..., description="Message role: system, user, or assistant"
    )
    content: str = Field(default=..., description="Message content")
|
|
|
|
class ChatRequest(BaseModel):
    """Chat inference request"""

    # Required: the conversation so far, in the order it should be replayed.
    messages: List[Message] = Field(
        default=..., description="Conversation messages"
    )

    # Optional generation settings; the bounds are enforced when the request
    # body is validated.
    max_length: int = Field(
        default=200, description="Maximum output length", ge=1, le=2048
    )
    temperature: float = Field(
        default=0.6, description="Temperature for sampling", ge=0.0, le=2.0
    )
    top_p: float = Field(
        default=0.95, description="Top-p for nucleus sampling", ge=0.0, le=1.0
    )
    top_k: int = Field(
        default=20, description="Top-k for sampling", ge=1, le=100
    )
|
|
|
|
class ChatResponse(BaseModel):
    """Chat inference response"""

    # The request's messages followed by the newly generated assistant turn.
    messages: List[Message]
    # The assistant's text on its own, with surrounding whitespace stripped.
    assistant_response: str
|
|
|
|
class TokenizeRequest(BaseModel):
    """Tokenization request"""

    # Required: the string passed to the tokenizer's encode call.
    text: str = Field(..., description="Text to tokenize")
|
|
|
|
class TokenizeResponse(BaseModel):
    """Tokenization response"""

    # The input text from the request, echoed back unchanged.
    text: str
    # Token ids the tokenizer produced for that text.
    token_ids: List[int]
    # Number of entries in token_ids.
    num_tokens: int
|
|
|
|
| |
@app.get("/health")
async def health_check():
    """Check if model is loaded"""
    # Report "ok" only when both globals set by the startup hook are truthy.
    if model and tokenizer:
        status = "ok"
    else:
        status = "error"

    return {"status": status, "model": "Qwen3-ONNX"}
|
|
|
|
| |
def _run_generation(input_tokens, options):
    """Decode until the generator reports completion; return the full sequence.

    Shared by the /generate and /chat endpoints so both use the same search
    setup and decoding loop.

    Args:
        input_tokens: Prompt token ids as returned by ``tokenizer.encode``.
        options: Request object exposing ``max_length``, ``temperature``,
            ``top_p`` and ``top_k`` (``GenerateRequest`` or ``ChatRequest``).

    Returns:
        The generator's token sequence for the single prompt. Per the
        onnxruntime-genai API this holds the prompt tokens first, followed by
        the newly generated ones.
    """
    # Search options go through GeneratorParams; the Model object itself does
    # not hand out a mutable parameter set in the documented API.
    # NOTE(review): whether temperature/top_p/top_k have any effect depends on
    # the model's configured do_sample value, which is deliberately left
    # untouched here - confirm it matches the intended behaviour.
    params = og.GeneratorParams(model)
    params.set_search_options(
        max_length=options.max_length,
        temperature=options.temperature,
        top_p=options.top_p,
        top_k=options.top_k,
    )

    generator = og.Generator(model, params)
    generator.append_tokens(input_tokens)

    # NOTE: this loop is synchronous. The async endpoints call it directly, so
    # the event loop is occupied until generation finishes.
    while not generator.is_done():
        # With the append_tokens API, generate_next_token computes the logits
        # itself; there is no separate compute_logits step.
        generator.generate_next_token()

    return generator.get_sequence(0)


@app.post("/generate", response_model=GenerateResponse)
async def generate(request: GenerateRequest):
    """Generate text from a prompt"""
    # Responds 503 until the startup hook has set model and tokenizer. Any
    # failure while encoding, generating or decoding is logged with its
    # traceback and reported as a 500 carrying the error text.
    if not model or not tokenizer:
        raise HTTPException(status_code=503, detail="Model not loaded")

    try:
        input_tokens = tokenizer.encode(request.prompt)
        output_tokens = _run_generation(input_tokens, request)
        # Drop the prompt by token count instead of by text prefix: decoding
        # does not have to reproduce the prompt string character for
        # character, in which case a startswith() check would leave the
        # whole prompt in the returned text.
        generated_text = tokenizer.decode(output_tokens[len(input_tokens):])
    except Exception as e:
        logger.exception("Generation error: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e

    return GenerateResponse(
        prompt=request.prompt,
        generated_text=generated_text.strip(),
        # Still the length of the whole sequence, prompt tokens included.
        total_length=len(output_tokens),
    )


@app.post("/chat", response_model=ChatResponse)
async def chat(request: ChatRequest):
    """Chat inference with conversation history"""
    # Responds 503 until the startup hook has set model and tokenizer. Any
    # failure while encoding, generating or decoding is logged with its
    # traceback and reported as a 500 carrying the error text.
    if not model or not tokenizer:
        raise HTTPException(status_code=503, detail="Model not loaded")

    try:
        # One <|im_start|>role ... <|im_end|> turn per message, then an open
        # assistant turn for the model to complete.
        turns = [
            f"<|im_start|>{msg.role}\n{msg.content}<|im_end|>\n"
            for msg in request.messages
        ]
        prompt_text = "".join(turns) + "<|im_start|>assistant\n"

        input_tokens = tokenizer.encode(prompt_text)
        output_tokens = _run_generation(input_tokens, request)
        # Decode only the tokens after the prompt; decoding the whole sequence
        # would echo the entire conversation template back as the reply.
        reply = tokenizer.decode(output_tokens[len(input_tokens):]).strip()
    except Exception as e:
        logger.exception("Chat error: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e

    # The same stripped text is used for the appended message and for
    # assistant_response so the two cannot disagree.
    messages = [*request.messages, Message(role="assistant", content=reply)]
    return ChatResponse(messages=messages, assistant_response=reply)
|
|
|
|
| |
@app.post("/tokenize", response_model=TokenizeResponse)
async def tokenize(request: TokenizeRequest):
    """Tokenize text"""
    # Responds 503 until the startup hook has set the tokenizer. A failure
    # while encoding is logged with its traceback and reported as a 500
    # carrying the error text.
    if not tokenizer:
        raise HTTPException(status_code=503, detail="Tokenizer not loaded")

    try:
        # The tokenizer may hand back a numpy array of fixed-width integers;
        # convert to built-in ints so the List[int] response field validates
        # and serialises whatever container type encode() returns.
        token_ids = [int(token_id) for token_id in tokenizer.encode(request.text)]
    except Exception as e:
        logger.exception("Tokenization error: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e

    return TokenizeResponse(
        text=request.text,
        token_ids=token_ids,
        num_tokens=len(token_ids),
    )
|
|
|
|
| |
@app.get("/info")
async def model_info():
    """Get model information"""
    # Responds 503 until the startup hook has set the model. A failure while
    # reading the configuration is logged with its traceback and reported as
    # a 500 carrying the error text.
    if not model:
        raise HTTPException(status_code=503, detail="Model not loaded")

    try:
        # The default search options live in the "search" section of the
        # genai_config.json that sits in the model directory; the documented
        # Model API has no accessor that returns them.
        config_path = MODEL_DIR / "genai_config.json"
        genai_config = json.loads(config_path.read_text(encoding="utf-8"))
        search_defaults = genai_config.get("search", {})
    except Exception as e:
        logger.exception("Info error: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e

    # NOTE(review): context_length and vocab_size are fixed constants here
    # rather than read from the config - confirm they match the deployed
    # model.
    return {
        "model_type": "Qwen3",
        "model_dir": str(MODEL_DIR),
        "context_length": 40960,
        "vocab_size": 151936,
        # A key missing from the config is reported as null, not as an error.
        "default_max_length": search_defaults.get("max_length"),
        "default_temperature": search_defaults.get("temperature"),
        "default_top_p": search_defaults.get("top_p"),
        "default_top_k": search_defaults.get("top_k"),
    }
|
|
|
|
if __name__ == "__main__":
    # Imported here because uvicorn is only needed when this file is executed
    # directly as a script.
    import uvicorn

    server_options = {"host": "0.0.0.0", "port": 8000, "log_level": "info"}
    uvicorn.run(app, **server_options)
|
|