Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import StreamingResponse | |
| from pydantic import BaseModel | |
| from llama_cpp import Llama | |
| from huggingface_hub import hf_hub_download | |
| from typing import List, Optional | |
| import asyncio | |
| import os | |
| import json | |
| import uvicorn | |
| import gc | |
# =============================================================================
# FASTAPI
# =============================================================================
# Application instance. CORS is wide open: every origin, method and header is
# allowed.
app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)
# =============================================================================
# MODEL CONFIG
# =============================================================================
# GGUF weights fetched from the Hugging Face Hub at startup.
MODEL_REPO = "unsloth/Qwen3-4B-GGUF"
MODEL_FILE = "Qwen3-4B-Q4_K_M.gguf"

MAX_HISTORY = 6    # history window in user/assistant pairs (x2 messages)
MAX_CTX = 8192     # context window passed to llama.cpp as n_ctx
MAX_TOKENS = 4096  # upper bound on generated tokens per request

# CPU tuning -- keep these values unchanged (as requested).
THREADS = 2
N_BATCH = 512
N_UBATCH = 512

# Default system prompt
# ("You are an AI assistant; answer concisely in Vietnamese.").
DEFAULT_SYSTEM = "Bạn là trợ lý AI, trả lời bằng tiếng Việt ngắn gọn."

# Generation stops when either end-of-turn marker is produced.
STOP_TOKENS = ["<|im_end|>", "<|endoftext|>"]
# =============================================================================
# GLOBALS
# =============================================================================
# Inference runs on the CPU, so requests are serialized through a single-slot
# semaphore to avoid lag / token collapse when several arrive at once.
# NOTE(review): on Python < 3.10 an asyncio primitive created at import time is
# bound to the import-time event loop; confirm the runtime is 3.10+.
inference_lock = asyncio.Semaphore(1)

# The loaded model; stays None until startup_event() assigns it.
llm: Optional[Llama] = None
# =============================================================================
# REQUEST MODELS
# =============================================================================
class Message(BaseModel):
    # One earlier chat turn supplied by the client. build_messages() keeps only
    # roles "user" and "assistant" (compared after strip().lower()).
    role: str
    content: str
class ChatRequest(BaseModel):
    # JSON body of a chat request.
    prompt: str  # the new user message
    # Earlier turns; pydantic copies mutable defaults, so this [] is not shared.
    history: List[Message] = []
    # Empty string or None falls back to DEFAULT_SYSTEM in build_messages().
    system_prompt: Optional[str] = None
    max_tokens: int = MAX_TOKENS  # chat() clamps to [1, MAX_TOKENS]
    temperature: float = 0.7  # chat() clamps to [0.0, 2.0]
    top_p: float = 0.9  # chat() clamps to [0.1, 1.0]
# =============================================================================
# HELPERS
# =============================================================================
def cleanup_text(text: str) -> str:
    """Remove NUL characters from *text*, then trim surrounding whitespace.

    NULs are removed *before* trimming: ``str.strip`` does not treat NUL as
    whitespace, so stripping first would leave spaces that sit behind a
    leading or trailing NUL.
    """
    return text.replace("\x00", "").strip()
def build_messages(req: ChatRequest) -> list:
    """Build the chat message list passed to llama.cpp for *req*.

    Layout: one system message (``req.system_prompt`` or ``DEFAULT_SYSTEM``),
    then up to ``MAX_HISTORY * 2`` of the most recent history entries, then
    the new user prompt.

    A history entry is dropped when:
    - its role is not "user"/"assistant";
    - its cleaned content is empty;
    - it repeats the role of the previously kept message (so roles alternate).

    A trailing "user" entry is replaced by the new prompt.

    Raises:
        HTTPException(400): the prompt is empty after cleanup, or longer than
            8000 characters.
    """
    # Validate the new prompt first: cheap, and avoids processing history for
    # a request that will be rejected anyway.
    prompt = cleanup_text(req.prompt)
    if not prompt:
        raise HTTPException(400, "Prompt trống")
    if len(prompt) > 8000:
        raise HTTPException(400, "Prompt quá dài")

    messages = [
        {
            "role": "system",
            "content": cleanup_text(req.system_prompt or DEFAULT_SYSTEM),
        }
    ]

    # history[-0:] would return the WHOLE list, so a non-positive window must
    # be handled explicitly.
    window = MAX_HISTORY * 2
    recent = req.history[-window:] if window > 0 else []

    last_role = "system"
    for msg in recent:
        role = msg.role.strip().lower()
        content = cleanup_text(msg.content)
        if role not in ("user", "assistant") or not content:
            continue
        # Skip consecutive messages with the same role so turns alternate.
        if role == last_role:
            continue
        messages.append({"role": role, "content": content})
        last_role = role

    # The new prompt supersedes a trailing user turn from history (presumably
    # for clients that already appended the current prompt to history).
    if messages[-1]["role"] == "user":
        messages.pop()
    messages.append({"role": "user", "content": prompt})
    return messages
