Toilatop1sever commited on
Commit
86e793d
·
verified ·
1 Parent(s): 58615c9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +302 -104
app.py CHANGED
@@ -5,11 +5,18 @@ from pydantic import BaseModel
5
  from llama_cpp import Llama
6
  from huggingface_hub import hf_hub_download
7
  from typing import List, Optional
 
8
  import os
9
  import json
10
  import uvicorn
 
 
 
 
 
11
 
12
  app = FastAPI()
 
13
  app.add_middleware(
14
  CORSMiddleware,
15
  allow_origins=["*"],
@@ -17,166 +24,357 @@ app.add_middleware(
17
  allow_headers=["*"],
18
  )
19
 
 
 
 
 
20
  MODEL_REPO = "unsloth/Qwen3-4B-GGUF"
21
  MODEL_FILE = "Qwen3-4B-Q4_K_M.gguf"
22
 
23
- # ── Triết lý tối ưu ───────────────────────────────────────────────────────
24
- # RAM 18GB dư dả → nhét hết vào RAM, dùng prefix cache để CPU
25
- # không phải recompute system prompt mỗi request
26
- # n_batch = 4096 (sweet spot) — đủ để prefill nhanh mà không gây RAM spike
27
- # ─────────────────────────────────────────────────────────────────────────
28
  MAX_HISTORY = 6
29
- MAX_CTX = 8192
30
- MAX_TOKENS = 2048
31
- THREADS = 2
 
 
 
 
 
 
 
 
 
 
 
 
 
32
 
33
- # System prompt cố định — sẽ được cache sẵn vào KV cache lúc startup
34
- # CPU chỉ tính 1 lần duy nhất, mọi request sau dùng lại cache này
35
- DEFAULT_SYSTEM = "Bạn là trợ lý AI, trả lời bằng tiếng Việt ngắn gọn."
36
 
37
  llm: Optional[Llama] = None
38
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
 
40
  @app.on_event("startup")
41
  async def startup_event():
42
  global llm
43
 
44
- if os.path.exists(MODEL_FILE) and os.path.getsize(MODEL_FILE) < 1_000_000:
 
 
 
 
45
  os.remove(MODEL_FILE)
46
 
 
47
  if not os.path.exists(MODEL_FILE):
48
  print(f"Downloading {MODEL_FILE}...")
49
- hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILE, local_dir=".")
50
- print("Download done!")
51
 
52
- print("Loading model — RAM-heavy, CPU-light mode...")
 
 
 
 
 
 
 
 
 
53
  llm = Llama(
54
- model_path = MODEL_FILE,
55
 
56
- # ── Context & batch ───────────────────────────────────────────────
57
- n_ctx = MAX_CTX,
58
- n_batch = 512 , # Nhỏ vừa tay CPU: 2 vCPU không bị nghẹt khi prefill
59
- n_ubatch = 512 , # Giữ nhỏ: ổn định hơn khi decode
60
 
61
- # ── CPU ───────────────────────────────────────────────────────────
62
- n_threads = THREADS,
63
- n_threads_batch = THREADS,
64
- n_gpu_layers = 0,
65
 
66
- # ── RAM: load toàn bộ, khóa lại, không swap ──────────────────────
67
- use_mmap = False,
68
- use_mlock = True,
 
69
 
70
- # ── KV Cache quantize — ăn RAM ít hơn, CPU vẫn nhẹ ───────────────
71
- cache_type_k = "q4_0",
72
- cache_type_v = "q4_0",
73
 
74
- # ── Prefix cache: CPU tính system prompt 1 lần rồi thôi ──────────
75
- last_n_tokens_size = 64, # Cửa sổ detect prefix trùng
 
76
 
77
- flash_attn = True,
78
- verbose = False,
79
- )
80
 
81
- # ── Warm up prefix cache với system prompt ────────────────────────────
82
- # Gọi 1 lần lúc startup để KV cache của system prompt được lưu sẵn
83
- # Mọi request sau có cùng system prompt → CPU bỏ qua phần này hoàn toàn
84
- print("Warming up prefix cache...")
85
- warmup_msgs = [
86
- {"role": "system", "content": DEFAULT_SYSTEM},
87
- {"role": "user", "content": "hi"},
88
- ]
89
- _ = llm.create_chat_completion(
90
- messages = warmup_msgs,
91
- max_tokens = 1,
92
- stream = False,
93
- )
94
- print("Prefix cache warmed up! Model ready.")
95
 
 
 
 
96
 
