import asyncio
import gc
import json
import os
from typing import List, Optional

import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from huggingface_hub import hf_hub_download
from llama_cpp import Llama
from pydantic import BaseModel, Field

# =============================================================================
# FASTAPI
# =============================================================================

app = FastAPI()

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# =============================================================================
# MODEL CONFIG
# =============================================================================

MODEL_REPO = "unsloth/Qwen3-4B-GGUF"
MODEL_FILE = "Qwen3-4B-Q4_K_M.gguf"

# Up to 2 * MAX_HISTORY history messages are kept (presumably user/assistant pairs).
MAX_HISTORY = 6
MAX_CTX = 8192
MAX_TOKENS = 4096

# Keep these parameters unchanged (as requested).
THREADS = 2
N_BATCH = 512
N_UBATCH = 512

DEFAULT_SYSTEM = (
    "Bạn là trợ lý AI, trả lời bằng tiếng Việt ngắn gọn."
)

STOP_TOKENS = [
    "<|im_end|>",
    "<|endoftext|>",
]

# =============================================================================
# GLOBALS
# =============================================================================

# Loaded once at startup; None until the model is ready.
llm: Optional[Llama] = None

# CPU inference: serialize requests (one generation at a time) to avoid
# lag / token collapse from concurrent use of the single model instance.
inference_lock = asyncio.Semaphore(1)

# =============================================================================
# REQUEST MODELS
# =============================================================================


class Message(BaseModel):
    """One prior chat turn supplied by the client."""

    role: str  # only "user"/"assistant" (case-insensitive) survive build_messages
    content: str


class ChatRequest(BaseModel):
    """Request body for the chat endpoint."""

    prompt: str
    history: List[Message] = Field(default_factory=list)
    system_prompt: Optional[str] = None  # blank/None -> DEFAULT_SYSTEM
    max_tokens: int = MAX_TOKENS
    temperature: float = 0.7
    top_p: float = 0.9


# =============================================================================
# HELPERS
# =============================================================================


def cleanup_text(text: str) -> str:
    """Remove NUL characters, then strip surrounding whitespace.

    NULs are removed first so that whitespace sitting next to a leading or
    trailing NUL is stripped as well.
    """
    return text.replace("\x00", "").strip()


def build_messages(req: ChatRequest) -> list:
    """Build the chat-completion message list for *req*.

    Layout: one system message, then at most ``2 * MAX_HISTORY`` of the most
    recent history messages, then the new user prompt.

    History is sanitized: only ``user``/``assistant`` roles (case-insensitive)
    with non-empty content are kept, and a message whose role equals the
    previously kept role is dropped (the first of a run wins).

    Raises:
        HTTPException: 400 if the prompt is empty after cleanup, or longer
            than 8000 characters.
    """
    prompt = cleanup_text(req.prompt)
    if not prompt:
        raise HTTPException(400, "Prompt trống")
    if len(prompt) > 8000:
        raise HTTPException(400, "Prompt quá dài")

    # A blank / whitespace-only system prompt falls back to the default.
    system_prompt = (
        cleanup_text(req.system_prompt or "")
        or cleanup_text(DEFAULT_SYSTEM)
    )
    messages = [{"role": "system", "content": system_prompt}]

    # Explicit start index so that MAX_HISTORY == 0 keeps nothing
    # (a `[-0:]` slice would keep everything).
    start = max(0, len(req.history) - MAX_HISTORY * 2)

    last_role = "system"
    for msg in req.history[start:]:
        role = msg.role.strip().lower()
        content = cleanup_text(msg.content)
        if role not in ("user", "assistant") or not content:
            continue
        # Avoid consecutive duplicate roles.
        if role == last_role:
            continue
        messages.append({"role": role, "content": content})
        last_role = role

    # A trailing unanswered user turn would sit directly before the new
    # prompt; drop it so the roles keep alternating.
    if messages[-1]["role"] == "user":
        messages.pop()

    messages.append({"role": "user", "content": prompt})
    return messages


def sse(data) -> str:
    """Format *data* as one Server-Sent Events ``data:`` frame (JSON, non-ASCII kept)."""
    return f"data: {json.dumps(data, ensure_ascii=False)}\n\n"


# =============================================================================
# STARTUP
# =============================================================================


@app.on_event("startup")
async def startup_event():
    """Download (if needed), load and warm up the GGUF model into ``llm``.

    A local ``MODEL_FILE`` smaller than 1 MB is treated as corrupt and
    removed before the download check.
    """
    global llm

    # Remove a file too small to be a real model (corrupt / partial).
    if os.path.exists(MODEL_FILE) and os.path.getsize(MODEL_FILE) < 1_000_000:
        os.remove(MODEL_FILE)

    # Download if missing.
    if not os.path.exists(MODEL_FILE):
        print(f"Downloading {MODEL_FILE}...")
        hf_hub_download(
            repo_id=MODEL_REPO,
            filename=MODEL_FILE,
            local_dir=".",
        )
        print("Download complete!")

    # llama-cpp-python takes the KV-cache types as integer ggml type ids via
    # `type_k` / `type_v`. The CLI-style `cache_type_k="q4_0"` keywords are not
    # `Llama` parameters (they land in its catch-all **kwargs) and had no effect.
    from llama_cpp import GGML_TYPE_Q4_0

    print("Loading model...")
    llm = Llama(
        model_path=MODEL_FILE,
        # Context
        n_ctx=MAX_CTX,
        # Batch sizes (kept unchanged)
        n_batch=N_BATCH,
        n_ubatch=N_UBATCH,
        # CPU only
        n_threads=THREADS,
        n_threads_batch=THREADS,
        n_gpu_layers=0,
        # RAM: load fully into memory and lock it
        use_mmap=False,
        use_mlock=True,
        # Quantized KV cache (llama.cpp needs flash_attn for a quantized V cache)
        type_k=GGML_TYPE_Q4_0,
        type_v=GGML_TYPE_Q4_0,
        # Size of the recent-token window (used by the repetition penalty)
        last_n_tokens_size=64,
        # Performance
        flash_attn=True,
        # Cleaner logs
        verbose=False,
    )

    print("Warmup model...")
    try:
        llm.create_chat_completion(
            messages=[
                {"role": "system", "content": DEFAULT_SYSTEM},
                {"role": "user", "content": "hi"},
            ],
            max_tokens=1,
            stream=False,
        )
    except Exception as e:
        # Best effort: a failed warmup must not prevent serving.
        print(f"Warmup failed: {e}")

    gc.collect()
    print("Model ready!")


