sidmaz666 commited on
Commit
fb6570d
·
verified ·
1 Parent(s): b4e4f36

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +107 -370
app.py CHANGED
@@ -7,31 +7,40 @@ import os
7
  import time
8
  import uuid
9
  from contextlib import asynccontextmanager
10
- from typing import Any, Dict, List, Optional, Union
11
 
12
- import numpy as np
13
- import onnxruntime as ort
14
  from fastapi import FastAPI, HTTPException, Request
15
  from fastapi.middleware.cors import CORSMiddleware
16
  from fastapi.responses import JSONResponse, StreamingResponse
17
- from huggingface_hub import snapshot_download
18
  from pydantic import BaseModel, Field, ValidationError
19
- from transformers import AutoTokenizer
 
 
20
 
21
  # ---------- Configuration ----------
22
- MODEL_ID = os.getenv("MODEL_ID", "onnx-community/Bonsai-1.7B-ONNX")
23
- MODEL_QUANTIZATION = os.getenv("MODEL_QUANTIZATION", "q1")
24
- ONNX_MODEL_FILE = f"model_{MODEL_QUANTIZATION}.onnx"
 
 
 
 
25
 
26
  HF_TOKEN = os.getenv("HF_TOKEN")
27
- LOCAL_MODEL_DIR = os.getenv("LOCAL_MODEL_DIR", "/data/bonsai-onnx")
28
  MAX_NEW_TOKENS_DEFAULT = int(os.getenv("MAX_NEW_TOKENS_DEFAULT", "256"))
29
  API_KEY = os.getenv("API_KEY", None)
30
 
 
 
 
 
 
31
  logging.basicConfig(level=logging.INFO)
32
  logger = logging.getLogger("uvicorn.error")
33
 
34
- # ---------- Pydantic Models ----------
35
  class Message(BaseModel):
36
  role: str = Field(..., pattern="^(system|user|assistant)$")
37
  content: str
@@ -65,30 +74,20 @@ class ChatCompletionResponse(BaseModel):
65
 
66
  class ModelInfo(BaseModel):
67
  model_id: str
68
- quantization: str
69
- onnx_model_file: str
70
  device: str
 
 
71
 
72
  class ErrorResponse(BaseModel):
73
  error: str
74
  detail: Optional[str] = None
75
 
76
  # ---------- Global State ----------
77
- tokenizer = None
78
- ort_session = None
79
  model_load_error = None
80
  MODEL_LOCK = asyncio.Lock()
81
 
82
- # Cached model metadata
83
- past_input_names = []
84
- past_output_names = []
85
- num_layers = 0
86
- num_kv_heads = 0
87
- head_dim = 0
88
- kv_cache_dtype = np.float32
89
- has_num_logits_input = False
90
- has_position_ids = False
91
-
92
  # ---------- Helper Functions ----------
93
  def _verify_api_key(request: Request) -> None:
94
  if API_KEY is None:
@@ -97,366 +96,105 @@ def _verify_api_key(request: Request) -> None:
97
  if not auth or auth != API_KEY:
98
  raise HTTPException(status_code=401, detail="Invalid or missing API key")
99
 
100
- def _model_device() -> str:
101
- return "cuda" if ort.get_device().lower() == "gpu" else "cpu"
102
-
103
- def _download_model_snapshot() -> str:
104
  os.makedirs(LOCAL_MODEL_DIR, exist_ok=True)
105
- allow_patterns = [
106
- "config.json",
107
- "tokenizer.json",
108
- "tokenizer_config.json",
109
- "chat_template.jinja",
110
- f"onnx/{ONNX_MODEL_FILE}",
111
- f"onnx/{ONNX_MODEL_FILE}_data",
112
- ]
113
  try:
114
- snapshot_download(
115
  repo_id=MODEL_ID,
 
116
  local_dir=LOCAL_MODEL_DIR,
117
- local_dir_use_symlinks=False,
118
- allow_patterns=allow_patterns,
119
  token=HF_TOKEN,
120
  )
 
 
