File size: 7,560 Bytes
5abb996
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
"""
FastAPI server for Qwen ONNX model inference
Run with: uvicorn api_server:app --reload --host 0.0.0.0 --port 8000
"""

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
from typing import List, Optional
import onnxruntime_genai as og
from pathlib import Path
import logging

# Configure logging
# Root logging at INFO so the model-loading and error messages below are shown.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize FastAPI app
# All endpoints in this module are registered on this single app instance.
app = FastAPI(title="Qwen ONNX Model API", version="1.0")

# Path to model directory
# The ONNX model files are expected to live in the same folder as this script.
MODEL_DIR = Path(__file__).parent

# Global model and tokenizer
# Both start as None and are filled in by startup_event(). The inference
# endpoints check them and answer 503 while they are still unset; /health
# reports "error" instead.
model = None
tokenizer = None


@app.on_event("startup")
async def startup_event():
    """Load the ONNX model and its tokenizer from MODEL_DIR into the module globals.

    Raises:
        Exception: whatever onnxruntime-genai raises while loading. It is
            logged together with its traceback and re-raised, so application
            startup fails instead of serving without a model.
    """
    # NOTE(review): newer FastAPI releases deprecate on_event in favour of
    # lifespan handlers; consider migrating when the FastAPI version allows.
    global model, tokenizer

    logger.info("Loading model from %s", MODEL_DIR)
    try:
        loaded_model = og.Model(str(MODEL_DIR))
        loaded_tokenizer = og.Tokenizer(loaded_model)
    except Exception as e:
        # logger.exception keeps the traceback, which logger.error dropped.
        logger.exception("Failed to load model: %s", e)
        raise

    # Publish both globals only once loading fully succeeded, so they are
    # never left half-initialised (model set but tokenizer missing).
    model, tokenizer = loaded_model, loaded_tokenizer
    logger.info("Model loaded successfully")


# Request/Response models
class GenerateRequest(BaseModel):
    """Text generation request"""

    # Raw text handed to the tokenizer as-is (required).
    prompt: str = Field(default=..., description="Input prompt")
    # Generation length budget, bounded to 1..2048.
    max_length: int = Field(
        default=100,
        ge=1,
        le=2048,
        description="Maximum output length",
    )
    # Sampling settings; the defaults are the server's suggested values.
    temperature: float = Field(
        default=0.6,
        ge=0.0,
        le=2.0,
        description="Temperature for sampling",
    )
    top_p: float = Field(
        default=0.95,
        ge=0.0,
        le=1.0,
        description="Top-p for nucleus sampling",
    )
    top_k: int = Field(
        default=20,
        ge=1,
        le=100,
        description="Top-k for sampling",
    )


class GenerateResponse(BaseModel):
    """Text generation response"""
    # Echo of the request prompt.
    prompt: str
    # Model output produced by /generate, whitespace-stripped.
    generated_text: str
    # Number of tokens in the full generated sequence (as returned by the
    # generator for sequence 0).
    total_length: int


class Message(BaseModel):
    """Chat message"""

    # Speaker of the turn; /chat writes it verbatim after <|im_start|>.
    role: str = Field(
        default=...,
        description="Message role: system, user, or assistant",
    )
    # Text of the turn.
    content: str = Field(default=..., description="Message content")


class ChatRequest(BaseModel):
    """Chat inference request"""

    # Full conversation so far, oldest turn first (required).
    messages: List[Message] = Field(
        default=...,
        description="Conversation messages",
    )
    # Generation length budget, bounded to 1..2048.
    max_length: int = Field(
        default=200,
        ge=1,
        le=2048,
        description="Maximum output length",
    )
    # Sampling settings, identical in shape to GenerateRequest.
    temperature: float = Field(
        default=0.6,
        ge=0.0,
        le=2.0,
        description="Temperature for sampling",
    )
    top_p: float = Field(
        default=0.95,
        ge=0.0,
        le=1.0,
        description="Top-p for nucleus sampling",
    )
    top_k: int = Field(
        default=20,
        ge=1,
        le=100,
        description="Top-k for sampling",
    )


class ChatResponse(BaseModel):
    """Chat inference response"""
    # The request's messages followed by one new assistant message.
    messages: List[Message]
    # The assistant's reply text, whitespace-stripped.
    assistant_response: str


class TokenizeRequest(BaseModel):
    """Tokenization request"""

    # Text to encode with the model's tokenizer (required).
    text: str = Field(default=..., description="Text to tokenize")


class TokenizeResponse(BaseModel):
    """Tokenization response"""
    # Echo of the input text.
    text: str
    # Token ids produced by the tokenizer, in order.
    token_ids: List[int]
    # len(token_ids), provided for convenience.
    num_tokens: int


# Health check
@app.get("/health")
async def health_check():
    """Check if model is loaded"""
    # Healthy only when startup_event has populated both globals.
    loaded = bool(model and tokenizer)
    status = "ok" if loaded else "error"
    return {"status": status, "model": "Qwen3-ONNX"}


