from __future__ import annotations
import json
import logging
import os
import re
import secrets
import sys
import time as _time
from collections import defaultdict
from contextlib import asynccontextmanager
from pathlib import Path
# Module-wide logging at INFO so the retrieval/LLM timing logs below are emitted.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

from fastapi import FastAPI, Request, Depends, HTTPException, Security
from fastapi.responses import StreamingResponse, JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from fastapi.security import APIKeyHeader
from dotenv import find_dotenv, load_dotenv
from openai import OpenAI
# Setup path & env
# Make the repository root importable BEFORE the `core.*` imports below;
# assumes this file lives two directories below the repo root.
REPO_ROOT = Path(__file__).resolve().parents[2]
if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))
# Load environment variables from the nearest .env (searched from the CWD).
load_dotenv(find_dotenv(usecwd=True))

from core.rag.embedding_model import EmbeddingConfig, QwenEmbeddings
from core.rag.vector_store import ChromaConfig, ChromaVectorDB
from core.rag.retrieval import Retriever, RetrievalMode, get_retrieval_config
from core.rag.generator import RAGContextBuilder, SYSTEM_PROMPT
# Config
RETRIEVAL_MODE = RetrievalMode.HYBRID
RETRIEVAL_CFG = get_retrieval_config()
LLM_MODEL = os.getenv("LLM_MODEL", "qwen/qwen3-32b")
LLM_API_BASE = "https://api.groq.com/openai/v1"
# Shared state
_state = {}
# Matches a reasoning span, non-greedily, across newlines.
_THINK_RE = re.compile(r'<think>.*?</think>', re.DOTALL)

