# AI_Coder / app.py
# Scraped Hugging Face page header (kept as comments so the module parses):
#   Toilatop1sever's picture
#   Update app.py
#   86e793d verified
#   Raw
#   History Blame Contribute Delete
#   8.4 kB
import asyncio
import gc
import json
import os
import threading
from typing import List, Optional

import llama_cpp
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from huggingface_hub import hf_hub_download
from llama_cpp import Llama
from pydantic import BaseModel
# =============================================================================
# FASTAPI
# =============================================================================
app = FastAPI()

# Fully permissive CORS: any origin, any HTTP method and any request header
# may call this API from a browser.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)
# =============================================================================
# MODEL CONFIG
# =============================================================================
# GGUF checkpoint downloaded from the Hugging Face Hub at startup.
MODEL_REPO = "unsloth/Qwen3-4B-GGUF"
MODEL_FILE = "Qwen3-4B-Q4_K_M.gguf"

MAX_HISTORY = 6    # build_messages keeps the last MAX_HISTORY * 2 history messages
MAX_CTX = 8192     # context window passed to Llama(n_ctx=...)
MAX_TOKENS = 4096  # default and upper bound for tokens generated per request

# Inference parameters, kept unchanged as requested.
THREADS = 2
N_BATCH = N_UBATCH = 512

# Default system prompt (Vietnamese: "You are an AI assistant; answer
# concisely in Vietnamese.").
DEFAULT_SYSTEM = "Bạn là trợ lý AI, trả lời bằng tiếng Việt ngắn gọn."

# Stop sequences passed to every chat completion.
STOP_TOKENS = ["<|im_end|>", "<|endoftext|>"]
# =============================================================================
# GLOBALS
# =============================================================================
# The loaded model; assigned by startup_event, None until then.
llm: Optional[Llama] = None

# CPU inference: only one generation at a time, to avoid lag / token collapse.
# NOTE(review): on Python < 3.10 an asyncio.Semaphore binds to the event loop
# current at construction time; created at import it may not be the loop the
# server runs on — confirm the deployed Python version.
inference_lock = asyncio.Semaphore(value=1)
# =============================================================================
# REQUEST MODELS
# =============================================================================
class Message(BaseModel):
    # One earlier conversation turn supplied by the client.
    role: str     # only "user" / "assistant" are kept by build_messages
    content: str  # cleaned with cleanup_text; empty turns are dropped
class ChatRequest(BaseModel):
    # JSON body of POST /chat.
    prompt: str                          # the new user message
    history: List[Message] = []          # earlier turns; build_messages keeps the tail
    system_prompt: Optional[str] = None  # falls back to DEFAULT_SYSTEM when empty
    max_tokens: int = MAX_TOKENS         # clamped to [1, MAX_TOKENS] by /chat
    temperature: float = 0.7             # clamped to [0.0, 2.0] by /chat
    top_p: float = 0.9                   # clamped to [0.1, 1.0] by /chat
# =============================================================================
# HELPERS
# =============================================================================
def cleanup_text(text: str) -> str:
    """Return *text* with NUL characters removed and surrounding whitespace trimmed.

    NULs are removed *before* stripping so whitespace that sat next to a
    leading/trailing NUL is trimmed as well (``"a \\x00"`` -> ``"a"``).
    """
    return text.replace("\x00", "").strip()
