sidmaz666 committed on
Commit
41ee25d
·
verified ·
1 Parent(s): ed4f560

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +251 -164
app.py CHANGED
@@ -4,32 +4,31 @@ import asyncio
4
  import json
5
  import logging
6
  import os
7
- import threading
8
  import time
9
  import uuid
10
  from contextlib import asynccontextmanager
11
- from typing import Any, AsyncGenerator, Dict, List, Optional, Union
12
 
 
 
13
  from fastapi import FastAPI, HTTPException, Request
14
  from fastapi.middleware.cors import CORSMiddleware
15
  from fastapi.responses import JSONResponse, StreamingResponse
16
- from huggingface_hub import hf_hub_download
17
- from llama_cpp import Llama
18
  from pydantic import BaseModel, Field, ValidationError
 
19
 
20
  # ---------- Configuration ----------
21
- MODEL_REPO = os.getenv("MODEL_REPO", "prism-ml/Bonsai-8B-gguf")
22
- MODEL_FILE = os.getenv("MODEL_FILE", "Bonsai-8B-Q1_0_g128.gguf")
23
- MODEL_REVISION = os.getenv("MODEL_REVISION", "main")
24
- HF_TOKEN = os.getenv("HF_TOKEN")
25
-
26
- N_CTX = int(os.getenv("N_CTX", "4096"))
27
- N_THREADS = int(os.getenv("N_THREADS", "4"))
28
- N_GPU_LAYERS = int(os.getenv("N_GPU_LAYERS", "0")) # 0 = CPU only
29
- MAX_TOKENS_DEFAULT = int(os.getenv("MAX_TOKENS_DEFAULT", "512"))
30
- TEMPERATURE_DEFAULT = float(os.getenv("TEMPERATURE_DEFAULT", "0.7"))
31
- TOP_P_DEFAULT = float(os.getenv("TOP_P_DEFAULT", "0.95"))
32
 
 
 
 
33
  API_KEY = os.getenv("API_KEY", None)
34
 
35
  logging.basicConfig(level=logging.INFO)
@@ -37,21 +36,17 @@ logger = logging.getLogger("uvicorn.error")
37
 
38
  # ---------- Pydantic Models ----------
39
  class Message(BaseModel):
40
- role: str = Field(..., pattern="^(system|user|assistant|tool)$")
41
- content: Optional[Union[str, List[Dict[str, Any]]]] = None
42
- tool_calls: Optional[List[Dict[str, Any]]] = None
43
- tool_call_id: Optional[str] = None
44
- name: Optional[str] = None
45
 
46
  class ChatCompletionRequest(BaseModel):
47
  messages: List[Message]
48
- model: Optional[str] = MODEL_REPO
49
- max_tokens: int = Field(default=MAX_TOKENS_DEFAULT, ge=1, le=4096)
50
- temperature: float = Field(default=TEMPERATURE_DEFAULT, ge=0.0, le=2.0)
51
- top_p: float = Field(default=TOP_P_DEFAULT, gt=0.0, le=1.0)
52
  stream: bool = False
53
  stop: Optional[Union[str, List[str]]] = None
54
- user: Optional[str] = None
55
 
56
  class ChatCompletionResponseChoice(BaseModel):
57
  index: int
@@ -71,21 +66,21 @@ class ChatCompletionResponse(BaseModel):
71
  choices: List[ChatCompletionResponseChoice]
72
  usage: Usage
73
 
74
- class TokenCountRequest(BaseModel):
75
- text: str = Field(..., min_length=1)
76
-
77
- class TokenCountResponse(BaseModel):
78
- text: str
79
- token_count: int
80
 
81
  class ErrorResponse(BaseModel):
82
  error: str
83
  detail: Optional[str] = None
84
 
85
  # ---------- Global State ----------
86
- llm: Optional[Llama] = None
87
- MODEL_LOCK = threading.Lock()
88
- model_load_error: Optional[str] = None
 
89
 
90
  # ---------- Helper Functions ----------
91
  def _verify_api_key(request: Request) -> None:
@@ -95,124 +90,220 @@ def _verify_api_key(request: Request) -> None:
95
  if not auth or auth != API_KEY:
96
  raise HTTPException(status_code=401, detail="Invalid or missing API key")
97
 
98
- def _download_model() -> str:
99
- """Download the GGUF model from Hugging Face Hub."""
100
- os.makedirs("/data", exist_ok=True)
 
 
 
 
 
 
 
 
 
 
101
  try:
