from __future__ import annotations

import json
import logging
import os
import re
import sys
import time as _time
from collections import defaultdict
from contextlib import asynccontextmanager
from pathlib import Path

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

from fastapi import FastAPI, Request, Depends, HTTPException, Security
from fastapi.responses import StreamingResponse, JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from fastapi.security import APIKeyHeader
from dotenv import find_dotenv, load_dotenv
from openai import OpenAI

# Setup path & env: make the repo root importable BEFORE the core.* imports below.
REPO_ROOT = Path(__file__).resolve().parents[2]
if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))
load_dotenv(find_dotenv(usecwd=True))

from core.rag.embedding_model import EmbeddingConfig, QwenEmbeddings
from core.rag.vector_store import ChromaConfig, ChromaVectorDB
from core.rag.retrieval import Retriever, RetrievalMode, get_retrieval_config
from core.rag.generator import RAGContextBuilder, SYSTEM_PROMPT

# Config
RETRIEVAL_MODE = RetrievalMode.HYBRID
RETRIEVAL_CFG = get_retrieval_config()
LLM_MODEL = os.getenv("LLM_MODEL", "qwen/qwen3-32b")
LLM_API_BASE = "https://api.groq.com/openai/v1"

# Shared state, populated by lifespan() on startup and cleared on shutdown.
_state: dict = {}