97
- class Message(BaseModel):
98
- role: str
99
- content: str
100
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
 
102
- class ChatRequest(BaseModel):
103
- prompt: str
104
- history: List[Message] = []
105
- system_prompt: Optional[str] = None # Để None → tận dụng prefix cache
106
- max_tokens: int = MAX_TOKENS
107
- temperature: float = 0.7
108
- top_p: float = 0.9
109
 
 
110
 
111
- def build_messages(req: ChatRequest) -> list:
112
- # Dùng DEFAULT_SYSTEM nếu không truyền system_prompt
113
- # → prefix cache luôn hit, CPU không recompute
114
- system = req.system_prompt or DEFAULT_SYSTEM
115
- msgs = [{"role": "system", "content": system}]
116
 
117
- recent = req.history[-(MAX_HISTORY * 2):]
118
- for msg in recent:
119
- if msg.role in ("user", "assistant") and msg.content.strip():
120
- if msgs[-1]["role"] != msg.role:
121
- msgs.append({"role": msg.role, "content": msg.content.strip()})
122
 
123
- if msgs[-1]["role"] == "user":
124
- msgs.pop()
125
- msgs.append({"role": "user", "content": req.prompt.strip()})
126
- return msgs
127
 
128
 
129
  @app.post("/chat")
130
  async def chat(req: ChatRequest):
 
 
131
  if llm is None:
132
- raise HTTPException(503, "Model chưa sẵn sàng, thử lại sau!")
133
- if not req.prompt.strip():
134
- raise HTTPException(400, "Prompt trống")
135
- if len(req.prompt) > 8000:
136
- raise HTTPException(400, "Prompt quá dài")
137
 
138
  messages = build_messages(req)
139
 
140
- def generate():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
141
  full = ""
142
- try:
143
- for chunk in llm.create_chat_completion(
144
- messages = messages,
145
- max_tokens = req.max_tokens,
146
- temperature = req.temperature,
147
- top_p = req.top_p,
148
- stream = True,
149
- ):
150
- delta = chunk["choices"][0]["delta"].get("content", "")
151
- if delta:
152
- full += delta
153
- yield f"data: {json.dumps({'delta': delta}, ensure_ascii=False)}\n\n"
154
- except Exception as e:
155
- yield f"data: {json.dumps({'delta': f'[Lỗi: {str(e)}]'})}\n\n"
156
- finally:
157
- print(f">> Done ({len(full)} chars): {full[:80]}")
158
- yield "data: [DONE]\n\n"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
159
 
