# core/api/server.py
from __future__ import annotations
import os
import sys
import re
import json
import logging
import time as _time
from collections import defaultdict
from pathlib import Path
from contextlib import asynccontextmanager
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
from fastapi import FastAPI, Request, Depends, HTTPException, Security
from fastapi.responses import StreamingResponse, JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from fastapi.security import APIKeyHeader
from dotenv import find_dotenv, load_dotenv
from openai import OpenAI
# Setup path & env
REPO_ROOT = Path(__file__).resolve().parents[2]
if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))
load_dotenv(find_dotenv(usecwd=True))
from core.rag.embedding_model import EmbeddingConfig, QwenEmbeddings
from core.rag.vector_store import ChromaConfig, ChromaVectorDB
from core.rag.retrieval import Retriever, RetrievalMode, get_retrieval_config
from core.rag.generator import RAGContextBuilder, SYSTEM_PROMPT
# Config
RETRIEVAL_MODE = RetrievalMode.HYBRID
RETRIEVAL_CFG = get_retrieval_config()
LLM_MODEL = os.getenv("LLM_MODEL", "qwen/qwen3-32b")
LLM_API_BASE = "https://api.groq.com/openai/v1"
# Shared state
_state = {}
def _filter_think_tags(text: str) -> str:
    """Strip <think>...</think> reasoning blocks from model output."""
    return re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL).strip()
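# For illustration, the helper above maps "<think>draft reasoning</think>Answer"
# to "Answer"; text without <think> blocks only has surrounding whitespace stripped.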
@asynccontextmanager
async def lifespan(app: FastAPI):
"""Initialize RAG resources on startup."""
print("⏳ Initializing RAG pipeline...")
emb = QwenEmbeddings(EmbeddingConfig())
db = ChromaVectorDB(embedder=emb, config=ChromaConfig())
retriever = Retriever(vector_db=db)
api_key = (os.getenv("GROQ_API_KEY") or "").strip()
if not api_key:
raise RuntimeError("Missing GROQ_API_KEY")
_state["rag"] = RAGContextBuilder(retriever=retriever)
_state["llm"] = OpenAI(api_key=api_key, base_url=LLM_API_BASE)
print("βœ… Ready!")
yield
_state.clear()
app = FastAPI(title="HUST RAG API", lifespan=lifespan)
# ── Security: CORS ──────────────────────────────────────────────
# Only allow the same-origin frontend or an explicit list of origins.
ALLOWED_ORIGINS = os.getenv("ALLOWED_ORIGINS", "").split(",")
ALLOWED_ORIGINS = [o.strip() for o in ALLOWED_ORIGINS if o.strip()] or ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=ALLOWED_ORIGINS,
    allow_methods=["GET", "POST"],
    allow_headers=["Content-Type", "X-API-Key"],
)
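# Illustrative .env entry (hypothetical origins; leaving it unset falls back to "*"):
#   ALLOWED_ORIGINS=https://chat.example.edu.vn,http://localhost:3000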
# ── Security: API Key Authentication ────────────────────────────
_api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
FRONTEND_API_KEY = os.getenv("FRONTEND_API_KEY", "").strip()
async def verify_api_key(api_key: str = Security(_api_key_header)):
    """Verify the API key from the request header."""
    if not FRONTEND_API_KEY:
        # If FRONTEND_API_KEY is not set, skip the check (dev mode).
        return None
    if api_key != FRONTEND_API_KEY:
        raise HTTPException(status_code=403, detail="Invalid or missing API key")
    return api_key
# ── Security: Rate Limiting (in-memory) ─────────────────────────
RATE_LIMIT_WINDOW = 60  # seconds
RATE_LIMIT_MAX = int(os.getenv("RATE_LIMIT_MAX", "30"))  # max requests per window
_rate_limit_store: dict[str, list[float]] = defaultdict(list)
async def rate_limit(request: Request):
    """Simple per-IP rate limiter."""
    client_ip = request.client.host if request.client else "unknown"
    now = _time.time()
    # Cleanup entries that have fallen outside the sliding window
    _rate_limit_store[client_ip] = [
        t for t in _rate_limit_store[client_ip] if now - t < RATE_LIMIT_WINDOW
    ]
    if len(_rate_limit_store[client_ip]) >= RATE_LIMIT_MAX:
        raise HTTPException(
            status_code=429,
            detail=f"Rate limit exceeded. Max {RATE_LIMIT_MAX} requests per minute.",
        )
    _rate_limit_store[client_ip].append(now)
    return client_ip
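# Note: these counters live only in this process's memory, so each worker keeps
# its own window and all counts reset when the server restarts.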
# Serve static files (CSS, JS, images, etc.)
STATIC_DIR = Path(__file__).parent / "static"
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse
app.mount("/static", StaticFiles(directory=str(STATIC_DIR)), name="static")
@app.get("/")
async def index():
"""Serve the chat UI."""
return FileResponse(str(STATIC_DIR / "index.html"))
@app.post("/api/chat")
async def chat(
    request: Request,
    _key: str = Depends(verify_api_key),
    _ip: str = Depends(rate_limit),
):
    """Chat endpoint with Server-Sent Events streaming."""
    body = await request.json()
    question = (body.get("message") or "").strip()
    if not question:
        return JSONResponse({"error": "Empty message"}, status_code=400)
    # Retrieve context (timings reuse the module-level `time as _time` import)
    start_time = _time.time()
    logger.info(f"Start retrieval for question: {question}")
    prepared = _state["rag"].retrieve_and_prepare(
        question,
        k=RETRIEVAL_CFG.top_k,
        initial_k=RETRIEVAL_CFG.initial_k,
        mode=RETRIEVAL_MODE.value,
    )
    if not prepared["results"]:
        # "Sorry, I could not find any relevant information."
        return JSONResponse({"answer": "Xin lỗi, tôi không tìm thấy thông tin phù hợp."})
    retrieval_time = _time.time() - start_time
    logger.info(f"Retrieval took {retrieval_time:.2f}s")
    def stream():
        llm_start_time = _time.time()
        first_token = True
        completion = _state["llm"].chat.completions.create(
            model=LLM_MODEL,
            messages=[{"role": "user", "content": prepared["prompt"]}],
            temperature=0.0,
            max_tokens=4096,
            stream=True,
        )
        for chunk in completion:
            delta = getattr(chunk.choices[0].delta, "content", "") or ""
            if delta:
                if first_token:
                    ttft = _time.time() - llm_start_time
                    logger.info(f"LLM TTFT (Time To First Token): {ttft:.2f}s")
                    first_token = False
                # SSE format: one JSON-encoded token per event
                yield f"data: {json.dumps({'token': delta}, ensure_ascii=False)}\n\n"
        total_time = _time.time() - start_time
        logger.info(f"Total request took: {total_time:.2f}s")
        yield "data: [DONE]\n\n"
    return StreamingResponse(stream(), media_type="text/event-stream")
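# Illustrative client call (hypothetical host; key taken from FRONTEND_API_KEY):
#   curl -N -X POST http://localhost:8000/api/chat \
#     -H "Content-Type: application/json" -H "X-API-Key: $FRONTEND_API_KEY" \
#     -d '{"message": "your question here"}'
# The response is an SSE stream of `data: {"token": "..."}` events, terminated
# by `data: [DONE]`.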
@app.get("/api/config")
async def config():
"""Provide frontend config (API key for same-origin frontend)."""
return {"api_key": FRONTEND_API_KEY}
@app.get("/api/health")
async def health():
return {"status": "ok"}
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(
        "core.api.server:app",
        host=os.getenv("API_HOST", "0.0.0.0"),
        port=int(os.getenv("API_PORT", "8000")),
        reload=False,
    )
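# For local development the same server can also be started with the uvicorn CLI,
# e.g. (equivalent to the block above): uvicorn core.api.server:app --host 0.0.0.0 --port 8000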