# =============================================================================
# CHAT
# =============================================================================

# Queue sentinel meaning "the generation worker has finished".
_STREAM_END = object()


def _extract_delta(chunk) -> str:
    """Return the text content of one streamed chat chunk ("" if it has none)."""
    try:
        return chunk["choices"][0].get("delta", {}).get("content") or ""
    except (KeyError, IndexError, TypeError, AttributeError):
        # Malformed chunk: skip it rather than abort the whole stream.
        return ""


def _generate(model, messages, max_tokens, temperature, top_p, emit, stop) -> None:
    """Run one streaming completion on *model*; blocking, meant for a worker thread.

    Every non-empty text delta is passed to ``emit(delta)``. Iteration stops
    early once the ``stop`` event is set (e.g. the client disconnected).
    Exceptions propagate to the caller's future.
    """
    stream = model.create_chat_completion(
        messages=messages,
        max_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        stop=STOP_TOKENS,
        stream=True,
    )
    try:
        for chunk in stream:
            if stop.is_set():
                break
            delta = _extract_delta(chunk)
            if delta:
                emit(delta)
    finally:
        # Finalize the llama-cpp generator here, in the thread that owns it.
        close = getattr(stream, "close", None)
        if close is not None:
            close()


@app.post("/chat")
async def chat(req: ChatRequest):
    """Stream a chat completion for *req* as Server-Sent Events.

    Frames are ``data: {"delta": ...}``. A generation failure is sent as
    ``data: {"error": ...}``, and a stream that runs to completion ends with
    ``data: [DONE]``. Generations are serialized through ``inference_lock``.

    Raises:
        HTTPException: 503 while the model is not loaded; 400 from
            ``build_messages`` for an invalid prompt.
    """
    import threading  # local import: only needed for the per-request stop flag

    model = llm
    if model is None:
        raise HTTPException(
            503,
            "Model chưa sẵn sàng",
        )

    messages = build_messages(req)

    # Clamp client-supplied values so users cannot spam e.g. max_tokens=999999.
    max_tokens = min(max(1, req.max_tokens), MAX_TOKENS)
    temperature = min(max(0.0, req.temperature), 2.0)
    top_p = min(max(0.1, req.top_p), 1.0)

    async def event_stream():
        loop = asyncio.get_running_loop()
        queue: asyncio.Queue = asyncio.Queue()
        stop = threading.Event()
        n_chars = 0

        def emit(delta: str) -> None:
            # Runs in the worker thread: hand the delta over to the event loop.
            loop.call_soon_threadsafe(queue.put_nowait, delta)

        await inference_lock.acquire()
        try:
            # llama.cpp inference is synchronous and CPU-bound; running it in a
            # thread keeps the event loop (other requests, /health) responsive.
            worker = loop.run_in_executor(
                None,
                _generate,
                model,
                messages,
                max_tokens,
                temperature,
                top_p,
                emit,
                stop,
            )
        except BaseException:
            inference_lock.release()
            raise

        # Release the lock only when the thread is truly done with the model,
        # even if this generator is cancelled first (client disconnected).
        worker.add_done_callback(lambda _fut: inference_lock.release())
        # Scheduled after all deltas the worker already queued, so it arrives last.
        worker.add_done_callback(lambda _fut: queue.put_nowait(_STREAM_END))

        try:
            while True:
                item = await queue.get()
                if item is _STREAM_END:
                    break
                n_chars += len(item)
                yield sse({"delta": item})

            error = worker.exception()
            if error is not None:
                yield sse({"error": str(error)})

            yield "data: [DONE]\n\n"
        finally:
            # No `yield` here: yielding while the generator is being closed
            # raises "async generator ignored GeneratorExit".
            stop.set()  # stop the worker at its next token (no-op if finished)
            print(f"[DONE] {n_chars} chars")
            gc.collect()

    return StreamingResponse(
        event_stream(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",
        },
    )


# =============================================================================
# HEALTH
# =============================================================================


@app.get("/")
async def root():
    """Report load status and the main model/runtime settings."""
    return {
        "status": "ok" if llm is not None else "loading",
        "model": MODEL_FILE,
        "ctx": MAX_CTX,
        "batch": N_BATCH,
        "threads": THREADS,
    }


@app.get("/health")
async def health():
    """Report whether the model has been loaded."""
    return {
        "healthy": llm is not None,
    }


# =============================================================================
# MAIN
# =============================================================================

if __name__ == "__main__":
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=7860,
        # production-ish
        access_log=False,
        server_header=False,
    )