def sse(data):
    """Encode *data* as a single Server-Sent Events ``data:`` frame.

    The payload is JSON with non-ASCII characters emitted verbatim, and the
    frame is terminated by the blank line SSE requires.
    """
    payload = json.dumps(data, ensure_ascii=False)
    return "data: " + payload + "\n\n"
# =============================================================================
# STARTUP
# =============================================================================
# ggml_type enum value for Q4_0 (ggml.h: F32=0, F16=1, Q4_0=2).
_GGML_TYPE_Q4_0 = 2


def _ensure_model_file() -> None:
    """Download MODEL_FILE into the working directory unless a usable copy exists."""
    # A file under ~1 MB cannot be the real GGUF model: treat it as a corrupt or
    # interrupted download and fetch it again.
    if os.path.exists(MODEL_FILE) and os.path.getsize(MODEL_FILE) < 1_000_000:
        os.remove(MODEL_FILE)
    if not os.path.exists(MODEL_FILE):
        print(f"Downloading {MODEL_FILE}...")
        hf_hub_download(
            repo_id=MODEL_REPO,
            filename=MODEL_FILE,
            local_dir=".",
        )
        print("Download complete!")


def _load_model() -> Llama:
    """Create the CPU-only Llama instance, preferring a q4_0-quantized KV cache.

    If construction with the quantized cache fails, retries once with the
    library's default KV cache type.
    """
    params = dict(
        model_path=MODEL_FILE,
        # Context
        n_ctx=MAX_CTX,
        # Batch sizes kept unchanged as requested
        n_batch=N_BATCH,
        n_ubatch=N_UBATCH,
        # CPU only
        n_threads=THREADS,
        n_threads_batch=THREADS,
        n_gpu_layers=0,
        # RAM: load fully into memory and try to lock it there
        use_mmap=False,
        use_mlock=True,
        last_n_tokens_size=64,
        # llama.cpp requires flash attention for a quantized V cache.
        flash_attn=True,
        verbose=False,
    )
    try:
        # llama-cpp-python names these `type_k` / `type_v` (ggml_type ints).
        # The former `cache_type_k` / `cache_type_v` kwargs were swallowed by
        # Llama's **kwargs and silently ignored, leaving the cache at f16.
        return Llama(**params, type_k=_GGML_TYPE_Q4_0, type_v=_GGML_TYPE_Q4_0)
    except Exception as e:
        print(f"Quantized KV cache unavailable ({e}); retrying with default cache.")
        return Llama(**params)


def _warmup(model: Llama) -> None:
    """Run a one-token chat completion as a warm-up; failures are printed, not raised."""
    try:
        model.create_chat_completion(
            messages=[
                {"role": "system", "content": DEFAULT_SYSTEM},
                {"role": "user", "content": "hi"},
            ],
            max_tokens=1,
            stream=False,
        )
    except Exception as e:
        print(f"Warmup failed: {e}")


# Registered as a startup hook; without this decorator FastAPI never calls it
# and `llm` stays None forever.
@app.on_event("startup")
async def startup_event():
    """Ensure the model file is present, load it into the global `llm`, and warm it up."""
    global llm
    _ensure_model_file()
    print("Loading model...")
    llm = _load_model()
    print("Warmup model...")
    _warmup(llm)
    gc.collect()
    print("Model ready!")
