File size: 6,871 Bytes
4f9286e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ba44ab9
4f9286e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6e77184
4f9286e
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
from __future__ import annotations
import json
import logging
import os
import re
import secrets
import sys
import time as _time
from collections import defaultdict
from contextlib import asynccontextmanager
from pathlib import Path

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

from fastapi import FastAPI, Request, Depends, HTTPException, Security
from fastapi.responses import StreamingResponse, JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from fastapi.security import APIKeyHeader
from dotenv import find_dotenv, load_dotenv
from openai import OpenAI

# Setup path & env
# Make the repository root importable so the `core.*` packages below resolve
# when this file is executed directly (it sits two levels under the root).
REPO_ROOT = Path(__file__).resolve().parents[2]
if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))

# Load variables from the nearest .env file, searching from the current
# working directory (usecwd=True) rather than from this module's location.
load_dotenv(find_dotenv(usecwd=True))

# Project imports must stay below the sys.path tweak above.
from core.rag.embedding_model import EmbeddingConfig, QwenEmbeddings
from core.rag.vector_store import ChromaConfig, ChromaVectorDB
from core.rag.retrieval import Retriever, RetrievalMode, get_retrieval_config
from core.rag.generator import RAGContextBuilder, SYSTEM_PROMPT

# Config
RETRIEVAL_MODE = RetrievalMode.HYBRID  # dense + sparse retrieval combined
RETRIEVAL_CFG = get_retrieval_config()
LLM_MODEL = os.getenv("LLM_MODEL", "qwen/qwen3-32b")  # overridable via env
LLM_API_BASE = "https://api.groq.com/openai/v1"  # Groq's OpenAI-compatible endpoint

# Shared state filled in by lifespan(): "rag" -> RAGContextBuilder, "llm" -> OpenAI client.
_state = {}


