# backend / main.py
# --- Hugging Face file-viewer residue, commented out so the file parses ---
# Rofati's picture
# Update main.py
# f0edfa1 verified
# Raw
# History Blame Contribute Delete
# 8.09 kB
import asyncio
import json
import logging
import threading

from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, RedirectResponse, StreamingResponse
from llama_cpp import Llama, LlamaRAMCache
# Configure the root logger at INFO; this module logs through "BestieLocal".
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("BestieLocal")

app = FastAPI()

# Wide-open CORS: any origin, method and header may call the API.
# allow_credentials is not passed, so it stays at the middleware default.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)
# The model file is baked into the image at build time (see Dockerfile), so a
# reset of the Space's ephemeral disk never forces a re-download.
MODEL_REPO = "unsloth/gemma-4-E4B-it-GGUF"
MODEL_FILE = "gemma-4-E4B-it-Q4_K_M.gguf"
MODEL_PATH = "/code/models/" + MODEL_FILE

# Reply cap, shared by generation (max_tokens) and history budgeting
# (trim_to_budget) so both always reserve the same amount of room.
MAX_REPLY_TOKENS = 90
# Context window. Raised from 768 so conversations can grow much longer
# before any trimming is needed; the extra RAM this costs is modest.
N_CTX = 2048

llm = Llama(
    model_path=MODEL_PATH,
    n_ctx=N_CTX,
    # Both thread counts match the 2 vCPUs of HF's CPU-Basic hardware:
    # n_threads drives token generation, n_threads_batch prompt/batch eval.
    n_threads=2,
    n_threads_batch=2,
    n_batch=384,
    # TEMPORARY: llama.cpp prints real timing / tokens-per-second stats to
    # the logs while this is on -- read them, then switch it off.
    verbose=True,
)

# NOTE: flash_attn, type_k/type_v KV-cache quantization and LlamaRAMCache are
# deliberately absent for this test. None of them is a guaranteed win on CPU
# for short, single-stream chat; re-add them ONE AT A TIME afterward, keeping
# each only if it measurably improves the tok/s figure verbose mode prints.

# One generation at a time: a single Llama instance is not safe for
# concurrent calls, and with only 2 cores two generations would just fight
# over them anyway.
inference_lock = asyncio.Lock()
# Persona and safety rules for the assistant, sent as the system turn of
# every chat request. Paragraphs are separated by a blank line.
SYSTEM_PROMPT = "\n\n".join((
    # Who the assistant is.
    "You are the AI Assistant on Bestie, a wellness platform for young "
    "Kenyans. Talk like a caring close friend, not a clinician.",
    # Tone and length.
    "Style: 2-4 short sentences, no lists. Match the person's "
    "language/energy (English, Swahili, Sheng all fine). Validate the "
    "feeling before advising. Sound like a real person, not a script.",
    # Listen before fixing.
    "Don't rush to fix things. If someone shares something painful "
    "(a breakup, loss, conflict), don't jump straight to suggestions or "
    "coping tips -- ask what they actually want right now (to vent, to "
    "think out loud, or for ideas) unless they've already made that clear. "
    "A vague or short reply like 'then' or 'idk' usually means they need "
    "more space to talk, not an action plan.",
    # Naming.
    "Never invent a nickname for the person (e.g. don't call anyone "
    "'mzungu') or write a placeholder like '[Name]'. Use their name only "
    "if they gave it.",
    # Medical boundaries.
    "Never diagnose, give medical advice, or tell someone to start/stop "
    "medication. You can suggest a doctor/counselor/trusted adult, but "
    "don't repeat it every message.",
    # Crisis handling.
    "If self-harm or suicide comes up: stay calm, don't ask about "
    "methods/timing, share the Kenya Red Cross lifeline 1199 right away, "
    "and don't lecture.",
))
def trim_to_budget(
    messages,
    system_prompt,
    max_new_tokens,
    n_ctx,
    safety=64,
    per_message_overhead=6,
    tokenize=None,
):
    """Return the newest contiguous run of *messages* that fits the context.

    Trimming by measured token count rather than a fixed message count (e.g.
    "last 8") keeps the conversation prefix stable from turn to turn, so
    llama.cpp's prefix-match cache stays valid and only the new turn needs
    prefill. A fixed window shifts on every turn once the chat outgrows it,
    forcing a full re-prefill (seen in the logs as 150-230 new prompt tokens
    taking 14-19 seconds). With token budgeting the cached prefix stays valid
    until the conversation genuinely outgrows n_ctx.

    Args:
        messages: Chat turns, oldest first. Each is a dict; only its
            ``content`` is read. ``None`` content is counted as empty and
            other non-string content is measured via ``str()``.
        system_prompt: Text of the system turn; its tokens (plus BOS) are
            reserved up front.
        max_new_tokens: Tokens reserved for the generated reply.
        n_ctx: Total context window size in tokens.
        safety: Extra tokens held back as a margin for estimation error.
        per_message_overhead: Rough allowance added per message for the chat
            template's per-turn delimiter tokens, which are not part of the
            raw message content.
        tokenize: Callable ``tokenize(data: bytes, add_bos=bool)`` returning
            a sized sequence of tokens. Defaults to ``llm.tokenize``.

    Returns:
        A new list of the kept message dicts (same objects), oldest first.
        Walking stops at the first message (newest to oldest) that would
        overflow the budget, so the result is always a suffix of *messages*;
        it is empty when even the newest message does not fit.
    """
    if tokenize is None:
        tokenize = llm.tokenize

    system_tokens = len(tokenize(system_prompt.encode("utf-8"), add_bos=True))
    budget = n_ctx - max_new_tokens - system_tokens - safety

    kept = []
    used = 0
    for msg in reversed(messages):
        content = msg.get("content")
        # JSON `null` content would otherwise crash on .encode().
        if content is None:
            content = ""
        elif not isinstance(content, str):
            content = str(content)
        cost = (
            len(tokenize(content.encode("utf-8"), add_bos=False))
            + per_message_overhead
        )
        if used + cost > budget:
            break
        kept.append(msg)
        used += cost
    kept.reverse()
    return kept