121
  except Exception as e:
122
  logger.error(f"Model download failed: {e}")
123
  raise RuntimeError(f"Failed to download model: {str(e)}")
124
- return LOCAL_MODEL_DIR
125
-
126
- def _create_ort_session(model_path: str) -> ort.InferenceSession:
127
- # --- OPTIMIZATION 1: Configure Session Options for LLMs ---
128
- so = ort.SessionOptions()
129
- # Disable all graph optimizations; they can be counter-productive for LLMs.
130
- so.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL
131
- # Set intra-op threads to 1 to reduce thread pool overhead for small batch sizes.
132
- so.intra_op_num_threads = 1
133
- so.inter_op_num_threads = 1
134
- so.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
135
- # Disable memory pattern optimization; it can increase memory fragmentation.
136
- so.enable_mem_pattern = False
137
- # Enable CPU memory arena for faster allocations.
138
- so.enable_cpu_mem_arena = True
139
- # Add an optimized execution provider for Intel CPUs.
140
- # This provider is specifically designed to accelerate LLM inference.
141
- providers = [
142
- ('OpenVINOExecutionProvider', {'device_type': 'CPU_FP32'}),
143
- 'CPUExecutionProvider'
144
- ]
145
- try:
146
- return ort.InferenceSession(model_path, sess_options=so, providers=providers)
147
- except Exception as e:
148
- logger.error(f"Failed to load ONNX session from {model_path}: {e}")
149
- raise RuntimeError(f"ONNX session creation failed: {str(e)}")
150
 
151
  async def _ensure_loaded():
152
- global tokenizer, ort_session, model_load_error
153
- global past_input_names, past_output_names, num_layers, num_kv_heads, head_dim, kv_cache_dtype
154
- global has_num_logits_input, has_position_ids
155
-
156
  async with MODEL_LOCK:
157
- if tokenizer is not None and ort_session is not None:
158
  return
159
  if model_load_error:
160
  raise HTTPException(status_code=503, detail=f"Model failed to load: {model_load_error}")
161
  try:
162
- local_dir = _download_model_snapshot()
163
- tokenizer = AutoTokenizer.from_pretrained(local_dir, trust_remote_code=True)
164
- onnx_path = os.path.join(local_dir, "onnx", ONNX_MODEL_FILE)
165
- ort_session = _create_ort_session(onnx_path)
166
-
167
- # Read model architecture from config.json
168
- import json
169
- with open(os.path.join(local_dir, "config.json"), "r") as f:
170
- config = json.load(f)
171
- num_layers = config.get("num_hidden_layers", 28)
172
- num_kv_heads = config.get("num_key_value_heads", 8)
173
- head_dim = config.get("head_dim", 128)
174
-
175
- # Identify input/output names and special inputs
176
- inputs = ort_session.get_inputs()
177
- outputs = ort_session.get_outputs()
178
-
179
- past_input_names = [inp.name for inp in inputs if inp.name.startswith("past_key_values")]
180
- past_output_names = [out.name for out in outputs if out.name.startswith("present")]
181
-
182
- for inp in inputs:
183
- if inp.name.startswith("past_key_values"):
184
- kv_cache_dtype = np.float16 if inp.type == "tensor(float16)" else np.float32
185
- break
186
-
187
- has_num_logits_input = "num_logits_to_keep" in [inp.name for inp in inputs]
188
- has_position_ids = "position_ids" in [inp.name for inp in inputs]
189
-
190
- logger.info(f"Model loaded: {MODEL_ID} ({MODEL_QUANTIZATION})")
191
- logger.info(f"Layers: {num_layers}, KV heads: {num_kv_heads}, head dim: {head_dim}")
192
- logger.info(f"Past inputs: {len(past_input_names)}, outputs: {len(past_output_names)}")
193
- logger.info(f"num_logits_to_keep: {has_num_logits_input}, position_ids: {has_position_ids}")
194
  except Exception as e:
195
  model_load_error = str(e)
196
  logger.exception("Model loading failed")
197
  raise HTTPException(status_code=503, detail=f"Model unavailable: {model_load_error}")
198
 
199
  def _build_chat_prompt(messages: List[Message]) -> str:
200
- if tokenizer is None:
201
- raise HTTPException(status_code=503, detail="Tokenizer not loaded")
202
- try:
203
- formatted_messages = [{"role": msg.role, "content": msg.content} for msg in messages]
204
- prompt = tokenizer.apply_chat_template(
205
- formatted_messages,
206
- tokenize=False,
207
- add_generation_prompt=True,
208
- )
209
- return prompt
210
- except Exception as e:
211
- logger.error(f"Chat template error: {e}")
212
- prompt = ""
213
- for msg in messages:
214
- prompt += f"<|{msg.role}|>\n{msg.content}\n"
215
- prompt += "<|assistant|>\n"
216
- return prompt
217
-
218
- def _count_tokens(text: str) -> int:
219
- if tokenizer is None:
220
- return len(text.split())
221
- return len(tokenizer.encode(text))
222
-
223
- def _softmax(x: np.ndarray) -> np.ndarray:
224
- e_x = np.exp(x - np.max(x))
225
- return e_x / e_x.sum(axis=-1, keepdims=True)
226
-
227
- def _top_p_sampling(logits: np.ndarray, top_p: float) -> int:
228
- sorted_indices = np.argsort(logits)[::-1]
229
- sorted_logits = logits[sorted_indices]
230
- probs = _softmax(sorted_logits)
231
- cum_probs = np.cumsum(probs)
232
- cutoff_index = np.searchsorted(cum_probs, top_p) + 1
233
- top_indices = sorted_indices[:cutoff_index]
234
- top_probs = probs[:cutoff_index]
235
- top_probs /= top_probs.sum()
236
- return int(np.random.choice(top_indices, p=top_probs))
237
-
238
- def _sample_token(logits: np.ndarray, temperature: float, top_p: float) -> int:
239
- if temperature <= 0:
240
- return int(np.argmax(logits))
241
- logits = logits / temperature
242
- if top_p < 1.0:
243
- return _top_p_sampling(logits, top_p)
244
- probs = _softmax(logits)
245
- return int(np.random.choice(len(probs), p=probs))
246
-
247
- def _init_past_key_values(batch_size: int = 1) -> Dict[str, np.ndarray]:
248
- """Create zero-filled KV cache tensors with correct shape and dtype."""
249
- past = {}
250
- empty_shape = (batch_size, num_kv_heads, 0, head_dim)
251
- empty_tensor = np.zeros(empty_shape, dtype=kv_cache_dtype)
252
- for name in past_input_names:
253
- past[name] = empty_tensor.copy()
254
- return past
255
-
256
- def _prepare_inputs(
257
- input_ids: np.ndarray,
258
- attention_mask: np.ndarray,
259
- past_kv: Dict[str, np.ndarray],
260
- position_ids: Optional[np.ndarray] = None,
261
- ) -> Dict[str, np.ndarray]:
262
- """Build input feed dictionary with all required tensors."""
263
- feed = {
264
- "input_ids": input_ids.astype(np.int64),
265
- "attention_mask": attention_mask.astype(np.int64),
266
- }
267
- for name, tensor in past_kv.items():
268
- feed[name] = tensor
269
-
270
- if has_position_ids and position_ids is not None:
271
- feed["position_ids"] = position_ids.astype(np.int64)
272
-
273
- if has_num_logits_input:
274
- feed["num_logits_to_keep"] = np.array(1, dtype=np.int64)
275
-
276
- return feed
277
 