def _filter_think_tags(text: str) -> str:
    return re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL).strip()


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Initialize RAG resources on startup; release them on shutdown.

    Populates the module-level _state dict with:
      - "rag": RAGContextBuilder over the Chroma-backed retriever
      - "llm": OpenAI-compatible client pointed at the Groq API

    Raises:
        RuntimeError: if GROQ_API_KEY is not set in the environment.
    """
    # Fix: use the module logger instead of print() for consistency with the
    # logging configured at the top of this file.
    logger.info("Initializing RAG pipeline...")
    emb = QwenEmbeddings(EmbeddingConfig())
    db = ChromaVectorDB(embedder=emb, config=ChromaConfig())
    retriever = Retriever(vector_db=db)

    api_key = (os.getenv("GROQ_API_KEY") or "").strip()
    if not api_key:
        raise RuntimeError("Missing GROQ_API_KEY")

    _state["rag"] = RAGContextBuilder(retriever=retriever)
    _state["llm"] = OpenAI(api_key=api_key, base_url=LLM_API_BASE)
    logger.info("RAG pipeline ready")
    yield
    # Shutdown: drop references so resources can be garbage-collected.
    _state.clear()


app = FastAPI(title="HUST RAG API", lifespan=lifespan)

# ── Security: CORS ──────────────────────────────────────────────
# Only allow the same-origin frontend or explicitly configured origins.
# NOTE(review): when ALLOWED_ORIGINS is unset this falls back to ["*"]
# (any origin) — confirm the permissive default is intended for production.
ALLOWED_ORIGINS = os.getenv("ALLOWED_ORIGINS", "").split(",")
ALLOWED_ORIGINS = [o.strip() for o in ALLOWED_ORIGINS if o.strip()] or ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=ALLOWED_ORIGINS,
    allow_methods=["GET", "POST"],  # only the verbs this API serves
    allow_headers=["Content-Type", "X-API-Key"],  # X-API-Key consumed by verify_api_key
)

# ── Security: API Key Authentication ────────────────────────────
# auto_error=False: a missing header yields None so verify_api_key can decide.
_api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
FRONTEND_API_KEY = os.getenv("FRONTEND_API_KEY", "").strip()


async def verify_api_key(api_key: str = Security(_api_key_header)):
    """Verify the X-API-Key request header against FRONTEND_API_KEY.

    Returns None (auth disabled) when FRONTEND_API_KEY is unset — dev mode.

    Raises:
        HTTPException: 403 when the key is missing or does not match.
    """
    if not FRONTEND_API_KEY:
        # No key configured: skip authentication (development mode).
        return None
    # Fix: constant-time comparison avoids a timing side channel on the key;
    # api_key may be None because the header uses auto_error=False.
    if not api_key or not secrets.compare_digest(api_key, FRONTEND_API_KEY):
        raise HTTPException(status_code=403, detail="Invalid or missing API key")
    return api_key


# ── Security: Rate Limiting (in-memory) ─────────────────────────
RATE_LIMIT_WINDOW = 60   # sliding-window length, seconds
RATE_LIMIT_MAX = int(os.getenv("RATE_LIMIT_MAX", "30"))  # max requests per window
# Per-client-IP request timestamps. Process-local: resets on restart and is
# not shared across workers — adequate for a single-process deployment.
_rate_limit_store: dict[str, list[float]] = defaultdict(list)


async def rate_limit(request: Request):
    """Simple per-IP sliding-window rate limiter.

    Records the current request for the client IP and returns that IP.

    Raises:
        HTTPException: 429 once the client has made RATE_LIMIT_MAX requests
            within the last RATE_LIMIT_WINDOW seconds.
    """
    client_ip = request.client.host if request.client else "unknown"
    now = _time.time()
    # Keep only timestamps still inside the window.
    recent = [t for t in _rate_limit_store[client_ip] if now - t < RATE_LIMIT_WINDOW]
    if len(recent) >= RATE_LIMIT_MAX:
        _rate_limit_store[client_ip] = recent
        raise HTTPException(
            status_code=429,
            detail=f"Rate limit exceeded. Max {RATE_LIMIT_MAX} requests per minute."
        )
    recent.append(now)
    _rate_limit_store[client_ip] = recent
    # Fix: evict IPs whose entire history has expired. The original kept one
    # entry per IP forever, so the store grew without bound over time.
    stale = [
        ip for ip, ts in _rate_limit_store.items()
        if ip != client_ip and (not ts or now - ts[-1] >= RATE_LIMIT_WINDOW)
    ]
    for ip in stale:
        del _rate_limit_store[ip]
    return client_ip

# Serve static files (CSS, JS, images, etc.)
STATIC_DIR = Path(__file__).parent / "static"
# NOTE(review): mid-file imports — kept in place here to avoid reordering
# module side effects; conventionally these belong at the top of the file.
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse

# Anything under ./static is served at /static/*.
app.mount("/static", StaticFiles(directory=str(STATIC_DIR)), name="static")


@app.get("/")
async def index():
    """Serve the chat UI (static/index.html)."""
    page = STATIC_DIR / "index.html"
    return FileResponse(str(page))


@app.post("/api/chat")
async def chat(
    request: Request,
    _key: str = Depends(verify_api_key),
    _ip: str = Depends(rate_limit),
):
    """Answer a user question with RAG context, streamed via Server-Sent Events.

    Expects a JSON body {"message": "<question>"}. Returns 400 on a missing,
    malformed, or empty message; a JSON answer when retrieval finds nothing;
    otherwise a text/event-stream of `data: {"token": ...}` events terminated
    by `data: [DONE]`.
    """
    # Fix: a malformed or non-object JSON body previously crashed with a 500
    # (JSONDecodeError / AttributeError); reject it with a 400 instead.
    try:
        body = await request.json()
    except json.JSONDecodeError:
        return JSONResponse({"error": "Invalid JSON body"}, status_code=400)
    if not isinstance(body, dict):
        return JSONResponse({"error": "Invalid JSON body"}, status_code=400)

    question = (body.get("message") or "").strip()
    if not question:
        return JSONResponse({"error": "Empty message"}, status_code=400)

    # Retrieve context. Fix: reuse the module-level `time as _time` import
    # instead of a redundant function-local `import time`.
    start_time = _time.time()
    logger.info("Start retrieval for question: %s", question)
    prepared = _state["rag"].retrieve_and_prepare(
        question,
        k=RETRIEVAL_CFG.top_k,
        initial_k=RETRIEVAL_CFG.initial_k,
        mode=RETRIEVAL_MODE.value,
    )

    if not prepared["results"]:
        # No relevant documents — fixed apology message (Vietnamese).
        return JSONResponse({"answer": "Xin lỗi, tôi không tìm thấy thông tin phù hợp."})

    logger.info("Retrieval took %.2fs", _time.time() - start_time)

    def stream():
        # Generator yielding SSE frames; consumed by StreamingResponse.
        llm_start_time = _time.time()
        first_token = True
        completion = _state["llm"].chat.completions.create(
            model=LLM_MODEL,
            messages=[{"role": "user", "content": prepared["prompt"]}],
            temperature=0.0,
            max_tokens=4096,
            stream=True,
        )
        for chunk in completion:
            delta = getattr(chunk.choices[0].delta, "content", "") or ""
            if delta:
                if first_token:
                    # Time-to-first-token: latency before the model starts answering.
                    logger.info("LLM TTFT (Time To First Token): %.2fs",
                                _time.time() - llm_start_time)
                    first_token = False
                # SSE format. NOTE(review): reasoning models may emit <think>
                # blocks; they are forwarded verbatim here (_filter_think_tags
                # is never applied to the stream) — confirm that is intended.
                yield f"data: {json.dumps({'token': delta}, ensure_ascii=False)}\n\n"
        logger.info("Total request took: %.2fs", _time.time() - start_time)
        yield "data: [DONE]\n\n"

    return StreamingResponse(stream(), media_type="text/event-stream")


@app.get("/api/config")
async def config():
    """Provide frontend config (API key for same-origin frontend).

    NOTE(review): this endpoint returns FRONTEND_API_KEY to any caller
    without authentication, so the key only deters cross-origin browser
    use, not direct HTTP clients — confirm this trade-off is intended.
    """
    return {"api_key": FRONTEND_API_KEY}


@app.get("/api/health")
async def health():
    """Liveness probe: always reports the service as up."""
    payload = {"status": "ok"}
    return payload


if __name__ == "__main__":
    # Run the API directly with uvicorn; host/port overridable via env.
    import uvicorn
    uvicorn.run(
        "core.api.server:app",  # import string (not the object) per uvicorn convention
        host=os.getenv("API_HOST", "0.0.0.0"),  # default: bind all interfaces
        port=int(os.getenv("API_PORT", "8000")),
        reload=False,  # no auto-reload in production mode
    )