def build_messages(req: ChatRequest) -> list:
    """Build the chat transcript (list of role/content dicts) for the model.

    The transcript starts with the system prompt (``req.system_prompt`` or
    ``DEFAULT_SYSTEM``), followed by at most the last ``MAX_HISTORY * 2``
    history entries. Only non-empty "user"/"assistant" turns are kept, and a
    turn whose role equals the previously kept role is skipped. The cleaned
    prompt is appended as the final user turn, replacing a trailing user turn
    from the history.

    Raises:
        HTTPException: 400 when the prompt is empty or longer than 8000 chars.
    """
    conversation = [
        {"role": "system", "content": cleanup_text(req.system_prompt or DEFAULT_SYSTEM)}
    ]

    previous = "system"
    for entry in req.history[-(MAX_HISTORY * 2):]:
        speaker = entry.role.strip().lower()
        text = cleanup_text(entry.content)
        # Keep non-empty user/assistant turns, never the same speaker twice in a row.
        if speaker in ("user", "assistant") and text and speaker != previous:
            conversation.append({"role": speaker, "content": text})
            previous = speaker

    question = cleanup_text(req.prompt)
    if not question:
        raise HTTPException(400, "Prompt trống")
    if len(question) > 8000:
        raise HTTPException(400, "Prompt quá dài")

    # The new prompt takes the place of a trailing user turn so roles alternate.
    if conversation[-1]["role"] == "user":
        del conversation[-1]
    conversation.append({"role": "user", "content": question})
    return conversation
def sse(data):
    """Encode *data* as a single Server-Sent Events ``data:`` frame.

    The payload is JSON with non-ASCII characters kept as-is
    (``ensure_ascii=False``); the frame ends with a blank line.
    """
    payload = json.dumps(data, ensure_ascii=False)
    return "data: " + payload + "\n\n"
# =============================================================================
# STARTUP
# =============================================================================
@app.on_event("startup")
async def startup_event():
    """Ensure the GGUF file is present, load it into ``llm`` and warm it up.

    A local ``MODEL_FILE`` smaller than 1 MB is treated as corrupt and
    deleted; a missing file is downloaded from ``MODEL_REPO`` into the
    working directory. A failed warmup is only printed, not raised.
    """
    global llm

    # Remove a file too small to be a real model (e.g. an interrupted download).
    if (
        os.path.exists(MODEL_FILE)
        and os.path.getsize(MODEL_FILE) < 1_000_000
    ):
        os.remove(MODEL_FILE)

    # Download if missing.
    if not os.path.exists(MODEL_FILE):
        print(f"Downloading {MODEL_FILE}...")
        hf_hub_download(
            repo_id=MODEL_REPO,
            filename=MODEL_FILE,
            local_dir=".",
        )
        print("Download complete!")

    print("Loading model...")
    llm = Llama(
        model_path=MODEL_FILE,
        # Context
        n_ctx=MAX_CTX,
        # Batch sizes kept unchanged as requested
        n_batch=N_BATCH,
        n_ubatch=N_UBATCH,
        # CPU only
        n_threads=THREADS,
        n_threads_batch=THREADS,
        n_gpu_layers=0,
        # Load fully into RAM and lock it there
        use_mmap=False,
        use_mlock=True,
        # KV cache quantization. llama-cpp-python names these `type_k` /
        # `type_v` and expects ggml type ids; the previous `cache_type_k` /
        # `cache_type_v` keywords were swallowed by **kwargs and ignored.
        type_k=llama_cpp.GGML_TYPE_Q4_0,
        type_v=llama_cpp.GGML_TYPE_Q4_0,
        # llama.cpp requires flash attention for a quantized V cache.
        flash_attn=True,
        # Maximum number of recent tokens kept in the last_n_tokens window.
        last_n_tokens_size=64,
        # Quieter logs
        verbose=False,
    )

    print("Warmup model...")
    try:
        _ = llm.create_chat_completion(
            messages=[
                {"role": "system", "content": DEFAULT_SYSTEM},
                {"role": "user", "content": "hi"},
            ],
            max_tokens=1,
            stream=False,
        )
    except Exception as e:
        # Best effort: a failed warmup must not prevent the server from starting.
        print(f"Warmup failed: {e}")

    gc.collect()
    print("Model ready!")
