sidmaz666 commited on
Commit
99f0108
·
verified ·
1 Parent(s): 535963d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +230 -92
app.py CHANGED
@@ -7,7 +7,7 @@ import os
7
  import time
8
  import uuid
9
  from contextlib import asynccontextmanager
10
- from typing import List, Optional, Union
11
 
12
  from fastapi import FastAPI, HTTPException, Request
13
  from fastapi.middleware.cors import CORSMiddleware
@@ -17,34 +17,64 @@ from pydantic import BaseModel, Field, ValidationError
17
  from llama_cpp import Llama
18
 
19
  # ---------- Configuration ----------
20
- MODEL_ID = os.getenv("MODEL_ID", "lilyanatia/Bonsai-1.7B-requantized")
21
- MODEL_FILENAME = os.getenv("MODEL_FILENAME", "Bonsai-1.7B-IQ1_S.gguf")
22
- HF_TOKEN = os.getenv("HF_TOKEN")
23
  LOCAL_MODEL_DIR = os.getenv("LOCAL_MODEL_DIR", "/data/models")
24
  MAX_NEW_TOKENS_DEFAULT = int(os.getenv("MAX_NEW_TOKENS_DEFAULT", "256"))
25
  API_KEY = os.getenv("API_KEY", None)
 
26
 
27
  # Performance settings
28
  N_CTX = int(os.getenv("N_CTX", "4096"))
29
  N_THREADS = int(os.getenv("N_THREADS", "4"))
30
  N_BATCH = int(os.getenv("N_BATCH", "512"))
31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  logging.basicConfig(level=logging.INFO)
33
  logger = logging.getLogger("uvicorn.error")
34
 
35
  # ---------- Pydantic Models ----------
36
  class Message(BaseModel):
37
- role: str = Field(..., pattern="^(system|user|assistant)$")
38
- content: str
 
 
 
 
 
 
 
 
 
 
 
 
39
 
40
  class ChatCompletionRequest(BaseModel):
41
  messages: List[Message]
42
- model: Optional[str] = MODEL_ID
43
- max_tokens: int = Field(default=MAX_NEW_TOKENS_DEFAULT, ge=1, le=1024)
44
  temperature: float = Field(default=0.7, ge=0.0, le=2.0)
45
  top_p: float = Field(default=0.95, gt=0.0, le=1.0)
46
  stream: bool = False
47
  stop: Optional[Union[str, List[str]]] = None
 
 
 
48
 
49
  class ChatCompletionResponseChoice(BaseModel):
50
  index: int
@@ -65,20 +95,25 @@ class ChatCompletionResponse(BaseModel):
65
  usage: Usage
66
 
67
  class ModelInfo(BaseModel):
68
- model_id: str
69
- filename: str
70
- device: str
71
- n_ctx: int
72
- n_threads: int
 
 
 
73
 
74
  class ErrorResponse(BaseModel):
75
  error: str
76
  detail: Optional[str] = None
77
 
78
  # ---------- Global State ----------
79
- llm = None
80
- model_load_error = None
 
81
  MODEL_LOCK = asyncio.Lock()
 
82
 
83
  # ---------- Helper Functions ----------
84
  def _verify_api_key(request: Request) -> None:
@@ -88,37 +123,67 @@ def _verify_api_key(request: Request) -> None:
88
  if not auth or auth != API_KEY:
89
  raise HTTPException(status_code=401, detail="Invalid or missing API key")
90
 
91
- def _download_model() -> str:
 
 
 
 
 
 
 
92
  os.makedirs(LOCAL_MODEL_DIR, exist_ok=True)
93
- local_path = os.path.join(LOCAL_MODEL_DIR, MODEL_FILENAME)
94
 
95
  if os.path.exists(local_path):
96
- logger.info(f"Model already downloaded at {local_path}")
97
  return local_path
98
 
99
- logger.info(f"Downloading model {MODEL_ID}/{MODEL_FILENAME}...")
100
  try:
101
  hf_hub_download(
102
- repo_id=MODEL_ID,
103
- filename=MODEL_FILENAME,
104
  local_dir=LOCAL_MODEL_DIR,
105
  token=HF_TOKEN,
106
  )
107
- logger.info("Model downloaded successfully.")
108
  return local_path
109
  except Exception as e:
110
- logger.error(f"Model download failed: {e}")
111
- raise RuntimeError(f"Failed to download model: {str(e)}")
112
-
113
- async def _ensure_loaded():
114
- global llm, model_load_error
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
115
  async with MODEL_LOCK:
116
- if llm is not None:
117
  return
118
- if model_load_error:
119
- raise HTTPException(status_code=503, detail=f"Model failed to load: {model_load_error}")
 
 
 
 
 
120
  try:
121
- model_path = _download_model()
122
  llm = Llama(
123
  model_path=model_path,
124
  n_ctx=N_CTX,
@@ -126,56 +191,107 @@ async def _ensure_loaded():
126
  n_batch=N_BATCH,
127
  verbose=False,
128
  )
129
- logger.info(f"Model loaded successfully: {MODEL_ID} ({MODEL_FILENAME})")
130
- logger.info(f"Context: {N_CTX}, Threads: {N_THREADS}, Batch: {N_BATCH}")
131
  except Exception as e:
132
  model_load_error = str(e)
133
- logger.exception("Model loading failed")
134
  raise HTTPException(status_code=503, detail=f"Model unavailable: {model_load_error}")
135
 
136
- def _build_chat_prompt(messages: List[Message]) -> list:
137
- return [{"role": msg.role, "content": msg.content} for msg in messages]
138
-
139
- async def _generate_full(prompt: list, max_new_tokens: int, temperature: float, top_p: float, stop_sequences: Optional[List[str]] = None) -> str:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
140
  if llm is None:
141
  raise HTTPException(status_code=503, detail="Model not loaded")
142
- result = await asyncio.to_thread(
143
- lambda: llm.create_chat_completion(
144
- messages=prompt,
145
- max_tokens=max_new_tokens,
146
- temperature=temperature,
147
- top_p=top_p,
148
- stop=stop_sequences,
149
- stream=False,
150
- )
151
- )
152
- return result["choices"][0]["message"]["content"]
153
-
154
- async def _generate_stream(prompt: list, max_new_tokens: int, temperature: float, top_p: float, stop_sequences: Optional[List[str]] = None):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
155
  if llm is None:
156
  raise HTTPException(status_code=503, detail="Model not loaded")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
157
  def sync_gen():
158
- for chunk in llm.create_chat_completion(
159
- messages=prompt,
160
- max_tokens=max_new_tokens,
161
- temperature=temperature,
162
- top_p=top_p,
163
- stop=stop_sequences,
164
- stream=True,
165
- ):
166
- if "content" in chunk["choices"][0]["delta"]:
167
- yield chunk["choices"][0]["delta"]["content"]
168
- # Convert sync generator to async
169
- for token in await asyncio.to_thread(list, sync_gen()):
170
- yield token
171
  await asyncio.sleep(0)
172
 
173
  # ---------- FastAPI App ----------
174
  @asynccontextmanager
175
  async def lifespan(app: FastAPI):
176
  try:
177
- await _ensure_loaded()
178
- logger.info("Model loaded successfully")
 
179
  except Exception as e:
180
  logger.error(f"Startup model load failed: {e}")
181
  yield
@@ -183,9 +299,9 @@ async def lifespan(app: FastAPI):
183
  llm = None
184
 
