Gemma34B / app /main.py
VietCat's picture
add log for prompt
02e09e0
from fastapi import FastAPI, Request
from pydantic import BaseModel
from typing import Optional
import logging
import time
import asyncio
import os
import re
import threading
from app.model_loader import load_model
app = FastAPI()
# Thread-local để mỗi thread giữ 1 instance riêng
thread_local = threading.local()
n_ctx_cached = None # Cache context length
class PromptRequest(BaseModel):
prompt: str
temperature: Optional[float] = 0.0
top_p: Optional[float] = 0.5
top_k: Optional[int] = 1
max_tokens: Optional[int] = None
# Setup logging
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
def format_prompt_as_chat(user_prompt: str, include_system: bool = True) -> str:
prompt_lines = []
prompt_lines.append("<bos>")
if include_system:
prompt_lines.append("<start_of_turn>system\nBạn là trợ lý AI hữu ích, luôn trả lời ngắn gọn và chính xác.")
prompt_lines.append("<end_of_turn>")
prompt_lines.append("<start_of_turn>user\n" + user_prompt.strip())
prompt_lines.append("<end_of_turn>")
prompt_lines.append("<start_of_turn>model\n")
return "\n".join(prompt_lines)
def format_prompt_as_user_prompt(user_prompt: str) -> str:
return "<|user|>\n" + user_prompt.strip() + "</s>\n"
def format_prompt_as_pure_prompt(user_prompt: str) -> str:
return user_prompt.strip() + "\n"
def get_model():
if not hasattr(thread_local, "llm"):
models_dir = "models"
model_path = None
for fname in os.listdir(models_dir):
if fname.endswith(".gguf"):
model_path = os.path.join(models_dir, fname)
break
if not model_path:
raise RuntimeError("❌ Không tìm thấy mô hình trong thư mục 'models'")
logging.info(f"🚀 Thread khởi tạo mô hình: {model_path}")
model_obj = load_model(model_path)
thread_local.llm = model_obj["llm"]
thread_local.n_ctx = model_obj["n_ctx"]
return thread_local.llm
def get_n_ctx():
return getattr(thread_local, "n_ctx", 2048)
@app.on_event("startup")
async def startup_event():
global n_ctx_cached
model_path = None
models_dir = "models"
for fname in os.listdir(models_dir):
if fname.endswith(".gguf"):
model_path = os.path.join(models_dir, fname)
break
if not model_path:
raise RuntimeError("❌ Không tìm thấy mô hình trong thư mục 'models'")
model_obj = await asyncio.to_thread(load_model, model_path)
n_ctx_cached = model_obj["n_ctx"]
logging.info(f"✅ Mô hình sẵn sàng. Context size: {n_ctx_cached}")
semaphore = asyncio.Semaphore(1)
async def generate_response(formatted_prompt: str, prompt: PromptRequest):
prompt_length = len(formatted_prompt.split())
max_tokens = prompt.max_tokens
if max_tokens is None or max_tokens <= 0:
max_tokens = prompt_length * 2
logging.info(f"⚙️ Sử dụng max_tokens tự động = {max_tokens}")
max_tokens = min(max_tokens, get_n_ctx() - prompt_length)
logging.info(f"🧠 prompt_length = {prompt_length}, max_tokens = {max_tokens}, n_ctx = {get_n_ctx()}\n\t{formatted_prompt[:20]} ... {formatted_prompt[-20:]}")
async with semaphore:
return await asyncio.to_thread(
get_model(),
formatted_prompt,
max_tokens=max_tokens,
temperature=prompt.temperature,
top_k=prompt.top_k,
top_p=prompt.top_p,
stop=["</s>"]
)
@app.post("/chat")
async def chat(request: Request, prompt: PromptRequest):
start_time = time.time()
logging.info(f"📩 Nhận request từ {request.client.host} lúc {time.strftime('%Y-%m-%d %H:%M:%S')}")
formatted_prompt = format_prompt_as_chat(prompt.prompt)
output = await generate_response(formatted_prompt, prompt)
logging.info(f"✅ Xử lý xong sau {time.time() - start_time:.2f} giây.")
return {"response": output["choices"][0]["text"].strip()}
@app.post("/userchat")
async def userchat(request: Request, prompt: PromptRequest):
start_time = time.time()
logging.info(f"📩 Nhận request từ {request.client.host} lúc {time.strftime('%Y-%m-%d %H:%M:%S')}")
formatted_prompt = format_prompt_as_user_prompt(prompt.prompt)
output = await generate_response(formatted_prompt, prompt)
logging.info(f"✅ Xử lý xong sau {time.time() - start_time:.2f} giây.")
return {"response": output["choices"][0]["text"].strip()}
@app.post("/purechat")
async def purechat(request: Request, prompt: PromptRequest):
start_time = time.time()
logging.info(f"📩 Nhận request từ {request.client.host} lúc {time.strftime('%Y-%m-%d %H:%M:%S')}")
formatted_prompt = format_prompt_as_pure_prompt(prompt.prompt)
output = await generate_response(formatted_prompt, prompt)
logging.info(f"✅ Xử lý xong sau {time.time() - start_time:.2f} giây.")
return {"response": output["choices"][0]["text"].strip()}
@app.post("/analyze")
async def analyze(request: Request, prompt: PromptRequest):
import json
start_time = time.time()
logging.info(f"📩 Nhận analyze request từ {request.client.host} lúc {time.strftime('%Y-%m-%d %H:%M:%S')}")
analysis_prompt = f"""
Phân tích ngữ nghĩa câu sau: \"{prompt.prompt.strip()}\"
Trả lời dưới dạng JSON với 3 trường sau:
{{
"muc_dich": "...",
"phuong_tien": "...",
"hanh_vi_vi_pham": "..."
}}
Ví dụ:
"Tôi chạy xe hơi không bật đèn vào ban đêm thì có bị sao không?"
→ {{
"muc_dich": "Hỏi về hậu quả/hình phạt khi không bật đèn xe hơi ban đêm",
"phuong_tien": "Xe hơi",
"hanh_vi_vi_pham": "Không bật đèn khi lái xe vào ban đêm"
}}
Câu bạn cần phân tích:
\"{prompt.prompt.strip()}\"
""".strip()
formatted_prompt = format_prompt_as_pure_prompt(analysis_prompt)
output = await generate_response(formatted_prompt, prompt)
raw_text = output["choices"][0]["text"].strip()
logging.info(f"📤 Phản hồi gốc từ mô hình:\n{raw_text}")
try:
match = re.search(r"\{.*\}", raw_text, re.DOTALL)
if not match:
raise ValueError("Không tìm thấy đoạn JSON trong output")
parsed = json.loads(match.group(0))
result = {
"p": parsed.get("muc_dich", "").strip(),
"v": parsed.get("phuong_tien", "").strip(),
"a": parsed.get("hanh_vi_vi_pham", "").strip()
}
except Exception as e:
logging.error(f"❌ Lỗi khi phân tích JSON: {e}")
result = {"p": "", "v": "", "a": ""}
logging.info(f"✅ Output đã format: {result}")
logging.info(f"✅ Xử lý analyze xong sau {time.time() - start_time:.2f} giây.")
return result
@app.get("/")
async def get():
start_time = time.time()
logging.info(f"📩 Nhận get request lúc {time.strftime('%Y-%m-%d %H:%M:%S')}")
sample_prompt = "Xin chào, ngày hôm nay của bạn thế nào?"
logging.info(f"✅ Xử lý xong sau {time.time() - start_time:.2f} giây.")
return {"response": sample_prompt}