@app.get("/")
async def root():
    """Redirect requests for the bare root URL to the ``/health`` endpoint."""
    return RedirectResponse("/health")
def _sanitize_messages(raw_messages):
    """Keep only well-formed, client-authored chat turns from the request.

    Returns new dicts of the form ``{"role": str, "content": str}``:

    * entries that are not dicts, or whose ``role`` is not a string, are
      dropped instead of crashing the role-merge step with a 500;
    * ``"system"`` turns are dropped: the server owns the system prompt, and
      a client-sent one would otherwise be merged straight into it;
    * ``None`` content becomes ``""`` and other non-string content is
      stringified, so tokenizing and newline-joining never see ``None``.
    """
    cleaned = []
    for msg in raw_messages:
        if not isinstance(msg, dict):
            continue
        role = msg.get("role")
        if not isinstance(role, str) or role == "system":
            continue
        content = msg.get("content")
        if content is None:
            content = ""
        elif not isinstance(content, str):
            content = str(content)
        cleaned.append({"role": role, "content": content})
    return cleaned


def _build_chat_messages(history):
    """Prepend the system prompt and normalize turn order for the template.

    Leading non-user turns are dropped so the first turn after the system
    prompt is the user's. Budget trimming can cut the history just before an
    assistant reply, and the frontend may open with an assistant greeting.
    NOTE(review): Gemma-family chat templates have historically raised on
    conversations that don't alternate starting with ``user``; confirm this
    still applies to the Gemma 4 template.

    Consecutive same-role turns (e.g. two user messages in a row) are then
    collapsed into one, joined with a newline, so roles strictly alternate.
    Gemma 4 supports a native system role, so the prompt is a normal
    system message rather than being merged into the first user turn.
    """
    start = 0
    while start < len(history) and history[start]["role"] != "user":
        start += 1

    merged = [{"role": "system", "content": SYSTEM_PROMPT}]
    for msg in history[start:]:
        if merged[-1]["role"] == msg["role"]:
            merged[-1]["content"] += "\n" + msg["content"]
        else:
            merged.append(dict(msg))
    return merged