102
- local_path = hf_hub_download(
103
- repo_id=MODEL_REPO,
104
- filename=MODEL_FILE,
105
- revision=MODEL_REVISION,
 
106
  token=HF_TOKEN,
107
- cache_dir="/data/.cache/huggingface",
108
  )
109
- logger.info(f"Model downloaded to {local_path}")
110
- return local_path
111
  except Exception as e:
112
  logger.error(f"Model download failed: {e}")
113
  raise RuntimeError(f"Failed to download model: {str(e)}")
114
-
115
- def _ensure_loaded() -> None:
116
- global llm, model_load_error
117
- if llm is not None:
118
- return
119
- if model_load_error:
120
- raise HTTPException(status_code=503, detail=f"Model failed to load: {model_load_error}")
 
 
121
  try:
122
- model_path = _download_model()
123
- llm = Llama(
124
- model_path=model_path,
125
- n_ctx=N_CTX,
126
- n_threads=N_THREADS,
127
- n_gpu_layers=N_GPU_LAYERS,
128
- verbose=False,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
129
  )
130
- logger.info("Model loaded successfully")
131
  except Exception as e:
132
- model_load_error = str(e)
133
- logger.exception("Model loading failed")
134
- raise HTTPException(status_code=503, detail=f"Model unavailable: {model_load_error}")
135
-
136
- def _format_chat_messages(messages: List[Message]) -> List[Dict[str, str]]:
137
- """Convert Pydantic messages to a list of dicts for llama.cpp."""
138
- formatted = []
139
- for msg in messages:
140
- content = msg.content
141
- if isinstance(content, list):
142
- # For simplicity, we only handle text content here.
143
- # Multimodal not needed for Bonsai.
144
- text_parts = [p["text"] for p in content if p.get("type") == "text"]
145
- content = " ".join(text_parts) if text_parts else ""
146
- formatted.append({"role": msg.role, "content": content or ""})
147
- return formatted
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
148
 
149
  def _generate_full(
150
- messages: List[Message],
151
- max_tokens: int,
152
  temperature: float,
153
  top_p: float,
154
- stop: Optional[Union[str, List[str]]],
155
- ) -> tuple[str, Usage]:
156
- _ensure_loaded()
157
- with MODEL_LOCK:
158
- chat_messages = _format_chat_messages(messages)
159
- result = llm.create_chat_completion(
160
- messages=chat_messages,
161
- max_tokens=max_tokens,
162
- temperature=temperature,
163
- top_p=top_p,
164
- stop=stop,
165
- stream=False,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
166
  )
167
- content = result["choices"][0]["message"]["content"]
168
- usage = Usage(
169
- prompt_tokens=result["usage"]["prompt_tokens"],
170
- completion_tokens=result["usage"]["completion_tokens"],
171
- total_tokens=result["usage"]["total_tokens"],
172
- )
173
- return content, usage
 
 
 
 
174
 
175
async def _generate_stream(
    messages: List[Message],
    max_tokens: int,
    temperature: float,
    top_p: float,
    stop: Optional[Union[str, List[str]]],
) -> AsyncGenerator[str, None]:
    """Stream completion text from llama.cpp without stalling the event loop.

    Args:
        messages: Conversation to complete.
        max_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling threshold.
        stop: Optional stop string(s) forwarded to llama.cpp.

    Yields:
        Each non-empty ``delta["content"]`` fragment produced by the model.
    """
    _ensure_loaded()
    chat_messages = _format_chat_messages(messages)

    # MODEL_LOCK is a threading.Lock. Acquiring it directly on the event-loop
    # thread while another stream holds it across a ``yield`` would freeze the
    # loop (and the holder with it), so wait for it in a worker thread.
    await asyncio.to_thread(MODEL_LOCK.acquire)
    try:
        stream = await asyncio.to_thread(
            llm.create_chat_completion,
            messages=chat_messages,
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            stop=stop,
            stream=True,
        )
        chunks = iter(stream)
        exhausted = object()
        while True:
            # Each next() runs blocking inference; keep it off the loop thread.
            chunk = await asyncio.to_thread(next, chunks, exhausted)
            if chunk is exhausted:
                break
            choice = chunk["choices"][0]
            text = choice["delta"].get("content")
            if text:
                yield text
            if choice.get("finish_reason") == "stop":
                break
    finally:
        # Runs on normal completion, on error, and when the consumer closes
        # the generator early (e.g. client disconnect).
        MODEL_LOCK.release()
 
 
 
 
 
199
 
200
  # ---------- FastAPI App ----------
201
  @asynccontextmanager
202
  async def lifespan(app: FastAPI):
203
  try:
204
- _ensure_loaded()
205
  logger.info("Model loaded successfully")
206
  except Exception as e:
207
  logger.error(f"Startup model load failed: {e}")
208
  yield
209
- global llm
210
- llm = None
 
211
 
212
  app = FastAPI(
213
- title="Bonsai LLM API (llama.cpp)",
214
  version="1.0.0",
215
- description="Production‑ready API for PrismML Bonsai models via GGUF.",
216
  docs_url="/docs",
217
  redoc_url="/redoc",
218
  lifespan=lifespan,
@@ -229,7 +320,8 @@ app.add_middleware(
229
  @app.middleware("http")
230
  async def auth_middleware(request: Request, call_next):
231
  _verify_api_key(request)
232
- return await call_next(request)
 
233
 
234
  # ---------- Error Handlers ----------
235
  @app.exception_handler(HTTPException)
@@ -257,71 +349,66 @@ async def generic_exception_handler(request, exc):
257
  # ---------- Endpoints ----------
258
  @app.get("/", summary="Root")
259
  def root():
260
- return {"message": "Bonsai API is running", "docs": "/docs"}
261
 
262
  @app.get("/health", summary="Health check")
263
  def health():
264
- loaded = llm is not None
265
  return {
266
  "status": "ok" if loaded else "degraded",
267
  "model_loaded": loaded,
268
- "model_id": MODEL_REPO,
 
 
269
  "error": model_load_error if model_load_error else None,
270
  }
271
 
272
- @app.get("/v1/model", summary="Model information")
273
  def model_info():
274
- return {
275
- "model_id": MODEL_REPO,
276
- "model_file": MODEL_FILE,
277
- "revision": MODEL_REVISION,
278
- "context_length": N_CTX,
279
- "gpu_layers": N_GPU_LAYERS,
280
- "cpu_threads": N_THREADS,
281
- }
282
-
283
- @app.post("/v1/token/count", response_model=TokenCountResponse)
284
- def token_count(req: TokenCountRequest):
285
- _ensure_loaded()
286
- tokens = llm.tokenize(req.text.encode("utf-8"))
287
- return TokenCountResponse(text=req.text, token_count=len(tokens))
288
 
289
- @app.post("/v1/chat/completions")
290
  async def chat_completions(req: ChatCompletionRequest):
291
- _ensure_loaded()
292
-
 
 
 
 
 
 
 
293
  if req.stream:
294
  async def stream_generator():
295
- yield f"data: {json.dumps({'id': f'chatcmpl-{uuid.uuid4().hex[:12]}', 'object': 'chat.completion.chunk', 'created': int(time.time()), 'model': req.model or MODEL_REPO, 'choices': [{'index': 0, 'delta': {'role': 'assistant'}, 'finish_reason': None}]})}\n\n"
296
- async for chunk in _generate_stream(
297
- req.messages, req.max_tokens, req.temperature, req.top_p, req.stop
298
- ):
299
- yield f"data: {json.dumps({'id': f'chatcmpl-{uuid.uuid4().hex[:12]}', 'object': 'chat.completion.chunk', 'created': int(time.time()), 'model': req.model or MODEL_REPO, 'choices': [{'index': 0, 'delta': {'content': chunk}, 'finish_reason': None}]})}\n\n"
300
  await asyncio.sleep(0)
301
- yield f"data: {json.dumps({'id': f'chatcmpl-{uuid.uuid4().hex[:12]}', 'object': 'chat.completion.chunk', 'created': int(time.time()), 'model': req.model or MODEL_REPO, 'choices': [{'index': 0, 'delta': {}, 'finish_reason': 'stop'}]})}\n\n"
302
  yield "data: [DONE]\n\n"
303
  return StreamingResponse(stream_generator(), media_type="text/event-stream")
304
-
305
  else:
306
- content, usage = await asyncio.to_thread(
307
  _generate_full,
308
- req.messages,
309
- req.max_tokens,
310
- req.temperature,
311
- req.top_p,
312
- req.stop,
 
 
313
  )
314
  return ChatCompletionResponse(
315
  id=f"chatcmpl-{uuid.uuid4().hex[:12]}",
316
  created=int(time.time()),
317
- model=req.model or MODEL_REPO,
318
- choices=[
319
- ChatCompletionResponseChoice(
320
- index=0,
321
- message=Message(role="assistant", content=content),
322
- finish_reason="stop",
323
- )
324
- ],
325
  usage=usage,
326
  )
327
 
 
4
  import json
5
  import logging
6
  import os
 
7
  import time
8
  import uuid
9
  from contextlib import asynccontextmanager
10
+ from typing import Any, Dict, List, Optional, Union
11
 
12
+ import numpy as np
13
+ import onnxruntime as ort
14
  from fastapi import FastAPI, HTTPException, Request
15
  from fastapi.middleware.cors import CORSMiddleware
16
  from fastapi.responses import JSONResponse, StreamingResponse
17
+ from huggingface_hub import snapshot_download
 
18
  from pydantic import BaseModel, Field, ValidationError
19
+ from transformers import AutoTokenizer
20
 
21
# ---------- Configuration ----------
# Hugging Face repository to serve. Use "onnx-community/Bonsai-1.7B-ONNX"
# or "onnx-community/Bonsai-8B-ONNX".
MODEL_ID = os.environ.get("MODEL_ID", "onnx-community/Bonsai-1.7B-ONNX")
# Quantization variant: one of 'q1', 'q2', 'q4', 'q8', depending on the files
# published in the ONNX model repo.
MODEL_QUANTIZATION = os.environ.get("MODEL_QUANTIZATION", "q1")
# ONNX graph file name, derived from the quantization variant.
ONNX_MODEL_FILE = "model_" + MODEL_QUANTIZATION + ".onnx"

HF_TOKEN = os.environ.get("HF_TOKEN")
LOCAL_MODEL_DIR = os.environ.get("LOCAL_MODEL_DIR", "/data/bonsai-onnx")
MAX_NEW_TOKENS_DEFAULT = int(os.environ.get("MAX_NEW_TOKENS_DEFAULT", "256"))
API_KEY = os.environ.get("API_KEY")

logging.basicConfig(level=logging.INFO)
 
36
 
37
  # ---------- Pydantic Models ----------
38
class Message(BaseModel):
    """A single chat turn sent by the client or returned by the model."""

    # Speaker of the turn; restricted to system, user or assistant.
    role: str = Field(..., pattern="^(system|user|assistant)$")
    # Plain-text body of the turn (required).
    content: str
 
 
 
41
 
42
class ChatCompletionRequest(BaseModel):
    """Request body accepted by the chat-completions endpoint."""

    # Conversation so far, oldest turn first.
    messages: List[Message]
    # Model name echoed back in the response; defaults to the served model.
    model: Optional[str] = MODEL_ID
    # Upper bound on generated tokens (1..1024).
    max_tokens: int = Field(default=MAX_NEW_TOKENS_DEFAULT, ge=1, le=1024)
    # Sampling temperature (0.0..2.0).
    temperature: float = Field(default=0.7, ge=0.0, le=2.0)
    # Nucleus-sampling threshold, in (0.0, 1.0].
    top_p: float = Field(default=0.95, gt=0.0, le=1.0)
    # When true, the answer is delivered as server-sent events.
    stream: bool = False
    # Optional stop string, or list of stop strings.
    stop: Optional[Union[str, List[str]]] = None
 
50
 
51
  class ChatCompletionResponseChoice(BaseModel):
52
  index: int
 
66
  choices: List[ChatCompletionResponseChoice]
67
  usage: Usage
68
 
69
class ModelInfo(BaseModel):
    """Static description of the served model, returned by ``/v1/model``."""

    # Hugging Face repository id of the model.
    model_id: str
    # Quantization variant in use.
    quantization: str
    # File name of the ONNX graph.
    onnx_model_file: str
    # Device label reported for inference.
    device: str
 
74
 
75
class ErrorResponse(BaseModel):
    """Shape of an error payload."""

    # Short error message (required).
    error: str
    # Optional additional detail.
    detail: Optional[str] = None
78
 
79
# ---------- Global State ----------
# tokenizer / ort_session / model_load_error are assigned by _ensure_loaded();
# all three start out unset.
tokenizer = ort_session = model_load_error = None
# Serialises model loading between concurrent coroutines.
MODEL_LOCK = asyncio.Lock()
84
 
85
  # ---------- Helper Functions ----------
86
  def _verify_api_key(request: Request) -> None:
 
90
  if not auth or auth != API_KEY:
91
  raise HTTPException(status_code=401, detail="Invalid or missing API key")
92
 
93
def _model_device() -> str:
    """Report the device used for inference as ``"cuda"`` or ``"cpu"``.

    Once a session exists, the answer is taken from the execution providers
    attached to that session, so a GPU-enabled onnxruntime build running a
    CPU-only session is reported as ``"cpu"``. Before any session is loaded
    the function falls back to what the installed onnxruntime build supports.
    """
    session = ort_session
    if session is not None:
        return "cuda" if "CUDAExecutionProvider" in session.get_providers() else "cpu"
    return "cuda" if ort.get_device().lower() == "gpu" else "cpu"
95
+
96
def _download_model_snapshot() -> str:
    """Fetch the tokenizer and ONNX files for ``MODEL_ID`` into ``LOCAL_MODEL_DIR``.

    Only the config, tokenizer files, chat template and the selected ONNX
    graph (plus its ``_data`` companion file) are requested.

    Returns:
        ``LOCAL_MODEL_DIR``, the directory the snapshot was written to.

    Raises:
        RuntimeError: If the download fails; the original exception is
            attached as ``__cause__``.
    """
    os.makedirs(LOCAL_MODEL_DIR, exist_ok=True)
    allow_patterns = [
        "config.json",
        "tokenizer.json",
        "tokenizer_config.json",
        "chat_template.jinja",
        f"onnx/{ONNX_MODEL_FILE}",
        f"onnx/{ONNX_MODEL_FILE}_data",
    ]
    try:
        snapshot_download(
            repo_id=MODEL_ID,
            local_dir=LOCAL_MODEL_DIR,
            local_dir_use_symlinks=False,
            allow_patterns=allow_patterns,
            token=HF_TOKEN,
        )
    except Exception as e:
        logger.error(f"Model download failed: {e}")
        # Chain explicitly so the underlying hub/network error stays visible.
        raise RuntimeError(f"Failed to download model: {str(e)}") from e
    return LOCAL_MODEL_DIR
118
+
119
def _create_ort_session(model_path: str) -> ort.InferenceSession:
    """Create a CPU ONNX Runtime session for the graph at ``model_path``.

    The session uses full graph optimisation, sequential execution, one
    inter-op thread and ``ORT_INTRA_OP_THREADS`` (default 2) intra-op threads.

    Raises:
        RuntimeError: If the session cannot be created; the original
            exception is attached as ``__cause__``.
    """
    so = ort.SessionOptions()
    so.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
    so.intra_op_num_threads = int(os.getenv("ORT_INTRA_OP_THREADS", "2"))
    so.inter_op_num_threads = 1
    so.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
    so.enable_mem_pattern = True
    try:
        return ort.InferenceSession(model_path, sess_options=so, providers=["CPUExecutionProvider"])
    except Exception as e:
        logger.error(f"Failed to load ONNX session from {model_path}: {e}")
        # Chain explicitly so the underlying ONNX Runtime error stays visible.
        raise RuntimeError(f"ONNX session creation failed: {str(e)}") from e
131
+
132
def _load_model_blocking():
    """Download the snapshot, then build the tokenizer and ONNX session (blocking)."""
    local_dir = _download_model_snapshot()
    loaded_tokenizer = AutoTokenizer.from_pretrained(local_dir, trust_remote_code=True)
    onnx_path = os.path.join(local_dir, "onnx", ONNX_MODEL_FILE)
    return loaded_tokenizer, _create_ort_session(onnx_path)


async def _ensure_loaded():
    """Make sure the tokenizer and ONNX session are loaded, loading them once.

    The load is attempted at most once: a failure is remembered in
    ``model_load_error`` and every later call raises immediately.

    Raises:
        HTTPException: 503 if loading fails now or failed on an earlier call.
    """
    global tokenizer, ort_session, model_load_error
    # Fast path: nothing to do and no need to queue on the lock.
    if tokenizer is not None and ort_session is not None:
        return
    async with MODEL_LOCK:
        # Re-check: another coroutine may have finished loading while we waited.
        if tokenizer is not None and ort_session is not None:
            return
        if model_load_error:
            raise HTTPException(status_code=503, detail=f"Model failed to load: {model_load_error}")
        try:
            # Download and session creation block for a long time; run them in
            # a worker thread so the event loop keeps serving other requests.
            loaded_tokenizer, loaded_session = await asyncio.to_thread(_load_model_blocking)
        except Exception as e:
            model_load_error = str(e)
            logger.exception("Model loading failed")
            raise HTTPException(status_code=503, detail=f"Model unavailable: {model_load_error}") from e
        # Publish both together so a failed load never leaves a tokenizer
        # without a session.
        tokenizer, ort_session = loaded_tokenizer, loaded_session
        logger.info(f"Model loaded successfully: {MODEL_ID} ({MODEL_QUANTIZATION})")
149
+
150
def _build_chat_prompt(messages: List[Message]) -> str:
    """Render the conversation into one prompt string.

    The tokenizer's chat template is tried first. If it raises, the error is
    logged and a plain ``<|role|>`` layout ending in an assistant header is
    used instead.

    Raises:
        HTTPException: 503 if the tokenizer has not been loaded.
    """
    if tokenizer is None:
        raise HTTPException(status_code=503, detail="Tokenizer not loaded")
    try:
        conversation = [{"role": m.role, "content": m.content} for m in messages]
        return tokenizer.apply_chat_template(
            conversation,
            tokenize=False,
            add_generation_prompt=True,
        )
    except Exception as e:
        logger.error(f"Chat template error: {e}")
        # Fallback: one tagged section per turn, then the assistant header.
        turns = [f"<|{m.role}|>\n{m.content}\n" for m in messages]
        return "".join(turns) + "<|assistant|>\n"
170
+
171
def _count_tokens(text: str) -> int:
    """Return the token count of ``text``.

    Uses the loaded tokenizer when there is one; otherwise approximates the
    count by the number of whitespace-separated words.
    """
    if tokenizer is not None:
        return len(tokenizer.encode(text))
    return len(text.split())
175
+
176
+ def _softmax(x: np.ndarray) -> np.ndarray:
177
+ e_x = np.exp(x - np.max(x))
178
+ return e_x / e_x.sum(axis=-1, keepdims=True)
179
+
180
+ def _top_p_sampling(logits: np.ndarray, top_p: float) -> int:
181
+ sorted_indices = np.argsort(logits)[::-1]
182
+ sorted_logits = logits[sorted_indices]
183
+ probs = _softmax(sorted_logits)
184
+ cum_probs = np.cumsum(probs)
185
+ cutoff_index = np.searchsorted(cum_probs, top_p) + 1
186
+ top_indices = sorted_indices[:cutoff_index]
187
+ top_probs = probs[:cutoff_index]
188
+ top_probs /= top_probs.sum()
189
+ return int(np.random.choice(top_indices, p=top_probs))
190
+
191
+ def _sample_token(logits: np.ndarray, temperature: float, top_p: float) -> int:
192
+ if temperature <= 0:
193
+ return int(np.argmax(logits))
194
+ logits = logits / temperature
195
+ if top_p < 1.0:
196
+ return _top_p_sampling(logits, top_p)
197
+ probs = _softmax(logits)
198
+ return int(np.random.choice(len(probs), p=probs))
199
 
200
def _generate_full(
    prompt: str,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
    stop_sequences: Optional[List[str]] = None,
) -> str:
    """Generate a complete reply for ``prompt`` and return it as text.

    Each step re-runs the ONNX session over the whole sequence (prompt plus
    everything generated so far) and samples one token from the final
    position's logits.

    Args:
        prompt: Fully formatted prompt text.
        max_new_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature; ``<= 0`` means greedy.
        top_p: Nucleus-sampling threshold.
        stop_sequences: Generation stops at the first of these strings found
            in the decoded output; empty strings are ignored.

    Returns:
        The decoded, whitespace-stripped text, cut before any stop sequence.

    Raises:
        HTTPException: 503 if the model or tokenizer is not loaded.
    """
    if ort_session is None or tokenizer is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    # NOTE(review): this assumes the exported graph takes only input_ids and
    # attention_mask (no position_ids / past_key_values inputs) -- confirm
    # against the ONNX file in use.
    input_ids = tokenizer.encode(prompt, return_tensors="np").astype(np.int64)
    attention_mask = np.ones_like(input_ids, dtype=np.int64)

    generated_tokens = []
    # An empty stop string is contained in every text and would end
    # generation immediately, so drop those.
    stop_sequences = [seq for seq in (stop_sequences or []) if seq]
    eos_token_id = tokenizer.eos_token_id

    for _ in range(max_new_tokens):
        outputs = ort_session.run(
            None, {"input_ids": input_ids, "attention_mask": attention_mask}
        )
        next_token = _sample_token(outputs[0][0, -1, :], temperature, top_p)
        generated_tokens.append(next_token)

        # Check stop conditions before growing the inputs.
        if next_token == eos_token_id:
            break
        partial_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)
        for stop_seq in stop_sequences:
            if stop_seq in partial_text:
                return partial_text.split(stop_seq)[0].strip()

        # Append to the *running* sequence. Concatenating onto the original
        # prompt ids would drop every earlier generated token and leave the
        # attention mask longer than the ids.
        next_token_id = np.array([[next_token]], dtype=np.int64)
        input_ids = np.concatenate([input_ids, next_token_id], axis=1)
        attention_mask = np.ones_like(input_ids, dtype=np.int64)

    full_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)
    return full_text.strip()