160
  return StreamingResponse(
161
- generate(),
162
- media_type = "text/event-stream",
163
- headers = {"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
 
 
 
 
164
  )
165
 
166
 
 
 
 
 
 
167
  @app.get("/")
168
  async def root():
169
  return {
170
- "status" : "ok" if llm else "loading",
171
- "model" : MODEL_FILE,
172
- "message" : "Model ready (prefix cache active)!" if llm else "Model đang tải...",
 
 
173
  }
174
 
175
 
176
  @app.get("/health")
177
  async def health():
178
- return {"status": "healthy", "model_loaded": llm is not None}
 
 
 
179
 
 
 
 
180
 
181
  if __name__ == "__main__":
182
- uvicorn.run(app, host="0.0.0.0", port=7860)
 
 
 
 
 
 
 
 
 
5
  from llama_cpp import Llama
6
  from huggingface_hub import hf_hub_download
7
  from typing import List, Optional
8
+ import asyncio
9
  import os
10
  import json
11
  import uvicorn
12
+ import gc
13
+
14
+ # =============================================================================
15
+ # FASTAPI
16
+ # =============================================================================
17
 
18
  app = FastAPI()
19
+
20
  app.add_middleware(
21
  CORSMiddleware,
22
  allow_origins=["*"],
 
24
  allow_headers=["*"],
25
  )
26
 
27
+ # =============================================================================
28
+ # MODEL CONFIG
29
+ # =============================================================================
30
+
31
  MODEL_REPO = "unsloth/Qwen3-4B-GGUF"
32
  MODEL_FILE = "Qwen3-4B-Q4_K_M.gguf"
33
 
 
 
 
 
 
34
  MAX_HISTORY = 6
35
+ MAX_CTX = 8192
36
+ MAX_TOKENS = 4096
37
+
38
+ # Giữ nguyên tham số theo yêu cầu
39
+ THREADS = 2
40
+ N_BATCH = 512
41
+ N_UBATCH = 512
42
+
43
+ DEFAULT_SYSTEM = (
44
+ "Bạn là trợ lý AI, trả lời bằng tiếng Việt ngắn gọn."
45
+ )
46
+
47
+ STOP_TOKENS = [
48
+ "<|im_end|>",
49
+ "<|endoftext|>",
50
+ ]
51
 
52
+ # =============================================================================
53
+ # GLOBALS
54
+ # =============================================================================
55
 
56
  llm: Optional[Llama] = None
57
 
58
+ # CPU inference -> serialize request để tránh lag/token collapse
59
+ inference_lock = asyncio.Semaphore(1)
60
+
61
+ # =============================================================================
62
+ # REQUEST MODELS
63
+ # =============================================================================
64
+
65
+
66
+ class Message(BaseModel):
67
+ role: str
68
+ content: str
69
+
70
+
71
+ class ChatRequest(BaseModel):
72
+ prompt: str
73
+ history: List[Message] = []
74
+ system_prompt: Optional[str] = None
75
+
76
+ max_tokens: int = MAX_TOKENS
77
+ temperature: float = 0.7
78
+ top_p: float = 0.9
79
+
80
+
81
+ # =============================================================================
82
+ # HELPERS
83
+ # =============================================================================
84
+
85
+
86
+ def cleanup_text(text: str) -> str:
87
+ return text.strip().replace("\x00", "")
88
+
89
+
90
+ def build_messages(req: ChatRequest) -> list:
91
+ system_prompt = cleanup_text(
92
+ req.system_prompt or DEFAULT_SYSTEM
93
+ )
94
+
95
+ messages = [
96
+ {
97
+ "role": "system",
98
+ "content": system_prompt,
99
+ }
100
+ ]
101
+
102
+ recent = req.history[-(MAX_HISTORY * 2):]
103
+
104
+ last_role = "system"
105
+
106
+ for msg in recent:
107
+ role = msg.role.strip().lower()
108
+ content = cleanup_text(msg.content)
109
+
110
+ if (
111
+ role not in ("user", "assistant")
112
+ or not content
113
+ ):
114
+ continue
115
+
116
+ # tránh duplicate role liên tục
117
+ if role == last_role:
118
+ continue
119
+
120
+ messages.append(
121
+ {
122
+ "role": role,
123
+ "content": content,
124
+ }
125
+ )
126
+
127
+ last_role = role
128
+
129
+ prompt = cleanup_text(req.prompt)
130
+
131
+ if not prompt:
132
+ raise HTTPException(400, "Prompt trống")
133
+
134
+ if len(prompt) > 8000:
135
+ raise HTTPException(400, "Prompt quá dài")
136
+
137
+ if messages[-1]["role"] == "user":
138
+ messages.pop()
139
+
140
+ messages.append(
141
+ {
142
+ "role": "user",
143
+ "content": prompt,
144
+ }
145
+ )
146
+
147
+ return messages
148
+
149
+
150
+ def sse(data):
151
+ return f"data: {json.dumps(data, ensure_ascii=False)}\n\n"
152
+
153
+
154
+ # =============================================================================
155
+ # STARTUP
156
+ # =============================================================================
157
+
158
 
159
  @app.on_event("startup")
160
  async def startup_event():
161
  global llm
162
 
163
+ # Xóa file corrupt
164
+ if (
165
+ os.path.exists(MODEL_FILE)
166
+ and os.path.getsize(MODEL_FILE) < 1_000_000
167
+ ):
168
  os.remove(MODEL_FILE)
169
 
170
+ # Download nếu chưa có
171
  if not os.path.exists(MODEL_FILE):
172
  print(f"Downloading {MODEL_FILE}...")
 
 
173
 
174
+ hf_hub_download(
175
+ repo_id=MODEL_REPO,
176
+ filename=MODEL_FILE,
177
+ local_dir=".",
178
+ )
179
+
180
+ print("Download complete!")
181
+
182
+ print("Loading model...")
183
+
184
  llm = Llama(
185
+ model_path=MODEL_FILE,
186
 
187
+ # Context
188
+ n_ctx=MAX_CTX,
 
 
189
 
190
+ # Giữ nguyên batch
191
+ n_batch=N_BATCH,
192
+ n_ubatch=N_UBATCH,
 
193
 
194
+ # CPU
195
+ n_threads=THREADS,
196
+ n_threads_batch=THREADS,
197
+ n_gpu_layers=0,
198
 
199
+ # RAM
200
+ use_mmap=False,
201
+ use_mlock=True,
202
 
203
+ # KV cache
204
+ cache_type_k="q4_0",
205
+ cache_type_v="q4_0",
206
 
207
+ # Prefix detection
208
+ last_n_tokens_size=64,
 
209
 
210
+ # Performance
211
+ flash_attn=True,
 
 
 
 
 
 
 
 
 
 
 
 
212
 
213
+ # Cleaner logs
214
+ verbose=False,
215
+ )
216
 
217
+ print("Warmup model...")
 
 
218
 
219
+ try:
220
+ _ = llm.create_chat_completion(
221
+ messages=[
222
+ {
223
+ "role": "system",
224
+ "content": DEFAULT_SYSTEM,
225
+ },
226
+ {
227
+ "role": "user",
228
+ "content": "hi",
229
+ },
230
+ ],
231
+ max_tokens=1,
232
+ stream=False,
233
+ )
234
 
235
+ except Exception as e:
236
+ print(f"Warmup failed: {e}")
 
 
 
 
 
237
 
238
+ gc.collect()
239
 
240
+ print("Model ready!")
 
 
 
 
241
 
 
 
 
 
 
242
 
243
+ # =============================================================================
244
+ # CHAT
245
+ # =============================================================================
 
246
 
247
 
248
  @app.post("/chat")
249
  async def chat(req: ChatRequest):
250
+ global llm
251
+
252
  if llm is None:
253
+ raise HTTPException(
254
+ 503,
255
+ "Model chưa sẵn sàng",
256
+ )
 
257
 
258
  messages = build_messages(req)
259
 
260
+ # Clamp để user không spam 999999
261
+ max_tokens = min(
262
+ max(1, req.max_tokens),
263
+ MAX_TOKENS,
264
+ )
265
+
266
+ temperature = min(
267
+ max(0.0, req.temperature),
268
+ 2.0,
269
+ )
270
+
271
+ top_p = min(
272
+ max(0.1, req.top_p),
273
+ 1.0,
274
+ )
275
+
276
+ async def event_stream():
277
  full = ""
278
+
279
+ async with inference_lock:
280
+ try:
281
+ stream = llm.create_chat_completion(
282
+ messages=messages,
283
+
284
+ max_tokens=max_tokens,
285
+
286
+ temperature=temperature,
287
+ top_p=top_p,
288
+
289
+ stop=STOP_TOKENS,
290
+
291
+ stream=True,
292
+ )
293
+
294
+ for chunk in stream:
295
+ try:
296
+ delta = (
297
+ chunk["choices"][0]
298
+ .get("delta", {})
299
+ .get("content", "")
300
+ )
301
+
302
+ if not delta:
303
+ continue
304
+
305
+ full += delta
306
+
307
+ yield sse(
308
+ {
309
+ "delta": delta,
310
+ }
311
+ )
312
+
313
+ except Exception:
314
+ continue
315
+
316
+ except Exception as e:
317
+ yield sse(
318
+ {
319
+ "error": str(e),
320
+ }
321
+ )
322
+
323
+ finally:
324
+ print(
325
+ f"[DONE] "
326
+ f"{len(full)} chars"
327
+ )
328
+
329
+ yield "data: [DONE]\n\n"
330
+
331
+ gc.collect()
332
 
333
  return StreamingResponse(
334
+ event_stream(),
335
+ media_type="text/event-stream",
336
+ headers={
337
+ "Cache-Control": "no-cache",
338
+ "Connection": "keep-alive",
339
+ "X-Accel-Buffering": "no",
340
+ },
341
  )
342
 
343
 
344
+ # =============================================================================
345
+ # HEALTH
346
+ # =============================================================================
347
+
348
+
349
  @app.get("/")
350
  async def root():
351
  return {
352
+ "status": "ok" if llm else "loading",
353
+ "model": MODEL_FILE,
354
+ "ctx": MAX_CTX,
355
+ "batch": N_BATCH,
356
+ "threads": THREADS,
357
  }
358
 
359
 
360
  @app.get("/health")
361
  async def health():
362
+ return {
363
+ "healthy": llm is not None,
364
+ }
365
+
366
 
367
+ # =============================================================================
368
+ # MAIN
369
+ # =============================================================================
370
 
371
  if __name__ == "__main__":
372
+ uvicorn.run(
373
+ app,
374
+ host="0.0.0.0",
375
+ port=7860,
376
+
377
+ # production-ish
378
+ access_log=False,
379
+ server_header=False,
380
+ )