def _run_inference(chat_messages, loop, queue, stop, done):
    """Run one streaming completion on a worker thread, feeding *queue*.

    Every call into ``llm`` -- including the token-by-token iteration, not
    just creating the stream -- happens here, off the event loop, so a long
    generation never blocks the server. Results go back through
    ``loop.call_soon_threadsafe`` because ``asyncio.Queue`` is not
    thread-safe:

    * each non-empty text delta is enqueued as a ``str``;
    * ``"…"`` is enqueued if generation ended on ``max_tokens``;
    * on failure a ``RuntimeError("System busy")`` is enqueued;
    * *done* is always enqueued last.

    Decoding stops early once *stop* is set (the consumer has gone away).
    """
    try:
        stream = llm.create_chat_completion(
            messages=chat_messages,
            max_tokens=MAX_REPLY_TOKENS,  # decode is the real cost, so keep replies capped
            temperature=0.5,  # lower than the 0.8 default -> more focused, consistent replies
            repeat_penalty=1.1,
            stream=True,
        )
        finish_reason = None
        for chunk in stream:
            if stop.is_set():
                # Nobody is listening any more; free the model for the next
                # request instead of decoding the rest of the reply.
                break
            choice = chunk["choices"][0]
            finish_reason = choice.get("finish_reason") or finish_reason
            delta = choice["delta"].get("content")
            if delta:
                loop.call_soon_threadsafe(queue.put_nowait, delta)
        if finish_reason == "length":
            # Hit max_tokens before a natural stop -- say so instead of
            # leaving the sentence cut off mid-word.
            loop.call_soon_threadsafe(queue.put_nowait, "…")
    except Exception as exc:
        logger.exception("Inference error: %s", exc)
        loop.call_soon_threadsafe(queue.put_nowait, RuntimeError("System busy"))
    finally:
        loop.call_soon_threadsafe(queue.put_nowait, done)


@app.post("/api/chat")
async def chat_endpoint(request: Request):
    """Stream a reply to ``{"messages": [...]}`` as Server-Sent Events.

    Each event is ``data: {"token": "..."}``. If generation fails, a single
    ``data: {"error": "System busy"}`` event ends the stream. A body that is
    not valid JSON, is not an object, or whose ``messages`` is not a list
    gets a 400 response with ``{"error": "Invalid JSON"}``.
    """
    try:
        data = await request.json()
    except ValueError:  # json.loads raises ValueError subclasses on bad bodies
        return JSONResponse({"error": "Invalid JSON"}, status_code=400)
    messages = data.get("messages", []) if isinstance(data, dict) else None
    if not isinstance(messages, list):
        return JSONResponse({"error": "Invalid JSON"}, status_code=400)

    history = trim_to_budget(
        _sanitize_messages(messages),
        SYSTEM_PROMPT,
        max_new_tokens=MAX_REPLY_TOKENS,
        n_ctx=N_CTX,
    )
    chat_messages = _build_chat_messages(history)

    async def event_generator():
        loop = asyncio.get_running_loop()
        queue = asyncio.Queue()
        stop = threading.Event()
        done = object()

        await inference_lock.acquire()
        try:
            future = loop.run_in_executor(
                None, _run_inference, chat_messages, loop, queue, stop, done
            )
        except BaseException:
            inference_lock.release()
            raise
        # Release the lock when the *worker thread* finishes, not when this
        # generator exits. If the client disconnects mid-reply, this
        # generator can be closed or cancelled while the thread is still
        # inside llm; tying release to thread completion keeps the next
        # request from driving the same Llama instance concurrently.
        future.add_done_callback(lambda _fut: inference_lock.release())

        try:
            while True:
                item = await queue.get()
                if item is done:
                    break
                if isinstance(item, RuntimeError):
                    yield f"data: {json.dumps({'error': str(item)})}\n\n"
                    break
                yield f"data: {json.dumps({'token': item})}\n\n"
        finally:
            # Runs on normal completion, error and disconnect alike; tells
            # the worker to stop decoding at the next chunk.
            stop.set()

    return StreamingResponse(event_generator(), media_type="text/event-stream")
@app.get("/health")
def health():
    """Report server status ("ready") and the name of the loaded model file.

    The model is loaded when this module is imported, so by the time this
    route can answer, it is always "ready".
    """
    return dict(status="ready", model=MODEL_FILE)