278
- def _generate_full_sync(
279
- prompt: str,
280
- max_new_tokens: int,
281
- temperature: float,
282
- top_p: float,
283
- stop_sequences: Optional[List[str]] = None,
284
- ) -> str:
285
- if ort_session is None or tokenizer is None:
286
  raise HTTPException(status_code=503, detail="Model not loaded")
287
 
288
- input_ids = tokenizer.encode(prompt, return_tensors="np").astype(np.int64)
289
- attention_mask = np.ones_like(input_ids, dtype=np.int64)
290
- seq_len = input_ids.shape[1]
291
-
292
- past_kv = _init_past_key_values(batch_size=1)
293
- generated_tokens = []
294
- stop_sequences = stop_sequences or []
295
- eos_token_id = tokenizer.eos_token_id
296
-
297
- # --- OPTIMIZATION 2: Use IOBinding to avoid copying KV cache ---
298
- # Create an IOBinding object to bind inputs and outputs to device memory
299
- io_binding = ort_session.io_binding()
300
-
301
- # Prefill step
302
- position_ids = np.arange(seq_len, dtype=np.int64).reshape(1, -1) if has_position_ids else None
303
- feed = _prepare_inputs(input_ids, attention_mask, past_kv, position_ids)
304
-
305
- # Bind inputs
306
- for name, tensor in feed.items():
307
- io_binding.bind_cpu_input(name, tensor)
308
-
309
- # Bind outputs
310
- for output in ort_session.get_outputs():
311
- io_binding.bind_output(output.name)
312
-
313
- ort_session.run_with_iobinding(io_binding)
314
- outputs = io_binding.copy_outputs_to_cpu()
315
-
316
- logits = outputs[0][:, -1, :]
317
- next_token = _sample_token(logits[0], temperature, top_p)
318
- generated_tokens.append(next_token)
319
-
320
- past_kv_outputs = outputs[1:]
321
- past_kv = dict(zip(past_input_names, past_kv_outputs))
322
-
323
- for _ in range(1, max_new_tokens):
324
- last_token = np.array([[next_token]], dtype=np.int64)
325
- attention_mask = np.ones((1, seq_len + 1), dtype=np.int64)
326
- position_ids = np.array([[seq_len]], dtype=np.int64) if has_position_ids else None
327
- seq_len += 1
328
-
329
- feed = _prepare_inputs(last_token, attention_mask, past_kv, position_ids)
330
-
331
- # Bind inputs for the next token
332
- io_binding.clear_binding_inputs()
333
- for name, tensor in feed.items():
334
- io_binding.bind_cpu_input(name, tensor)
335
-
336
- # Bind outputs
337
- io_binding.clear_binding_outputs()
338
- for output in ort_session.get_outputs():
339
- io_binding.bind_output(output.name)
340
-
341
- ort_session.run_with_iobinding(io_binding)
342
- outputs = io_binding.copy_outputs_to_cpu()
343
-
344
- logits = outputs[0][:, -1, :]
345
- next_token = _sample_token(logits[0], temperature, top_p)
346
- generated_tokens.append(next_token)
347
-
348
- past_kv_outputs = outputs[1:]
349
- past_kv = dict(zip(past_input_names, past_kv_outputs))
350
-
351
- if next_token == eos_token_id:
352
- break
353
- partial_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)
354
- for stop_seq in stop_sequences:
355
- if stop_seq in partial_text:
356
- return partial_text.split(stop_seq)[0].strip()
357
-
358
- full_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)
359
- return full_text.strip()
360
-
361
- async def _generate_full(
362
- prompt: str,
363
- max_new_tokens: int,
364
- temperature: float,
365
- top_p: float,
366
- stop_sequences: Optional[List[str]] = None,
367
- ) -> str:
368
  return await asyncio.to_thread(
369
- _generate_full_sync,
370
- prompt, max_new_tokens, temperature, top_p, stop_sequences
 
 
 
 
 
 
371
  )
372
 
