"""
FastAPI server for Qwen ONNX model inference.

Run with:
    uvicorn api_server:app --reload --host 0.0.0.0 --port 8000

NOTE(review): the generation code targets the onnxruntime-genai >= 0.6 Python
API (``GeneratorParams.set_search_options`` + ``Generator.append_tokens``,
without ``compute_logits``) -- confirm against the installed version.
"""

import json
import logging
from pathlib import Path
from typing import List, Optional

import onnxruntime_genai as og
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize FastAPI app
app = FastAPI(title="Qwen ONNX Model API", version="1.0")

# Path to model directory (the directory that contains this file)
MODEL_DIR = Path(__file__).parent

# Values reported by /info; CONTEXT_LENGTH also caps the total sequence length
# requested from the generator.
CONTEXT_LENGTH = 40960
VOCAB_SIZE = 151936

# Global model and tokenizer, populated by startup_event()
model = None
tokenizer = None


@app.on_event("startup")
async def startup_event():
    """Load the model and tokenizer from MODEL_DIR when the app starts.

    Any loading failure is logged with its traceback and re-raised so the
    server does not start in a half-initialised state.
    """
    global model, tokenizer
    try:
        logger.info(f"Loading model from {MODEL_DIR}")
        model = og.Model(str(MODEL_DIR))
        tokenizer = og.Tokenizer(model)
        logger.info("Model loaded successfully")
    except Exception:
        logger.exception("Failed to load model")
        raise


# Request/Response models
class GenerateRequest(BaseModel):
    """Text generation request"""

    prompt: str = Field(..., description="Input prompt")
    max_length: int = Field(100, ge=1, le=2048, description="Maximum output length")
    temperature: float = Field(0.6, ge=0.0, le=2.0, description="Temperature for sampling")
    top_p: float = Field(0.95, ge=0.0, le=1.0, description="Top-p for nucleus sampling")
    top_k: int = Field(20, ge=1, le=100, description="Top-k for sampling")


class GenerateResponse(BaseModel):
    """Text generation response"""

    prompt: str
    generated_text: str
    total_length: int


class Message(BaseModel):
    """Chat message"""

    role: str = Field(..., description="Message role: system, user, or assistant")
    content: str = Field(..., description="Message content")


class ChatRequest(BaseModel):
    """Chat inference request"""

    messages: List[Message] = Field(..., description="Conversation messages")
    max_length: int = Field(200, ge=1, le=2048, description="Maximum output length")
    temperature: float = Field(0.6, ge=0.0, le=2.0, description="Temperature for sampling")
    top_p: float = Field(0.95, ge=0.0, le=1.0, description="Top-p for nucleus sampling")
    top_k: int = Field(20, ge=1, le=100, description="Top-k for sampling")


class ChatResponse(BaseModel):
    """Chat inference response"""

    messages: List[Message]
    assistant_response: str


class TokenizeRequest(BaseModel):
    """Tokenization request"""

    text: str = Field(..., description="Text to tokenize")


class TokenizeResponse(BaseModel):
    """Tokenization response"""

    text: str
    token_ids: List[int]
    num_tokens: int


def _run_generation(prompt_text, max_new_tokens, temperature, top_p, top_k):
    """Generate a continuation of ``prompt_text`` with the global model.

    Args:
        prompt_text: Fully formatted prompt to feed to the model.
        max_new_tokens: Upper bound on the number of *generated* tokens.
        temperature: Sampling temperature; ``0`` selects greedy decoding.
        top_p: Nucleus-sampling probability mass.
        top_k: Number of highest-probability tokens considered when sampling.

    Returns:
        ``(generated_text, total_tokens)`` where ``generated_text`` is decoded
        from the newly generated tokens only and ``total_tokens`` counts
        prompt plus generated tokens.

    Raises:
        HTTPException: 400 if the prompt already fills CONTEXT_LENGTH.
    """
    input_tokens = tokenizer.encode(prompt_text)
    prompt_len = len(input_tokens)
    if prompt_len >= CONTEXT_LENGTH:
        raise HTTPException(
            status_code=400,
            detail=f"Prompt is too long: {prompt_len} tokens (limit {CONTEXT_LENGTH})",
        )

    # The runtime's `max_length` search option counts prompt + generated
    # tokens, while the request field is documented as the maximum *output*
    # length, so the prompt length is added here and the sum is capped.
    search_options = {
        "max_length": min(prompt_len + max_new_tokens, CONTEXT_LENGTH),
        "top_p": top_p,
        "top_k": top_k,
    }
    if temperature > 0:
        search_options["do_sample"] = True
        search_options["temperature"] = temperature
    else:
        # A temperature of 0 is accepted by the request models; treat it as
        # deterministic decoding instead of passing a zero divisor through.
        search_options["do_sample"] = False

    params = og.GeneratorParams(model)
    params.set_search_options(**search_options)

    generator = og.Generator(model, params)
    generator.append_tokens(input_tokens)
    while not generator.is_done():
        generator.generate_next_token()

    output_tokens = generator.get_sequence(0)
    # Slice by token count rather than stripping the prompt from the decoded
    # text: decoding does not always reproduce the prompt string exactly.
    generated_text = tokenizer.decode(output_tokens[prompt_len:])
    return generated_text, len(output_tokens)