def _filter_think_tags(text: str) -> str:
    """Remove every <think>...</think> span from *text* and trim whitespace."""
    cleaned = _THINK_RE.sub('', text)
    return cleaned.strip()
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Initialize RAG resources on startup; release them on shutdown.

    Populates the module-level `_state` dict with:
      - "rag": RAGContextBuilder wired to the Chroma-backed retriever
      - "llm": OpenAI-compatible client pointed at LLM_API_BASE

    Raises:
        RuntimeError: if GROQ_API_KEY is missing from the environment.
    """
    # Use the module logger instead of print(); the original print calls
    # carried mojibake emoji (one literal was split across lines).
    logger.info("Initializing RAG pipeline...")
    emb = QwenEmbeddings(EmbeddingConfig())
    db = ChromaVectorDB(embedder=emb, config=ChromaConfig())
    retriever = Retriever(vector_db=db)
    api_key = (os.getenv("GROQ_API_KEY") or "").strip()
    if not api_key:
        # Fail fast at startup rather than on the first chat request.
        raise RuntimeError("Missing GROQ_API_KEY")
    _state["rag"] = RAGContextBuilder(retriever=retriever)
    _state["llm"] = OpenAI(api_key=api_key, base_url=LLM_API_BASE)
    logger.info("RAG pipeline ready")
    yield
    # Shutdown: drop references so resources can be garbage-collected.
    _state.clear()
app = FastAPI(title="HUST RAG API", lifespan=lifespan)

# ── Security: CORS ──────────────────────────────────────────────
# Only allow the same-origin frontend or explicitly listed origins.
# Falls back to "*" (any origin) when ALLOWED_ORIGINS is unset/empty —
# acceptable for dev, should be pinned in production.
ALLOWED_ORIGINS = os.getenv("ALLOWED_ORIGINS", "").split(",")
ALLOWED_ORIGINS = [o.strip() for o in ALLOWED_ORIGINS if o.strip()] or ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=ALLOWED_ORIGINS,
    allow_methods=["GET", "POST"],
    allow_headers=["Content-Type", "X-API-Key"],
)

# ── Security: API Key Authentication ────────────────────────────
# auto_error=False: a missing X-API-Key header yields None instead of an
# immediate 403, so verify_api_key can decide (dev mode skips the check).
_api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
FRONTEND_API_KEY = os.getenv("FRONTEND_API_KEY", "").strip()
async def verify_api_key(api_key: str = Security(_api_key_header)):
    """Verify the X-API-Key request header against FRONTEND_API_KEY.

    Returns:
        The provided key on success, or None when auth is disabled
        (no FRONTEND_API_KEY configured — dev mode).

    Raises:
        HTTPException: 403 when the key is missing or does not match.
    """
    if not FRONTEND_API_KEY:
        # No key configured: skip the check entirely (dev mode).
        return None
    # compare_digest gives a constant-time comparison, avoiding a timing
    # side channel on this untrusted header. Guard against api_key=None
    # (header absent), which compare_digest would reject with TypeError.
    if not api_key or not secrets.compare_digest(api_key, FRONTEND_API_KEY):
        raise HTTPException(status_code=403, detail="Invalid or missing API key")
    return api_key
# ── Security: Rate Limiting (in-memory) ─────────────────────────
RATE_LIMIT_WINDOW = 60  # sliding-window length, seconds
RATE_LIMIT_MAX = int(os.getenv("RATE_LIMIT_MAX", "30"))  # max requests per window
_rate_limit_store: dict[str, list[float]] = defaultdict(list)
# Purge idle IPs once the store grows past this many entries, so the dict
# cannot grow without bound (the original never removed stale keys).
_RATE_LIMIT_PURGE_THRESHOLD = 10_000

async def rate_limit(request: Request):
    """Simple per-IP sliding-window rate limiter.

    Returns the client IP on success.

    Raises:
        HTTPException: 429 when the IP exceeded RATE_LIMIT_MAX requests
        within the last RATE_LIMIT_WINDOW seconds.
    """
    # NOTE(review): behind a reverse proxy this sees the proxy's address;
    # confirm X-Forwarded-For handling if deployed that way.
    client_ip = request.client.host if request.client else "unknown"
    now = _time.time()
    # Keep only timestamps still inside the window.
    recent = [
        t for t in _rate_limit_store[client_ip] if now - t < RATE_LIMIT_WINDOW
    ]
    if len(recent) >= RATE_LIMIT_MAX:
        _rate_limit_store[client_ip] = recent
        raise HTTPException(
            status_code=429,
            detail=f"Rate limit exceeded. Max {RATE_LIMIT_MAX} requests per minute."
        )
    recent.append(now)
    _rate_limit_store[client_ip] = recent
    # Opportunistic cleanup: drop IPs whose newest request fell out of the
    # window, bounding memory use under many distinct client addresses.
    if len(_rate_limit_store) > _RATE_LIMIT_PURGE_THRESHOLD:
        stale = [
            ip for ip, ts in _rate_limit_store.items()
            if not ts or now - ts[-1] >= RATE_LIMIT_WINDOW
        ]
        for ip in stale:
            del _rate_limit_store[ip]
    return client_ip
# Serve static files (CSS, JS, images, etc.)
STATIC_DIR = Path(__file__).parent / "static"
# NOTE(review): mid-file imports — used only from here down, but they belong
# in the top-level import block per convention.
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse
app.mount("/static", StaticFiles(directory=str(STATIC_DIR)), name="static")
@app.get("/")
async def index():
    """Serve the single-page chat UI from the static directory."""
    index_page = STATIC_DIR / "index.html"
    return FileResponse(str(index_page))
@app.post("/api/chat")
async def chat(
    request: Request,
    _key: str = Depends(verify_api_key),
    _ip: str = Depends(rate_limit),
):
    """Chat endpoint with Server-Sent Events streaming.

    Expects a JSON body {"message": str}. Retrieves RAG context for the
    question, then streams the LLM answer as SSE frames of the form
    `data: {"token": ...}` terminated by `data: [DONE]`.

    Returns 400 on an empty message; a plain JSON answer when retrieval
    finds nothing.
    """
    body = await request.json()
    question = (body.get("message") or "").strip()
    if not question:
        return JSONResponse({"error": "Empty message"}, status_code=400)

    # Retrieve context. Uses the module-level `_time` import; the original
    # redundantly did a function-local `import time` for the same clock.
    start_time = _time.time()
    # NOTE(review): this logs user-submitted content — confirm acceptable.
    logger.info("Start retrieval for question: %s", question)
    prepared = _state["rag"].retrieve_and_prepare(
        question,
        k=RETRIEVAL_CFG.top_k,
        initial_k=RETRIEVAL_CFG.initial_k,
        mode=RETRIEVAL_MODE.value,
    )
    if not prepared["results"]:
        return JSONResponse({"answer": "Xin lỗi, tôi không tìm thấy thông tin phù hợp."})
    logger.info("Retrieval took %.2fs", _time.time() - start_time)

    def stream():
        """Yield SSE-framed token chunks from the streaming completion."""
        llm_start_time = _time.time()
        first_token = True
        completion = _state["llm"].chat.completions.create(
            model=LLM_MODEL,
            messages=[{"role": "user", "content": prepared["prompt"]}],
            temperature=0.0,
            max_tokens=4096,
            stream=True,
        )
        # NOTE(review): deltas are forwarded raw — `_filter_think_tags` is
        # never applied here, so any <think> spans the model emits reach
        # the client. Confirm whether stream-side filtering is intended.
        for chunk in completion:
            delta = getattr(chunk.choices[0].delta, "content", "") or ""
            if delta:
                if first_token:
                    logger.info(
                        "LLM TTFT (Time To First Token): %.2fs",
                        _time.time() - llm_start_time,
                    )
                    first_token = False
                # SSE format: one JSON payload per `data:` line, blank-line
                # terminated; ensure_ascii=False keeps Vietnamese readable.
                yield f"data: {json.dumps({'token': delta}, ensure_ascii=False)}\n\n"
        logger.info("Total request took: %.2fs", _time.time() - start_time)
        yield "data: [DONE]\n\n"

    return StreamingResponse(stream(), media_type="text/event-stream")
@app.get("/api/config")
async def config():
    """Provide frontend config (API key for same-origin frontend).

    NOTE(review): this hands FRONTEND_API_KEY to any unauthenticated
    caller, which defeats the X-API-Key check on /api/chat unless access
    to this route is restricted elsewhere (network/CORS alone does not
    stop non-browser clients) — confirm the deployment's threat model.
    """
    return {"api_key": FRONTEND_API_KEY}
@app.get("/api/health")
async def health():
    """Liveness probe: always reports an OK status."""
    payload = {"status": "ok"}
    return payload
if __name__ == "__main__":
    import uvicorn

    # Run by import string (module path per the project layout) so uvicorn
    # owns the app lifecycle; host/port overridable via API_HOST / API_PORT.
    uvicorn.run(
        "core.api.server:app",
        host=os.getenv("API_HOST", "0.0.0.0"),
        port=int(os.getenv("API_PORT", "8000")),
        reload=False,
    )
|