def _filter_think_tags(text: str) -> str:
    """Strip ``<think>...</think>`` reasoning blocks from model output.

    BUG FIX: the previous pattern was ``r'.*?'`` — the tag literals had been
    lost (likely stripped as HTML), so the lazy pattern matched only empty
    strings and removed nothing. The qwen3 model configured above emits
    ``<think>`` blocks, which is what this helper is named for.
    """
    return re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL).strip()


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Initialize RAG resources on startup; clear shared state on shutdown.

    Raises:
        RuntimeError: if GROQ_API_KEY is not set in the environment.
    """
    print("⏳ Initializing RAG pipeline...")
    emb = QwenEmbeddings(EmbeddingConfig())
    db = ChromaVectorDB(embedder=emb, config=ChromaConfig())
    retriever = Retriever(vector_db=db)
    api_key = (os.getenv("GROQ_API_KEY") or "").strip()
    if not api_key:
        raise RuntimeError("Missing GROQ_API_KEY")
    _state["rag"] = RAGContextBuilder(retriever=retriever)
    _state["llm"] = OpenAI(api_key=api_key, base_url=LLM_API_BASE)
    print("✅ Ready!")
    yield
    _state.clear()


app = FastAPI(title="HUST RAG API", lifespan=lifespan)

# ── Security: CORS ──────────────────────────────────────────────
# Only allow the same-origin frontend or explicitly configured origins;
# falls back to "*" when ALLOWED_ORIGINS is unset (dev mode).
ALLOWED_ORIGINS = os.getenv("ALLOWED_ORIGINS", "").split(",")
ALLOWED_ORIGINS = [o.strip() for o in ALLOWED_ORIGINS if o.strip()] or ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=ALLOWED_ORIGINS,
    allow_methods=["GET", "POST"],
    allow_headers=["Content-Type", "X-API-Key"],
)

# ── Security: API Key Authentication ────────────────────────────
_api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
FRONTEND_API_KEY = os.getenv("FRONTEND_API_KEY", "").strip()


async def verify_api_key(api_key: str = Security(_api_key_header)):
    """Verify the API key from the X-API-Key request header.

    Returns None (auth skipped) when FRONTEND_API_KEY is not configured;
    raises HTTPException(403) on a missing or mismatched key.
    """
    if not FRONTEND_API_KEY:
        # FRONTEND_API_KEY not set: skip auth entirely (dev mode)
        return None
    if api_key != FRONTEND_API_KEY:
        raise HTTPException(status_code=403, detail="Invalid or missing API key")
    return api_key


# ── Security: Rate Limiting (in-memory) ─────────────────────────
RATE_LIMIT_WINDOW = 60  # seconds
RATE_LIMIT_MAX = int(os.getenv("RATE_LIMIT_MAX", "30"))  # max requests per window
# NOTE(review): in-memory store — resets on restart and is per-process only;
# fine for a single worker, not shared across multiple workers.
_rate_limit_store: dict[str, list[float]] = defaultdict(list)


async def rate_limit(request: Request):
    """Simple per-IP sliding-window rate limiter.

    Raises HTTPException(429) once an IP exceeds RATE_LIMIT_MAX requests
    within RATE_LIMIT_WINDOW seconds; otherwise records the hit and returns
    the client IP.
    """
    client_ip = request.client.host if request.client else "unknown"
    now = _time.time()
    # Drop timestamps that have aged out of the window
    _rate_limit_store[client_ip] = [
        t for t in _rate_limit_store[client_ip] if now - t < RATE_LIMIT_WINDOW
    ]
    if len(_rate_limit_store[client_ip]) >= RATE_LIMIT_MAX:
        raise HTTPException(
            status_code=429,
            detail=f"Rate limit exceeded. Max {RATE_LIMIT_MAX} requests per minute."
        )
    _rate_limit_store[client_ip].append(now)
    return client_ip


# Serve static files (CSS, JS, images, etc.)
STATIC_DIR = Path(__file__).parent / "static"
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse

app.mount("/static", StaticFiles(directory=str(STATIC_DIR)), name="static")


@app.get("/")
async def index():
    """Serve the chat UI."""
    return FileResponse(str(STATIC_DIR / "index.html"))


@app.post("/api/chat")
async def chat(
    request: Request,
    _key: str = Depends(verify_api_key),
    _ip: str = Depends(rate_limit),
):
    """Chat endpoint with Server-Sent Events streaming.

    Expects a JSON body with a "message" field; retrieves RAG context,
    then streams LLM tokens as SSE `data:` events, terminated by [DONE].
    """
    body = await request.json()
    question = (body.get("message") or "").strip()
    if not question:
        return JSONResponse({"error": "Empty message"}, status_code=400)

    # Retrieve context.
    # CONSISTENCY FIX: use the module-level `time as _time` import instead of
    # a redundant function-local `import time`.
    start_time = _time.time()
    # Lazy %-style logging args (avoids building the string when INFO is off).
    logger.info("Start retrieval for question: %s", question)
    prepared = _state["rag"].retrieve_and_prepare(
        question,
        k=RETRIEVAL_CFG.top_k,
        initial_k=RETRIEVAL_CFG.initial_k,
        mode=RETRIEVAL_MODE.value,
    )
    if not prepared["results"]:
        return JSONResponse({"answer": "Xin lỗi, tôi không tìm thấy thông tin phù hợp."})
    retrieval_time = _time.time() - start_time
    logger.info("Retrieval took %.2fs", retrieval_time)

    def stream():
        """Yield SSE-framed LLM tokens; sync generator runs in a threadpool."""
        llm_start_time = _time.time()
        first_token = True
        completion = _state["llm"].chat.completions.create(
            model=LLM_MODEL,
            messages=[{"role": "user", "content": prepared["prompt"]}],
            temperature=0.0,
            max_tokens=4096,
            stream=True,
        )
        for chunk in completion:
            delta = getattr(chunk.choices[0].delta, "content", "") or ""
            if delta:
                if first_token:
                    ttft = _time.time() - llm_start_time
                    logger.info("LLM TTFT (Time To First Token): %.2fs", ttft)
                    first_token = False
                # SSE format: one JSON payload per `data:` event
                yield f"data: {json.dumps({'token': delta}, ensure_ascii=False)}\n\n"
        total_time = _time.time() - start_time
        logger.info("Total request took: %.2fs", total_time)
        yield "data: [DONE]\n\n"

    return StreamingResponse(stream(), media_type="text/event-stream")


@app.get("/api/config")
async def config():
    """Provide frontend config (API key for same-origin frontend)."""
    # NOTE(review): this returns FRONTEND_API_KEY to any caller of this
    # unauthenticated endpoint — acceptable only if ALLOWED_ORIGINS is
    # restricted; confirm the deployment sets it.
    return {"api_key": FRONTEND_API_KEY}


@app.get("/api/health")
async def health():
    """Liveness probe."""
    return {"status": "ok"}


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(
        "core.api.server:app",
        host=os.getenv("API_HOST", "0.0.0.0"),
        port=int(os.getenv("API_PORT", "8000")),
        reload=False,
    )