from __future__ import annotations

import json
import logging
import os
import re
import sys
import time as _time
from collections import defaultdict
from contextlib import asynccontextmanager
from pathlib import Path

# Configure the root logger early so request-timing logs below are visible.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
|
|
from dotenv import find_dotenv, load_dotenv
from fastapi import Depends, FastAPI, HTTPException, Request, Security
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.security import APIKeyHeader
from openai import OpenAI

# Make the repository root importable so `core.*` packages resolve when this
# module is run directly (e.g. `python core/api/server.py`).
REPO_ROOT = Path(__file__).resolve().parents[2]
if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))

# Load environment variables from the nearest .env, searched from the CWD.
load_dotenv(find_dotenv(usecwd=True))
|
|
from core.rag.embedding_model import EmbeddingConfig, QwenEmbeddings
from core.rag.generator import RAGContextBuilder, SYSTEM_PROMPT
from core.rag.retrieval import RetrievalMode, Retriever, get_retrieval_config
from core.rag.vector_store import ChromaConfig, ChromaVectorDB

# Retrieval and LLM configuration. LLM_MODEL is env-overridable; the
# retrieval mode and Groq-compatible endpoint are fixed here.
RETRIEVAL_MODE = RetrievalMode.HYBRID
RETRIEVAL_CFG = get_retrieval_config()
LLM_MODEL = os.getenv("LLM_MODEL", "qwen/qwen3-32b")
LLM_API_BASE = "https://api.groq.com/openai/v1"

