"""Bestie local chat server.

Serves a single llama.cpp model over a Server-Sent-Events chat endpoint
(`POST /api/chat`) plus a small health check (`GET /health`).
"""

import asyncio
import json
import logging
import threading

from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse, RedirectResponse
from fastapi.middleware.cors import CORSMiddleware
# LlamaRAMCache is currently unused (see the NOTE below the Llama() call); the
# import is kept so it can be re-enabled without touching the import block.
from llama_cpp import Llama, LlamaRAMCache

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("BestieLocal")

app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# Model is baked into the image at build time (see Dockerfile) so it
# isn't re-downloaded every time the Space's ephemeral disk resets.
MODEL_REPO = "unsloth/gemma-4-E4B-it-GGUF"
MODEL_FILE = "gemma-4-E4B-it-Q4_K_M.gguf"
MODEL_PATH = f"/code/models/{MODEL_FILE}"

# Shared constant -- used both for generation and for budgeting history below.
MAX_REPLY_TOKENS = 90
N_CTX = 2048

# Rough allowance for the chat template's per-turn delimiter tokens, which
# aren't visible in the raw message content itself.
PER_TURN_TOKEN_OVERHEAD = 6

# Roles a client is allowed to send. "system" is deliberately excluded: the
# only system message is SYSTEM_PROMPT, and a client-supplied system turn
# would otherwise be merged straight into it by the same-role collapse in
# chat_endpoint.
CLIENT_ROLES = ("user", "assistant")

llm = Llama(
    model_path=MODEL_PATH,
    # Was 768 -- a bigger budget means conversations can grow much longer
    # before a trim is ever needed at all (see history budgeting below).
    # RAM cost for this jump is modest.
    n_ctx=N_CTX,
    n_threads=2,  # matches HF CPU-Basic's 2 vCPUs
    n_threads_batch=2,
    n_batch=384,
    # TEMPORARY: prints real timing/tok-per-sec stats to the logs -- check these.
    verbose=True,
)
# NOTE: flash_attn, type_k/type_v KV quantization, and LlamaRAMCache were
# removed for this test. They're not guaranteed wins on CPU for short,
# single-stream chat -- add them back ONE AT A TIME afterward, only if you
# can show each one actually improves the tok/s number in the logs.

# Serializes access -- a single Llama() instance is not safe for concurrent
# calls, and with only 2 cores you don't want two generations fighting for
# them anyway. The lock is acquired by the request coroutine and released by
# the inference worker thread once it has really stopped using `llm`.
inference_lock = asyncio.Lock()

SYSTEM_PROMPT = (
    "You are the AI Assistant on Bestie, a wellness platform for young Kenyans. "
    "Talk like a caring "
    "close friend, not a clinician.\n\n"
    "Style: 2-4 short sentences, no lists. Match the person's language/energy "
    "(English, Swahili, Sheng all fine). Validate the feeling before advising. "
    "Sound like a real person, not a script.\n\n"
    "Don't rush to fix things. If someone shares something painful (a breakup, "
    "loss, conflict), don't jump straight to suggestions or coping tips -- ask "
    "what they actually want right now (to vent, to think out loud, or for ideas) "
    "unless they've already made that clear. A vague or short reply like 'then' "
    "or 'idk' usually means they need more space to talk, not an action plan.\n\n"
    "Never invent a nickname for the person (e.g. don't call anyone 'mzungu') "
    "or write a placeholder like '[Name]'. Use their name only if they gave it.\n\n"
    "Never diagnose, give medical advice, or tell someone to start/stop medication. "
    "You can suggest a doctor/counselor/trusted adult, but don't repeat it every message.\n\n"
    "If self-harm or suicide comes up: stay calm, don't ask about methods/timing, "
    "share the Kenya Red Cross lifeline 1199 right away, and don't lecture."
)


def _sanitize_messages(raw_messages):
    """Return only the well-formed client turns from *raw_messages*.

    An entry is kept when it is a dict whose ``role`` is one of
    ``CLIENT_ROLES`` and whose ``content`` is a string. Kept entries are
    copied into fresh ``{"role", "content"}`` dicts so later in-place edits
    never touch the request payload.
    """
    clean = []
    for entry in raw_messages:
        if not isinstance(entry, dict):
            continue
        role = entry.get("role")
        content = entry.get("content")
        # Tuple membership compares with ==, so an unhashable role value
        # (e.g. a list) is simply rejected instead of raising TypeError.
        if role not in CLIENT_ROLES or not isinstance(content, str):
            continue
        clean.append({"role": role, "content": content})
    return clean


def trim_to_budget(messages, system_prompt, max_new_tokens, n_ctx, safety=64):
    """Keep the newest messages that actually fit in the context budget.

    This trims by token count instead of blindly slicing by message count.
    A fixed-size window (e.g. last-8) invalidates llama.cpp's prefix-match
    cache the moment a conversation outgrows it, forcing an expensive full
    re-prefill on every turn after that -- exactly what showed up in the logs
    as 150-230 *new* prompt tokens needing 14-19 seconds to evaluate.
    Trimming by an actual token count means the cached prefix stays valid
    until the conversation genuinely outgrows n_ctx, which for a short
    check-in chat may be rarely or never.

    Args:
        messages: Sequence of message dicts, oldest first. Only the
            ``content`` value is read; a missing or ``None`` content counts
            as an empty string and a non-string content is measured via
            ``str()``.
        system_prompt: System prompt text whose tokens are reserved up front.
        max_new_tokens: Tokens reserved for the generated reply.
        n_ctx: Total context window size in tokens.
        safety: Extra tokens held back as slack for template overhead.

    Returns:
        A new list holding the longest suffix of *messages* (original order,
        same dict objects) whose estimated token cost fits the budget. It is
        empty when even the newest message does not fit.
    """
    system_tokens = len(llm.tokenize(system_prompt.encode("utf-8"), add_bos=True))
    budget = n_ctx - max_new_tokens - system_tokens - safety

    kept = []
    used = 0
    # Walk newest -> oldest and stop at the first turn that would overflow.
    for msg in reversed(messages):
        content = msg.get("content") or ""
        if not isinstance(content, str):
            content = str(content)
        cost = (
            len(llm.tokenize(content.encode("utf-8"), add_bos=False))
            + PER_TURN_TOKEN_OVERHEAD
        )
        if used + cost > budget:
            break
        kept.append(msg)
        used += cost
    kept.reverse()
    return kept


