sidmaz666 committed on
Commit
fdf82b0
·
verified ·
1 Parent(s): 41ee25d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +112 -45
app.py CHANGED
@@ -19,11 +19,8 @@ from pydantic import BaseModel, Field, ValidationError
19
  from transformers import AutoTokenizer
20
 
21
  # ---------- Configuration ----------
22
- # Model Selection: Use "onnx-community/Bonsai-1.7B-ONNX" or "onnx-community/Bonsai-8B-ONNX"
23
  MODEL_ID = os.getenv("MODEL_ID", "onnx-community/Bonsai-1.7B-ONNX")
24
- # Quantization: Choose from 'q1', 'q2', 'q4', 'q8' based on the files in the ONNX model repo
25
  MODEL_QUANTIZATION = os.getenv("MODEL_QUANTIZATION", "q1")
26
- # Model file name based on quantization
27
  ONNX_MODEL_FILE = f"model_{MODEL_QUANTIZATION}.onnx"
28
 
29
  HF_TOKEN = os.getenv("HF_TOKEN")
@@ -34,6 +31,12 @@ API_KEY = os.getenv("API_KEY", None)
34
  logging.basicConfig(level=logging.INFO)
35
  logger = logging.getLogger("uvicorn.error")
36
 
 
 
 
 
 
 
37
  # ---------- Pydantic Models ----------
38
  class Message(BaseModel):
39
  role: str = Field(..., pattern="^(system|user|assistant)$")
@@ -151,7 +154,6 @@ def _build_chat_prompt(messages: List[Message]) -> str:
151
  if tokenizer is None:
152
  raise HTTPException(status_code=503, detail="Tokenizer not loaded")
153
  try:
154
- # Use the tokenizer's chat template to format the conversation
155
  formatted_messages = [{"role": msg.role, "content": msg.content} for msg in messages]
