| from fastapi import FastAPI, Request |
| from pydantic import BaseModel |
| from typing import Optional |
| import logging |
| import time |
| import asyncio |
| import os |
| import re |
| import threading |
|
|
| from app.model_loader import load_model |
|
|
| app = FastAPI() |
|
|
| |
| thread_local = threading.local() |
| n_ctx_cached = None |
|
|
| class PromptRequest(BaseModel): |
| prompt: str |
| temperature: Optional[float] = 0.0 |
| top_p: Optional[float] = 0.5 |
| top_k: Optional[int] = 1 |
| max_tokens: Optional[int] = None |
|
|
| |
| logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s") |
|
|
| def format_prompt_as_chat(user_prompt: str, include_system: bool = True) -> str: |
| prompt_lines = [] |
| prompt_lines.append("<bos>") |
| if include_system: |
| prompt_lines.append("<start_of_turn>system\nBạn là trợ lý AI hữu ích, luôn trả lời ngắn gọn và chính xác.") |
| prompt_lines.append("<end_of_turn>") |
| prompt_lines.append("<start_of_turn>user\n" + user_prompt.strip()) |
| prompt_lines.append("<end_of_turn>") |
| prompt_lines.append("<start_of_turn>model\n") |
| return "\n".join(prompt_lines) |
|
|
| def format_prompt_as_user_prompt(user_prompt: str) -> str: |
| return "<|user|>\n" + user_prompt.strip() + "</s>\n" |
|
|
| def format_prompt_as_pure_prompt(user_prompt: str) -> str: |
| return user_prompt.strip() + "\n" |
|
|
| def get_model(): |
| if not hasattr(thread_local, "llm"): |
| models_dir = "models" |
| model_path = None |
| for fname in os.listdir(models_dir): |
| if fname.endswith(".gguf"): |
| model_path = os.path.join(models_dir, fname) |
| break |
| if not model_path: |
| raise RuntimeError("❌ Không tìm thấy mô hình trong thư mục 'models'") |
| logging.info(f"🚀 Thread khởi tạo mô hình: {model_path}") |
| model_obj = load_model(model_path) |
| thread_local.llm = model_obj["llm"] |
| thread_local.n_ctx = model_obj["n_ctx"] |
| return thread_local.llm |
|
|
| def get_n_ctx(): |
| return getattr(thread_local, "n_ctx", 2048) |
|
|
| @app.on_event("startup") |
| async def startup_event(): |
| global n_ctx_cached |
| model_path = None |
| models_dir = "models" |
| for fname in os.listdir(models_dir): |
| if fname.endswith(".gguf"): |
| model_path = os.path.join(models_dir, fname) |
| break |
| if not model_path: |
| raise RuntimeError("❌ Không tìm thấy mô hình trong thư mục 'models'") |
| model_obj = await asyncio.to_thread(load_model, model_path) |
| n_ctx_cached = model_obj["n_ctx"] |
| logging.info(f"✅ Mô hình sẵn sàng. Context size: {n_ctx_cached}") |
|
|
| semaphore = asyncio.Semaphore(1) |
|
|
| async def generate_response(formatted_prompt: str, prompt: PromptRequest): |
| prompt_length = len(formatted_prompt.split()) |
| max_tokens = prompt.max_tokens |
|
|
| if max_tokens is None or max_tokens <= 0: |
| max_tokens = prompt_length * 2 |
| logging.info(f"⚙️ Sử dụng max_tokens tự động = {max_tokens}") |
|
|
| max_tokens = min(max_tokens, get_n_ctx() - prompt_length) |
| logging.info(f"🧠 prompt_length = {prompt_length}, max_tokens = {max_tokens}, n_ctx = {get_n_ctx()}\n\t{formatted_prompt[:20]} ... {formatted_prompt[-20:]}") |
|
|
| async with semaphore: |
| return await asyncio.to_thread( |
| get_model(), |
| formatted_prompt, |
| max_tokens=max_tokens, |
| temperature=prompt.temperature, |
| top_k=prompt.top_k, |
| top_p=prompt.top_p, |
| stop=["</s>"] |
| ) |
|
|
| @app.post("/chat") |
| async def chat(request: Request, prompt: PromptRequest): |
| start_time = time.time() |
| logging.info(f"📩 Nhận request từ {request.client.host} lúc {time.strftime('%Y-%m-%d %H:%M:%S')}") |
| formatted_prompt = format_prompt_as_chat(prompt.prompt) |
| output = await generate_response(formatted_prompt, prompt) |
| logging.info(f"✅ Xử lý xong sau {time.time() - start_time:.2f} giây.") |
| return {"response": output["choices"][0]["text"].strip()} |
|
|
| @app.post("/userchat") |
| async def userchat(request: Request, prompt: PromptRequest): |
| start_time = time.time() |
| logging.info(f"📩 Nhận request từ {request.client.host} lúc {time.strftime('%Y-%m-%d %H:%M:%S')}") |
| formatted_prompt = format_prompt_as_user_prompt(prompt.prompt) |
| output = await generate_response(formatted_prompt, prompt) |
| logging.info(f"✅ Xử lý xong sau {time.time() - start_time:.2f} giây.") |
| return {"response": output["choices"][0]["text"].strip()} |
|
|
| @app.post("/purechat") |
| async def purechat(request: Request, prompt: PromptRequest): |
| start_time = time.time() |
| logging.info(f"📩 Nhận request từ {request.client.host} lúc {time.strftime('%Y-%m-%d %H:%M:%S')}") |
| formatted_prompt = format_prompt_as_pure_prompt(prompt.prompt) |
| output = await generate_response(formatted_prompt, prompt) |
| logging.info(f"✅ Xử lý xong sau {time.time() - start_time:.2f} giây.") |
| return {"response": output["choices"][0]["text"].strip()} |
|
|
| @app.post("/analyze") |
| async def analyze(request: Request, prompt: PromptRequest): |
| import json |
| start_time = time.time() |
| logging.info(f"📩 Nhận analyze request từ {request.client.host} lúc {time.strftime('%Y-%m-%d %H:%M:%S')}") |
|
|
| analysis_prompt = f""" |
| Phân tích ngữ nghĩa câu sau: \"{prompt.prompt.strip()}\" |
| |
| Trả lời dưới dạng JSON với 3 trường sau: |
| {{ |
| "muc_dich": "...", |
| "phuong_tien": "...", |
| "hanh_vi_vi_pham": "..." |
| }} |
| |
| Ví dụ: |
| "Tôi chạy xe hơi không bật đèn vào ban đêm thì có bị sao không?" |
| → {{ |
| "muc_dich": "Hỏi về hậu quả/hình phạt khi không bật đèn xe hơi ban đêm", |
| "phuong_tien": "Xe hơi", |
| "hanh_vi_vi_pham": "Không bật đèn khi lái xe vào ban đêm" |
| }} |
| |
| Câu bạn cần phân tích: |
| \"{prompt.prompt.strip()}\" |
| """.strip() |
|
|
| formatted_prompt = format_prompt_as_pure_prompt(analysis_prompt) |
| output = await generate_response(formatted_prompt, prompt) |
| raw_text = output["choices"][0]["text"].strip() |
| logging.info(f"📤 Phản hồi gốc từ mô hình:\n{raw_text}") |
|
|
| try: |
| match = re.search(r"\{.*\}", raw_text, re.DOTALL) |
| if not match: |
| raise ValueError("Không tìm thấy đoạn JSON trong output") |
|
|
| parsed = json.loads(match.group(0)) |
| result = { |
| "p": parsed.get("muc_dich", "").strip(), |
| "v": parsed.get("phuong_tien", "").strip(), |
| "a": parsed.get("hanh_vi_vi_pham", "").strip() |
| } |
| except Exception as e: |
| logging.error(f"❌ Lỗi khi phân tích JSON: {e}") |
| result = {"p": "", "v": "", "a": ""} |
|
|
| logging.info(f"✅ Output đã format: {result}") |
| logging.info(f"✅ Xử lý analyze xong sau {time.time() - start_time:.2f} giây.") |
| return result |
|
|
| @app.get("/") |
| async def get(): |
| start_time = time.time() |
| logging.info(f"📩 Nhận get request lúc {time.strftime('%Y-%m-%d %H:%M:%S')}") |
| sample_prompt = "Xin chào, ngày hôm nay của bạn thế nào?" |
| logging.info(f"✅ Xử lý xong sau {time.time() - start_time:.2f} giây.") |
| return {"response": sample_prompt} |
|
|