Update app.py
Browse files
app.py
CHANGED
|
@@ -7,31 +7,40 @@ import os
|
|
| 7 |
import time
|
| 8 |
import uuid
|
| 9 |
from contextlib import asynccontextmanager
|
| 10 |
-
from typing import
|
| 11 |
|
| 12 |
-
import numpy as np
|
| 13 |
-
import onnxruntime as ort
|
| 14 |
from fastapi import FastAPI, HTTPException, Request
|
| 15 |
from fastapi.middleware.cors import CORSMiddleware
|
| 16 |
from fastapi.responses import JSONResponse, StreamingResponse
|
| 17 |
-
from huggingface_hub import
|
| 18 |
from pydantic import BaseModel, Field, ValidationError
|
| 19 |
-
|
|
|
|
|
|
|
| 20 |
|
| 21 |
# ---------- Configuration ----------
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
|
| 26 |
HF_TOKEN = os.getenv("HF_TOKEN")
|
| 27 |
-
LOCAL_MODEL_DIR = os.getenv("LOCAL_MODEL_DIR", "/data/
|
| 28 |
MAX_NEW_TOKENS_DEFAULT = int(os.getenv("MAX_NEW_TOKENS_DEFAULT", "256"))
|
| 29 |
API_KEY = os.getenv("API_KEY", None)
|
| 30 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
logging.basicConfig(level=logging.INFO)
|
| 32 |
logger = logging.getLogger("uvicorn.error")
|
| 33 |
|
| 34 |
-
# ---------- Pydantic Models ----------
|
| 35 |
class Message(BaseModel):
|
| 36 |
role: str = Field(..., pattern="^(system|user|assistant)$")
|
| 37 |
content: str
|
|
@@ -65,30 +74,20 @@ class ChatCompletionResponse(BaseModel):
|
|
| 65 |
|
| 66 |
class ModelInfo(BaseModel):
|
| 67 |
model_id: str
|
| 68 |
-
|
| 69 |
-
onnx_model_file: str
|
| 70 |
device: str
|
|
|
|
|
|
|
| 71 |
|
| 72 |
class ErrorResponse(BaseModel):
|
| 73 |
error: str
|
| 74 |
detail: Optional[str] = None
|
| 75 |
|
| 76 |
# ---------- Global State ----------
|
| 77 |
-
|
| 78 |
-
ort_session = None
|
| 79 |
model_load_error = None
|
| 80 |
MODEL_LOCK = asyncio.Lock()
|
| 81 |
|
| 82 |
-
# Cached model metadata
|
| 83 |
-
past_input_names = []
|
| 84 |
-
past_output_names = []
|
| 85 |
-
num_layers = 0
|
| 86 |
-
num_kv_heads = 0
|
| 87 |
-
head_dim = 0
|
| 88 |
-
kv_cache_dtype = np.float32
|
| 89 |
-
has_num_logits_input = False
|
| 90 |
-
has_position_ids = False
|
| 91 |
-
|
| 92 |
# ---------- Helper Functions ----------
|
| 93 |
def _verify_api_key(request: Request) -> None:
|
| 94 |
if API_KEY is None:
|
|
@@ -97,366 +96,105 @@ def _verify_api_key(request: Request) -> None:
|
|
| 97 |
if not auth or auth != API_KEY:
|
| 98 |
raise HTTPException(status_code=401, detail="Invalid or missing API key")
|
| 99 |
|
| 100 |
-
def
|
| 101 |
-
return "cuda" if ort.get_device().lower() == "gpu" else "cpu"
|
| 102 |
-
|
| 103 |
-
def _download_model_snapshot() -> str:
|
| 104 |
os.makedirs(LOCAL_MODEL_DIR, exist_ok=True)
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
]
|
| 113 |
try:
|
| 114 |
-
|
| 115 |
repo_id=MODEL_ID,
|
|
|
|
| 116 |
local_dir=LOCAL_MODEL_DIR,
|
| 117 |
-
local_dir_use_symlinks=False,
|
| 118 |
-
allow_patterns=allow_patterns,
|
| 119 |
token=HF_TOKEN,
|
| 120 |
)
|
|
|
|
|
|
|
| 121 |
except Exception as e:
|
| 122 |
logger.error(f"Model download failed: {e}")
|
| 123 |
raise RuntimeError(f"Failed to download model: {str(e)}")
|
| 124 |
-
return LOCAL_MODEL_DIR
|
| 125 |
-
|
| 126 |
-
def _create_ort_session(model_path: str) -> ort.InferenceSession:
|
| 127 |
-
# --- OPTIMIZATION 1: Configure Session Options for LLMs ---
|
| 128 |
-
so = ort.SessionOptions()
|
| 129 |
-
# Disable all graph optimizations; they can be counter-productive for LLMs.
|
| 130 |
-
so.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL
|
| 131 |
-
# Set intra-op threads to 1 to reduce thread pool overhead for small batch sizes.
|
| 132 |
-
so.intra_op_num_threads = 1
|
| 133 |
-
so.inter_op_num_threads = 1
|
| 134 |
-
so.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
|
| 135 |
-
# Disable memory pattern optimization; it can increase memory fragmentation.
|
| 136 |
-
so.enable_mem_pattern = False
|
| 137 |
-
# Enable CPU memory arena for faster allocations.
|
| 138 |
-
so.enable_cpu_mem_arena = True
|
| 139 |
-
# Add an optimized execution provider for Intel CPUs.
|
| 140 |
-
# This provider is specifically designed to accelerate LLM inference.
|
| 141 |
-
providers = [
|
| 142 |
-
('OpenVINOExecutionProvider', {'device_type': 'CPU_FP32'}),
|
| 143 |
-
'CPUExecutionProvider'
|
| 144 |
-
]
|
| 145 |
-
try:
|
| 146 |
-
return ort.InferenceSession(model_path, sess_options=so, providers=providers)
|
| 147 |
-
except Exception as e:
|
| 148 |
-
logger.error(f"Failed to load ONNX session from {model_path}: {e}")
|
| 149 |
-
raise RuntimeError(f"ONNX session creation failed: {str(e)}")
|
| 150 |
|
| 151 |
async def _ensure_loaded():
|
| 152 |
-
global
|
| 153 |
-
global past_input_names, past_output_names, num_layers, num_kv_heads, head_dim, kv_cache_dtype
|
| 154 |
-
global has_num_logits_input, has_position_ids
|
| 155 |
-
|
| 156 |
async with MODEL_LOCK:
|
| 157 |
-
if
|
| 158 |
return
|
| 159 |
if model_load_error:
|
| 160 |
raise HTTPException(status_code=503, detail=f"Model failed to load: {model_load_error}")
|
| 161 |
try:
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
head_dim = config.get("head_dim", 128)
|
| 174 |
-
|
| 175 |
-
# Identify input/output names and special inputs
|
| 176 |
-
inputs = ort_session.get_inputs()
|
| 177 |
-
outputs = ort_session.get_outputs()
|
| 178 |
-
|
| 179 |
-
past_input_names = [inp.name for inp in inputs if inp.name.startswith("past_key_values")]
|
| 180 |
-
past_output_names = [out.name for out in outputs if out.name.startswith("present")]
|
| 181 |
-
|
| 182 |
-
for inp in inputs:
|
| 183 |
-
if inp.name.startswith("past_key_values"):
|
| 184 |
-
kv_cache_dtype = np.float16 if inp.type == "tensor(float16)" else np.float32
|
| 185 |
-
break
|
| 186 |
-
|
| 187 |
-
has_num_logits_input = "num_logits_to_keep" in [inp.name for inp in inputs]
|
| 188 |
-
has_position_ids = "position_ids" in [inp.name for inp in inputs]
|
| 189 |
-
|
| 190 |
-
logger.info(f"Model loaded: {MODEL_ID} ({MODEL_QUANTIZATION})")
|
| 191 |
-
logger.info(f"Layers: {num_layers}, KV heads: {num_kv_heads}, head dim: {head_dim}")
|
| 192 |
-
logger.info(f"Past inputs: {len(past_input_names)}, outputs: {len(past_output_names)}")
|
| 193 |
-
logger.info(f"num_logits_to_keep: {has_num_logits_input}, position_ids: {has_position_ids}")
|
| 194 |
except Exception as e:
|
| 195 |
model_load_error = str(e)
|
| 196 |
logger.exception("Model loading failed")
|
| 197 |
raise HTTPException(status_code=503, detail=f"Model unavailable: {model_load_error}")
|
| 198 |
|
| 199 |
def _build_chat_prompt(messages: List[Message]) -> str:
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
add_generation_prompt=True,
|
| 208 |
-
)
|
| 209 |
-
return prompt
|
| 210 |
-
except Exception as e:
|
| 211 |
-
logger.error(f"Chat template error: {e}")
|
| 212 |
-
prompt = ""
|
| 213 |
-
for msg in messages:
|
| 214 |
-
prompt += f"<|{msg.role}|>\n{msg.content}\n"
|
| 215 |
-
prompt += "<|assistant|>\n"
|
| 216 |
-
return prompt
|
| 217 |
-
|
| 218 |
-
def _count_tokens(text: str) -> int:
|
| 219 |
-
if tokenizer is None:
|
| 220 |
-
return len(text.split())
|
| 221 |
-
return len(tokenizer.encode(text))
|
| 222 |
-
|
| 223 |
-
def _softmax(x: np.ndarray) -> np.ndarray:
|
| 224 |
-
e_x = np.exp(x - np.max(x))
|
| 225 |
-
return e_x / e_x.sum(axis=-1, keepdims=True)
|
| 226 |
-
|
| 227 |
-
def _top_p_sampling(logits: np.ndarray, top_p: float) -> int:
|
| 228 |
-
sorted_indices = np.argsort(logits)[::-1]
|
| 229 |
-
sorted_logits = logits[sorted_indices]
|
| 230 |
-
probs = _softmax(sorted_logits)
|
| 231 |
-
cum_probs = np.cumsum(probs)
|
| 232 |
-
cutoff_index = np.searchsorted(cum_probs, top_p) + 1
|
| 233 |
-
top_indices = sorted_indices[:cutoff_index]
|
| 234 |
-
top_probs = probs[:cutoff_index]
|
| 235 |
-
top_probs /= top_probs.sum()
|
| 236 |
-
return int(np.random.choice(top_indices, p=top_probs))
|
| 237 |
-
|
| 238 |
-
def _sample_token(logits: np.ndarray, temperature: float, top_p: float) -> int:
|
| 239 |
-
if temperature <= 0:
|
| 240 |
-
return int(np.argmax(logits))
|
| 241 |
-
logits = logits / temperature
|
| 242 |
-
if top_p < 1.0:
|
| 243 |
-
return _top_p_sampling(logits, top_p)
|
| 244 |
-
probs = _softmax(logits)
|
| 245 |
-
return int(np.random.choice(len(probs), p=probs))
|
| 246 |
-
|
| 247 |
-
def _init_past_key_values(batch_size: int = 1) -> Dict[str, np.ndarray]:
|
| 248 |
-
"""Create zero-filled KV cache tensors with correct shape and dtype."""
|
| 249 |
-
past = {}
|
| 250 |
-
empty_shape = (batch_size, num_kv_heads, 0, head_dim)
|
| 251 |
-
empty_tensor = np.zeros(empty_shape, dtype=kv_cache_dtype)
|
| 252 |
-
for name in past_input_names:
|
| 253 |
-
past[name] = empty_tensor.copy()
|
| 254 |
-
return past
|
| 255 |
-
|
| 256 |
-
def _prepare_inputs(
|
| 257 |
-
input_ids: np.ndarray,
|
| 258 |
-
attention_mask: np.ndarray,
|
| 259 |
-
past_kv: Dict[str, np.ndarray],
|
| 260 |
-
position_ids: Optional[np.ndarray] = None,
|
| 261 |
-
) -> Dict[str, np.ndarray]:
|
| 262 |
-
"""Build input feed dictionary with all required tensors."""
|
| 263 |
-
feed = {
|
| 264 |
-
"input_ids": input_ids.astype(np.int64),
|
| 265 |
-
"attention_mask": attention_mask.astype(np.int64),
|
| 266 |
-
}
|
| 267 |
-
for name, tensor in past_kv.items():
|
| 268 |
-
feed[name] = tensor
|
| 269 |
-
|
| 270 |
-
if has_position_ids and position_ids is not None:
|
| 271 |
-
feed["position_ids"] = position_ids.astype(np.int64)
|
| 272 |
-
|
| 273 |
-
if has_num_logits_input:
|
| 274 |
-
feed["num_logits_to_keep"] = np.array(1, dtype=np.int64)
|
| 275 |
-
|
| 276 |
-
return feed
|
| 277 |
|
| 278 |
-
def
|
| 279 |
-
|
| 280 |
-
max_new_tokens: int,
|
| 281 |
-
temperature: float,
|
| 282 |
-
top_p: float,
|
| 283 |
-
stop_sequences: Optional[List[str]] = None,
|
| 284 |
-
) -> str:
|
| 285 |
-
if ort_session is None or tokenizer is None:
|
| 286 |
raise HTTPException(status_code=503, detail="Model not loaded")
|
| 287 |
|
| 288 |
-
|
| 289 |
-
attention_mask = np.ones_like(input_ids, dtype=np.int64)
|
| 290 |
-
seq_len = input_ids.shape[1]
|
| 291 |
-
|
| 292 |
-
past_kv = _init_past_key_values(batch_size=1)
|
| 293 |
-
generated_tokens = []
|
| 294 |
-
stop_sequences = stop_sequences or []
|
| 295 |
-
eos_token_id = tokenizer.eos_token_id
|
| 296 |
-
|
| 297 |
-
# --- OPTIMIZATION 2: Use IOBinding to avoid copying KV cache ---
|
| 298 |
-
# Create an IOBinding object to bind inputs and outputs to device memory
|
| 299 |
-
io_binding = ort_session.io_binding()
|
| 300 |
-
|
| 301 |
-
# Prefill step
|
| 302 |
-
position_ids = np.arange(seq_len, dtype=np.int64).reshape(1, -1) if has_position_ids else None
|
| 303 |
-
feed = _prepare_inputs(input_ids, attention_mask, past_kv, position_ids)
|
| 304 |
-
|
| 305 |
-
# Bind inputs
|
| 306 |
-
for name, tensor in feed.items():
|
| 307 |
-
io_binding.bind_cpu_input(name, tensor)
|
| 308 |
-
|
| 309 |
-
# Bind outputs
|
| 310 |
-
for output in ort_session.get_outputs():
|
| 311 |
-
io_binding.bind_output(output.name)
|
| 312 |
-
|
| 313 |
-
ort_session.run_with_iobinding(io_binding)
|
| 314 |
-
outputs = io_binding.copy_outputs_to_cpu()
|
| 315 |
-
|
| 316 |
-
logits = outputs[0][:, -1, :]
|
| 317 |
-
next_token = _sample_token(logits[0], temperature, top_p)
|
| 318 |
-
generated_tokens.append(next_token)
|
| 319 |
-
|
| 320 |
-
past_kv_outputs = outputs[1:]
|
| 321 |
-
past_kv = dict(zip(past_input_names, past_kv_outputs))
|
| 322 |
-
|
| 323 |
-
for _ in range(1, max_new_tokens):
|
| 324 |
-
last_token = np.array([[next_token]], dtype=np.int64)
|
| 325 |
-
attention_mask = np.ones((1, seq_len + 1), dtype=np.int64)
|
| 326 |
-
position_ids = np.array([[seq_len]], dtype=np.int64) if has_position_ids else None
|
| 327 |
-
seq_len += 1
|
| 328 |
-
|
| 329 |
-
feed = _prepare_inputs(last_token, attention_mask, past_kv, position_ids)
|
| 330 |
-
|
| 331 |
-
# Bind inputs for the next token
|
| 332 |
-
io_binding.clear_binding_inputs()
|
| 333 |
-
for name, tensor in feed.items():
|
| 334 |
-
io_binding.bind_cpu_input(name, tensor)
|
| 335 |
-
|
| 336 |
-
# Bind outputs
|
| 337 |
-
io_binding.clear_binding_outputs()
|
| 338 |
-
for output in ort_session.get_outputs():
|
| 339 |
-
io_binding.bind_output(output.name)
|
| 340 |
-
|
| 341 |
-
ort_session.run_with_iobinding(io_binding)
|
| 342 |
-
outputs = io_binding.copy_outputs_to_cpu()
|
| 343 |
-
|
| 344 |
-
logits = outputs[0][:, -1, :]
|
| 345 |
-
next_token = _sample_token(logits[0], temperature, top_p)
|
| 346 |
-
generated_tokens.append(next_token)
|
| 347 |
-
|
| 348 |
-
past_kv_outputs = outputs[1:]
|
| 349 |
-
past_kv = dict(zip(past_input_names, past_kv_outputs))
|
| 350 |
-
|
| 351 |
-
if next_token == eos_token_id:
|
| 352 |
-
break
|
| 353 |
-
partial_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)
|
| 354 |
-
for stop_seq in stop_sequences:
|
| 355 |
-
if stop_seq in partial_text:
|
| 356 |
-
return partial_text.split(stop_seq)[0].strip()
|
| 357 |
-
|
| 358 |
-
full_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)
|
| 359 |
-
return full_text.strip()
|
| 360 |
-
|
| 361 |
-
async def _generate_full(
|
| 362 |
-
prompt: str,
|
| 363 |
-
max_new_tokens: int,
|
| 364 |
-
temperature: float,
|
| 365 |
-
top_p: float,
|
| 366 |
-
stop_sequences: Optional[List[str]] = None,
|
| 367 |
-
) -> str:
|
| 368 |
return await asyncio.to_thread(
|
| 369 |
-
|
| 370 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 371 |
)
|
| 372 |
|
| 373 |
-
async def _generate_stream(
|
| 374 |
-
|
| 375 |
-
max_new_tokens: int,
|
| 376 |
-
temperature: float,
|
| 377 |
-
top_p: float,
|
| 378 |
-
stop_sequences: Optional[List[str]] = None,
|
| 379 |
-
):
|
| 380 |
-
if ort_session is None or tokenizer is None:
|
| 381 |
raise HTTPException(status_code=503, detail="Model not loaded")
|
| 382 |
|
| 383 |
-
|
| 384 |
-
|
| 385 |
-
|
| 386 |
-
|
| 387 |
-
|
| 388 |
-
|
| 389 |
-
|
| 390 |
-
|
| 391 |
-
|
| 392 |
-
|
| 393 |
-
|
| 394 |
-
|
| 395 |
-
|
| 396 |
-
|
| 397 |
-
|
| 398 |
-
|
| 399 |
-
|
| 400 |
-
|
| 401 |
-
|
| 402 |
-
|
| 403 |
-
|
| 404 |
-
|
| 405 |
-
|
| 406 |
-
|
| 407 |
-
return io_binding.copy_outputs_to_cpu()
|
| 408 |
-
|
| 409 |
-
outputs = await asyncio.to_thread(prefill_step)
|
| 410 |
-
logits = outputs[0][:, -1, :]
|
| 411 |
-
next_token = _sample_token(logits[0], temperature, top_p)
|
| 412 |
-
generated_tokens.append(next_token)
|
| 413 |
-
|
| 414 |
-
past_kv_outputs = outputs[1:]
|
| 415 |
-
past_kv = dict(zip(past_input_names, past_kv_outputs))
|
| 416 |
-
|
| 417 |
-
new_text = tokenizer.decode([next_token], skip_special_tokens=True)
|
| 418 |
-
if new_text:
|
| 419 |
-
yield new_text
|
| 420 |
-
|
| 421 |
-
for _ in range(1, max_new_tokens):
|
| 422 |
-
last_token = np.array([[next_token]], dtype=np.int64)
|
| 423 |
-
attention_mask = np.ones((1, seq_len + 1), dtype=np.int64)
|
| 424 |
-
position_ids = np.array([[seq_len]], dtype=np.int64) if has_position_ids else None
|
| 425 |
-
seq_len += 1
|
| 426 |
-
|
| 427 |
-
def step():
|
| 428 |
-
feed = _prepare_inputs(last_token, attention_mask, past_kv, position_ids)
|
| 429 |
-
|
| 430 |
-
io_binding.clear_binding_inputs()
|
| 431 |
-
io_binding.clear_binding_outputs()
|
| 432 |
-
for name, tensor in feed.items():
|
| 433 |
-
io_binding.bind_cpu_input(name, tensor)
|
| 434 |
-
for output in ort_session.get_outputs():
|
| 435 |
-
io_binding.bind_output(output.name)
|
| 436 |
-
|
| 437 |
-
ort_session.run_with_iobinding(io_binding)
|
| 438 |
-
return io_binding.copy_outputs_to_cpu()
|
| 439 |
-
|
| 440 |
-
outputs = await asyncio.to_thread(step)
|
| 441 |
-
logits = outputs[0][:, -1, :]
|
| 442 |
-
next_token = _sample_token(logits[0], temperature, top_p)
|
| 443 |
-
generated_tokens.append(next_token)
|
| 444 |
-
|
| 445 |
-
past_kv_outputs = outputs[1:]
|
| 446 |
-
past_kv = dict(zip(past_input_names, past_kv_outputs))
|
| 447 |
-
|
| 448 |
-
new_text = tokenizer.decode([next_token], skip_special_tokens=True)
|
| 449 |
-
if new_text:
|
| 450 |
-
yield new_text
|
| 451 |
-
|
| 452 |
-
full_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)
|
| 453 |
-
for stop_seq in stop_sequences:
|
| 454 |
-
if stop_seq in full_text:
|
| 455 |
-
return
|
| 456 |
-
if next_token == eos_token_id:
|
| 457 |
-
break
|
| 458 |
-
|
| 459 |
-
# ---------- FastAPI App ----------
|
| 460 |
@asynccontextmanager
|
| 461 |
async def lifespan(app: FastAPI):
|
| 462 |
try:
|
|
@@ -465,14 +203,13 @@ async def lifespan(app: FastAPI):
|
|
| 465 |
except Exception as e:
|
| 466 |
logger.error(f"Startup model load failed: {e}")
|
| 467 |
yield
|
| 468 |
-
global
|
| 469 |
-
|
| 470 |
-
ort_session = None
|
| 471 |
|
| 472 |
app = FastAPI(
|
| 473 |
-
title="Bonsai
|
| 474 |
-
version="
|
| 475 |
-
description="
|
| 476 |
docs_url="/docs",
|
| 477 |
redoc_url="/redoc",
|
| 478 |
lifespan=lifespan,
|
|
@@ -516,17 +253,16 @@ async def generic_exception_handler(request, exc):
|
|
| 516 |
|
| 517 |
@app.get("/", summary="Root")
|
| 518 |
def root():
|
| 519 |
-
return {"message": "Bonsai
|
| 520 |
|
| 521 |
@app.get("/health", summary="Health check")
|
| 522 |
def health():
|
| 523 |
-
loaded =
|
| 524 |
return {
|
| 525 |
"status": "ok" if loaded else "degraded",
|
| 526 |
"model_loaded": loaded,
|
| 527 |
"model_id": MODEL_ID,
|
| 528 |
-
"
|
| 529 |
-
"device": _model_device(),
|
| 530 |
"error": model_load_error if model_load_error else None,
|
| 531 |
}
|
| 532 |
|
|
@@ -534,9 +270,10 @@ def health():
|
|
| 534 |
def model_info():
|
| 535 |
return ModelInfo(
|
| 536 |
model_id=MODEL_ID,
|
| 537 |
-
|
| 538 |
-
|
| 539 |
-
|
|
|
|
| 540 |
)
|
| 541 |
|
| 542 |
@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
|
|
@@ -564,9 +301,9 @@ async def chat_completions(req: ChatCompletionRequest):
|
|
| 564 |
text = await _generate_full(prompt, req.max_tokens, req.temperature, req.top_p, stop_seq)
|
| 565 |
assistant_msg = Message(role="assistant", content=text)
|
| 566 |
usage = Usage(
|
| 567 |
-
prompt_tokens=
|
| 568 |
-
completion_tokens=
|
| 569 |
-
total_tokens=
|
| 570 |
)
|
| 571 |
return ChatCompletionResponse(
|
| 572 |
id=f"chatcmpl-{uuid.uuid4().hex[:12]}",
|
|
|
|
| 7 |
import time
|
| 8 |
import uuid
|
| 9 |
from contextlib import asynccontextmanager
|
| 10 |
+
from typing import List, Optional, Union
|
| 11 |
|
|
|
|
|
|
|
| 12 |
from fastapi import FastAPI, HTTPException, Request
|
| 13 |
from fastapi.middleware.cors import CORSMiddleware
|
| 14 |
from fastapi.responses import JSONResponse, StreamingResponse
|
| 15 |
+
from huggingface_hub import hf_hub_download
|
| 16 |
from pydantic import BaseModel, Field, ValidationError
|
| 17 |
+
|
| 18 |
+
# NEW: Import llama.cpp
|
| 19 |
+
from llama_cpp import Llama
|
| 20 |
|
| 21 |
# ---------- Configuration ----------
|
| 22 |
+
# You can now use GGUF models for even faster inference!
|
| 23 |
+
# These are specifically optimized by the PrismML team.
|
| 24 |
+
MODEL_ID = os.getenv("MODEL_ID", "prism-ml/Bonsai-1.7B-gguf")
|
| 25 |
+
MODEL_FILENAME = os.getenv("MODEL_FILENAME", "Bonsai-1.7B-v1.0-Q1_0.gguf")
|
| 26 |
+
|
| 27 |
+
# Quantization types in GGUF: Q1_0 is for 1-bit models.
|
| 28 |
+
# For 8B, use MODEL_ID="prism-ml/Bonsai-8B-gguf" and MODEL_FILENAME="Bonsai-8B-v1.0-Q1_0.gguf"
|
| 29 |
|
| 30 |
HF_TOKEN = os.getenv("HF_TOKEN")
|
| 31 |
+
LOCAL_MODEL_DIR = os.getenv("LOCAL_MODEL_DIR", "/data/models")
|
| 32 |
MAX_NEW_TOKENS_DEFAULT = int(os.getenv("MAX_NEW_TOKENS_DEFAULT", "256"))
|
| 33 |
API_KEY = os.getenv("API_KEY", None)
|
| 34 |
|
| 35 |
+
# Performance settings for CPU inference
|
| 36 |
+
N_CTX = int(os.getenv("N_CTX", "4096")) # Context window
|
| 37 |
+
N_THREADS = int(os.getenv("N_THREADS", "4")) # Number of CPU threads to use
|
| 38 |
+
N_BATCH = int(os.getenv("N_BATCH", "512")) # Batch size for prompt processing
|
| 39 |
+
|
| 40 |
logging.basicConfig(level=logging.INFO)
|
| 41 |
logger = logging.getLogger("uvicorn.error")
|
| 42 |
|
| 43 |
+
# ---------- Pydantic Models (Same as before) ----------
|
| 44 |
class Message(BaseModel):
|
| 45 |
role: str = Field(..., pattern="^(system|user|assistant)$")
|
| 46 |
content: str
|
|
|
|
| 74 |
|
| 75 |
class ModelInfo(BaseModel):
|
| 76 |
model_id: str
|
| 77 |
+
filename: str
|
|
|
|
| 78 |
device: str
|
| 79 |
+
n_ctx: int
|
| 80 |
+
n_threads: int
|
| 81 |
|
| 82 |
class ErrorResponse(BaseModel):
|
| 83 |
error: str
|
| 84 |
detail: Optional[str] = None
|
| 85 |
|
| 86 |
# ---------- Global State ----------
|
| 87 |
+
llm = None
|
|
|
|
| 88 |
model_load_error = None
|
| 89 |
MODEL_LOCK = asyncio.Lock()
|
| 90 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 91 |
# ---------- Helper Functions ----------
|
| 92 |
def _verify_api_key(request: Request) -> None:
|
| 93 |
if API_KEY is None:
|
|
|
|
| 96 |
if not auth or auth != API_KEY:
|
| 97 |
raise HTTPException(status_code=401, detail="Invalid or missing API key")
|
| 98 |
|
| 99 |
+
def _download_model() -> str:
|
|
|
|
|
|
|
|
|
|
| 100 |
os.makedirs(LOCAL_MODEL_DIR, exist_ok=True)
|
| 101 |
+
local_path = os.path.join(LOCAL_MODEL_DIR, MODEL_FILENAME)
|
| 102 |
+
|
| 103 |
+
if os.path.exists(local_path):
|
| 104 |
+
logger.info(f"Model already downloaded at {local_path}")
|
| 105 |
+
return local_path
|
| 106 |
+
|
| 107 |
+
logger.info(f"Downloading model {MODEL_ID}/{MODEL_FILENAME}...")
|
|
|
|
| 108 |
try:
|
| 109 |
+
hf_hub_download(
|
| 110 |
repo_id=MODEL_ID,
|
| 111 |
+
filename=MODEL_FILENAME,
|
| 112 |
local_dir=LOCAL_MODEL_DIR,
|
|
|
|
|
|
|
| 113 |
token=HF_TOKEN,
|
| 114 |
)
|
| 115 |
+
logger.info("Model downloaded successfully.")
|
| 116 |
+
return local_path
|
| 117 |
except Exception as e:
|
| 118 |
logger.error(f"Model download failed: {e}")
|
| 119 |
raise RuntimeError(f"Failed to download model: {str(e)}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 120 |
|
| 121 |
async def _ensure_loaded():
|
| 122 |
+
global llm, model_load_error
|
|
|
|
|
|
|
|
|
|
| 123 |
async with MODEL_LOCK:
|
| 124 |
+
if llm is not None:
|
| 125 |
return
|
| 126 |
if model_load_error:
|
| 127 |
raise HTTPException(status_code=503, detail=f"Model failed to load: {model_load_error}")
|
| 128 |
try:
|
| 129 |
+
model_path = _download_model()
|
| 130 |
+
# Load the model with CPU-optimized settings
|
| 131 |
+
llm = Llama(
|
| 132 |
+
model_path=model_path,
|
| 133 |
+
n_ctx=N_CTX, # Context window
|
| 134 |
+
n_threads=N_THREADS, # Number of CPU threads
|
| 135 |
+
n_batch=N_BATCH, # Batch size for prompt processing
|
| 136 |
+
verbose=False,
|
| 137 |
+
)
|
| 138 |
+
logger.info(f"Model loaded successfully: {MODEL_ID} ({MODEL_FILENAME})")
|
| 139 |
+
logger.info(f"Context: {N_CTX}, Threads: {N_THREADS}, Batch: {N_BATCH}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 140 |
except Exception as e:
|
| 141 |
model_load_error = str(e)
|
| 142 |
logger.exception("Model loading failed")
|
| 143 |
raise HTTPException(status_code=503, detail=f"Model unavailable: {model_load_error}")
|
| 144 |
|
| 145 |
def _build_chat_prompt(messages: List[Message]) -> str:
|
| 146 |
+
# llama.cpp handles chat templates automatically, so we can just pass the messages directly.
|
| 147 |
+
# This is for compatibility; the actual formatting is done by llama.cpp.
|
| 148 |
+
if llm is None:
|
| 149 |
+
raise HTTPException(status_code=503, detail="Model not loaded")
|
| 150 |
+
|
| 151 |
+
# The create_chat_completion method expects a list of messages in this format
|
| 152 |
+
return [{"role": msg.role, "content": msg.content} for msg in messages]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 153 |
|
| 154 |
+
async def _generate_full(prompt: list, max_new_tokens: int, temperature: float, top_p: float, stop_sequences: Optional[List[str]] = None) -> str:
|
| 155 |
+
if llm is None:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 156 |
raise HTTPException(status_code=503, detail="Model not loaded")
|
| 157 |
|
| 158 |
+
# Run the blocking llama.cpp call in a thread
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 159 |
return await asyncio.to_thread(
|
| 160 |
+
lambda: llm.create_chat_completion(
|
| 161 |
+
messages=prompt,
|
| 162 |
+
max_tokens=max_new_tokens,
|
| 163 |
+
temperature=temperature,
|
| 164 |
+
top_p=top_p,
|
| 165 |
+
stop=stop_sequences,
|
| 166 |
+
stream=False,
|
| 167 |
+
)["choices"][0]["message"]["content"]
|
| 168 |
)
|
| 169 |
|
| 170 |
+
async def _generate_stream(prompt: list, max_new_tokens: int, temperature: float, top_p: float, stop_sequences: Optional[List[str]] = None):
|
| 171 |
+
if llm is None:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 172 |
raise HTTPException(status_code=503, detail="Model not loaded")
|
| 173 |
|
| 174 |
+
# llama.cpp can yield a Python generator. We'll run it in a thread and yield the results.
|
| 175 |
+
def generator():
|
| 176 |
+
for chunk in llm.create_chat_completion(
|
| 177 |
+
messages=prompt,
|
| 178 |
+
max_tokens=max_new_tokens,
|
| 179 |
+
temperature=temperature,
|
| 180 |
+
top_p=top_p,
|
| 181 |
+
stop=stop_sequences,
|
| 182 |
+
stream=True,
|
| 183 |
+
):
|
| 184 |
+
if "content" in chunk["choices"][0]["delta"]:
|
| 185 |
+
yield chunk["choices"][0]["delta"]["content"]
|
| 186 |
+
|
| 187 |
+
# We need a helper to bridge the sync generator to an async one
|
| 188 |
+
def sync_generator():
|
| 189 |
+
for item in generator():
|
| 190 |
+
yield item
|
| 191 |
+
|
| 192 |
+
# Run the sync generator in a thread and yield items as they come
|
| 193 |
+
for item in await asyncio.to_thread(list, sync_generator()):
|
| 194 |
+
yield item
|
| 195 |
+
await asyncio.sleep(0) # Yield control to the event loop
|
| 196 |
+
|
| 197 |
+
# ---------- FastAPI App (Same structure) ----------
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 198 |
@asynccontextmanager
|
| 199 |
async def lifespan(app: FastAPI):
|
| 200 |
try:
|
|
|
|
| 203 |
except Exception as e:
|
| 204 |
logger.error(f"Startup model load failed: {e}")
|
| 205 |
yield
|
| 206 |
+
global llm
|
| 207 |
+
llm = None
|
|
|
|
| 208 |
|
| 209 |
app = FastAPI(
|
| 210 |
+
title="Bonsai CPU-Optimized Inference API",
|
| 211 |
+
version="2.0.0",
|
| 212 |
+
description="Lightning-fast inference for 1-bit Bonsai LLMs using llama.cpp.",
|
| 213 |
docs_url="/docs",
|
| 214 |
redoc_url="/redoc",
|
| 215 |
lifespan=lifespan,
|
|
|
|
| 253 |
|
| 254 |
@app.get("/", summary="Root")
|
| 255 |
def root():
|
| 256 |
+
return {"message": "Bonsai CPU API is running", "docs": "/docs"}
|
| 257 |
|
| 258 |
@app.get("/health", summary="Health check")
|
| 259 |
def health():
|
| 260 |
+
loaded = llm is not None
|
| 261 |
return {
|
| 262 |
"status": "ok" if loaded else "degraded",
|
| 263 |
"model_loaded": loaded,
|
| 264 |
"model_id": MODEL_ID,
|
| 265 |
+
"filename": MODEL_FILENAME,
|
|
|
|
| 266 |
"error": model_load_error if model_load_error else None,
|
| 267 |
}
|
| 268 |
|
|
|
|
| 270 |
def model_info():
|
| 271 |
return ModelInfo(
|
| 272 |
model_id=MODEL_ID,
|
| 273 |
+
filename=MODEL_FILENAME,
|
| 274 |
+
device="CPU",
|
| 275 |
+
n_ctx=N_CTX,
|
| 276 |
+
n_threads=N_THREADS,
|
| 277 |
)
|
| 278 |
|
| 279 |
@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
|
|
|
|
| 301 |
text = await _generate_full(prompt, req.max_tokens, req.temperature, req.top_p, stop_seq)
|
| 302 |
assistant_msg = Message(role="assistant", content=text)
|
| 303 |
usage = Usage(
|
| 304 |
+
prompt_tokens=0, # llama.cpp can return this, but we can omit for simplicity
|
| 305 |
+
completion_tokens=0,
|
| 306 |
+
total_tokens=0,
|
| 307 |
)
|
| 308 |
return ChatCompletionResponse(
|
| 309 |
id=f"chatcmpl-{uuid.uuid4().hex[:12]}",
|