# Text generation endpoint
@app.post("/generate", response_model=GenerateResponse)
async def generate(request: GenerateRequest):
    """Generate text from a prompt

    The prompt is tokenized as-is (no chat template) and up to `max_length`
    new tokens are generated after it. Only the newly generated tokens are
    decoded into `generated_text`.
    """
    if not model or not tokenizer:
        raise HTTPException(status_code=503, detail="Model not loaded")

    # NOTE(review): the decode loop below is synchronous inside an async
    # handler, so it blocks the event loop (and every other request,
    # including /health) until generation finishes.
    try:
        input_tokens = tokenizer.encode(request.prompt)
        prompt_len = len(input_tokens)

        params = og.GeneratorParams(model)
        params.set_search_options(
            # onnxruntime-genai's max_length counts prompt + output tokens;
            # offset by the prompt so the field means "maximum output length"
            # and long prompts do not exhaust the budget before generating.
            max_length=prompt_len + request.max_length,
            # The sampling settings only take effect with do_sample enabled;
            # a temperature of 0 keeps greedy decoding.
            do_sample=request.temperature > 0,
            temperature=request.temperature,
            top_p=request.top_p,
            top_k=request.top_k,
        )

        generator = og.Generator(model, params)
        generator.append_tokens(input_tokens)
        while not generator.is_done():
            generator.generate_next_token()

        # The sequence holds prompt + completion. Slicing off the prompt at
        # the token level is robust where string-prefix matching fails
        # whenever decode() does not reproduce the prompt text exactly.
        output_tokens = generator.get_sequence(0)
        generated_text = tokenizer.decode(output_tokens[prompt_len:])

        return GenerateResponse(
            prompt=request.prompt,
            generated_text=generated_text.strip(),
            total_length=len(output_tokens),
        )

    except Exception as e:
        logger.exception("Generation error: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e


# Chat endpoint
@app.post("/chat", response_model=ChatResponse)
async def chat(request: ChatRequest):
    """Chat inference with conversation history

    Renders the messages with the ChatML-style `<|im_start|>` / `<|im_end|>`
    template, opens an assistant turn, and generates up to `max_length` new
    tokens. Only the newly generated tokens form the assistant reply.
    """
    if not model or not tokenizer:
        raise HTTPException(status_code=503, detail="Model not loaded")

    # NOTE(review): the decode loop below is synchronous inside an async
    # handler, so it blocks the event loop until generation finishes.
    try:
        # Format conversation, then open the assistant turn to be completed.
        prompt_text = "".join(
            f"<|im_start|>{msg.role}\n{msg.content}<|im_end|>\n"
            for msg in request.messages
        ) + "<|im_start|>assistant\n"

        input_tokens = tokenizer.encode(prompt_text)
        prompt_len = len(input_tokens)

        params = og.GeneratorParams(model)
        params.set_search_options(
            # onnxruntime-genai's max_length counts prompt + output tokens, so
            # a long history would otherwise leave no room for the reply.
            max_length=prompt_len + request.max_length,
            # The sampling settings only take effect with do_sample enabled;
            # a temperature of 0 keeps greedy decoding.
            do_sample=request.temperature > 0,
            temperature=request.temperature,
            top_p=request.top_p,
            top_k=request.top_k,
        )

        generator = og.Generator(model, params)
        generator.append_tokens(input_tokens)
        while not generator.is_done():
            generator.generate_next_token()

        # The sequence holds the whole templated prompt followed by the reply;
        # decode only the reply so the conversation is not echoed back.
        output_tokens = generator.get_sequence(0)
        response_text = tokenizer.decode(output_tokens[prompt_len:]).strip()

        # Same (stripped) text in the history and in assistant_response.
        messages = list(request.messages)
        messages.append(Message(role="assistant", content=response_text))

        return ChatResponse(
            messages=messages,
            assistant_response=response_text,
        )

    except Exception as e:
        logger.exception("Chat error: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e


# Tokenization endpoint
@app.post("/tokenize", response_model=TokenizeResponse)
async def tokenize(request: TokenizeRequest):
    """Tokenize text"""
    if not tokenizer:
        raise HTTPException(status_code=503, detail="Tokenizer not loaded")

    try:
        # onnxruntime-genai's encode() returns a NumPy int32 array rather
        # than a list; convert to plain ints so the List[int] response field
        # validates and serialises cleanly.
        token_ids = [int(token) for token in tokenizer.encode(request.text)]

        return TokenizeResponse(
            text=request.text,
            token_ids=token_ids,
            num_tokens=len(token_ids),
        )

    except Exception as e:
        logger.exception("Tokenization error: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e


# Model info endpoint
@app.get("/info")
async def model_info():
    """Get model information"""
    if not model:
        raise HTTPException(status_code=503, detail="Model not loaded")

    import json

    try:
        # onnxruntime-genai loads the model from genai_config.json in this
        # folder. Its "model" and "search" sections hold the architecture
        # facts and default search options reported here, so read them
        # instead of calling a non-existent model method.
        config_path = MODEL_DIR / "genai_config.json"
        with open(config_path, encoding="utf-8") as config_file:
            genai_config = json.load(config_file)
        model_section = genai_config.get("model", {})
        search = genai_config.get("search", {})

        return {
            "model_type": "Qwen3",
            "model_dir": str(MODEL_DIR),
            # Fall back to the previously hard-coded Qwen3 values if the
            # config omits them.
            "context_length": model_section.get("context_length", 40960),
            "vocab_size": model_section.get("vocab_size", 151936),
            "default_max_length": search.get("max_length"),
            "default_temperature": search.get("temperature"),
            "default_top_p": search.get("top_p"),
            "default_top_k": search.get("top_k"),
        }

    except Exception as e:
        logger.exception("Info error: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e


if __name__ == "__main__":
    # Direct `python <file>` launch; equivalent to running uvicorn on `app`
    # bound to all interfaces on port 8000 with info-level logging.
    import uvicorn

    server_options = {"host": "0.0.0.0", "port": 8000, "log_level": "info"}
    uvicorn.run(app, **server_options)