Update app.py
Browse files
app.py
CHANGED
|
@@ -14,33 +14,25 @@ from fastapi.middleware.cors import CORSMiddleware
|
|
| 14 |
from fastapi.responses import JSONResponse, StreamingResponse
|
| 15 |
from huggingface_hub import hf_hub_download
|
| 16 |
from pydantic import BaseModel, Field, ValidationError
|
| 17 |
-
|
| 18 |
-
# NEW: Import llama.cpp
|
| 19 |
from llama_cpp import Llama
|
| 20 |
|
| 21 |
# ---------- Configuration ----------
|
| 22 |
-
# You can now use GGUF models for even faster inference!
|
| 23 |
-
# These are specifically optimized by the PrismML team.
|
| 24 |
MODEL_ID = os.getenv("MODEL_ID", "prism-ml/Bonsai-1.7B-gguf")
|
| 25 |
-
MODEL_FILENAME = os.getenv("MODEL_FILENAME", "Bonsai-1.7B-
|
| 26 |
-
|
| 27 |
-
# Quantization types in GGUF: Q1_0 is for 1-bit models.
|
| 28 |
-
# For 8B, use MODEL_ID="prism-ml/Bonsai-8B-gguf" and MODEL_FILENAME="Bonsai-8B-v1.0-Q1_0.gguf"
|
| 29 |
-
|
| 30 |
HF_TOKEN = os.getenv("HF_TOKEN")
|
| 31 |
LOCAL_MODEL_DIR = os.getenv("LOCAL_MODEL_DIR", "/data/models")
|
| 32 |
MAX_NEW_TOKENS_DEFAULT = int(os.getenv("MAX_NEW_TOKENS_DEFAULT", "256"))
|
| 33 |
API_KEY = os.getenv("API_KEY", None)
|
| 34 |
|
| 35 |
-
# Performance settings
|
| 36 |
-
N_CTX = int(os.getenv("N_CTX", "4096"))
|
| 37 |
-
N_THREADS = int(os.getenv("N_THREADS", "4"))
|
| 38 |
-
N_BATCH = int(os.getenv("N_BATCH", "512"))
|
| 39 |
|
| 40 |
logging.basicConfig(level=logging.INFO)
|
| 41 |
logger = logging.getLogger("uvicorn.error")
|
| 42 |
|
| 43 |
-
# ---------- Pydantic Models
|
| 44 |
class Message(BaseModel):
|
| 45 |
role: str = Field(..., pattern="^(system|user|assistant)$")
|
| 46 |
content: str
|
|
@@ -127,12 +119,11 @@ async def _ensure_loaded():
|
|
| 127 |
raise HTTPException(status_code=503, detail=f"Model failed to load: {model_load_error}")
|
| 128 |
try:
|
| 129 |
model_path = _download_model()
|
| 130 |
-
# Load the model with CPU-optimized settings
|
| 131 |
llm = Llama(
|
| 132 |
model_path=model_path,
|
| 133 |
-
n_ctx=N_CTX,
|
| 134 |
-
n_threads=N_THREADS,
|
| 135 |
-
n_batch=N_BATCH,
|
| 136 |
verbose=False,
|
| 137 |
)
|
| 138 |
logger.info(f"Model loaded successfully: {MODEL_ID} ({MODEL_FILENAME})")
|
|
@@ -142,21 +133,13 @@ async def _ensure_loaded():
|
|
| 142 |
logger.exception("Model loading failed")
|
| 143 |
raise HTTPException(status_code=503, detail=f"Model unavailable: {model_load_error}")
|
| 144 |
|
| 145 |
-
def _build_chat_prompt(messages: List[Message]) ->
|
| 146 |
-
# llama.cpp handles chat templates automatically, so we can just pass the messages directly.
|
| 147 |
-
# This is for compatibility; the actual formatting is done by llama.cpp.
|
| 148 |
-
if llm is None:
|
| 149 |
-
raise HTTPException(status_code=503, detail="Model not loaded")
|
| 150 |
-
|
| 151 |
-
# The create_chat_completion method expects a list of messages in this format
|
| 152 |
return [{"role": msg.role, "content": msg.content} for msg in messages]
|
| 153 |
|
| 154 |
async def _generate_full(prompt: list, max_new_tokens: int, temperature: float, top_p: float, stop_sequences: Optional[List[str]] = None) -> str:
|
| 155 |
if llm is None:
|
| 156 |
raise HTTPException(status_code=503, detail="Model not loaded")
|
| 157 |
-
|
| 158 |
-
# Run the blocking llama.cpp call in a thread
|
| 159 |
-
return await asyncio.to_thread(
|
| 160 |
lambda: llm.create_chat_completion(
|
| 161 |
messages=prompt,
|
| 162 |
max_tokens=max_new_tokens,
|
|
@@ -164,15 +147,14 @@ async def _generate_full(prompt: list, max_new_tokens: int, temperature: float,
|
|
| 164 |
top_p=top_p,
|
| 165 |
stop=stop_sequences,
|
| 166 |
stream=False,
|
| 167 |
-
)
|
| 168 |
)
|
|
|
|
| 169 |
|
| 170 |
async def _generate_stream(prompt: list, max_new_tokens: int, temperature: float, top_p: float, stop_sequences: Optional[List[str]] = None):
|
| 171 |
if llm is None:
|
| 172 |
raise HTTPException(status_code=503, detail="Model not loaded")
|
| 173 |
-
|
| 174 |
-
# llama.cpp can yield a Python generator. We'll run it in a thread and yield the results.
|
| 175 |
-
def generator():
|
| 176 |
for chunk in llm.create_chat_completion(
|
| 177 |
messages=prompt,
|
| 178 |
max_tokens=max_new_tokens,
|
|
@@ -183,18 +165,12 @@ async def _generate_stream(prompt: list, max_new_tokens: int, temperature: float
|
|
| 183 |
):
|
| 184 |
if "content" in chunk["choices"][0]["delta"]:
|
| 185 |
yield chunk["choices"][0]["delta"]["content"]
|
|
|
|
|
|
|
|
|
|
|
|
|
| 186 |
|
| 187 |
-
|
| 188 |
-
def sync_generator():
|
| 189 |
-
for item in generator():
|
| 190 |
-
yield item
|
| 191 |
-
|
| 192 |
-
# Run the sync generator in a thread and yield items as they come
|
| 193 |
-
for item in await asyncio.to_thread(list, sync_generator()):
|
| 194 |
-
yield item
|
| 195 |
-
await asyncio.sleep(0) # Yield control to the event loop
|
| 196 |
-
|
| 197 |
-
# ---------- FastAPI App (Same structure) ----------
|
| 198 |
@asynccontextmanager
|
| 199 |
async def lifespan(app: FastAPI):
|
| 200 |
try:
|
|
@@ -233,14 +209,14 @@ async def auth_middleware(request: Request, call_next):
|
|
| 233 |
async def http_exception_handler(request, exc):
|
| 234 |
return JSONResponse(
|
| 235 |
status_code=exc.status_code,
|
| 236 |
-
content=ErrorResponse(error=exc.detail, detail=str(exc.detail)).
|
| 237 |
)
|
| 238 |
|
| 239 |
@app.exception_handler(ValidationError)
|
| 240 |
async def validation_exception_handler(request, exc):
|
| 241 |
return JSONResponse(
|
| 242 |
status_code=422,
|
| 243 |
-
content=ErrorResponse(error="Validation error", detail=str(exc)).
|
| 244 |
)
|
| 245 |
|
| 246 |
@app.exception_handler(Exception)
|
|
@@ -248,7 +224,7 @@ async def generic_exception_handler(request, exc):
|
|
| 248 |
logger.exception("Unhandled exception")
|
| 249 |
return JSONResponse(
|
| 250 |
status_code=500,
|
| 251 |
-
content=ErrorResponse(error="Internal server error", detail=str(exc)).
|
| 252 |
)
|
| 253 |
|
| 254 |
@app.get("/", summary="Root")
|
|
@@ -279,12 +255,7 @@ def model_info():
|
|
| 279 |
@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
|
| 280 |
async def chat_completions(req: ChatCompletionRequest):
|
| 281 |
await _ensure_loaded()
|
| 282 |
-
|
| 283 |
-
try:
|
| 284 |
-
prompt = _build_chat_prompt(req.messages)
|
| 285 |
-
except Exception as e:
|
| 286 |
-
raise HTTPException(status_code=400, detail=f"Prompt formatting error: {str(e)}")
|
| 287 |
-
|
| 288 |
stop_seq = req.stop if isinstance(req.stop, list) else ([req.stop] if req.stop else None)
|
| 289 |
|
| 290 |
if req.stream:
|
|
@@ -300,11 +271,7 @@ async def chat_completions(req: ChatCompletionRequest):
|
|
| 300 |
else:
|
| 301 |
text = await _generate_full(prompt, req.max_tokens, req.temperature, req.top_p, stop_seq)
|
| 302 |
assistant_msg = Message(role="assistant", content=text)
|
| 303 |
-
usage = Usage(
|
| 304 |
-
prompt_tokens=0, # llama.cpp can return this, but we can omit for simplicity
|
| 305 |
-
completion_tokens=0,
|
| 306 |
-
total_tokens=0,
|
| 307 |
-
)
|
| 308 |
return ChatCompletionResponse(
|
| 309 |
id=f"chatcmpl-{uuid.uuid4().hex[:12]}",
|
| 310 |
created=int(time.time()),
|
|
|
|
| 14 |
from fastapi.responses import JSONResponse, StreamingResponse
|
| 15 |
from huggingface_hub import hf_hub_download
|
| 16 |
from pydantic import BaseModel, Field, ValidationError
|
|
|
|
|
|
|
| 17 |
from llama_cpp import Llama
|
| 18 |
|
| 19 |
# ---------- Configuration ----------
|
|
|
|
|
|
|
| 20 |
MODEL_ID = os.getenv("MODEL_ID", "prism-ml/Bonsai-1.7B-gguf")
|
| 21 |
+
MODEL_FILENAME = os.getenv("MODEL_FILENAME", "Bonsai-1.7B-Q1_0.gguf")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
HF_TOKEN = os.getenv("HF_TOKEN")
|
| 23 |
LOCAL_MODEL_DIR = os.getenv("LOCAL_MODEL_DIR", "/data/models")
|
| 24 |
MAX_NEW_TOKENS_DEFAULT = int(os.getenv("MAX_NEW_TOKENS_DEFAULT", "256"))
|
| 25 |
API_KEY = os.getenv("API_KEY", None)
|
| 26 |
|
| 27 |
+
# Performance settings
|
| 28 |
+
N_CTX = int(os.getenv("N_CTX", "4096"))
|
| 29 |
+
N_THREADS = int(os.getenv("N_THREADS", "4"))
|
| 30 |
+
N_BATCH = int(os.getenv("N_BATCH", "512"))
|
| 31 |
|
| 32 |
logging.basicConfig(level=logging.INFO)
|
| 33 |
logger = logging.getLogger("uvicorn.error")
|
| 34 |
|
| 35 |
+
# ---------- Pydantic Models ----------
|
| 36 |
class Message(BaseModel):
|
| 37 |
role: str = Field(..., pattern="^(system|user|assistant)$")
|
| 38 |
content: str
|
|
|
|
| 119 |
raise HTTPException(status_code=503, detail=f"Model failed to load: {model_load_error}")
|
| 120 |
try:
|
| 121 |
model_path = _download_model()
|
|
|
|
| 122 |
llm = Llama(
|
| 123 |
model_path=model_path,
|
| 124 |
+
n_ctx=N_CTX,
|
| 125 |
+
n_threads=N_THREADS,
|
| 126 |
+
n_batch=N_BATCH,
|
| 127 |
verbose=False,
|
| 128 |
)
|
| 129 |
logger.info(f"Model loaded successfully: {MODEL_ID} ({MODEL_FILENAME})")
|
|
|
|
| 133 |
logger.exception("Model loading failed")
|
| 134 |
raise HTTPException(status_code=503, detail=f"Model unavailable: {model_load_error}")
|
| 135 |
|
| 136 |
+
def _build_chat_prompt(messages: List[Message]) -> list:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 137 |
return [{"role": msg.role, "content": msg.content} for msg in messages]
|
| 138 |
|
| 139 |
async def _generate_full(prompt: list, max_new_tokens: int, temperature: float, top_p: float, stop_sequences: Optional[List[str]] = None) -> str:
|
| 140 |
if llm is None:
|
| 141 |
raise HTTPException(status_code=503, detail="Model not loaded")
|
| 142 |
+
result = await asyncio.to_thread(
|
|
|
|
|
|
|
| 143 |
lambda: llm.create_chat_completion(
|
| 144 |
messages=prompt,
|
| 145 |
max_tokens=max_new_tokens,
|
|
|
|
| 147 |
top_p=top_p,
|
| 148 |
stop=stop_sequences,
|
| 149 |
stream=False,
|
| 150 |
+
)
|
| 151 |
)
|
| 152 |
+
return result["choices"][0]["message"]["content"]
|
| 153 |
|
| 154 |
async def _generate_stream(prompt: list, max_new_tokens: int, temperature: float, top_p: float, stop_sequences: Optional[List[str]] = None):
|
| 155 |
if llm is None:
|
| 156 |
raise HTTPException(status_code=503, detail="Model not loaded")
|
| 157 |
+
def sync_gen():
|
|
|
|
|
|
|
| 158 |
for chunk in llm.create_chat_completion(
|
| 159 |
messages=prompt,
|
| 160 |
max_tokens=max_new_tokens,
|
|
|
|
| 165 |
):
|
| 166 |
if "content" in chunk["choices"][0]["delta"]:
|
| 167 |
yield chunk["choices"][0]["delta"]["content"]
|
| 168 |
+
# Convert sync generator to async
|
| 169 |
+
for token in await asyncio.to_thread(list, sync_gen()):
|
| 170 |
+
yield token
|
| 171 |
+
await asyncio.sleep(0)
|
| 172 |
|
| 173 |
+
# ---------- FastAPI App ----------
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 174 |
@asynccontextmanager
|
| 175 |
async def lifespan(app: FastAPI):
|
| 176 |
try:
|
|
|
|
| 209 |
async def http_exception_handler(request, exc):
|
| 210 |
return JSONResponse(
|
| 211 |
status_code=exc.status_code,
|
| 212 |
+
content=ErrorResponse(error=exc.detail, detail=str(exc.detail)).model_dump(),
|
| 213 |
)
|
| 214 |
|
| 215 |
@app.exception_handler(ValidationError)
|
| 216 |
async def validation_exception_handler(request, exc):
|
| 217 |
return JSONResponse(
|
| 218 |
status_code=422,
|
| 219 |
+
content=ErrorResponse(error="Validation error", detail=str(exc)).model_dump(),
|
| 220 |
)
|
| 221 |
|
| 222 |
@app.exception_handler(Exception)
|
|
|
|
| 224 |
logger.exception("Unhandled exception")
|
| 225 |
return JSONResponse(
|
| 226 |
status_code=500,
|
| 227 |
+
content=ErrorResponse(error="Internal server error", detail=str(exc)).model_dump(),
|
| 228 |
)
|
| 229 |
|
| 230 |
@app.get("/", summary="Root")
|
|
|
|
| 255 |
@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
|
| 256 |
async def chat_completions(req: ChatCompletionRequest):
|
| 257 |
await _ensure_loaded()
|
| 258 |
+
prompt = _build_chat_prompt(req.messages)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 259 |
stop_seq = req.stop if isinstance(req.stop, list) else ([req.stop] if req.stop else None)
|
| 260 |
|
| 261 |
if req.stream:
|
|
|
|
| 271 |
else:
|
| 272 |
text = await _generate_full(prompt, req.max_tokens, req.temperature, req.top_p, stop_seq)
|
| 273 |
assistant_msg = Message(role="assistant", content=text)
|
| 274 |
+
usage = Usage(prompt_tokens=0, completion_tokens=0, total_tokens=0)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 275 |
return ChatCompletionResponse(
|
| 276 |
id=f"chatcmpl-{uuid.uuid4().hex[:12]}",
|
| 277 |
created=int(time.time()),
|