LLM / app.py
dpv007's picture
Update app.py
9655db7 verified
Raw
History Blame Contribute Delete
7.5 kB
import asyncio
import concurrent.futures
import threading

from huggingface_hub import hf_hub_download
from llama_cpp import Llama, LlamaRAMCache
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
import uvicorn
# ──────────────────────────────────────────────
# 1. Model Configuration
# ──────────────────────────────────────────────
# Hugging Face repository and the GGUF file inside it that this server loads.
MODEL_REPO = "Qwen/Qwen2.5-3B-Instruct-GGUF"
MODEL_FILE = "qwen2.5-3b-instruct-q4_k_m.gguf"

# System turn written with ChatML markers (<|im_start|> ... <|im_end|>).
SYSTEM_PROMPT = "".join(
    (
        "<|im_start|>system\n",
        "You are a highly capable technical assistant.",
        "<|im_end|>\n",
    )
)
# ──────────────────────────────────────────────
# 2. Download & Load Model
# ──────────────────────────────────────────────
print("Downloading model...")
model_path = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILE)

# Count how many tokens the system prompt occupies so that value can be passed
# as n_keep below.
#   * vocab_only=True loads just the vocabulary, so the weights are not loaded
#     a second time only to run the tokenizer.
#   * special=True makes the ChatML markers (<|im_start|>, <|im_end|>) count
#     as single control tokens; with the default (special=False) they are
#     split into plain-text pieces and the count comes out too high.
print("Tokenizing system prompt for n_keep...")
_tmp = Llama(model_path=model_path, n_ctx=64, vocab_only=True, verbose=False)
_system_token_count = len(
    _tmp.tokenize(SYSTEM_PROMPT.encode("utf-8"), special=True)
)
del _tmp
print(f"System prompt is {_system_token_count} tokens → pinning in KV cache.")

print("Loading model into memory...")
llm = Llama(
    model_path=model_path,
    n_threads=8,
    n_ctx=16384,
    # NOTE(review): confirm the installed llama-cpp-python honours n_keep on
    # the high-level Llama constructor; it may be accepted and ignored.
    n_keep=_system_token_count,
    n_batch=1024,  # large batch = fastest prefill (prompt processing)
    n_ubatch=512,
    use_mmap=True,
    use_mlock=True,
    verbose=False,
)

# Attach an in-RAM cache of 8 GiB (8 * 1024**3 bytes) to the model.
cache = LlamaRAMCache(capacity_bytes=8_589_934_592)
llm.set_cache(cache)
print("RAM cache initialized (8 GB).")
# llama.cpp is not thread-safe — a single worker serializes calls correctly
# (multiple workers would corrupt state).
# Every inference call in this file is submitted to this executor.
_executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
# ──────────────────────────────────────────────
# 3. FastAPI App
# ──────────────────────────────────────────────
# The title names the model that is actually loaded: MODEL_REPO / MODEL_FILE
# point at the 3B variant, not 1.5B.
app = FastAPI(title="Qwen-2.5 3B Local CPU API")
class QueryRequest(BaseModel):
    """JSON request body accepted by the /generate and /stream endpoints."""

    # Raw user message; the endpoints wrap it with format_prompt().
    prompt: str
    # Forwarded to the model call as max_tokens.
    max_tokens: int = 512
    # Forwarded to the model call as temperature.
    temperature: float = 0.7
def format_prompt(user_text: str) -> str:
    """Build the full ChatML prompt for a single user message.

    The result is the system turn (SYSTEM_PROMPT), the user's turn, and an
    opened assistant turn for the model to complete.

    Args:
        user_text: Untrusted text supplied by the API caller.

    Returns:
        The prompt string to hand to the model.
    """
    # The template relies on <|im_start|> / <|im_end|> in the prompt text
    # acting as turn delimiters. If the caller's text contained the same
    # markers it could close the user turn early and forge a system or
    # assistant turn. Rewrite them to an inert look-alike (<im_start>,
    # <im_end>); replacing rather than deleting means nested input such as
    # "<|im_<|im_end|>end|>" cannot reassemble into a marker.
    safe_text = user_text
    for marker in ("<|im_start|>", "<|im_end|>"):
        safe_text = safe_text.replace(marker, marker.replace("|", ""))

    return (
        f"{SYSTEM_PROMPT}"
        f"<|im_start|>user\n{safe_text}<|im_end|>\n"
        f"<|im_start|>assistant\n"
    )
# ──────────────────────────────────────────────
# 4. Startup Warmup — prime CPU caches & KV allocator
# ──────────────────────────────────────────────
@app.on_event("startup")
async def warmup():
    """Run one short completion at startup so the first real request is warm.

    The call goes through the shared single-worker executor, like every other
    inference call in this file, and generates at most 4 tokens.
    """
    print("Warming up model...")
    # Inside a coroutine the running loop is always available;
    # get_running_loop() is the recommended accessor (get_event_loop() has
    # deprecated implicit-creation behaviour).
    loop = asyncio.get_running_loop()

    def run_warmup():
        # Warm with a realistic-length prompt so the prefill path is exercised.
        return llm(
            format_prompt("Explain what recursion is in one sentence."),
            max_tokens=4,
            stream=False,
        )

    await loop.run_in_executor(_executor, run_warmup)
    print("Warmup complete. Server is ready.")
