import asyncio
import json
import logging
import threading

from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, RedirectResponse, StreamingResponse
from llama_cpp import Llama, LlamaRAMCache
|
|
# Root logging at INFO; this module logs through the "BestieLocal" logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("BestieLocal")

# CORS is opened to every origin, method and header. allow_credentials is not
# passed, so it keeps CORSMiddleware's default.
_CORS_ALLOW_ALL = {
    "allow_origins": ["*"],
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app = FastAPI()
app.add_middleware(CORSMiddleware, **_CORS_ALLOW_ALL)
|
|
| |
| |
# Model identity. MODEL_PATH is the local file handed to Llama(model_path=...).
# MODEL_REPO is not referenced in the code visible here; presumably it records
# where MODEL_FILE came from.
MODEL_REPO = "unsloth/gemma-4-E4B-it-GGUF"
MODEL_FILE = "gemma-4-E4B-it-Q4_K_M.gguf"
MODEL_PATH = "/code/models/" + MODEL_FILE

# Token limits shared by the context-trimming and inference code.
MAX_REPLY_TOKENS = 90  # max_tokens for every generated reply
N_CTX = 2048  # context window size, also used as the trimming budget
|
|
# The model is loaded once, at import time, and this single instance is shared
# by every request. Thread counts and batch size are fixed small values.
# NOTE(review): presumably sized for a small CPU host; confirm on deployment.
_LLAMA_KWARGS = {
    "model_path": MODEL_PATH,
    "n_ctx": N_CTX,  # the same N_CTX is the budget given to trim_to_budget
    "n_threads": 2,  # llama-cpp-python: threads used for generation
    "n_threads_batch": 2,  # llama-cpp-python: threads used for batch/prompt processing
    "n_batch": 384,  # llama-cpp-python: prompt-processing batch size
    "verbose": True,
}
llm = Llama(**_LLAMA_KWARGS)
|
|
| |
| |
| |
| |
|
|
| |
| |
# Only one request at a time may run inference on the shared ``llm`` instance.
inference_lock = asyncio.Lock()


# System prompt, kept as one string per paragraph and joined with blank lines.
# This is runtime text: any wording change here changes the model's behaviour.
SYSTEM_PROMPT = "\n\n".join(
    (
        "You are the AI Assistant on Bestie, a wellness platform for young "
        "Kenyans. Talk like a caring close friend, not a clinician.",
        "Style: 2-4 short sentences, no lists. Match the person's "
        "language/energy (English, Swahili, Sheng all fine). Validate the "
        "feeling before advising. Sound like a real person, not a script.",
        "Don't rush to fix things. If someone shares something painful "
        "(a breakup, loss, conflict), don't jump straight to suggestions or "
        "coping tips -- ask what they actually want right now (to vent, to "
        "think out loud, or for ideas) unless they've already made that "
        "clear. A vague or short reply like 'then' or 'idk' usually means "
        "they need more space to talk, not an action plan.",
        "Never invent a nickname for the person (e.g. don't call anyone "
        "'mzungu') or write a placeholder like '[Name]'. Use their name only "
        "if they gave it.",
        "Never diagnose, give medical advice, or tell someone to start/stop "
        "medication. You can suggest a doctor/counselor/trusted adult, but "
        "don't repeat it every message.",
        "If self-harm or suicide comes up: stay calm, don't ask about "
        "methods/timing, share the Kenya Red Cross lifeline 1199 right away, "
        "and don't lecture.",
    )
)
|
|
|
|
def trim_to_budget(
    messages,
    system_prompt,
    max_new_tokens,
    n_ctx,
    safety=64,
    *,
    per_message_overhead=6,
    tokenize=None,
):
    """Return the newest contiguous run of ``messages`` that fits the context.

    Trimming by actual token count, instead of keeping a fixed last-N window,
    keeps the prompt prefix identical from turn to turn until the conversation
    genuinely outgrows ``n_ctx``. That keeps llama.cpp's prefix-match cache
    valid. A fixed window shifts the prefix on every turn once the chat is
    longer than N, forcing a full re-prefill each time. The author observed
    150-230 new prompt tokens taking 14-19 s to evaluate.

    Args:
        messages: chat messages (dicts with an optional "content"), oldest first.
        system_prompt: system prompt text; its tokens (with BOS) are reserved.
        max_new_tokens: tokens reserved for the model's reply.
        n_ctx: total context window size.
        safety: extra tokens held back as slack.
        per_message_overhead: tokens charged per message on top of its content.
            NOTE(review): presumably approximates the chat template's per-turn
            markup; confirm for the model's template.
        tokenize: callable ``(text_bytes, add_bos=...) -> sequence of tokens``.
            Defaults to the module-level ``llm.tokenize``.

    Returns:
        The kept messages (the original dict objects), oldest first. Scanning
        stops at the first message, newest to oldest, that does not fit. The
        result is therefore always a suffix of ``messages``, possibly empty.
    """
    if tokenize is None:
        tokenize = llm.tokenize

    system_tokens = len(tokenize(system_prompt.encode("utf-8"), add_bos=True))
    budget = n_ctx - max_new_tokens - system_tokens - safety

    kept = []
    used = 0
    for msg in reversed(messages):
        # A client may send "content": null. msg.get("content", "") would
        # return None in that case and crash on .encode(), so normalise here.
        content = msg.get("content")
        if content is None:
            content = ""
        elif not isinstance(content, str):
            content = str(content)
        cost = (
            len(tokenize(content.encode("utf-8"), add_bos=False))
            + per_message_overhead
        )
        if used + cost > budget:
            break
        kept.append(msg)
        used += cost
    kept.reverse()
    return kept
|
|
|
|
@app.get("/")
async def root():
    """Redirect the bare root URL to the health check (default redirect status)."""
    return RedirectResponse("/health")
|
|
|
|
def _normalize_messages(messages):
    """Validate the client's ``messages`` payload and return cleaned copies.

    Each entry must be a dict with a string "role". A missing or null
    "content" becomes "", and any other non-string content is rejected.
    Raises ``ValueError`` on malformed input, so the endpoint can answer 400
    instead of failing mid-request.
    """
    if not isinstance(messages, list):
        raise ValueError("'messages' must be a list")
    cleaned = []
    for msg in messages:
        if not isinstance(msg, dict) or not isinstance(msg.get("role"), str):
            raise ValueError("each message needs a string 'role'")
        content = msg.get("content")
        if content is None:
            content = ""
        elif not isinstance(content, str):
            raise ValueError("message 'content' must be a string")
        cleaned.append({**msg, "content": content})
    return cleaned


def _build_prompt_messages(messages):
    """Build the message list passed to ``llm.create_chat_completion``.

    The steps are:
    1. Trim ``messages`` to the context budget.
    2. Prepend ``SYSTEM_PROMPT``.
    3. Merge consecutive turns that share a role. A leading client "system"
       message is folded into the system prompt this way.
    4. Drop assistant turns that would directly follow the system prompt.
    """
    history = trim_to_budget(
        messages, SYSTEM_PROMPT, max_new_tokens=MAX_REPLY_TOKENS, n_ctx=N_CTX
    )

    merged = [{"role": "system", "content": SYSTEM_PROMPT}]
    for msg in history:
        if merged[-1]["role"] == msg["role"]:
            merged[-1]["content"] += "\n" + msg["content"]
        else:
            merged.append(dict(msg))

    # Budget trimming can cut between a user turn and the reply to it, leaving
    # an assistant turn right after the system prompt. Chat templates that
    # enforce strict user/assistant alternation raise on that, and the request
    # then fails. NOTE(review): confirm this model's template enforces it;
    # otherwise dropping only costs the dropped turn's context.
    first = 1
    while first < len(merged) and merged[first]["role"] == "assistant":
        first += 1
    return merged[:1] + merged[first:]


@app.post("/api/chat")
async def chat_endpoint(request: Request):
    """Stream a reply to the posted conversation as Server-Sent Events.

    Request body: ``{"messages": [{"role": ..., "content": ...}, ...]}``.

    Responses:
      * 400 ``{"error": "Invalid JSON"}`` if the body is not a JSON object or
        the messages are malformed.
      * Otherwise a ``text/event-stream`` of ``data: {"token": ...}`` events.
        If inference fails, the stream instead carries one
        ``data: {"error": "System busy"}`` event.
    """
    try:
        data = await request.json()
        messages = _normalize_messages(data.get("messages", []))
    except Exception:
        # Unparseable JSON, a non-object body, or malformed messages.
        return JSONResponse(status_code=400, content={"error": "Invalid JSON"})

    chat_messages = _build_prompt_messages(messages)

    async def event_generator():
        # Acquire by hand rather than with ``async with``. The lock must stay
        # held until the worker thread has finished with ``llm``, which can be
        # after this generator is cancelled (e.g. the client disconnects
        # mid-reply). The worker releases it from its ``finally`` below.
        await inference_lock.acquire()
        loop = asyncio.get_running_loop()
        queue = asyncio.Queue()
        stop = threading.Event()
        DONE = object()

        def run():
            # Runs in the default executor thread. Tokens are handed back to
            # the event loop through ``queue`` via call_soon_threadsafe.
            stream = None
            try:
                stream = llm.create_chat_completion(
                    messages=chat_messages,
                    max_tokens=MAX_REPLY_TOKENS,
                    temperature=0.5,
                    repeat_penalty=1.1,
                    stream=True,
                )
                finish_reason = None
                for chunk in stream:
                    if stop.is_set():
                        # The consumer is gone; stop generating.
                        break
                    choice = chunk["choices"][0]
                    delta = choice["delta"].get("content")
                    finish_reason = choice.get("finish_reason") or finish_reason
                    if delta:
                        loop.call_soon_threadsafe(queue.put_nowait, delta)
                if finish_reason == "length":
                    # The reply was cut at MAX_REPLY_TOKENS; mark the cut-off.
                    loop.call_soon_threadsafe(queue.put_nowait, "…")
            except Exception as e:
                logger.exception("Inference error: %s", e)
                loop.call_soon_threadsafe(queue.put_nowait, RuntimeError("System busy"))
            finally:
                # Close the stream generator (relevant after an early break)
                # before the model is handed to the next request.
                close = getattr(stream, "close", None)
                if close is not None:
                    try:
                        close()
                    except Exception:
                        logger.exception("Error closing inference stream")
                loop.call_soon_threadsafe(queue.put_nowait, DONE)
                loop.call_soon_threadsafe(inference_lock.release)

        try:
            loop.run_in_executor(None, run)
        except BaseException:
            # The worker never started, so it cannot release the lock.
            inference_lock.release()
            raise

        try:
            while True:
                item = await queue.get()
                if item is DONE:
                    break
                if isinstance(item, RuntimeError):
                    yield f"data: {json.dumps({'error': str(item)})}\n\n"
                    break
                yield f"data: {json.dumps({'token': item})}\n\n"
        finally:
            # Normal end, error, or client disconnect: ask the worker to stop
            # early if it is still generating. The worker releases the lock.
            stop.set()

    return StreamingResponse(event_generator(), media_type="text/event-stream")
|
|
|
|
@app.get("/health")
def health():
    """Report readiness and the model file name.

    Always reports "ready": the model is loaded at import time, before any
    route can be served.
    """
    payload = {"status": "ready"}
    payload["model"] = MODEL_FILE
    return payload