# Process-wide singletons ("rag" context builder, "llm" client), populated
# by the lifespan handler at startup and cleared on shutdown.
_state = {}
|
|
|
|
| def _filter_think_tags(text: str) -> str: |
| return re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL).strip() |
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Initialize RAG resources at startup; release them at shutdown.

    Populates the module-level ``_state`` dict with:
      * ``"rag"`` — a RAGContextBuilder wired to the Chroma-backed retriever.
      * ``"llm"`` — an OpenAI-compatible client pointed at the Groq API.

    Raises:
        RuntimeError: If GROQ_API_KEY is not set in the environment.
    """
    # NOTE: the original startup banners were mojibake-corrupted (one string
    # literal was split across two lines); replaced with plain ASCII.
    print("Initializing RAG pipeline...")
    emb = QwenEmbeddings(EmbeddingConfig())
    db = ChromaVectorDB(embedder=emb, config=ChromaConfig())
    retriever = Retriever(vector_db=db)

    # Fail fast at startup rather than on the first chat request.
    api_key = (os.getenv("GROQ_API_KEY") or "").strip()
    if not api_key:
        raise RuntimeError("Missing GROQ_API_KEY")

    _state["rag"] = RAGContextBuilder(retriever=retriever)
    _state["llm"] = OpenAI(api_key=api_key, base_url=LLM_API_BASE)
    print("Ready!")
    yield
    _state.clear()
|
|
|
|
app = FastAPI(title="HUST RAG API", lifespan=lifespan)

# CORS: comma-separated ALLOWED_ORIGINS env var; falls back to "*" when no
# non-empty origin is configured.
_raw_origins = os.getenv("ALLOWED_ORIGINS", "")
ALLOWED_ORIGINS = [o.strip() for o in _raw_origins.split(",") if o.strip()] or ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=ALLOWED_ORIGINS,
    allow_methods=["GET", "POST"],
    allow_headers=["Content-Type", "X-API-Key"],
)
|
|
| |
# Optional shared-secret auth: clients must send X-API-Key when
# FRONTEND_API_KEY is configured; auth is a no-op when it is empty.
_api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
FRONTEND_API_KEY = os.getenv("FRONTEND_API_KEY", "").strip()


async def verify_api_key(api_key: str = Security(_api_key_header)):
    """FastAPI dependency: validate the X-API-Key request header.

    Returns the key on success, or None when auth is disabled.
    Raises HTTPException(403) for a missing or incorrect key.
    """
    if not FRONTEND_API_KEY:
        # No key configured — authentication disabled.
        return None
    if api_key == FRONTEND_API_KEY:
        return api_key
    raise HTTPException(status_code=403, detail="Invalid or missing API key")
|
|
|
|
| |
# Sliding-window rate limiting: at most RATE_LIMIT_MAX requests per
# RATE_LIMIT_WINDOW seconds, tracked per client IP.
RATE_LIMIT_WINDOW = 60
RATE_LIMIT_MAX = int(os.getenv("RATE_LIMIT_MAX", "30"))
_rate_limit_store: dict[str, list[float]] = defaultdict(list)

# Sweep stale IPs once the store grows past this many entries, so clients
# that never return do not leak memory indefinitely.
_RATE_LIMIT_SWEEP_THRESHOLD = 1024


async def rate_limit(request: Request):
    """FastAPI dependency: simple in-memory per-IP sliding-window limiter.

    Returns the client IP on success; raises HTTPException(429) once the
    client has made RATE_LIMIT_MAX requests within RATE_LIMIT_WINDOW seconds.

    NOTE(review): state is per-process only — behind multiple workers each
    process enforces its own limit; confirm deployment uses a single worker
    or a shared limiter.
    """
    client_ip = request.client.host if request.client else "unknown"
    now = _time.time()

    # Fix for unbounded growth: the original never removed entries for IPs
    # that stopped making requests. Amortized sweep of fully-expired IPs.
    if len(_rate_limit_store) > _RATE_LIMIT_SWEEP_THRESHOLD:
        cutoff = now - RATE_LIMIT_WINDOW
        stale = [ip for ip, stamps in _rate_limit_store.items()
                 if not stamps or stamps[-1] < cutoff]
        for ip in stale:
            del _rate_limit_store[ip]

    # Keep only timestamps inside the current window.
    recent = [t for t in _rate_limit_store[client_ip] if now - t < RATE_LIMIT_WINDOW]
    if len(recent) >= RATE_LIMIT_MAX:
        _rate_limit_store[client_ip] = recent
        raise HTTPException(
            status_code=429,
            detail=f"Rate limit exceeded. Max {RATE_LIMIT_MAX} requests per minute."
        )
    recent.append(now)
    _rate_limit_store[client_ip] = recent
    return client_ip
|
|
| |
# Static assets (the chat UI) live alongside this module.
STATIC_DIR = Path(__file__).parent / "static"
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles

app.mount("/static", StaticFiles(directory=str(STATIC_DIR)), name="static")


@app.get("/")
async def index():
    """Serve the single-page chat UI."""
    return FileResponse(str(STATIC_DIR / "index.html"))
|
|
|
|
@app.post("/api/chat")
async def chat(
    request: Request,
    _key: str = Depends(verify_api_key),
    _ip: str = Depends(rate_limit),
):
    """Chat endpoint: retrieve context, then stream the LLM answer via SSE.

    Request body: ``{"message": "<user question>"}``.
    Response: a ``text/event-stream`` of ``data: {"token": ...}`` events,
    terminated by ``data: [DONE]``. Returns a plain JSON answer when no
    documents are retrieved, and 400 for a malformed or empty request.
    """
    # Fix: a malformed or non-object JSON body previously escaped as a 500.
    try:
        body = await request.json()
    except (json.JSONDecodeError, ValueError):
        return JSONResponse({"error": "Invalid JSON body"}, status_code=400)
    if not isinstance(body, dict):
        return JSONResponse({"error": "Invalid JSON body"}, status_code=400)

    question = (body.get("message") or "").strip()
    if not question:
        return JSONResponse({"error": "Empty message"}, status_code=400)

    # Fix: removed the redundant local `import time`; the module already
    # imports `time as _time`.
    start_time = _time.time()
    logger.info("Start retrieval for question: %s", question)
    # NOTE(review): retrieval runs synchronously on the event loop; consider
    # run_in_executor if retrieval latency proves significant.
    prepared = _state["rag"].retrieve_and_prepare(
        question,
        k=RETRIEVAL_CFG.top_k,
        initial_k=RETRIEVAL_CFG.initial_k,
        mode=RETRIEVAL_MODE.value,
    )

    if not prepared["results"]:
        # Mojibake repaired: "Sorry, I could not find relevant information."
        return JSONResponse(
            {"answer": "Xin lỗi, tôi không tìm thấy thông tin phù hợp."}
        )

    logger.info("Retrieval took %.2fs", _time.time() - start_time)

    def stream():
        # A sync generator is run in a threadpool by StreamingResponse, so
        # the blocking OpenAI stream does not stall the event loop.
        llm_start_time = _time.time()
        first_token = True
        completion = _state["llm"].chat.completions.create(
            model=LLM_MODEL,
            messages=[{"role": "user", "content": prepared["prompt"]}],
            temperature=0.0,
            max_tokens=4096,
            stream=True,
        )
        for chunk in completion:
            delta = getattr(chunk.choices[0].delta, "content", "") or ""
            if not delta:
                continue
            if first_token:
                logger.info(
                    "LLM TTFT (Time To First Token): %.2fs",
                    _time.time() - llm_start_time,
                )
                first_token = False
            yield f"data: {json.dumps({'token': delta}, ensure_ascii=False)}\n\n"
        logger.info("Total request took: %.2fs", _time.time() - start_time)
        yield "data: [DONE]\n\n"

    return StreamingResponse(stream(), media_type="text/event-stream")
|
|
|
|
@app.get("/api/config")
async def config():
    """Expose frontend config (the API key) to the chat UI.

    NOTE(review): this endpoint returns FRONTEND_API_KEY to any
    unauthenticated caller, so the key's protection reduces to network
    reachability — confirm this is intended for the same-origin frontend
    deployment only.
    """
    return {"api_key": FRONTEND_API_KEY}
|
|
|
|
@app.get("/api/health")
async def health():
    """Liveness probe; always reports OK."""
    return {"status": "ok"}
|
|
|
|
if __name__ == "__main__":
    import uvicorn

    # Host/port configurable via env; hot reload disabled for production.
    uvicorn.run(
        "core.api.server:app",
        host=os.getenv("API_HOST", "0.0.0.0"),
        port=int(os.getenv("API_PORT", "8000")),
        reload=False,
    )
|
|