Update app.py
Browse files
app.py
CHANGED
|
@@ -19,11 +19,8 @@ from pydantic import BaseModel, Field, ValidationError
|
|
| 19 |
from transformers import AutoTokenizer
|
| 20 |
|
| 21 |
# ---------- Configuration ----------
|
| 22 |
-
# Model Selection: Use "onnx-community/Bonsai-1.7B-ONNX" or "onnx-community/Bonsai-8B-ONNX"
|
| 23 |
MODEL_ID = os.getenv("MODEL_ID", "onnx-community/Bonsai-1.7B-ONNX")
|
| 24 |
-
# Quantization: Choose from 'q1', 'q2', 'q4', 'q8' based on the files in the ONNX model repo
|
| 25 |
MODEL_QUANTIZATION = os.getenv("MODEL_QUANTIZATION", "q1")
|
| 26 |
-
# Model file name based on quantization
|
| 27 |
ONNX_MODEL_FILE = f"model_{MODEL_QUANTIZATION}.onnx"
|
| 28 |
|
| 29 |
HF_TOKEN = os.getenv("HF_TOKEN")
|
|
@@ -34,6 +31,12 @@ API_KEY = os.getenv("API_KEY", None)
|
|
| 34 |
logging.basicConfig(level=logging.INFO)
|
| 35 |
logger = logging.getLogger("uvicorn.error")
|
| 36 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
# ---------- Pydantic Models ----------
|
| 38 |
class Message(BaseModel):
|
| 39 |
role: str = Field(..., pattern="^(system|user|assistant)$")
|
|
@@ -151,7 +154,6 @@ def _build_chat_prompt(messages: List[Message]) -> str:
|
|
| 151 |
if tokenizer is None:
|
| 152 |
raise HTTPException(status_code=503, detail="Tokenizer not loaded")
|
| 153 |
try:
|
| 154 |
-
# Use the tokenizer's chat template to format the conversation
|
| 155 |
formatted_messages = [{"role": msg.role, "content": msg.content} for msg in messages]
|
| 156 |
prompt = tokenizer.apply_chat_template(
|
| 157 |
formatted_messages,
|
|
@@ -161,7 +163,7 @@ def _build_chat_prompt(messages: List[Message]) -> str:
|
|
| 161 |
return prompt
|
| 162 |
except Exception as e:
|
| 163 |
logger.error(f"Chat template error: {e}")
|
| 164 |
-
# Fallback to a simple concatenation
|
| 165 |
prompt = ""
|
| 166 |
for msg in messages:
|
| 167 |
prompt += f"<|{msg.role}|>\n{msg.content}\n"
|
|
@@ -197,6 +199,26 @@ def _sample_token(logits: np.ndarray, temperature: float, top_p: float) -> int:
|
|
| 197 |
probs = _softmax(logits)
|
| 198 |
return int(np.random.choice(len(probs), p=probs))
|
| 199 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 200 |
def _generate_full(
|
| 201 |
prompt: str,
|
| 202 |
max_new_tokens: int,
|
|
@@ -207,40 +229,61 @@ def _generate_full(
|
|
| 207 |
if ort_session is None or tokenizer is None:
|
| 208 |
raise HTTPException(status_code=503, detail="Model not loaded")
|
| 209 |
|
| 210 |
-
input_ids = tokenizer.encode(prompt, return_tensors="np")
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
}
|
| 218 |
-
|
| 219 |
generated_tokens = []
|
| 220 |
stop_sequences = stop_sequences or []
|
| 221 |
eos_token_id = tokenizer.eos_token_id
|
| 222 |
-
|
| 223 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 224 |
outputs = ort_session.run(None, ort_inputs)
|
| 225 |
logits = outputs[0][:, -1, :]
|
| 226 |
next_token = _sample_token(logits[0], temperature, top_p)
|
| 227 |
generated_tokens.append(next_token)
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
ort_inputs["input_ids"] = np.concatenate([input_ids, next_token_id], axis=1)
|
| 232 |
-
ort_inputs["attention_mask"] = np.concatenate(
|
| 233 |
-
[ort_inputs["attention_mask"], np.ones((1, 1), dtype=np.int64)], axis=1
|
| 234 |
-
)
|
| 235 |
-
|
| 236 |
-
# Check stop conditions
|
| 237 |
if next_token == eos_token_id:
|
| 238 |
break
|
| 239 |
partial_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)
|
| 240 |
for stop_seq in stop_sequences:
|
| 241 |
if stop_seq in partial_text:
|
| 242 |
return partial_text.split(stop_seq)[0].strip()
|
| 243 |
-
|
| 244 |
full_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)
|
| 245 |
return full_text.strip()
|
| 246 |
|
|
@@ -255,31 +298,55 @@ async def _generate_stream(
|
|
| 255 |
raise HTTPException(status_code=503, detail="Model not loaded")
|
| 256 |
|
| 257 |
input_ids = tokenizer.encode(prompt, return_tensors="np").astype(np.int64)
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
|
| 263 |
generated_tokens = []
|
| 264 |
stop_sequences = stop_sequences or []
|
| 265 |
eos_token_id = tokenizer.eos_token_id
|
| 266 |
-
|
| 267 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 268 |
outputs = ort_session.run(None, ort_inputs)
|
| 269 |
logits = outputs[0][:, -1, :]
|
| 270 |
next_token = _sample_token(logits[0], temperature, top_p)
|
| 271 |
generated_tokens.append(next_token)
|
| 272 |
-
|
| 273 |
-
|
| 274 |
-
|
| 275 |
-
ort_inputs["attention_mask"] = np.concatenate(
|
| 276 |
-
[ort_inputs["attention_mask"], np.ones((1, 1), dtype=np.int64)], axis=1
|
| 277 |
-
)
|
| 278 |
-
|
| 279 |
new_text = tokenizer.decode([next_token], skip_special_tokens=True)
|
| 280 |
if new_text:
|
| 281 |
yield new_text
|
| 282 |
-
|
| 283 |
full_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)
|
| 284 |
for stop_seq in stop_sequences:
|
| 285 |
if stop_seq in full_text:
|
|
@@ -375,14 +442,14 @@ def model_info():
|
|
| 375 |
@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
|
| 376 |
async def chat_completions(req: ChatCompletionRequest):
|
| 377 |
await _ensure_loaded()
|
| 378 |
-
|
| 379 |
try:
|
| 380 |
prompt = _build_chat_prompt(req.messages)
|
| 381 |
except Exception as e:
|
| 382 |
raise HTTPException(status_code=400, detail=f"Prompt formatting error: {str(e)}")
|
| 383 |
-
|
| 384 |
stop_seq = req.stop if isinstance(req.stop, list) else ([req.stop] if req.stop else None)
|
| 385 |
-
|
| 386 |
if req.stream:
|
| 387 |
async def stream_generator():
|
| 388 |
yield f"data: {json.dumps({'id': f'chatcmpl-{uuid.uuid4().hex[:12]}', 'object': 'chat.completion.chunk', 'created': int(time.time()), 'model': req.model or MODEL_ID, 'choices': [{'index': 0, 'delta': {'role': 'assistant'}, 'finish_reason': None}]})}\n\n"
|
|
@@ -392,7 +459,7 @@ async def chat_completions(req: ChatCompletionRequest):
|
|
| 392 |
yield f"data: {json.dumps({'id': f'chatcmpl-{uuid.uuid4().hex[:12]}', 'object': 'chat.completion.chunk', 'created': int(time.time()), 'model': req.model or MODEL_ID, 'choices': [{'index': 0, 'delta': {}, 'finish_reason': 'stop'}]})}\n\n"
|
| 393 |
yield "data: [DONE]\n\n"
|
| 394 |
return StreamingResponse(stream_generator(), media_type="text/event-stream")
|
| 395 |
-
|
| 396 |
else:
|
| 397 |
text = await asyncio.to_thread(
|
| 398 |
_generate_full,
|
|
|
|
| 19 |
from transformers import AutoTokenizer
|
| 20 |
|
| 21 |
# ---------- Configuration ----------
|
|
|
|
| 22 |
MODEL_ID = os.getenv("MODEL_ID", "onnx-community/Bonsai-1.7B-ONNX")
|
|
|
|
| 23 |
MODEL_QUANTIZATION = os.getenv("MODEL_QUANTIZATION", "q1")
|
|
|
|
| 24 |
ONNX_MODEL_FILE = f"model_{MODEL_QUANTIZATION}.onnx"
|
| 25 |
|
| 26 |
HF_TOKEN = os.getenv("HF_TOKEN")
|
|
|
|
| 31 |
logging.basicConfig(level=logging.INFO)
|
| 32 |
logger = logging.getLogger("uvicorn.error")
|
| 33 |
|
| 34 |
+
# Bonsai architecture constants (from config.json)
|
| 35 |
+
NUM_LAYERS = 28
|
| 36 |
+
NUM_KV_HEADS = 8
|
| 37 |
+
HEAD_DIM = 128
|
| 38 |
+
DTYPE = np.float32
|
| 39 |
+
|
| 40 |
# ---------- Pydantic Models ----------
|
| 41 |
class Message(BaseModel):
|
| 42 |
role: str = Field(..., pattern="^(system|user|assistant)$")
|
|
|
|
| 154 |
if tokenizer is None:
|
| 155 |
raise HTTPException(status_code=503, detail="Tokenizer not loaded")
|
| 156 |
try:
|
|
|
|
| 157 |
formatted_messages = [{"role": msg.role, "content": msg.content} for msg in messages]
|
| 158 |
prompt = tokenizer.apply_chat_template(
|
| 159 |
formatted_messages,
|
|
|
|
| 163 |
return prompt
|
| 164 |
except Exception as e:
|
| 165 |
logger.error(f"Chat template error: {e}")
|
| 166 |
+
# Fallback to a simple concatenation
|
| 167 |
prompt = ""
|
| 168 |
for msg in messages:
|
| 169 |
prompt += f"<|{msg.role}|>\n{msg.content}\n"
|
|
|
|
| 199 |
probs = _softmax(logits)
|
| 200 |
return int(np.random.choice(len(probs), p=probs))
|
| 201 |
|
| 202 |
+
def _init_past_key_values(batch_size: int = 1) -> Dict[str, np.ndarray]:
|
| 203 |
+
"""Create empty past_key_values tensors for the first inference step."""
|
| 204 |
+
past_kv = {}
|
| 205 |
+
empty_shape = (batch_size, NUM_KV_HEADS, 0, HEAD_DIM)
|
| 206 |
+
empty_tensor = np.zeros(empty_shape, dtype=DTYPE)
|
| 207 |
+
for i in range(NUM_LAYERS):
|
| 208 |
+
past_kv[f"past_key_values.{i}.key"] = empty_tensor.copy()
|
| 209 |
+
past_kv[f"past_key_values.{i}.value"] = empty_tensor.copy()
|
| 210 |
+
return past_kv
|
| 211 |
+
|
| 212 |
+
def _update_past_key_values(outputs: List[np.ndarray], output_names: List[str]) -> Dict[str, np.ndarray]:
|
| 213 |
+
"""Extract present_key_values from ONNX outputs and return as dictionary."""
|
| 214 |
+
new_past = {}
|
| 215 |
+
for name, value in zip(output_names, outputs):
|
| 216 |
+
if name.startswith("present"):
|
| 217 |
+
# Convert "present_key_values.0.key" -> "past_key_values.0.key"
|
| 218 |
+
past_name = name.replace("present", "past")
|
| 219 |
+
new_past[past_name] = value
|
| 220 |
+
return new_past
|
| 221 |
+
|
| 222 |
def _generate_full(
|
| 223 |
prompt: str,
|
| 224 |
max_new_tokens: int,
|
|
|
|
| 229 |
if ort_session is None or tokenizer is None:
|
| 230 |
raise HTTPException(status_code=503, detail="Model not loaded")
|
| 231 |
|
| 232 |
+
input_ids = tokenizer.encode(prompt, return_tensors="np").astype(np.int64)
|
| 233 |
+
attention_mask = np.ones_like(input_ids, dtype=np.int64)
|
| 234 |
+
position_ids = np.arange(input_ids.shape[1], dtype=np.int64).reshape(1, -1)
|
| 235 |
+
|
| 236 |
+
# Initialize KV cache
|
| 237 |
+
past_kv = _init_past_key_values(batch_size=1)
|
| 238 |
+
|
|
|
|
|
|
|
| 239 |
generated_tokens = []
|
| 240 |
stop_sequences = stop_sequences or []
|
| 241 |
eos_token_id = tokenizer.eos_token_id
|
| 242 |
+
|
| 243 |
+
# Prefill step: process full prompt
|
| 244 |
+
ort_inputs = {
|
| 245 |
+
"input_ids": input_ids,
|
| 246 |
+
"attention_mask": attention_mask,
|
| 247 |
+
"position_ids": position_ids,
|
| 248 |
+
"num_logits_to_keep": np.array([1], dtype=np.int64),
|
| 249 |
+
**past_kv,
|
| 250 |
+
}
|
| 251 |
+
outputs = ort_session.run(None, ort_inputs)
|
| 252 |
+
# First output is logits, the rest are present_key_values
|
| 253 |
+
logits = outputs[0][:, -1, :]
|
| 254 |
+
next_token = _sample_token(logits[0], temperature, top_p)
|
| 255 |
+
generated_tokens.append(next_token)
|
| 256 |
+
|
| 257 |
+
# Update past_key_values from outputs
|
| 258 |
+
past_kv = _update_past_key_values(outputs, [out.name for out in ort_session.get_outputs()])
|
| 259 |
+
|
| 260 |
+
for step in range(1, max_new_tokens):
|
| 261 |
+
# Subsequent steps: only the last token
|
| 262 |
+
last_token = np.array([[next_token]], dtype=np.int64)
|
| 263 |
+
attention_mask = np.ones((1, past_kv[f"past_key_values.0.key"].shape[2] + 1), dtype=np.int64)
|
| 264 |
+
position_ids = np.array([[past_kv[f"past_key_values.0.key"].shape[2]]], dtype=np.int64)
|
| 265 |
+
|
| 266 |
+
ort_inputs = {
|
| 267 |
+
"input_ids": last_token,
|
| 268 |
+
"attention_mask": attention_mask,
|
| 269 |
+
"position_ids": position_ids,
|
| 270 |
+
"num_logits_to_keep": np.array([1], dtype=np.int64),
|
| 271 |
+
**past_kv,
|
| 272 |
+
}
|
| 273 |
outputs = ort_session.run(None, ort_inputs)
|
| 274 |
logits = outputs[0][:, -1, :]
|
| 275 |
next_token = _sample_token(logits[0], temperature, top_p)
|
| 276 |
generated_tokens.append(next_token)
|
| 277 |
+
|
| 278 |
+
past_kv = _update_past_key_values(outputs, [out.name for out in ort_session.get_outputs()])
|
| 279 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 280 |
if next_token == eos_token_id:
|
| 281 |
break
|
| 282 |
partial_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)
|
| 283 |
for stop_seq in stop_sequences:
|
| 284 |
if stop_seq in partial_text:
|
| 285 |
return partial_text.split(stop_seq)[0].strip()
|
| 286 |
+
|
| 287 |
full_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)
|
| 288 |
return full_text.strip()
|
| 289 |
|
|
|
|
| 298 |
raise HTTPException(status_code=503, detail="Model not loaded")
|
| 299 |
|
| 300 |
input_ids = tokenizer.encode(prompt, return_tensors="np").astype(np.int64)
|
| 301 |
+
attention_mask = np.ones_like(input_ids, dtype=np.int64)
|
| 302 |
+
position_ids = np.arange(input_ids.shape[1], dtype=np.int64).reshape(1, -1)
|
| 303 |
+
|
| 304 |
+
past_kv = _init_past_key_values(batch_size=1)
|
|
|
|
| 305 |
generated_tokens = []
|
| 306 |
stop_sequences = stop_sequences or []
|
| 307 |
eos_token_id = tokenizer.eos_token_id
|
| 308 |
+
|
| 309 |
+
# Prefill
|
| 310 |
+
ort_inputs = {
|
| 311 |
+
"input_ids": input_ids,
|
| 312 |
+
"attention_mask": attention_mask,
|
| 313 |
+
"position_ids": position_ids,
|
| 314 |
+
"num_logits_to_keep": np.array([1], dtype=np.int64),
|
| 315 |
+
**past_kv,
|
| 316 |
+
}
|
| 317 |
+
outputs = ort_session.run(None, ort_inputs)
|
| 318 |
+
logits = outputs[0][:, -1, :]
|
| 319 |
+
next_token = _sample_token(logits[0], temperature, top_p)
|
| 320 |
+
generated_tokens.append(next_token)
|
| 321 |
+
past_kv = _update_past_key_values(outputs, [out.name for out in ort_session.get_outputs()])
|
| 322 |
+
|
| 323 |
+
new_text = tokenizer.decode([next_token], skip_special_tokens=True)
|
| 324 |
+
if new_text:
|
| 325 |
+
yield new_text
|
| 326 |
+
|
| 327 |
+
for step in range(1, max_new_tokens):
|
| 328 |
+
last_token = np.array([[next_token]], dtype=np.int64)
|
| 329 |
+
attention_mask = np.ones((1, past_kv[f"past_key_values.0.key"].shape[2] + 1), dtype=np.int64)
|
| 330 |
+
position_ids = np.array([[past_kv[f"past_key_values.0.key"].shape[2]]], dtype=np.int64)
|
| 331 |
+
|
| 332 |
+
ort_inputs = {
|
| 333 |
+
"input_ids": last_token,
|
| 334 |
+
"attention_mask": attention_mask,
|
| 335 |
+
"position_ids": position_ids,
|
| 336 |
+
"num_logits_to_keep": np.array([1], dtype=np.int64),
|
| 337 |
+
**past_kv,
|
| 338 |
+
}
|
| 339 |
outputs = ort_session.run(None, ort_inputs)
|
| 340 |
logits = outputs[0][:, -1, :]
|
| 341 |
next_token = _sample_token(logits[0], temperature, top_p)
|
| 342 |
generated_tokens.append(next_token)
|
| 343 |
+
|
| 344 |
+
past_kv = _update_past_key_values(outputs, [out.name for out in ort_session.get_outputs()])
|
| 345 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
| 346 |
new_text = tokenizer.decode([next_token], skip_special_tokens=True)
|
| 347 |
if new_text:
|
| 348 |
yield new_text
|
| 349 |
+
|
| 350 |
full_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)
|
| 351 |
for stop_seq in stop_sequences:
|
| 352 |
if stop_seq in full_text:
|
|
|
|
| 442 |
@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
|
| 443 |
async def chat_completions(req: ChatCompletionRequest):
|
| 444 |
await _ensure_loaded()
|
| 445 |
+
|
| 446 |
try:
|
| 447 |
prompt = _build_chat_prompt(req.messages)
|
| 448 |
except Exception as e:
|
| 449 |
raise HTTPException(status_code=400, detail=f"Prompt formatting error: {str(e)}")
|
| 450 |
+
|
| 451 |
stop_seq = req.stop if isinstance(req.stop, list) else ([req.stop] if req.stop else None)
|
| 452 |
+
|
| 453 |
if req.stream:
|
| 454 |
async def stream_generator():
|
| 455 |
yield f"data: {json.dumps({'id': f'chatcmpl-{uuid.uuid4().hex[:12]}', 'object': 'chat.completion.chunk', 'created': int(time.time()), 'model': req.model or MODEL_ID, 'choices': [{'index': 0, 'delta': {'role': 'assistant'}, 'finish_reason': None}]})}\n\n"
|
|
|
|
| 459 |
yield f"data: {json.dumps({'id': f'chatcmpl-{uuid.uuid4().hex[:12]}', 'object': 'chat.completion.chunk', 'created': int(time.time()), 'model': req.model or MODEL_ID, 'choices': [{'index': 0, 'delta': {}, 'finish_reason': 'stop'}]})}\n\n"
|
| 460 |
yield "data: [DONE]\n\n"
|
| 461 |
return StreamingResponse(stream_generator(), media_type="text/event-stream")
|
| 462 |
+
|
| 463 |
else:
|
| 464 |
text = await asyncio.to_thread(
|
| 465 |
_generate_full,
|