# =============================================================================
# CHAT
# =============================================================================
def _chunk_delta(chunk) -> str:
    """Return the text delta of one streamed chat-completion chunk ('' if absent or malformed)."""
    try:
        return chunk["choices"][0].get("delta", {}).get("content") or ""
    except (KeyError, IndexError, TypeError, AttributeError):
        return ""


# NOTE(review): route path inferred from the handler name -- confirm it matches
# the client. Without a decorator the handler was never registered.
@app.post("/chat")
async def chat(req: ChatRequest):
    """Stream a chat completion for *req* as Server-Sent Events.

    Stream format:
    - each generated text fragment is sent as ``{"delta": ...}``;
    - a failure during generation is sent as ``{"error": ...}``;
    - the stream ends with ``data: [DONE]`` unless the client disconnects first.

    Raises:
        HTTPException(503): the model has not been loaded yet.
        HTTPException(400): propagated from build_messages() for a bad prompt.
    """
    from concurrent.futures import ThreadPoolExecutor

    model = llm
    if model is None:
        raise HTTPException(503, "Model chưa sẵn sàng")

    messages = build_messages(req)

    # Clamp so a client cannot request e.g. 999999 tokens or extreme sampling.
    max_tokens = min(max(1, req.max_tokens), MAX_TOKENS)
    temperature = min(max(0.0, req.temperature), 2.0)
    top_p = min(max(0.1, req.top_p), 1.0)

    end_of_stream = object()

    def start_stream():
        return model.create_chat_completion(
            messages=messages,
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            stop=STOP_TOKENS,
            stream=True,
        )

    async def event_stream():
        full = ""
        error = None
        stream = None
        # One generation at a time: CPU inference is serialized.
        async with inference_lock:
            try:
                # llama.cpp calls are blocking and CPU-bound; run them on a
                # worker thread so the event loop (health checks, clients
                # queueing on the lock) stays responsive.
                with ThreadPoolExecutor(max_workers=1) as pool:

                    async def off_loop(fn, *args):
                        return await asyncio.wrap_future(pool.submit(fn, *args))

                    stream = await off_loop(start_stream)
                    while True:
                        chunk = await off_loop(next, stream, end_of_stream)
                        if chunk is end_of_stream:
                            break
                        delta = _chunk_delta(chunk)
                        if not delta:
                            continue
                        full += delta
                        yield sse({"delta": delta})
                # Leaving the executor waits for any in-flight llama call, so
                # the model is never touched after the lock is released -- also
                # when the request is cancelled by a client disconnect.
            except Exception as e:
                error = str(e)
            finally:
                # No yield here: yielding during GeneratorExit (client gone)
                # raises RuntimeError. Just release the llama generator.
                if stream is not None:
                    stream.close()
                print(f"[DONE] {len(full)} chars")
                gc.collect()
        if error is not None:
            yield sse({"error": error})
        yield "data: [DONE]\n\n"

    return StreamingResponse(
        event_stream(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",
        },
    )
# =============================================================================
# HEALTH
# =============================================================================
# NOTE(review): route path inferred from the handler name -- confirm. Without a
# decorator the handler was never registered.
@app.get("/")
async def root():
    """Report model load status together with the main model/runtime settings."""
    return {
        # Same `is not None` check as health(): the question is whether the
        # model was loaded, not how a Llama object evaluates as a bool.
        "status": "ok" if llm is not None else "loading",
        "model": MODEL_FILE,
        "ctx": MAX_CTX,
        "batch": N_BATCH,
        "threads": THREADS,
    }
# NOTE(review): route path inferred from the handler name -- confirm. Without a
# decorator the handler was never registered.
@app.get("/health")
async def health():
    """Health check: ``healthy`` is True once the global model has been loaded."""
    return {"healthy": llm is not None}
# =============================================================================
# MAIN
# =============================================================================
if __name__ == "__main__":
    # Listen on all interfaces at port 7860, with the access log and the
    # `Server` response header both turned off.
    uvicorn.run(
        app,
        port=7860,
        host="0.0.0.0",
        server_header=False,
        access_log=False,
    )