def _load_search_defaults():
    """Return the ``search`` section of MODEL_DIR/genai_config.json (or {})."""
    config_path = MODEL_DIR / "genai_config.json"
    with config_path.open(encoding="utf-8") as config_file:
        return json.load(config_file).get("search", {})


# Health check
@app.get("/health")
async def health_check():
    """Report whether both the model and the tokenizer are loaded."""
    is_ready = model is not None and tokenizer is not None
    return {
        "status": "ok" if is_ready else "error",
        "model": "Qwen3-ONNX",
    }


# Text generation endpoint
# NOTE(review): generation runs synchronously inside this coroutine, so the
# event loop is blocked for the duration of a request.
@app.post("/generate", response_model=GenerateResponse)
async def generate(request: GenerateRequest):
    """Generate text from a prompt.

    Returns the original prompt, the generated continuation (whitespace
    stripped) and the total number of tokens in the final sequence.

    Raises:
        HTTPException: 503 if the model is not loaded, 400 if the prompt is
            too long, 500 on any other generation failure.
    """
    if model is None or tokenizer is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    try:
        generated_text, total_length = _run_generation(
            request.prompt,
            request.max_length,
            request.temperature,
            request.top_p,
            request.top_k,
        )
    except HTTPException:
        # Keep deliberate HTTP errors (e.g. 400) instead of masking them as 500.
        raise
    except Exception as e:
        logger.exception("Generation error")
        raise HTTPException(status_code=500, detail=str(e))

    return GenerateResponse(
        prompt=request.prompt,
        generated_text=generated_text.strip(),
        total_length=total_length,
    )


# Chat endpoint
@app.post("/chat", response_model=ChatResponse)
async def chat(request: ChatRequest):
    """Chat inference with conversation history.

    The messages are rendered with ``<|im_start|>``/``<|im_end|>`` markers,
    an assistant turn is opened, and only the newly generated text is returned
    as the assistant's reply (also appended to the returned message list).

    Raises:
        HTTPException: 503 if the model is not loaded, 400 if the prompt is
            too long, 500 on any other generation failure.
    """
    if model is None or tokenizer is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    try:
        # Format conversation
        turns = [
            f"<|im_start|>{msg.role}\n{msg.content}<|im_end|>\n"
            for msg in request.messages
        ]
        prompt_text = "".join(turns) + "<|im_start|>assistant\n"

        response_text, _ = _run_generation(
            prompt_text,
            request.max_length,
            request.temperature,
            request.top_p,
            request.top_k,
        )
        response_text = response_text.strip()

        # Add assistant response to a copy of the conversation
        messages = [
            *request.messages,
            Message(role="assistant", content=response_text),
        ]
    except HTTPException:
        raise
    except Exception as e:
        logger.exception("Chat error")
        raise HTTPException(status_code=500, detail=str(e))

    return ChatResponse(messages=messages, assistant_response=response_text)


# Tokenization endpoint
@app.post("/tokenize", response_model=TokenizeResponse)
async def tokenize(request: TokenizeRequest):
    """Tokenize text and return the token ids with their count.

    Raises:
        HTTPException: 503 if the tokenizer is not loaded, 500 on failure.
    """
    if tokenizer is None:
        raise HTTPException(status_code=503, detail="Tokenizer not loaded")

    try:
        # Convert to plain Python ints so the List[int] response field
        # validates whatever sequence type the tokenizer returns.
        token_ids = [int(token_id) for token_id in tokenizer.encode(request.text)]
    except Exception as e:
        logger.exception("Tokenization error")
        raise HTTPException(status_code=500, detail=str(e))

    return TokenizeResponse(
        text=request.text,
        token_ids=token_ids,
        num_tokens=len(token_ids),
    )


# Model info endpoint
@app.get("/info")
async def model_info():
    """Get model information.

    The ``default_*`` values are read from the ``search`` section of
    ``genai_config.json`` in MODEL_DIR; a key missing there is reported as
    ``None``.

    Raises:
        HTTPException: 503 if the model is not loaded, 500 if the
            configuration cannot be read.
    """
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    try:
        defaults = _load_search_defaults()
    except Exception as e:
        logger.exception("Info error")
        raise HTTPException(status_code=500, detail=str(e))

    return {
        "model_type": "Qwen3",
        "model_dir": str(MODEL_DIR),
        "context_length": CONTEXT_LENGTH,
        "vocab_size": VOCAB_SIZE,
        "default_max_length": defaults.get("max_length"),
        "default_temperature": defaults.get("temperature"),
        "default_top_p": defaults.get("top_p"),
        "default_top_k": defaults.get("top_k"),
    }


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(
        app,
        host="0.0.0.0",
        port=8000,
        log_level="info",
    )