246
 
247
async def _generate_stream(
    prompt: str,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
    stop_sequences: Optional[List[str]] = None,
):
    """Generate a reply for ``prompt`` and yield it incrementally as text.

    Each step re-runs the ONNX session over the whole sequence in a worker
    thread, samples one token, and yields whatever new text the cumulative
    decode has produced.

    Args:
        prompt: Fully formatted prompt text.
        max_new_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature; ``<= 0`` means greedy.
        top_p: Nucleus-sampling threshold.
        stop_sequences: Generation ends at the first of these strings found
            in the decoded output; empty strings are ignored.

    Yields:
        Successive non-empty fragments of the generated text.

    Raises:
        HTTPException: 503 if the model or tokenizer is not loaded.
    """
    if ort_session is None or tokenizer is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    # NOTE(review): this assumes the exported graph takes only input_ids and
    # attention_mask (no position_ids / past_key_values inputs) -- confirm
    # against the ONNX file in use.
    input_ids = tokenizer.encode(prompt, return_tensors="np").astype(np.int64)
    attention_mask = np.ones_like(input_ids, dtype=np.int64)

    generated_tokens = []
    # An empty stop string matches any text; ignore those.
    stop_sequences = [seq for seq in (stop_sequences or []) if seq]
    eos_token_id = tokenizer.eos_token_id
    emitted = 0  # number of characters of the decoded text already yielded

    for _ in range(max_new_tokens):
        # Inference is blocking; run it off the event loop so other requests
        # keep being served while this stream is generating.
        outputs = await asyncio.to_thread(
            ort_session.run,
            None,
            {"input_ids": input_ids, "attention_mask": attention_mask},
        )
        next_token = _sample_token(outputs[0][0, -1, :], temperature, top_p)
        if next_token == eos_token_id:
            break
        generated_tokens.append(next_token)

        # Decode cumulatively: decoding a single token can split a multi-byte
        # character, while the full decode yields it once it is complete.
        text = tokenizer.decode(generated_tokens, skip_special_tokens=True)

        hits = [pos for pos in (text.find(seq) for seq in stop_sequences) if pos != -1]
        if hits:
            # Emit only what precedes the earliest stop sequence, then finish.
            # A stop-sequence prefix already streamed in an earlier fragment
            # cannot be taken back.
            cut = min(hits)
            if cut > emitted:
                yield text[emitted:cut]
            return

        # Hold back while the tail is an incomplete character (shown by the
        # decoder as U+FFFD); it is flushed once complete or at the end.
        if len(text) > emitted and not text.endswith("\ufffd"):
            yield text[emitted:]
            emitted = len(text)

        # Append to the *running* sequence so ids and mask stay aligned and
        # earlier generated tokens are kept.
        next_token_id = np.array([[next_token]], dtype=np.int64)
        input_ids = np.concatenate([input_ids, next_token_id], axis=1)
        attention_mask = np.ones_like(input_ids, dtype=np.int64)

    # Flush anything held back when generation ended.
    text = tokenizer.decode(generated_tokens, skip_special_tokens=True)
    if len(text) > emitted:
        yield text[emitted:]