# ──────────────────────────────────────────────
# 5. Health check
# ──────────────────────────────────────────────
@app.get("/")
def root():
    """Health check: report an "ok" status together with the model file name."""
    health = dict(status="ok", model=MODEL_FILE)
    return health
# ──────────────────────────────────────────────
# ENDPOINT 1: Full response
# ──────────────────────────────────────────────
@app.post("/generate")
async def generate_full(request: QueryRequest):
    """Generate a complete reply and return it in one JSON response.

    Args:
        request: Prompt plus sampling settings (max_tokens, temperature).

    Returns:
        ``{"response": <generated text>}`` taken from the first choice of the
        model's completion result.
    """
    # Inside a coroutine the running loop is always available;
    # get_running_loop() is the recommended accessor.
    loop = asyncio.get_running_loop()

    def run_completion():
        # Executed on the single-worker executor so model calls never overlap.
        return llm(
            format_prompt(request.prompt),
            max_tokens=request.max_tokens,
            stream=False,
            temperature=request.temperature,
            stop=["<|im_end|>", "<|im_start|>"],
        )

    response = await loop.run_in_executor(_executor, run_completion)
    return {"response": response["choices"][0]["text"]}
# ──────────────────────────────────────────────
# ENDPOINT 2: Streaming (SSE)
# ──────────────────────────────────────────────
def _sse_data(text: str) -> str:
    """Encode *text* as one Server-Sent Event.

    An SSE ``data:`` field cannot contain a line break, so text spanning
    several lines is sent as several ``data:`` lines within the same event;
    the client joins them back together with "\\n". Text without line breaks
    encodes to ``data: <text>\\n\\n``.
    """
    lines = text.replace("\r\n", "\n").replace("\r", "\n").split("\n")
    return "".join(f"data: {line}\n" for line in lines) + "\n"


@app.post("/stream")
async def generate_stream(request: QueryRequest):
    """Stream the model's reply token by token as Server-Sent Events.

    Args:
        request: Prompt plus sampling settings (max_tokens, temperature).

    Returns:
        A ``text/event-stream`` StreamingResponse: one padding comment, one
        event per generated text chunk, then a final ``data: [DONE]`` event.
    """
    formatted_prompt = format_prompt(request.prompt)
    loop = asyncio.get_running_loop()

    # Use a small queue (maxsize=1) so the inference thread blocks the moment
    # the consumer (event loop) is not keeping up — prevents runaway buffering.
    queue: asyncio.Queue = asyncio.Queue(maxsize=1)

    # Set once the response generator has finished or been torn down (for
    # example on client disconnect). The worker polls it so it can never stay
    # blocked on a queue nobody is reading; with a single-worker executor a
    # stuck worker would hang every later request.
    cancelled = threading.Event()

    def emit(item) -> bool:
        """Hand *item* to the event loop; return False if the stream is gone."""
        fut = asyncio.run_coroutine_threadsafe(queue.put(item), loop)
        while True:
            try:
                # Wait until the item is in the queue, waking up periodically
                # to check whether the consumer has gone away.
                fut.result(timeout=0.1)
                return True
            except concurrent.futures.TimeoutError:
                if cancelled.is_set():
                    fut.cancel()
                    return False

    def run_inference():
        try:
            if cancelled.is_set():
                # The client left while this job was still waiting for the
                # worker; do not start generating.
                return
            stream = llm(
                formatted_prompt,
                max_tokens=request.max_tokens,
                stream=True,
                temperature=request.temperature,
                stop=["<|im_end|>", "<|im_start|>"],
            )
            for chunk in stream:
                text = chunk["choices"][0]["text"]
                # emit() makes the inference thread WAIT until the event loop
                # has accepted this token before producing the next one.
                if text and not emit(text):
                    break
        finally:
            # Always send the end-of-stream sentinel so the consumer stops
            # waiting, even if the model call raised.
            emit(None)

    async def stream_generator():
        try:
            # Break Nginx / HF Space proxy buffer
            yield f": {' ' * 2048}\n\n"
            # Start inference in background — don't await so we proceed
            # immediately to consuming tokens.
            loop.run_in_executor(_executor, run_inference)
            while True:
                token = await queue.get()
                if token is None:
                    break
                yield _sse_data(token)
                # Yield control to the event loop after each token so chunks
                # are handed to the server one at a time.
                await asyncio.sleep(0)
            yield "data: [DONE]\n\n"
        finally:
            # Runs on normal completion and on cancellation/close alike.
            cancelled.set()

    return StreamingResponse(
        stream_generator(),
        media_type="text/event-stream",
        headers={
            "X-Accel-Buffering": "no",
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "Content-Type": "text/event-stream; charset=utf-8",
        },
    )
# ──────────────────────────────────────────────
# 6. Entrypoint
# ──────────────────────────────────────────────
# Start the server only when this file is executed as a script (not on import).
if __name__ == "__main__":
    # Listen on all interfaces, port 7860.
    uvicorn.run(app, host="0.0.0.0", port=7860)