from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from llama_cpp import Llama
from huggingface_hub import hf_hub_download
from supabase import create_client
from contextlib import asynccontextmanager
import os
import uvicorn

# =========================
# CONFIG
# =========================
HF_TOKEN = os.getenv("HF_TOKEN")
SUPABASE_URL = os.getenv("SUPABASE_URL")
SUPABASE_KEY = os.getenv("SUPABASE_KEY")

supabase = create_client(SUPABASE_URL, SUPABASE_KEY)

model = None

# =========================
# REQUEST
# =========================
class ChatRequest(BaseModel):
    message: str
    request_id: str
    temperature: float = 0.7

# =========================
# CLEAN OUTPUT
# =========================
def clean_output(text):
    # Truncate the generation at the first stop token or role marker.
    stop_words = [
        "<|eot_id|>", "<|end_of_text|>", "<|eof|>",
        "Human:", "Assistant:", "User:"
    ]
    for w in stop_words:
        if w in text:
            text = text.split(w)[0]
    return text.strip()

# =========================
# PROMPT
# =========================
def build_prompt(user_msg):
    # Llama 3 chat template with a system prompt tuned for short spoken replies.
    return f"""<|begin_of_text|>
<|start_header_id|>system<|end_header_id|>
Your name is Llama and you are a cheerful friendly AI buddy made for voice conversation.
Rules:
- Always refer to yourself as Llama
- Speak naturally like a real voice conversation with a friend
- Use casual spoken language like hey sure yep got it
- Answer in 1 to 2 sentences only
- Keep answer under 30 words
- Do not use symbols
- Do not use abbreviations
- Use digits instead of words
- No new lines
- Output plain text only
<|eot_id|>
<|start_header_id|>user<|end_header_id|>
{user_msg}
<|eot_id|>
<|start_header_id|>assistant<|end_header_id|>
"""

# =========================
# MODEL LOAD
# =========================
def load_model():
    # Download the GGUF weights from the Hub (cached under /data) and load them.
    return Llama(
        model_path=hf_hub_download(
            repo_id="Valtry/llama3.2-3b-q4-gguf",
            filename="llama3.2-3b-q4.gguf",
            token=HF_TOKEN,
            cache_dir="/data"
        ),
        n_ctx=2048,
        n_threads=4,
        n_batch=512,
        use_mmap=True,
        use_mlock=True,
        f16_kv=True,
        verbose=False
    )

@asynccontextmanager
async def lifespan(app: FastAPI):
    # Load the model once at startup so requests don't pay the load cost.
    global model
    model = load_model()
    yield

# =========================
# APP
# =========================
app = FastAPI(lifespan=lifespan)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# =========================
# SAVE
# =========================
def save_message(role, content, request_id):
    supabase.table("messages").insert({
        "role": role,
        "content": content,
        "request_id": request_id
    }).execute()

# =========================
# SUPABASE UPDATE HELPER
# =========================
def update_message(msg_id, content, status=None):
    data = {"content": content}
    if status:
        data["status"] = status
    try:
        supabase.table("messages").update(data).eq("id", msg_id).execute()
    except Exception as e:
        print(f"Supabase update failed: {e}")

# =========================
# CHAT
# =========================
@app.post("/v1/chat")
async def chat(req: ChatRequest):
    def generate():
        prompt = build_prompt(req.message)
        full_text = ""

        stream = model(
            prompt,
            max_tokens=2048,
            temperature=req.temperature,
            top_p=0.9,
            repeat_penalty=1.15,
            stop=["<|eot_id|>", "<|end_of_text|>", "<|eof|>"],
            stream=True
        )

        # Stream tokens directly to the client (e.g. the ESP device),
        # stripping newlines so the device receives a single line of text.
        for chunk in stream:
            token = chunk["choices"][0]["text"]
            full_text += token
            yield token.replace("\n", " ").replace("\r", "")

        # Persist both sides of the exchange once generation is complete.
        final = clean_output(full_text)
        save_message("user", req.message, req.request_id)
        save_message("assistant", final, req.request_id)

    return StreamingResponse(generate(), media_type="text/plain")

# =========================
# GET RESPONSE
# =========================
@app.get("/v1/get_response/{request_id}")
def get_response(request_id: str):
    # Return the latest assistant reply stored for this request_id.
    try:
        res = supabase.table("messages") \
            .select("content, status") \
            .eq("role", "assistant") \
            .eq("request_id", request_id) \
            .order("created_at", desc=True) \
            .limit(1) \
            .execute()

        data = res.data
        if data:
            return {
                "response": data[0]["content"],
                "status": data[0]["status"]
            }
        else:
            return {"response": None, "status": "waiting"}
    except Exception as e:
        return {"error": str(e)}

# =========================
# ROOT
# =========================
@app.get("/")
def root():
    return {"status": "LLaMA API running"}

# =========================
# RUN
# =========================
if __name__ == "__main__":
    uvicorn.run("app:app", host="0.0.0.0", port=7860)
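# =========================
# EXAMPLE CLIENT (sketch)
# =========================
# A minimal client sketch, kept commented out so it does not run with the server.
# It assumes the server is reachable at http://localhost:7860 and that the
# `requests` package is installed; the request_id value is just an illustration
# and should be generated per chat turn.
#
# import requests, uuid
#
# request_id = str(uuid.uuid4())
#
# # Stream the reply token by token from /v1/chat.
# with requests.post(
#     "http://localhost:7860/v1/chat",
#     json={"message": "hey llama how are you", "request_id": request_id},
#     stream=True,
# ) as resp:
#     for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
#         print(chunk, end="", flush=True)
#
# # After streaming finishes, the saved reply can be fetched by request_id.
# print(requests.get(f"http://localhost:7860/v1/get_response/{request_id}").json())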