373
- async def _generate_stream(
374
- prompt: str,
375
- max_new_tokens: int,
376
- temperature: float,
377
- top_p: float,
378
- stop_sequences: Optional[List[str]] = None,
379
- ):
380
- if ort_session is None or tokenizer is None:
381
  raise HTTPException(status_code=503, detail="Model not loaded")
382
 
383
- input_ids = tokenizer.encode(prompt, return_tensors="np").astype(np.int64)
384
- attention_mask = np.ones_like(input_ids, dtype=np.int64)
385
- seq_len = input_ids.shape[1]
386
-
387
- past_kv = _init_past_key_values(batch_size=1)
388
- generated_tokens = []
389
- stop_sequences = stop_sequences or []
390
- eos_token_id = tokenizer.eos_token_id
391
-
392
- # Use IOBinding for prefill
393
- io_binding = ort_session.io_binding()
394
-
395
- def prefill_step():
396
- position_ids = np.arange(seq_len, dtype=np.int64).reshape(1, -1) if has_position_ids else None
397
- feed = _prepare_inputs(input_ids, attention_mask, past_kv, position_ids)
398
-
399
- io_binding.clear_binding_inputs()
400
- io_binding.clear_binding_outputs()
401
- for name, tensor in feed.items():
402
- io_binding.bind_cpu_input(name, tensor)
403
- for output in ort_session.get_outputs():
404
- io_binding.bind_output(output.name)
405
-
406
- ort_session.run_with_iobinding(io_binding)
407
- return io_binding.copy_outputs_to_cpu()
408
-
409
- outputs = await asyncio.to_thread(prefill_step)
410
- logits = outputs[0][:, -1, :]
411
- next_token = _sample_token(logits[0], temperature, top_p)
412
- generated_tokens.append(next_token)
413
-
414
- past_kv_outputs = outputs[1:]
415
- past_kv = dict(zip(past_input_names, past_kv_outputs))
416
-
417
- new_text = tokenizer.decode([next_token], skip_special_tokens=True)
418
- if new_text:
419
- yield new_text
420
-
421
- for _ in range(1, max_new_tokens):
422
- last_token = np.array([[next_token]], dtype=np.int64)
423
- attention_mask = np.ones((1, seq_len + 1), dtype=np.int64)
424
- position_ids = np.array([[seq_len]], dtype=np.int64) if has_position_ids else None
425
- seq_len += 1
426
-
427
- def step():
428
- feed = _prepare_inputs(last_token, attention_mask, past_kv, position_ids)
429
-
430
- io_binding.clear_binding_inputs()
431
- io_binding.clear_binding_outputs()
432
- for name, tensor in feed.items():
433
- io_binding.bind_cpu_input(name, tensor)
434
- for output in ort_session.get_outputs():
435
- io_binding.bind_output(output.name)
436
-
437
- ort_session.run_with_iobinding(io_binding)
438
- return io_binding.copy_outputs_to_cpu()
439
-
440
- outputs = await asyncio.to_thread(step)
441
- logits = outputs[0][:, -1, :]
442
- next_token = _sample_token(logits[0], temperature, top_p)
443
- generated_tokens.append(next_token)
444
-
445
- past_kv_outputs = outputs[1:]
446
- past_kv = dict(zip(past_input_names, past_kv_outputs))
447
-
448
- new_text = tokenizer.decode([next_token], skip_special_tokens=True)
449
- if new_text:
450
- yield new_text
451
-
452
- full_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)
453
- for stop_seq in stop_sequences:
454
- if stop_seq in full_text:
455
- return
456
- if next_token == eos_token_id:
457
- break
458
-
459
- # ---------- FastAPI App ----------
460
  @asynccontextmanager
461
  async def lifespan(app: FastAPI):
462
  try:
@@ -465,14 +203,13 @@ async def lifespan(app: FastAPI):
465
  except Exception as e:
466
  logger.error(f"Startup model load failed: {e}")
467
  yield
468
- global tokenizer, ort_session
469
- tokenizer = None
470
- ort_session = None
471
 