# =============================================================================
# CHAT
# =============================================================================
@app.post("/chat")
async def chat(req: ChatRequest):
    """Stream a chat completion for *req* as Server-Sent Events.

    Each generated text fragment is sent as ``data: {"delta": ...}``. A
    generation failure is reported as ``data: {"error": ...}``. A stream that
    runs to completion ends with ``data: [DONE]``.

    Raises:
        HTTPException: 503 while the model is not loaded; 400 (from
            ``build_messages``) for an empty or over-long prompt.
    """
    model = llm
    if model is None:
        raise HTTPException(
            503,
            "Model chưa sẵn sàng",
        )

    messages = build_messages(req)

    # Clamp client-supplied sampling parameters so nobody can request 999999 tokens.
    max_tokens = min(max(1, req.max_tokens), MAX_TOKENS)
    temperature = min(max(0.0, req.temperature), 2.0)
    top_p = min(max(0.1, req.top_p), 1.0)

    def chunk_delta(chunk) -> str:
        """Return the text carried by one streamed chunk, or "" if none/malformed."""
        try:
            return chunk["choices"][0].get("delta", {}).get("content") or ""
        except (KeyError, IndexError, TypeError, AttributeError):
            return ""

    async def event_stream():
        loop = asyncio.get_running_loop()
        queue: asyncio.Queue = asyncio.Queue()
        stop = threading.Event()

        def post(kind, payload=None):
            # Hand an item from the worker thread over to the event loop.
            try:
                loop.call_soon_threadsafe(queue.put_nowait, (kind, payload))
            except RuntimeError:
                pass  # event loop already closed (server shutting down)

        def produce():
            # Runs in a worker thread: llama.cpp inference is blocking and
            # would otherwise freeze the event loop (and /health) for the
            # whole generation.
            try:
                stream = model.create_chat_completion(
                    messages=messages,
                    max_tokens=max_tokens,
                    temperature=temperature,
                    top_p=top_p,
                    stop=STOP_TOKENS,
                    stream=True,
                )
                for chunk in stream:
                    if stop.is_set():  # consumer went away (e.g. client disconnected)
                        break
                    post("chunk", chunk)
            except Exception as e:
                # Pass only the message so the traceback (and its frames) is not kept alive.
                post("error", str(e))
            finally:
                post("end")

        await inference_lock.acquire()
        try:
            worker = loop.run_in_executor(None, produce)
        except BaseException:
            inference_lock.release()
            raise
        # Release the lock only once the worker thread has actually finished,
        # even if this generator was closed earlier, so two generations never
        # use the model at the same time.
        worker.add_done_callback(lambda _fut: inference_lock.release())

        n_chars = 0
        try:
            while True:
                kind, payload = await queue.get()
                if kind == "end":
                    break
                if kind == "error":
                    yield sse({"error": payload})
                    continue
                delta = chunk_delta(payload)
                if not delta:
                    continue
                n_chars += len(delta)
                yield sse({"delta": delta})
        finally:
            # Never yield in here: yielding while the generator is being
            # closed/cancelled raises RuntimeError or swallows the cancel.
            stop.set()
            print(f"[DONE] {n_chars} chars")

        yield "data: [DONE]\n\n"
        gc.collect()

    return StreamingResponse(
        event_stream(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",
        },
    )
# =============================================================================
# HEALTH
# =============================================================================
@app.get("/")
async def root():
    """Report whether the model is loaded plus the main runtime settings."""
    return {
        # Explicit None check, matching /health, rather than relying on the
        # truthiness of the Llama object.
        "status": "ok" if llm is not None else "loading",
        "model": MODEL_FILE,
        "ctx": MAX_CTX,
        "batch": N_BATCH,
        "threads": THREADS,
    }
@app.get("/health")
async def health():
    """Liveness probe: ``healthy`` becomes true once ``llm`` has been loaded."""
    ready = llm is not None
    return {"healthy": ready}
# =============================================================================
# MAIN
# =============================================================================
if __name__ == "__main__":
    # Listen on every interface at port 7860; the access log and the
    # "Server" response header are turned off ("production-ish").
    server_options = {
        "host": "0.0.0.0",
        "port": 7860,
        "access_log": False,
        "server_header": False,
    }
    uvicorn.run(app, **server_options)