289
 
290
  # ---------- FastAPI App ----------
291
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the model on startup and drop the references on shutdown.

    A failed startup load is logged but does not stop the application from
    starting.
    """
    global tokenizer, ort_session
    try:
        await _ensure_loaded()
    except Exception as e:
        logger.error(f"Startup model load failed: {e}")
    else:
        logger.info("Model loaded successfully")
    yield
    # Shutdown: release the tokenizer and the ONNX session.
    tokenizer = ort_session = None
302
 
303
  app = FastAPI(
304
+ title="Bonsai ONNX Inference API",
305
  version="1.0.0",
306
+ description="Fast, production-ready inference for 1-bit Bonsai LLMs using ONNX Runtime.",
307
  docs_url="/docs",
308
  redoc_url="/redoc",
309
  lifespan=lifespan,
 
320
@app.middleware("http")
async def auth_middleware(request: Request, call_next):
    """Check the API key on every request before passing it down the stack.

    An ``HTTPException`` raised inside an HTTP middleware is not routed to
    the application's exception handlers and would reach the client as a
    500. The rejection is therefore turned into a JSON response here, with
    the status code and headers carried by the exception.
    """
    from http import HTTPStatus

    try:
        _verify_api_key(request)
    except HTTPException as exc:
        try:
            reason = HTTPStatus(exc.status_code).phrase
        except ValueError:
            reason = "HTTP error"
        # NOTE(review): payload mirrors the ErrorResponse fields (error,
        # detail); align it with the HTTPException handler's exact format.
        return JSONResponse(
            status_code=exc.status_code,
            content={"error": reason, "detail": exc.detail},
            headers=getattr(exc, "headers", None),
        )
    response = await call_next(request)
    return response
325
 
326
  # ---------- Error Handlers ----------
327
  @app.exception_handler(HTTPException)
 
349
  # ---------- Endpoints ----------
350
@app.get("/", summary="Root")
def root():
    """Landing endpoint: confirms the service is up and points at the docs."""
    payload = {"message": "Bonsai ONNX API is running", "docs": "/docs"}
    return payload
353
 
354
@app.get("/health", summary="Health check")
def health():
    """Report whether the model is loaded, plus basic model details.

    ``status`` is ``"ok"`` when both the tokenizer and the ONNX session are
    present and ``"degraded"`` otherwise. ``error`` carries the recorded
    load error, if any.
    """
    loaded = not (tokenizer is None or ort_session is None)
    status = "ok" if loaded else "degraded"
    return {
        "status": status,
        "model_loaded": loaded,
        "model_id": MODEL_ID,
        "quantization": MODEL_QUANTIZATION,
        "device": _model_device(),
        "error": model_load_error or None,
    }
365
 
366
@app.get("/v1/model", response_model=ModelInfo, summary="Model information")
def model_info():
    """Describe the served model: repo id, quantization, ONNX file and device."""
    details = {
        "model_id": MODEL_ID,
        "quantization": MODEL_QUANTIZATION,
        "onnx_model_file": ONNX_MODEL_FILE,
        "device": _model_device(),
    }
    return ModelInfo(**details)
 
 
 
 
 
 
 
 
374
 
375
@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
async def chat_completions(req: ChatCompletionRequest):
    """Handle a chat-completion request, streaming or non-streaming.

    With ``stream`` set, the reply is sent as server-sent events: a role
    chunk, one chunk per text fragment, a final chunk with
    ``finish_reason="stop"``, then ``[DONE]``. All chunks of one reply share
    the same ``id`` and ``created`` values. Otherwise the full reply is
    generated in a worker thread and returned together with token usage.

    Raises:
        HTTPException: 503 if the model is unavailable, 400 if the prompt
            cannot be built.
    """
    await _ensure_loaded()

    try:
        prompt = _build_chat_prompt(req.messages)
    except HTTPException:
        # Keep the helper's own status (e.g. 503 "Tokenizer not loaded")
        # instead of relabelling it as a client error.
        raise
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Prompt formatting error: {str(e)}") from e

    if isinstance(req.stop, list):
        stop_seq = req.stop
    else:
        stop_seq = [req.stop] if req.stop else None

    # One id / timestamp per completion, shared by every chunk of the reply.
    completion_id = f"chatcmpl-{uuid.uuid4().hex[:12]}"
    created = int(time.time())
    model_name = req.model or MODEL_ID

    if req.stream:
        def sse_chunk(delta: Dict[str, Any], finish_reason: Optional[str] = None) -> str:
            """Serialise one chat.completion.chunk as a server-sent event."""
            payload = {
                "id": completion_id,
                "object": "chat.completion.chunk",
                "created": created,
                "model": model_name,
                "choices": [{"index": 0, "delta": delta, "finish_reason": finish_reason}],
            }
            return f"data: {json.dumps(payload)}\n\n"

        async def stream_generator():
            yield sse_chunk({"role": "assistant"})
            async for chunk in _generate_stream(
                prompt, req.max_tokens, req.temperature, req.top_p, stop_seq
            ):
                yield sse_chunk({"content": chunk})
                # Give other tasks a turn between chunks.
                await asyncio.sleep(0)
            yield sse_chunk({}, "stop")
            yield "data: [DONE]\n\n"

        return StreamingResponse(stream_generator(), media_type="text/event-stream")

    text = await asyncio.to_thread(
        _generate_full,
        prompt, req.max_tokens, req.temperature, req.top_p, stop_seq
    )
    assistant_msg = Message(role="assistant", content=text)
    # Tokenise each side once and reuse the counts for the total.
    prompt_tokens = _count_tokens(prompt)
    completion_tokens = _count_tokens(text)
    usage = Usage(
        prompt_tokens=prompt_tokens,
        completion_tokens=completion_tokens,
        total_tokens=prompt_tokens + completion_tokens,
    )
    return ChatCompletionResponse(
        id=completion_id,
        created=created,
        model=model_name,
        choices=[ChatCompletionResponseChoice(index=0, message=assistant_msg, finish_reason="stop")],
        usage=usage,
    )
414