sidmaz666 committed on
Commit
28340f8
·
verified ·
1 Parent(s): 4862316

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -57
app.py CHANGED
@@ -14,33 +14,25 @@ from fastapi.middleware.cors import CORSMiddleware
14
  from fastapi.responses import JSONResponse, StreamingResponse
15
  from huggingface_hub import hf_hub_download
16
  from pydantic import BaseModel, Field, ValidationError
17
-
18
- # NEW: Import llama.cpp
19
  from llama_cpp import Llama
20
 
21
  # ---------- Configuration ----------
22
- # You can now use GGUF models for even faster inference!
23
- # These are specifically optimized by the PrismML team.
24
  MODEL_ID = os.getenv("MODEL_ID", "prism-ml/Bonsai-1.7B-gguf")
25
- MODEL_FILENAME = os.getenv("MODEL_FILENAME", "Bonsai-1.7B-v1.0-Q1_0.gguf")
26
-
27
- # Quantization types in GGUF: Q1_0 is for 1-bit models.
28
- # For 8B, use MODEL_ID="prism-ml/Bonsai-8B-gguf" and MODEL_FILENAME="Bonsai-8B-v1.0-Q1_0.gguf"
29
-
30
  HF_TOKEN = os.getenv("HF_TOKEN")
31
  LOCAL_MODEL_DIR = os.getenv("LOCAL_MODEL_DIR", "/data/models")
32
  MAX_NEW_TOKENS_DEFAULT = int(os.getenv("MAX_NEW_TOKENS_DEFAULT", "256"))
33
  API_KEY = os.getenv("API_KEY", None)
34
 
35
- # Performance settings for CPU inference
36
- N_CTX = int(os.getenv("N_CTX", "4096")) # Context window
37
- N_THREADS = int(os.getenv("N_THREADS", "4")) # Number of CPU threads to use
38
- N_BATCH = int(os.getenv("N_BATCH", "512")) # Batch size for prompt processing
39
 
40
  logging.basicConfig(level=logging.INFO)
41
  logger = logging.getLogger("uvicorn.error")
42
 
43
- # ---------- Pydantic Models (Same as before) ----------
44
  class Message(BaseModel):
45
  role: str = Field(..., pattern="^(system|user|assistant)$")
46
  content: str
@@ -127,12 +119,11 @@ async def _ensure_loaded():
127
  raise HTTPException(status_code=503, detail=f"Model failed to load: {model_load_error}")
128
  try:
129
  model_path = _download_model()
130
- # Load the model with CPU-optimized settings
131
  llm = Llama(
132
  model_path=model_path,
133
- n_ctx=N_CTX, # Context window
134
- n_threads=N_THREADS, # Number of CPU threads
135
- n_batch=N_BATCH, # Batch size for prompt processing
136
  verbose=False,
137
  )
138
  logger.info(f"Model loaded successfully: {MODEL_ID} ({MODEL_FILENAME})")
@@ -142,21 +133,13 @@ async def _ensure_loaded():
142
  logger.exception("Model loading failed")
143
  raise HTTPException(status_code=503, detail=f"Model unavailable: {model_load_error}")
144
 
145
- def _build_chat_prompt(messages: List[Message]) -> str:
146
- # llama.cpp handles chat templates automatically, so we can just pass the messages directly.
147
- # This is for compatibility; the actual formatting is done by llama.cpp.
148
- if llm is None:
149
- raise HTTPException(status_code=503, detail="Model not loaded")
150
-
151
- # The create_chat_completion method expects a list of messages in this format
152
  return [{"role": msg.role, "content": msg.content} for msg in messages]
153
 
154
  async def _generate_full(prompt: list, max_new_tokens: int, temperature: float, top_p: float, stop_sequences: Optional[List[str]] = None) -> str:
155
  if llm is None:
156
  raise HTTPException(status_code=503, detail="Model not loaded")