185
  app = FastAPI(
186
- title="Bonsai CPU-Optimized Inference API",
187
- version="2.0.0",
188
- description="Lightning-fast inference for 1-bit Bonsai LLMs using llama.cpp.",
189
  docs_url="/docs",
190
  redoc_url="/redoc",
191
  lifespan=lifespan,
@@ -229,7 +345,7 @@ async def generic_exception_handler(request, exc):
229
 
230
  @app.get("/", summary="Root")
231
  def root():
232
- return {"message": "Bonsai CPU API is running", "docs": "/docs"}
233
 
234
  @app.get("/health", summary="Health check")
235
  def health():
@@ -237,46 +353,68 @@ def health():
237
  return {
238
  "status": "ok" if loaded else "degraded",
239
  "model_loaded": loaded,
240
- "model_id": MODEL_ID,
241
- "filename": MODEL_FILENAME,
242
  "error": model_load_error if model_load_error else None,
243
  }
244
 
245
- @app.get("/v1/model", response_model=ModelInfo, summary="Model information")
246
- def model_info():
247
- return ModelInfo(
248
- model_id=MODEL_ID,
249
- filename=MODEL_FILENAME,
250
- device="CPU",
251
- n_ctx=N_CTX,
252
- n_threads=N_THREADS,
253
- )
 
 
 
254
 
255
  @app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
256
  async def chat_completions(req: ChatCompletionRequest):
257
- await _ensure_loaded()
 
 
258
  prompt = _build_chat_prompt(req.messages)
 
259
  stop_seq = req.stop if isinstance(req.stop, list) else ([req.stop] if req.stop else None)
260
 
261
  if req.stream:
262
  async def stream_generator():
263
- yield f"data: {json.dumps({'id': f'chatcmpl-{uuid.uuid4().hex[:12]}', 'object': 'chat.completion.chunk', 'created': int(time.time()), 'model': req.model or MODEL_ID, 'choices': [{'index': 0, 'delta': {'role': 'assistant'}, 'finish_reason': None}]})}\n\n"
264
- async for chunk in _generate_stream(prompt, req.max_tokens, req.temperature, req.top_p, stop_seq):
265
- yield f"data: {json.dumps({'id': f'chatcmpl-{uuid.uuid4().hex[:12]}', 'object': 'chat.completion.chunk', 'created': int(time.time()), 'model': req.model or MODEL_ID, 'choices': [{'index': 0, 'delta': {'content': chunk}, 'finish_reason': None}]})}\n\n"
 
 
 
 
 
266
  await asyncio.sleep(0)
267
- yield f"data: {json.dumps({'id': f'chatcmpl-{uuid.uuid4().hex[:12]}', 'object': 'chat.completion.chunk', 'created': int(time.time()), 'model': req.model or MODEL_ID, 'choices': [{'index': 0, 'delta': {}, 'finish_reason': 'stop'}]})}\n\n"
268
  yield "data: [DONE]\n\n"
269
  return StreamingResponse(stream_generator(), media_type="text/event-stream")
270
-
271
  else:
272
- text = await _generate_full(prompt, req.max_tokens, req.temperature, req.top_p, stop_seq)
273
- assistant_msg = Message(role="assistant", content=text)
274
- usage = Usage(prompt_tokens=0, completion_tokens=0, total_tokens=0)
 
 
 
 
 
 
 
 
 
 
 
 
275
  return ChatCompletionResponse(
276
  id=f"chatcmpl-{uuid.uuid4().hex[:12]}",
277
  created=int(time.time()),
278
- model=req.model or MODEL_ID,
279
- choices=[ChatCompletionResponseChoice(index=0, message=assistant_msg, finish_reason="stop")],
280
  usage=usage,
281
  )
282
 
 
7
  import time
8
  import uuid
9
  from contextlib import asynccontextmanager
10
+ from typing import Dict, List, Optional, Union, Any
11
 
12
  from fastapi import FastAPI, HTTPException, Request
13
  from fastapi.middleware.cors import CORSMiddleware
 
17
  from llama_cpp import Llama
18
 
19
  # ---------- Configuration ----------
20
+ DEFAULT_MODEL_NAME = os.getenv("DEFAULT_MODEL_NAME", "bonsai-1.7b")
 
 
21
  LOCAL_MODEL_DIR = os.getenv("LOCAL_MODEL_DIR", "/data/models")
22
  MAX_NEW_TOKENS_DEFAULT = int(os.getenv("MAX_NEW_TOKENS_DEFAULT", "256"))
23
  API_KEY = os.getenv("API_KEY", None)
24
+ HF_TOKEN = os.getenv("HF_TOKEN")
25
 
26
  # Performance settings
27
  N_CTX = int(os.getenv("N_CTX", "4096"))
28
  N_THREADS = int(os.getenv("N_THREADS", "4"))
29
  N_BATCH = int(os.getenv("N_BATCH", "512"))
30
 
31
+ # ---------- Model Registry ----------
32
+ MODEL_REGISTRY: Dict[str, Dict[str, str]] = {
33
+ "bonsai-1.7b": {
34
+ "repo_id": "lilyanatia/Bonsai-1.7B-requantized",
35
+ "filename": "Bonsai-1.7B-IQ1_S.gguf",
36
+ },
37
+ "bonsai-4b": {
38
+ "repo_id": "lilyanatia/Bonsai-4B-requantized",
39
+ "filename": "Bonsai-4B-IQ1_S.gguf",
40
+ },
41
+ "bonsai-8b": {
42
+ "repo_id": "lilyanatia/Bonsai-8B-requantized",
43
+ "filename": "Bonsai-8B-IQ1_S.gguf",
44
+ },
45
+ }
46
+
47
  logging.basicConfig(level=logging.INFO)
48
  logger = logging.getLogger("uvicorn.error")
49
 
50
  # ---------- Pydantic Models ----------
51
  class Message(BaseModel):
52
+ role: str = Field(..., pattern="^(system|user|assistant|tool)$")
53
+ content: Optional[str] = None
54
+ tool_calls: Optional[List[Dict[str, Any]]] = None
55
+ tool_call_id: Optional[str] = None
56
+ name: Optional[str] = None
57
+
58
+ class ToolFunction(BaseModel):
59
+ name: str
60
+ description: Optional[str] = None
61
+ parameters: Optional[Dict[str, Any]] = None
62
+
63
+ class Tool(BaseModel):
64
+ type: str = "function"
65
+ function: ToolFunction
66
 
67
  class ChatCompletionRequest(BaseModel):
68
  messages: List[Message]
69
+ model: str = Field(default=DEFAULT_MODEL_NAME)
70
+ max_tokens: int = Field(default=MAX_NEW_TOKENS_DEFAULT, ge=1, le=2048)
71
  temperature: float = Field(default=0.7, ge=0.0, le=2.0)
72
  top_p: float = Field(default=0.95, gt=0.0, le=1.0)
73
  stream: bool = False
74
  stop: Optional[Union[str, List[str]]] = None
75
+ tools: Optional[List[Tool]] = None
76
+ tool_choice: Optional[Union[str, Dict[str, Any]]] = None
77
+ response_format: Optional[Dict[str, str]] = None
78
 
79
  class ChatCompletionResponseChoice(BaseModel):
80
  index: int
 
95
  usage: Usage
96
 
97
  class ModelInfo(BaseModel):
98
+ id: str
99
+ object: str = "model"
100
+ created: int
101
+ owned_by: str = "lilyanatia"
102
+
103
+ class ModelListResponse(BaseModel):
104
+ object: str = "list"
105
+ data: List[ModelInfo]
106
 
107
  class ErrorResponse(BaseModel):
108
  error: str
109
  detail: Optional[str] = None
110
 
111
  # ---------- Global State ----------
112
+ current_model_name: Optional[str] = None
113
+ llm: Optional[Llama] = None
114
+ model_load_error: Optional[str] = None
115
  MODEL_LOCK = asyncio.Lock()
116
+ DOWNLOADED_MODELS = set()
117
 
118
  # ---------- Helper Functions ----------
119
  def _verify_api_key(request: Request) -> None:
 
123
  if not auth or auth != API_KEY:
124
  raise HTTPException(status_code=401, detail="Invalid or missing API key")
125
 
126
+ def _download_model(model_name: str) -> str:
127
+ """Downloads a model if it's not already present."""
128
+ if model_name not in MODEL_REGISTRY:
129
+ raise HTTPException(status_code=400, detail=f"Model '{model_name}' not found in registry.")
130
+
131
+ model_info = MODEL_REGISTRY[model_name]
132
+ repo_id = model_info["repo_id"]
133
+ filename = model_info["filename"]
134
  os.makedirs(LOCAL_MODEL_DIR, exist_ok=True)
135
+ local_path = os.path.join(LOCAL_MODEL_DIR, filename)
136
 
137
  if os.path.exists(local_path):
138
+ logger.info(f"Model '{model_name}' already downloaded at {local_path}")
139
  return local_path
140
 
141
+ logger.info(f"Downloading model '{model_name}' from {repo_id}/{filename}...")
142
  try:
143
  hf_hub_download(
144
+ repo_id=repo_id,
145
+ filename=filename,
146
  local_dir=LOCAL_MODEL_DIR,
147
  token=HF_TOKEN,
148
  )
149
+ logger.info(f"Model '{model_name}' downloaded successfully.")
150
  return local_path
151
  except Exception as e:
152
+ logger.error(f"Model download failed for '{model_name}': {e}")
153
+ raise HTTPException(status_code=500, detail=f"Failed to download model: {str(e)}")
154
+
155
+ async def _precache_all_models():
156
+ """Downloads all models in the registry at startup."""
157
+ logger.info("Pre-caching all models in registry...")
158
+ download_tasks = []
159
+ for model_name in MODEL_REGISTRY.keys():
160
+ download_tasks.append(asyncio.to_thread(_download_model, model_name))
161
+
162
+ results = await asyncio.gather(*download_tasks, return_exceptions=True)
163
+ for model_name, result in zip(MODEL_REGISTRY.keys(), results):
164
+ if isinstance(result, Exception):
165
+ logger.error(f"Failed to pre-cache model '{model_name}': {result}")
166
+ else:
167
+ DOWNLOADED_MODELS.add(model_name)
168
+ logger.info(f"Model '{model_name}' is ready.")
169
+
170
+ logger.info(f"Pre-caching complete. {len(DOWNLOADED_MODELS)}/{len(MODEL_REGISTRY)} models cached.")
171
+
172
+ async def _ensure_model_loaded(model_name: str):
173
+ """Loads the specified model, downloading it first if necessary."""
174
+ global llm, current_model_name, model_load_error
175
  async with MODEL_LOCK:
176
+ if current_model_name == model_name and llm is not None:
177
  return
178
+
179
+ if llm is not None:
180
+ logger.info(f"Unloading previous model '{current_model_name}'...")
181
+ del llm
182
+ llm = None
183
+ current_model_name = None
184
+
185
  try:
186
+ model_path = _download_model(model_name)
187
  llm = Llama(
188
  model_path=model_path,
189
  n_ctx=N_CTX,
 
191
  n_batch=N_BATCH,
192
  verbose=False,
193
  )
194
+ current_model_name = model_name
195
+ logger.info(f"Model '{model_name}' loaded successfully.")
196
  except Exception as e:
197
  model_load_error = str(e)
198
+ logger.exception(f"Model loading failed for '{model_name}'")
199
  raise HTTPException(status_code=503, detail=f"Model unavailable: {model_load_error}")
200
 
201
+ def _build_chat_prompt(messages: List[Message]) -> List[Dict[str, Any]]:
202
+ """Convert Pydantic messages to dict format for llama.cpp."""
203
+ formatted = []
204
+ for msg in messages:
205
+ msg_dict = {"role": msg.role, "content": msg.content}
206
+ if msg.tool_calls:
207
+ msg_dict["tool_calls"] = msg.tool_calls
208
+ if msg.tool_call_id:
209
+ msg_dict["tool_call_id"] = msg.tool_call_id
210
+ if msg.name:
211
+ msg_dict["name"] = msg.name
212
+ formatted.append(msg_dict)
213
+ return formatted
214
+
215
+ def _convert_tools(tools: Optional[List[Tool]]) -> Optional[List[Dict[str, Any]]]:
216
+ """Convert Pydantic tools to dict format for llama.cpp."""
217
+ if not tools:
218
+ return None
219
+ return [tool.model_dump() for tool in tools]
220
+
221
+ async def _generate_full(
222
+ prompt: List[Dict[str, Any]],
223
+ max_new_tokens: int,
224
+ temperature: float,
225
+ top_p: float,
226
+ stop_sequences: Optional[List[str]] = None,
227
+ tools: Optional[List[Dict[str, Any]]] = None,
228
+ tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
229
+ response_format: Optional[Dict[str, str]] = None,
230
+ ) -> Dict[str, Any]:
231
  if llm is None:
232
  raise HTTPException(status_code=503, detail="Model not loaded")
233
+
234
+ kwargs = {
235
+ "messages": prompt,
236
+ "max_tokens": max_new_tokens,
237
+ "temperature": temperature,
238
+ "top_p": top_p,
239
+ "stop": stop_sequences,
240
+ "stream": False,
241
+ }
242
+ if tools:
243
+ kwargs["tools"] = tools
244
+ if tool_choice:
245
+ kwargs["tool_choice"] = tool_choice
246
+ if response_format:
247
+ kwargs["response_format"] = response_format
248
+
249
+ result = await asyncio.to_thread(lambda: llm.create_chat_completion(**kwargs))
250
+ return result
251
+
252
+ async def _generate_stream(
253
+ prompt: List[Dict[str, Any]],
254
+ max_new_tokens: int,
255
+ temperature: float,
256
+ top_p: float,
257
+ stop_sequences: Optional[List[str]] = None,
258
+ tools: Optional[List[Dict[str, Any]]] = None,
259
+ tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
260
+ response_format: Optional[Dict[str, str]] = None,
261
+ ):
262
  if llm is None:
263
  raise HTTPException(status_code=503, detail="Model not loaded")
264
+
265
+ kwargs = {
266
+ "messages": prompt,
267
+ "max_tokens": max_new_tokens,
268
+ "temperature": temperature,
269
+ "top_p": top_p,
270
+ "stop": stop_sequences,
271
+ "stream": True,
272
+ }
273
+ if tools:
274
+ kwargs["tools"] = tools
275
+ if tool_choice:
276
+ kwargs["tool_choice"] = tool_choice
277
+ if response_format:
278
+ kwargs["response_format"] = response_format
279
+
280
  def sync_gen():
281
+ for chunk in llm.create_chat_completion(**kwargs):
282
+ yield chunk
283
+
284
+ for chunk in await asyncio.to_thread(list, sync_gen()):
285
+ yield chunk
 
 
 
 
 
 
 
 
286
  await asyncio.sleep(0)
287
 
288
  # ---------- FastAPI App ----------
289
  @asynccontextmanager
290
  async def lifespan(app: FastAPI):
291
  try:
292
+ await _precache_all_models()
293
+ await _ensure_model_loaded(DEFAULT_MODEL_NAME)
294
+ logger.info(f"Default model '{DEFAULT_MODEL_NAME}' loaded successfully")
295
  except Exception as e:
296
  logger.error(f"Startup model load failed: {e}")
297
  yield
 
299
  llm = None
300
 
301
  app = FastAPI(
302
+ title="Bonsai Multi-Model Inference API",
303
+ version="3.0.0",
304
+ description="Lightning-fast inference for Bonsai LLMs with tool calling support.",
305
  docs_url="/docs",
306
  redoc_url="/redoc",
307
  lifespan=lifespan,
 
345
 
346
  @app.get("/", summary="Root")
347
  def root():
348
+ return {"message": "Bonsai Multi-Model API is running", "docs": "/docs"}
349
 
350
  @app.get("/health", summary="Health check")
351
  def health():
 
353
  return {
354
  "status": "ok" if loaded else "degraded",
355
  "model_loaded": loaded,
356
+ "current_model": current_model_name,
357
+ "cached_models": list(DOWNLOADED_MODELS),
358
  "error": model_load_error if model_load_error else None,
359
  }
360
 
361
+ @app.get("/v1/models", response_model=ModelListResponse, summary="List available models")
362
+ def list_models():
363
+ models = []
364
+ for name in MODEL_REGISTRY.keys():
365
+ models.append(ModelInfo(id=name, created=int(time.time())))
366
+ return ModelListResponse(data=models)
367
+
368
+ @app.get("/v1/models/{model_name}", response_model=ModelInfo, summary="Get model information")
369
+ def get_model(model_name: str):
370
+ if model_name not in MODEL_REGISTRY:
371
+ raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found")
372
+ return ModelInfo(id=model_name, created=int(time.time()))
373
 
374
  @app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
375
  async def chat_completions(req: ChatCompletionRequest):
376
+ model_name = req.model or DEFAULT_MODEL_NAME
377
+ await _ensure_model_loaded(model_name)
378
+
379
  prompt = _build_chat_prompt(req.messages)
380
+ tools = _convert_tools(req.tools)
381
  stop_seq = req.stop if isinstance(req.stop, list) else ([req.stop] if req.stop else None)
382
 
383
  if req.stream:
384
  async def stream_generator():
385
+ yield f"data: {json.dumps({'id': f'chatcmpl-{uuid.uuid4().hex[:12]}', 'object': 'chat.completion.chunk', 'created': int(time.time()), 'model': model_name, 'choices': [{'index': 0, 'delta': {'role': 'assistant'}, 'finish_reason': None}]})}\n\n"
386
+ async for chunk in _generate_stream(prompt, req.max_tokens, req.temperature, req.top_p, stop_seq, tools, req.tool_choice, req.response_format):
387
+ delta = {}
388
+ if "choices" in chunk and len(chunk["choices"]) > 0:
389
+ choice = chunk["choices"][0]
390
+ if "delta" in choice:
391
+ delta = choice["delta"]
392
+ yield f"data: {json.dumps({'id': f'chatcmpl-{uuid.uuid4().hex[:12]}', 'object': 'chat.completion.chunk', 'created': int(time.time()), 'model': model_name, 'choices': [{'index': 0, 'delta': delta, 'finish_reason': None}]})}\n\n"
393
  await asyncio.sleep(0)
394
+ yield f"data: {json.dumps({'id': f'chatcmpl-{uuid.uuid4().hex[:12]}', 'object': 'chat.completion.chunk', 'created': int(time.time()), 'model': model_name, 'choices': [{'index': 0, 'delta': {}, 'finish_reason': 'stop'}]})}\n\n"
395
  yield "data: [DONE]\n\n"
396
  return StreamingResponse(stream_generator(), media_type="text/event-stream")
 
397
  else:
398
+ result = await _generate_full(prompt, req.max_tokens, req.temperature, req.top_p, stop_seq, tools, req.tool_choice, req.response_format)
399
+ choice = result["choices"][0]
400
+ message_data = choice.get("message", {})
401
+ assistant_msg = Message(
402
+ role=message_data.get("role", "assistant"),
403
+ content=message_data.get("content"),
404
+ tool_calls=message_data.get("tool_calls"),
405
+ )
406
+ finish_reason = choice.get("finish_reason", "stop")
407
+ usage_data = result.get("usage", {})
408
+ usage = Usage(
409
+ prompt_tokens=usage_data.get("prompt_tokens", 0),
410
+ completion_tokens=usage_data.get("completion_tokens", 0),
411
+ total_tokens=usage_data.get("total_tokens", 0),
412
+ )
413
  return ChatCompletionResponse(
414
  id=f"chatcmpl-{uuid.uuid4().hex[:12]}",
415
  created=int(time.time()),
416
+ model=model_name,
417
+ choices=[ChatCompletionResponseChoice(index=0, message=assistant_msg, finish_reason=finish_reason)],
418
  usage=usage,
419
  )
420