@app.get("/")
async def root():
    """Redirect the bare root URL to the health endpoint."""
    return RedirectResponse(url="/health")


@app.post("/api/chat")
async def chat_endpoint(request: Request):
    """Stream a chat reply as Server-Sent Events.

    Expects a JSON object with a ``messages`` list of
    ``{"role": "user"|"assistant", "content": str}`` turns, oldest first.
    Malformed entries (wrong type, other roles, non-string content) are
    dropped. On success the response is a ``text/event-stream`` whose events
    are ``{"token": ...}`` payloads, or a single ``{"error": ...}`` payload
    if inference fails. Input problems are reported as a plain JSON
    ``{"error": ...}`` body.
    """
    try:
        data = await request.json()
    except Exception:
        return {"error": "Invalid JSON"}
    if not isinstance(data, dict):
        return {"error": "Invalid JSON"}

    messages = data.get("messages", [])
    if not isinstance(messages, list):
        return {"error": "Invalid messages"}

    client_turns = _sanitize_messages(messages)
    dropped = len(messages) - len(client_turns)
    if dropped:
        logger.warning("Dropped %d malformed chat message(s)", dropped)
    if not client_turns:
        return {"error": "No messages"}

    history = trim_to_budget(
        client_turns, SYSTEM_PROMPT, max_new_tokens=MAX_REPLY_TOKENS, n_ctx=N_CTX
    )
    if not history:
        # Even the newest turn alone does not fit the context budget; calling
        # the model with only the system prompt would answer nothing.
        return {"error": "Message too long"}

    # Gemma 4 natively supports a system role (Gemma 3n didn't), so this can
    # just be a normal system message now -- no more merge-into-user hack.
    chat_messages = [{"role": "system", "content": SYSTEM_PROMPT}] + history

    # Defensive: collapse any consecutive same-role turns (e.g. if the
    # frontend ever sends two user messages in a row) so alternation holds.
    # Client turns can never carry the "system" role, so nothing is ever
    # folded into SYSTEM_PROMPT here.
    merged = []
    for msg in chat_messages:
        if merged and merged[-1]["role"] == msg["role"]:
            merged[-1]["content"] += "\n" + msg["content"]
        else:
            merged.append(dict(msg))
    chat_messages = merged

    async def event_generator():
        """Yield SSE lines for one reply while holding the inference lock."""
        loop = asyncio.get_running_loop()
        queue = asyncio.Queue()
        DONE = object()
        # Set when the consumer goes away (normal end, error, or client
        # disconnect) so the worker stops generating at the next token.
        stop_requested = threading.Event()

        def post(item):
            # Hand an item from the worker thread to the event loop's queue.
            loop.call_soon_threadsafe(queue.put_nowait, item)

        def run():
            # The whole generation -- including the token-by-token loop, not
            # just creating the stream -- runs in the executor thread so the
            # event loop stays free for new connections and queued requests.
            try:
                stream = llm.create_chat_completion(
                    messages=chat_messages,
                    # Decode is the real cost, so keep replies capped.
                    max_tokens=MAX_REPLY_TOKENS,
                    # Lower than the 0.8 default -> more focused, consistent replies.
                    temperature=0.5,
                    repeat_penalty=1.1,
                    stream=True,
                )
                finish_reason = None
                for chunk in stream:
                    if stop_requested.is_set():
                        return
                    choice = chunk["choices"][0]
                    delta = choice["delta"].get("content")
                    finish_reason = choice.get("finish_reason") or finish_reason
                    if delta:
                        post(delta)
                if finish_reason == "length":
                    # Hit max_tokens before a natural stop -- say so instead
                    # of leaving the sentence cut off mid-word.
                    post("…")
            except Exception:
                logger.exception("Inference error")
                post(RuntimeError("System busy"))
            finally:
                post(DONE)
                # Release from the worker, not from the coroutine: if the
                # client disconnects, the coroutine is cancelled immediately
                # while this thread may still be inside llama.cpp. Releasing
                # here guarantees the next request never touches `llm` until
                # this thread is truly finished with it.
                loop.call_soon_threadsafe(inference_lock.release)

        await inference_lock.acquire()
        try:
            loop.run_in_executor(None, run)
        except BaseException:
            # The worker never started, so it will never release the lock.
            inference_lock.release()
            raise

        try:
            while True:
                item = await queue.get()
                if item is DONE:
                    break
                if isinstance(item, RuntimeError):
                    yield f"data: {json.dumps({'error': str(item)})}\n\n"
                    break
                yield f"data: {json.dumps({'token': item})}\n\n"
        finally:
            stop_requested.set()

    return StreamingResponse(event_generator(), media_type="text/event-stream")


@app.get("/health")
def health():
    """Report that the server is up and which model file it serves."""
    return {"status": "ready", "model": MODEL_FILE}