157
-
158
- # Run the blocking llama.cpp call in a thread
159
- return await asyncio.to_thread(
160
  lambda: llm.create_chat_completion(
161
  messages=prompt,
162
  max_tokens=max_new_tokens,
@@ -164,15 +147,14 @@ async def _generate_full(prompt: list, max_new_tokens: int, temperature: float,
164
  top_p=top_p,
165
  stop=stop_sequences,
166
  stream=False,
167
- )["choices"][0]["message"]["content"]
168
  )
 
169
 
170
  async def _generate_stream(prompt: list, max_new_tokens: int, temperature: float, top_p: float, stop_sequences: Optional[List[str]] = None):
171
  if llm is None:
172
  raise HTTPException(status_code=503, detail="Model not loaded")
173
-
174
- # llama.cpp can yield a Python generator. We'll run it in a thread and yield the results.
175
- def generator():
176
  for chunk in llm.create_chat_completion(
177
  messages=prompt,
178
  max_tokens=max_new_tokens,
@@ -183,18 +165,12 @@ async def _generate_stream(prompt: list, max_new_tokens: int, temperature: float
183
  ):
184
  if "content" in chunk["choices"][0]["delta"]:
185
  yield chunk["choices"][0]["delta"]["content"]
 
 
 
 
186
 
187
- # We need a helper to bridge the sync generator to an async one
188
- def sync_generator():
189
- for item in generator():
190
- yield item
191
-
192
- # Run the sync generator in a thread and yield items as they come
193
- for item in await asyncio.to_thread(list, sync_generator()):
194
- yield item
195
- await asyncio.sleep(0) # Yield control to the event loop
196
-
197
- # ---------- FastAPI App (Same structure) ----------
198
  @asynccontextmanager
199
  async def lifespan(app: FastAPI):
200
  try:
@@ -233,14 +209,14 @@ async def auth_middleware(request: Request, call_next):
233
  async def http_exception_handler(request, exc):
234
  return JSONResponse(
235
  status_code=exc.status_code,
236
- content=ErrorResponse(error=exc.detail, detail=str(exc.detail)).dict(),
237
  )
238
 
239
  @app.exception_handler(ValidationError)
240
  async def validation_exception_handler(request, exc):
241
  return JSONResponse(
242
  status_code=422,
243
- content=ErrorResponse(error="Validation error", detail=str(exc)).dict(),
244
  )
245
 
246
  @app.exception_handler(Exception)
@@ -248,7 +224,7 @@ async def generic_exception_handler(request, exc):
248
  logger.exception("Unhandled exception")
249
  return JSONResponse(
250
  status_code=500,
251
- content=ErrorResponse(error="Internal server error", detail=str(exc)).dict(),
252
  )
253
 
254
  @app.get("/", summary="Root")
@@ -279,12 +255,7 @@ def model_info():
279
  @app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
280
  async def chat_completions(req: ChatCompletionRequest):
281
  await _ensure_loaded()
282
-
283
- try:
284
- prompt = _build_chat_prompt(req.messages)
285
- except Exception as e:
286
- raise HTTPException(status_code=400, detail=f"Prompt formatting error: {str(e)}")
287
-
288
  stop_seq = req.stop if isinstance(req.stop, list) else ([req.stop] if req.stop else None)
289
 
290
  if req.stream:
@@ -300,11 +271,7 @@ async def chat_completions(req: ChatCompletionRequest):
300
  else:
301
  text = await _generate_full(prompt, req.max_tokens, req.temperature, req.top_p, stop_seq)
302
  assistant_msg = Message(role="assistant", content=text)
303
- usage = Usage(
304
- prompt_tokens=0, # llama.cpp can return this, but we can omit for simplicity
305
- completion_tokens=0,
306
- total_tokens=0,
307
- )
308
  return ChatCompletionResponse(
309
  id=f"chatcmpl-{uuid.uuid4().hex[:12]}",
310
  created=int(time.time()),
 
14
  from fastapi.responses import JSONResponse, StreamingResponse
15
  from huggingface_hub import hf_hub_download
16
  from pydantic import BaseModel, Field, ValidationError
 
 
17
  from llama_cpp import Llama
18
 
19
  # ---------- Configuration ----------
 
 
20
  MODEL_ID = os.getenv("MODEL_ID", "prism-ml/Bonsai-1.7B-gguf")
21
+ MODEL_FILENAME = os.getenv("MODEL_FILENAME", "Bonsai-1.7B-Q1_0.gguf")
 
 
 
 
22
  HF_TOKEN = os.getenv("HF_TOKEN")
23
  LOCAL_MODEL_DIR = os.getenv("LOCAL_MODEL_DIR", "/data/models")
24
  MAX_NEW_TOKENS_DEFAULT = int(os.getenv("MAX_NEW_TOKENS_DEFAULT", "256"))
25
  API_KEY = os.getenv("API_KEY", None)
26
 
27
+ # Performance settings
28
+ N_CTX = int(os.getenv("N_CTX", "4096"))
29
+ N_THREADS = int(os.getenv("N_THREADS", "4"))
30
+ N_BATCH = int(os.getenv("N_BATCH", "512"))
31
 
32
  logging.basicConfig(level=logging.INFO)
33
  logger = logging.getLogger("uvicorn.error")
34
 
35
+ # ---------- Pydantic Models ----------
36
  class Message(BaseModel):
37
  role: str = Field(..., pattern="^(system|user|assistant)$")
38
  content: str
 
119
  raise HTTPException(status_code=503, detail=f"Model failed to load: {model_load_error}")
120
  try:
121
  model_path = _download_model()
 
122
  llm = Llama(
123
  model_path=model_path,
124
+ n_ctx=N_CTX,
125
+ n_threads=N_THREADS,
126
+ n_batch=N_BATCH,
127
  verbose=False,
128
  )
129
  logger.info(f"Model loaded successfully: {MODEL_ID} ({MODEL_FILENAME})")
 
133
  logger.exception("Model loading failed")
134
  raise HTTPException(status_code=503, detail=f"Model unavailable: {model_load_error}")
135
 
136
+ def _build_chat_prompt(messages: List[Message]) -> list:
 
 
 
 
 
 
137
  return [{"role": msg.role, "content": msg.content} for msg in messages]
138
 
139
  async def _generate_full(prompt: list, max_new_tokens: int, temperature: float, top_p: float, stop_sequences: Optional[List[str]] = None) -> str:
140
  if llm is None:
141
  raise HTTPException(status_code=503, detail="Model not loaded")
142
+ result = await asyncio.to_thread(
 
 
143
  lambda: llm.create_chat_completion(
144
  messages=prompt,
145
  max_tokens=max_new_tokens,
 
147
  top_p=top_p,
148
  stop=stop_sequences,
149
  stream=False,
150
+ )
151
  )
152
+ return result["choices"][0]["message"]["content"]
153
 
154
  async def _generate_stream(prompt: list, max_new_tokens: int, temperature: float, top_p: float, stop_sequences: Optional[List[str]] = None):
155
  if llm is None:
156
  raise HTTPException(status_code=503, detail="Model not loaded")
157
+ def sync_gen():
 
 
158
  for chunk in llm.create_chat_completion(
159
  messages=prompt,
160
  max_tokens=max_new_tokens,
 
165
  ):
166
  if "content" in chunk["choices"][0]["delta"]:
167
  yield chunk["choices"][0]["delta"]["content"]
168
+ # Convert sync generator to async
169
+ for token in await asyncio.to_thread(list, sync_gen()):
170
+ yield token
171
+ await asyncio.sleep(0)
172
 
173
+ # ---------- FastAPI App ----------
 
 
 
 
 
 
 
 
 
 
174
  @asynccontextmanager
175
  async def lifespan(app: FastAPI):
176
  try:
 
209
  async def http_exception_handler(request, exc):
210
  return JSONResponse(
211
  status_code=exc.status_code,
212
+ content=ErrorResponse(error=exc.detail, detail=str(exc.detail)).model_dump(),
213
  )
214
 
215
  @app.exception_handler(ValidationError)
216
  async def validation_exception_handler(request, exc):
217
  return JSONResponse(
218
  status_code=422,
219
+ content=ErrorResponse(error="Validation error", detail=str(exc)).model_dump(),
220
  )
221
 
222
  @app.exception_handler(Exception)
 
224
  logger.exception("Unhandled exception")
225
  return JSONResponse(
226
  status_code=500,
227
+ content=ErrorResponse(error="Internal server error", detail=str(exc)).model_dump(),
228
  )
229
 
230
  @app.get("/", summary="Root")
 
255
  @app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
256
  async def chat_completions(req: ChatCompletionRequest):
257
  await _ensure_loaded()
258
+ prompt = _build_chat_prompt(req.messages)
 
 
 
 
 
259
  stop_seq = req.stop if isinstance(req.stop, list) else ([req.stop] if req.stop else None)
260
 
261
  if req.stream:
 
271
  else:
272
  text = await _generate_full(prompt, req.max_tokens, req.temperature, req.top_p, stop_seq)
273
  assistant_msg = Message(role="assistant", content=text)
274
+ usage = Usage(prompt_tokens=0, completion_tokens=0, total_tokens=0)
 
 
 
 
275
  return ChatCompletionResponse(
276
  id=f"chatcmpl-{uuid.uuid4().hex[:12]}",
277
  created=int(time.time()),