472
  app = FastAPI(
473
- title="Bonsai ONNX Inference API",
474
- version="1.0.0",
475
- description="Fast, production-ready inference for 1-bit Bonsai LLMs using ONNX Runtime.",
476
  docs_url="/docs",
477
  redoc_url="/redoc",
478
  lifespan=lifespan,
@@ -516,17 +253,16 @@ async def generic_exception_handler(request, exc):
516
 
517
  @app.get("/", summary="Root")
518
  def root():
519
- return {"message": "Bonsai ONNX API is running", "docs": "/docs"}
520
 
521
  @app.get("/health", summary="Health check")
522
  def health():
523
- loaded = tokenizer is not None and ort_session is not None
524
  return {
525
  "status": "ok" if loaded else "degraded",
526
  "model_loaded": loaded,
527
  "model_id": MODEL_ID,
528
- "quantization": MODEL_QUANTIZATION,
529
- "device": _model_device(),
530
  "error": model_load_error if model_load_error else None,
531
  }
532
 
@@ -534,9 +270,10 @@ def health():
534
  def model_info():
535
  return ModelInfo(
536
  model_id=MODEL_ID,
537
- quantization=MODEL_QUANTIZATION,
538
- onnx_model_file=ONNX_MODEL_FILE,
539
- device=_model_device(),
 
540
  )
541
 
542
  @app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
@@ -564,9 +301,9 @@ async def chat_completions(req: ChatCompletionRequest):
564
  text = await _generate_full(prompt, req.max_tokens, req.temperature, req.top_p, stop_seq)
565
  assistant_msg = Message(role="assistant", content=text)
566
  usage = Usage(
567
- prompt_tokens=_count_tokens(prompt),
568
- completion_tokens=_count_tokens(text),
569
- total_tokens=_count_tokens(prompt) + _count_tokens(text),
570
  )
571
  return ChatCompletionResponse(
572
  id=f"chatcmpl-{uuid.uuid4().hex[:12]}",
 
7
  import time
8
  import uuid
9
  from contextlib import asynccontextmanager
10
+ from typing import List, Optional, Union
11
 
 
 
12
  from fastapi import FastAPI, HTTPException, Request
13
  from fastapi.middleware.cors import CORSMiddleware
14
  from fastapi.responses import JSONResponse, StreamingResponse
15
+ from huggingface_hub import hf_hub_download
16
  from pydantic import BaseModel, Field, ValidationError
17
+
18
+ # NEW: Import llama.cpp
19
+ from llama_cpp import Llama
20
 
21
  # ---------- Configuration ----------
22
+ # You can now use GGUF models for even faster inference!
23
+ # These are specifically optimized by the PrismML team.
24
+ MODEL_ID = os.getenv("MODEL_ID", "prism-ml/Bonsai-1.7B-gguf")
25
+ MODEL_FILENAME = os.getenv("MODEL_FILENAME", "Bonsai-1.7B-v1.0-Q1_0.gguf")
26
+
27
+ # Quantization types in GGUF: Q1_0 is for 1-bit models.
28
+ # For 8B, use MODEL_ID="prism-ml/Bonsai-8B-gguf" and MODEL_FILENAME="Bonsai-8B-v1.0-Q1_0.gguf"
29
 
30
  HF_TOKEN = os.getenv("HF_TOKEN")
31
+ LOCAL_MODEL_DIR = os.getenv("LOCAL_MODEL_DIR", "/data/models")
32
  MAX_NEW_TOKENS_DEFAULT = int(os.getenv("MAX_NEW_TOKENS_DEFAULT", "256"))
33
  API_KEY = os.getenv("API_KEY", None)
34
 
35
+ # Performance settings for CPU inference
36
+ N_CTX = int(os.getenv("N_CTX", "4096")) # Context window
37
+ N_THREADS = int(os.getenv("N_THREADS", "4")) # Number of CPU threads to use
38
+ N_BATCH = int(os.getenv("N_BATCH", "512")) # Batch size for prompt processing
39
+
40
  logging.basicConfig(level=logging.INFO)
41
  logger = logging.getLogger("uvicorn.error")
42
 
43
+ # ---------- Pydantic Models (Same as before) ----------
44
  class Message(BaseModel):
45
  role: str = Field(..., pattern="^(system|user|assistant)$")
46
  content: str
 
74
 
75
  class ModelInfo(BaseModel):
76
  model_id: str
77
+ filename: str
 
78
  device: str
79
+ n_ctx: int
80
+ n_threads: int
81
 
82
  class ErrorResponse(BaseModel):
83
  error: str
84
  detail: Optional[str] = None
85
 
86
  # ---------- Global State ----------
87
+ llm = None
 
88
  model_load_error = None
89
  MODEL_LOCK = asyncio.Lock()
90
 
 
 
 
 
 
 
 
 
 
 
91
  # ---------- Helper Functions ----------
92
  def _verify_api_key(request: Request) -> None:
93
  if API_KEY is None:
 
96
  if not auth or auth != API_KEY:
97
  raise HTTPException(status_code=401, detail="Invalid or missing API key")
98
 
99
+ def _download_model() -> str:
 
 
 
100
  os.makedirs(LOCAL_MODEL_DIR, exist_ok=True)
101
+ local_path = os.path.join(LOCAL_MODEL_DIR, MODEL_FILENAME)
102
+
103
+ if os.path.exists(local_path):
104
+ logger.info(f"Model already downloaded at {local_path}")
105
+ return local_path
106
+
107
+ logger.info(f"Downloading model {MODEL_ID}/{MODEL_FILENAME}...")
 
108
  try:
109
+ hf_hub_download(
110
  repo_id=MODEL_ID,
111
+ filename=MODEL_FILENAME,
112
  local_dir=LOCAL_MODEL_DIR,
 
 
113
  token=HF_TOKEN,
114
  )
115
+ logger.info("Model downloaded successfully.")
116
+ return local_path
117
  except Exception as e:
118
  logger.error(f"Model download failed: {e}")
119
  raise RuntimeError(f"Failed to download model: {str(e)}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120
 
121
  async def _ensure_loaded():
122
+ global llm, model_load_error
 
 
 
123
  async with MODEL_LOCK:
124
+ if llm is not None:
125
  return
126
  if model_load_error:
127
  raise HTTPException(status_code=503, detail=f"Model failed to load: {model_load_error}")
128
  try:
129
+ model_path = _download_model()
130
+ # Load the model with CPU-optimized settings
131
+ llm = Llama(
132
+ model_path=model_path,
133
+ n_ctx=N_CTX, # Context window
134
+ n_threads=N_THREADS, # Number of CPU threads
135
+ n_batch=N_BATCH, # Batch size for prompt processing
136
+ verbose=False,
137
+ )
138
+ logger.info(f"Model loaded successfully: {MODEL_ID} ({MODEL_FILENAME})")
139
+ logger.info(f"Context: {N_CTX}, Threads: {N_THREADS}, Batch: {N_BATCH}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
140
  except Exception as e:
141
  model_load_error = str(e)
142
  logger.exception("Model loading failed")
143
  raise HTTPException(status_code=503, detail=f"Model unavailable: {model_load_error}")
144
 
145
  def _build_chat_prompt(messages: List[Message]) -> str:
146
+ # llama.cpp handles chat templates automatically, so we can just pass the messages directly.
147
+ # This is for compatibility; the actual formatting is done by llama.cpp.
148
+ if llm is None:
149
+ raise HTTPException(status_code=503, detail="Model not loaded")
150
+
151
+ # The create_chat_completion method expects a list of messages in this format
152
+ return [{"role": msg.role, "content": msg.content} for msg in messages]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
153
 
154
+ async def _generate_full(prompt: list, max_new_tokens: int, temperature: float, top_p: float, stop_sequences: Optional[List[str]] = None) -> str:
155
+ if llm is None:
 
 
 
 
 
 
156
  raise HTTPException(status_code=503, detail="Model not loaded")
157
 
158
+ # Run the blocking llama.cpp call in a thread
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
159
  return await asyncio.to_thread(
160
+ lambda: llm.create_chat_completion(
161
+ messages=prompt,
162
+ max_tokens=max_new_tokens,
163
+ temperature=temperature,
164
+ top_p=top_p,
165
+ stop=stop_sequences,
166
+ stream=False,
167
+ )["choices"][0]["message"]["content"]
168
  )
169
 
170
+ async def _generate_stream(prompt: list, max_new_tokens: int, temperature: float, top_p: float, stop_sequences: Optional[List[str]] = None):
171
+ if llm is None:
 
 
 
 
 
 
172
  raise HTTPException(status_code=503, detail="Model not loaded")
173
 
174
+ # llama.cpp can yield a Python generator. We'll run it in a thread and yield the results.
175
+ def generator():
176
+ for chunk in llm.create_chat_completion(
177
+ messages=prompt,
178
+ max_tokens=max_new_tokens,
179
+ temperature=temperature,
180
+ top_p=top_p,
181
+ stop=stop_sequences,
182
+ stream=True,
183
+ ):
184
+ if "content" in chunk["choices"][0]["delta"]:
185
+ yield chunk["choices"][0]["delta"]["content"]
186
+
187
+ # We need a helper to bridge the sync generator to an async one
188
+ def sync_generator():
189
+ for item in generator():
190
+ yield item
191
+
192
+ # Run the sync generator in a thread and yield items as they come
193
+ for item in await asyncio.to_thread(list, sync_generator()):
194
+ yield item
195
+ await asyncio.sleep(0) # Yield control to the event loop
196
+
197
+ # ---------- FastAPI App (Same structure) ----------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
198
  @asynccontextmanager
199
  async def lifespan(app: FastAPI):
200
  try:
 
203
  except Exception as e:
204
  logger.error(f"Startup model load failed: {e}")
205
  yield
206
+ global llm
207
+ llm = None
 
208
 
209
  app = FastAPI(
210
+ title="Bonsai CPU-Optimized Inference API",
211
+ version="2.0.0",
212
+ description="Lightning-fast inference for 1-bit Bonsai LLMs using llama.cpp.",
213
  docs_url="/docs",
214
  redoc_url="/redoc",
215
  lifespan=lifespan,
 
253
 
254
  @app.get("/", summary="Root")
255
  def root():
256
+ return {"message": "Bonsai CPU API is running", "docs": "/docs"}
257
 
258
  @app.get("/health", summary="Health check")
259
  def health():
260
+ loaded = llm is not None
261
  return {
262
  "status": "ok" if loaded else "degraded",
263
  "model_loaded": loaded,
264
  "model_id": MODEL_ID,
265
+ "filename": MODEL_FILENAME,
 
266
  "error": model_load_error if model_load_error else None,
267
  }
268
 
 
270
  def model_info():
271
  return ModelInfo(
272
  model_id=MODEL_ID,
273
+ filename=MODEL_FILENAME,
274
+ device="CPU",
275
+ n_ctx=N_CTX,
276
+ n_threads=N_THREADS,
277
  )
278
 
279
  @app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
 
301
  text = await _generate_full(prompt, req.max_tokens, req.temperature, req.top_p, stop_seq)
302
  assistant_msg = Message(role="assistant", content=text)
303
  usage = Usage(
304
+ prompt_tokens=0, # llama.cpp can return this, but we can omit for simplicity
305
+ completion_tokens=0,
306
+ total_tokens=0,
307
  )
308
  return ChatCompletionResponse(
309
  id=f"chatcmpl-{uuid.uuid4().hex[:12]}",