156
  prompt = tokenizer.apply_chat_template(
157
  formatted_messages,
@@ -161,7 +163,7 @@ def _build_chat_prompt(messages: List[Message]) -> str:
161
  return prompt
162
  except Exception as e:
163
  logger.error(f"Chat template error: {e}")
164
- # Fallback to a simple concatenation if template fails
165
  prompt = ""
166
  for msg in messages:
167
  prompt += f"<|{msg.role}|>\n{msg.content}\n"
@@ -197,6 +199,26 @@ def _sample_token(logits: np.ndarray, temperature: float, top_p: float) -> int:
197
  probs = _softmax(logits)
198
  return int(np.random.choice(len(probs), p=probs))
199
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
200
  def _generate_full(
201
  prompt: str,
202
  max_new_tokens: int,
@@ -207,40 +229,61 @@ def _generate_full(
207
  if ort_session is None or tokenizer is None:
208
  raise HTTPException(status_code=503, detail="Model not loaded")
209
 
210
- input_ids = tokenizer.encode(prompt, return_tensors="np")
211
- input_ids = input_ids.astype(np.int64)
212
-
213
- # Prepare initial inputs for the ONNX model
214
- ort_inputs = {
215
- "input_ids": input_ids,
216
- "attention_mask": np.ones_like(input_ids, dtype=np.int64),
217
- }
218
-
219
  generated_tokens = []
220
  stop_sequences = stop_sequences or []
221
  eos_token_id = tokenizer.eos_token_id
222
-
223
- for _ in range(max_new_tokens):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
224
  outputs = ort_session.run(None, ort_inputs)
225
  logits = outputs[0][:, -1, :]
226
  next_token = _sample_token(logits[0], temperature, top_p)
227
  generated_tokens.append(next_token)
228
-
229
- # Update inputs for the next step
230
- next_token_id = np.array([[next_token]], dtype=np.int64)
231
- ort_inputs["input_ids"] = np.concatenate([input_ids, next_token_id], axis=1)
232
- ort_inputs["attention_mask"] = np.concatenate(
233
- [ort_inputs["attention_mask"], np.ones((1, 1), dtype=np.int64)], axis=1
234
- )
235
-
236
- # Check stop conditions
237
  if next_token == eos_token_id:
238
  break
239
  partial_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)
240
  for stop_seq in stop_sequences:
241
  if stop_seq in partial_text:
242
  return partial_text.split(stop_seq)[0].strip()
243
-
244
  full_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)
245
  return full_text.strip()
246
 
@@ -255,31 +298,55 @@ async def _generate_stream(
255
  raise HTTPException(status_code=503, detail="Model not loaded")
256
 
257
  input_ids = tokenizer.encode(prompt, return_tensors="np").astype(np.int64)
258
- ort_inputs = {
259
- "input_ids": input_ids,
260
- "attention_mask": np.ones_like(input_ids, dtype=np.int64),
261
- }
262
-
263
  generated_tokens = []
264
  stop_sequences = stop_sequences or []
265
  eos_token_id = tokenizer.eos_token_id
266
-
267
- for _ in range(max_new_tokens):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
268
  outputs = ort_session.run(None, ort_inputs)
269
  logits = outputs[0][:, -1, :]
270
  next_token = _sample_token(logits[0], temperature, top_p)
271
  generated_tokens.append(next_token)
272
-
273
- next_token_id = np.array([[next_token]], dtype=np.int64)
274
- ort_inputs["input_ids"] = np.concatenate([input_ids, next_token_id], axis=1)
275
- ort_inputs["attention_mask"] = np.concatenate(
276
- [ort_inputs["attention_mask"], np.ones((1, 1), dtype=np.int64)], axis=1
277
- )
278
-
279
  new_text = tokenizer.decode([next_token], skip_special_tokens=True)
280
  if new_text:
281
  yield new_text
282
-
283
  full_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)
284
  for stop_seq in stop_sequences:
285
  if stop_seq in full_text:
@@ -375,14 +442,14 @@ def model_info():
375
  @app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
376
  async def chat_completions(req: ChatCompletionRequest):
377
  await _ensure_loaded()
378
-
379
  try:
380
  prompt = _build_chat_prompt(req.messages)
381
  except Exception as e:
382
  raise HTTPException(status_code=400, detail=f"Prompt formatting error: {str(e)}")
383
-
384
  stop_seq = req.stop if isinstance(req.stop, list) else ([req.stop] if req.stop else None)
385
-
386
  if req.stream:
387
  async def stream_generator():
388
  yield f"data: {json.dumps({'id': f'chatcmpl-{uuid.uuid4().hex[:12]}', 'object': 'chat.completion.chunk', 'created': int(time.time()), 'model': req.model or MODEL_ID, 'choices': [{'index': 0, 'delta': {'role': 'assistant'}, 'finish_reason': None}]})}\n\n"
@@ -392,7 +459,7 @@ async def chat_completions(req: ChatCompletionRequest):
392
  yield f"data: {json.dumps({'id': f'chatcmpl-{uuid.uuid4().hex[:12]}', 'object': 'chat.completion.chunk', 'created': int(time.time()), 'model': req.model or MODEL_ID, 'choices': [{'index': 0, 'delta': {}, 'finish_reason': 'stop'}]})}\n\n"
393
  yield "data: [DONE]\n\n"
394
  return StreamingResponse(stream_generator(), media_type="text/event-stream")
395
-
396
  else:
397
  text = await asyncio.to_thread(
398
  _generate_full,
 
19
  from transformers import AutoTokenizer
20
 
21
  # ---------- Configuration ----------
 
22
  MODEL_ID = os.getenv("MODEL_ID", "onnx-community/Bonsai-1.7B-ONNX")
 
23
  MODEL_QUANTIZATION = os.getenv("MODEL_QUANTIZATION", "q1")
 
24
  ONNX_MODEL_FILE = f"model_{MODEL_QUANTIZATION}.onnx"
25
 
26
  HF_TOKEN = os.getenv("HF_TOKEN")
 
31
  logging.basicConfig(level=logging.INFO)
32
  logger = logging.getLogger("uvicorn.error")
33
 
34
+ # Bonsai architecture constants (from config.json)
35
+ NUM_LAYERS = 28
36
+ NUM_KV_HEADS = 8
37
+ HEAD_DIM = 128
38
+ DTYPE = np.float32
39
+
40
  # ---------- Pydantic Models ----------
41
  class Message(BaseModel):
42
  role: str = Field(..., pattern="^(system|user|assistant)$")
 
154
  if tokenizer is None:
155
  raise HTTPException(status_code=503, detail="Tokenizer not loaded")
156
  try:
 
157
  formatted_messages = [{"role": msg.role, "content": msg.content} for msg in messages]
158
  prompt = tokenizer.apply_chat_template(
159
  formatted_messages,
 
163
  return prompt
164
  except Exception as e:
165
  logger.error(f"Chat template error: {e}")
166
+ # Fallback to a simple concatenation
167
  prompt = ""
168
  for msg in messages:
169
  prompt += f"<|{msg.role}|>\n{msg.content}\n"
 
199
  probs = _softmax(logits)
200
  return int(np.random.choice(len(probs), p=probs))
201
 
def _init_past_key_values(batch_size: int = 1) -> Dict[str, np.ndarray]:
    """Build the zero-length KV-cache inputs fed to the first inference step.

    Returns a dict holding one ``past_key_values.{i}.key`` and one
    ``past_key_values.{i}.value`` entry per layer (``NUM_LAYERS`` layers).
    Each entry is its own array of shape
    ``(batch_size, NUM_KV_HEADS, 0, HEAD_DIM)`` with dtype ``DTYPE``; the
    sequence axis has length 0 because nothing has been cached yet.
    """
    cache_shape = (batch_size, NUM_KV_HEADS, 0, HEAD_DIM)
    cache: Dict[str, np.ndarray] = {}
    for layer in range(NUM_LAYERS):
        # Key first, then value, so insertion order matches per-layer pairs.
        for kind in ("key", "value"):
            cache[f"past_key_values.{layer}.{kind}"] = np.zeros(cache_shape, dtype=DTYPE)
    return cache
211
+
212
+ def _update_past_key_values(outputs: List[np.ndarray], output_names: List[str]) -> Dict[str, np.ndarray]:
213
+ """Extract present_key_values from ONNX outputs and return as dictionary."""
214
+ new_past = {}
215
+ for name, value in zip(output_names, outputs):
216
+ if name.startswith("present"):
217
+ # Convert "present_key_values.0.key" -> "past_key_values.0.key"
218
+ past_name = name.replace("present", "past")
219
+ new_past[past_name] = value
220
+ return new_past
221
+
222
  def _generate_full(
223
  prompt: str,
224
  max_new_tokens: int,
 
229
  if ort_session is None or tokenizer is None:
230
  raise HTTPException(status_code=503, detail="Model not loaded")
231
 
232
+ input_ids = tokenizer.encode(prompt, return_tensors="np").astype(np.int64)
233
+ attention_mask = np.ones_like(input_ids, dtype=np.int64)
234
+ position_ids = np.arange(input_ids.shape[1], dtype=np.int64).reshape(1, -1)
235
+
236
+ # Initialize KV cache
237
+ past_kv = _init_past_key_values(batch_size=1)
238
+
 
 
239
  generated_tokens = []
240
  stop_sequences = stop_sequences or []
241
  eos_token_id = tokenizer.eos_token_id
242
+
243
+ # Prefill step: process full prompt
244
+ ort_inputs = {
245
+ "input_ids": input_ids,
246
+ "attention_mask": attention_mask,
247
+ "position_ids": position_ids,
248
+ "num_logits_to_keep": np.array([1], dtype=np.int64),
249
+ **past_kv,
250
+ }
251
+ outputs = ort_session.run(None, ort_inputs)
252
+ # First output is logits, the rest are present_key_values
253
+ logits = outputs[0][:, -1, :]
254
+ next_token = _sample_token(logits[0], temperature, top_p)
255
+ generated_tokens.append(next_token)
256
+
257
+ # Update past_key_values from outputs
258
+ past_kv = _update_past_key_values(outputs, [out.name for out in ort_session.get_outputs()])
259
+
260
+ for step in range(1, max_new_tokens):
261
+ # Subsequent steps: only the last token
262
+ last_token = np.array([[next_token]], dtype=np.int64)
263
+ attention_mask = np.ones((1, past_kv[f"past_key_values.0.key"].shape[2] + 1), dtype=np.int64)
264
+ position_ids = np.array([[past_kv[f"past_key_values.0.key"].shape[2]]], dtype=np.int64)
265
+
266
+ ort_inputs = {
267
+ "input_ids": last_token,
268
+ "attention_mask": attention_mask,
269
+ "position_ids": position_ids,
270
+ "num_logits_to_keep": np.array([1], dtype=np.int64),
271
+ **past_kv,
272
+ }
273
  outputs = ort_session.run(None, ort_inputs)
274
  logits = outputs[0][:, -1, :]
275
  next_token = _sample_token(logits[0], temperature, top_p)
276
  generated_tokens.append(next_token)
277
+
278
+ past_kv = _update_past_key_values(outputs, [out.name for out in ort_session.get_outputs()])
279
+
 
 
 
 
 
 
280
  if next_token == eos_token_id:
281
  break
282
  partial_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)
283
  for stop_seq in stop_sequences:
284
  if stop_seq in partial_text:
285
  return partial_text.split(stop_seq)[0].strip()
286
+
287
  full_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)
288
  return full_text.strip()
289
 
 
298
  raise HTTPException(status_code=503, detail="Model not loaded")
299
 
300
  input_ids = tokenizer.encode(prompt, return_tensors="np").astype(np.int64)
301
+ attention_mask = np.ones_like(input_ids, dtype=np.int64)
302
+ position_ids = np.arange(input_ids.shape[1], dtype=np.int64).reshape(1, -1)
303
+
304
+ past_kv = _init_past_key_values(batch_size=1)
 
305
  generated_tokens = []
306
  stop_sequences = stop_sequences or []
307
  eos_token_id = tokenizer.eos_token_id
308
+
309
+ # Prefill
310
+ ort_inputs = {
311
+ "input_ids": input_ids,
312
+ "attention_mask": attention_mask,
313
+ "position_ids": position_ids,
314
+ "num_logits_to_keep": np.array([1], dtype=np.int64),
315
+ **past_kv,
316
+ }
317
+ outputs = ort_session.run(None, ort_inputs)
318
+ logits = outputs[0][:, -1, :]
319
+ next_token = _sample_token(logits[0], temperature, top_p)
320
+ generated_tokens.append(next_token)
321
+ past_kv = _update_past_key_values(outputs, [out.name for out in ort_session.get_outputs()])
322
+
323
+ new_text = tokenizer.decode([next_token], skip_special_tokens=True)
324
+ if new_text:
325
+ yield new_text
326
+
327
+ for step in range(1, max_new_tokens):
328
+ last_token = np.array([[next_token]], dtype=np.int64)
329
+ attention_mask = np.ones((1, past_kv[f"past_key_values.0.key"].shape[2] + 1), dtype=np.int64)
330
+ position_ids = np.array([[past_kv[f"past_key_values.0.key"].shape[2]]], dtype=np.int64)
331
+
332
+ ort_inputs = {
333
+ "input_ids": last_token,
334
+ "attention_mask": attention_mask,
335
+ "position_ids": position_ids,
336
+ "num_logits_to_keep": np.array([1], dtype=np.int64),
337
+ **past_kv,
338
+ }
339
  outputs = ort_session.run(None, ort_inputs)
340
  logits = outputs[0][:, -1, :]
341
  next_token = _sample_token(logits[0], temperature, top_p)
342
  generated_tokens.append(next_token)
343
+
344
+ past_kv = _update_past_key_values(outputs, [out.name for out in ort_session.get_outputs()])
345
+
 
 
 
 
346
  new_text = tokenizer.decode([next_token], skip_special_tokens=True)
347
  if new_text:
348
  yield new_text
349
+
350
  full_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)
351
  for stop_seq in stop_sequences:
352
  if stop_seq in full_text:
 
442
  @app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
443
  async def chat_completions(req: ChatCompletionRequest):
444
  await _ensure_loaded()
445
+
446
  try:
447
  prompt = _build_chat_prompt(req.messages)
448
  except Exception as e:
449
  raise HTTPException(status_code=400, detail=f"Prompt formatting error: {str(e)}")
450
+
451
  stop_seq = req.stop if isinstance(req.stop, list) else ([req.stop] if req.stop else None)
452
+
453
  if req.stream:
454
  async def stream_generator():
455
  yield f"data: {json.dumps({'id': f'chatcmpl-{uuid.uuid4().hex[:12]}', 'object': 'chat.completion.chunk', 'created': int(time.time()), 'model': req.model or MODEL_ID, 'choices': [{'index': 0, 'delta': {'role': 'assistant'}, 'finish_reason': None}]})}\n\n"
 
459
  yield f"data: {json.dumps({'id': f'chatcmpl-{uuid.uuid4().hex[:12]}', 'object': 'chat.completion.chunk', 'created': int(time.time()), 'model': req.model or MODEL_ID, 'choices': [{'index': 0, 'delta': {}, 'finish_reason': 'stop'}]})}\n\n"
460
  yield "data: [DONE]\n\n"
461
  return StreamingResponse(stream_generator(), media_type="text/event-stream")
462
+
463
  else:
464
  text = await asyncio.to_thread(
465
  _generate_full,