|
|
import os |
|
|
import time |
|
|
import asyncio |
|
|
import secrets |
|
|
import hmac |
|
|
import hashlib |
|
|
import logging |
|
|
import re |
|
|
from datetime import datetime, timezone |
|
|
from contextlib import asynccontextmanager |
|
|
from typing import Optional |
|
|
|
|
|
from fastapi import FastAPI, Request, HTTPException, Depends |
|
|
from fastapi.responses import JSONResponse, FileResponse |
|
|
from pathlib import Path |
|
|
from pydantic import BaseModel, Field, field_validator |
|
|
from dotenv import load_dotenv |
|
|
|
|
|
from langchain_google_genai import ChatGoogleGenerativeAI |
|
|
from langchain_huggingface import HuggingFaceEmbeddings |
|
|
from langchain_qdrant import QdrantVectorStore |
|
|
from qdrant_client import QdrantClient |
|
|
from g4f.client import Client |
|
|
from threading import Lock |
|
|
|
|
|
import database as db |
|
|
|
|
|
# Load a local .env file (python-dotenv) first, so every os.getenv() lookup
# below can see values defined there.
load_dotenv()
|
|
|
|
|
# Process-wide logging. datefmt only takes effect when the format string
# references %(asctime)s, so the timestamp field is included here.
logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s: %(message)s', datefmt='%H:%M:%S')
logger = logging.getLogger("rag-api")

# Maximum allowed difference, in seconds and in either direction, between the
# X-API-Timestamp header and the server clock (see verify_hmac_signature).
HMAC_TIME_WINDOW = 300

# Threshold handed to db.check_auth_lockout(); presumably the number of failed
# authentications after which a client IP is locked out — confirm in database.py.
FAILED_AUTH_LIMIT = 5

# A subject is "already clean" when it is one or more ASCII letters, digits,
# underscores or hyphens.
SUBJECT_PATTERN = re.compile(r'^[a-zA-Z0-9_-]+$')

# Wall-clock time at import, used to report uptime.
STARTUP_TIME = time.time()
|
|
|
|
|
# Registry of HMAC credentials, keyed by the value clients send in the
# X-API-Key-Id header. A key id is only registered when its environment
# variable holds a non-empty secret. role == "admin" is what /history and
# /admin/models/switch check for.
admin_secret = os.getenv("ADMIN_API_KEY")
bot_secret = os.getenv("BOT_API_KEY")
API_KEYS = {
    key_id: {"secret": secret, "active": True, "role": role}
    for key_id, secret, role in (
        ("admin", admin_secret, "admin"),
        ("bot", bot_secret, "user"),
    )
    if secret
}
if not API_KEYS:
    logger.warning("No API keys configured!")
|
|
|
|
|
# Prompt sent to the LLM for every /ask request. RAGService.ask fills it with
# str.format(context=..., question=...); any other literal "{" or "}" added to
# this text would have to be doubled ("{{", "}}") to survive formatting.
PROMPT_TEMPLATE = """


You are AskBookie, an assistant built on a RAG system using university slide data.





Your rules:


1. If the context has the answer, use it.


2. If the context is related but incomplete, answer from your knowledge but mention the context.


3. If unrelated, say it's not in context.


4. Format your answer nicely in Markdown. Use LaTeX for math ($...$ for inline, $$...$$ for block).





Context:


{context}





Question: {question}


"""


# Shared g4f client used by ModelManager for the DuckDuckGo-backed models
# (ids 4 and 5).
g4f_client = Client()
|
|
|
|
|
# Substrings that mark an advertising line in an LLM reply. clean_response()
# drops any line whose lowercased text contains one of these, so every entry
# must itself be lowercase. Presumably injected by the free g4f/DDG backends,
# whose replies are the ones passed through clean_response().
PROMO_PATTERNS = [
    "want best roleplay experience", "llmplayground.net",
    "want the best roleplay", "best ai roleplay",
]
|
|
|
|
|
# Selectable LLM backends, keyed by the integer id accepted by
# POST /admin/models/switch and dispatched on by ModelManager.call_llm.
# "name" is returned to clients and stored with each query-history entry;
# "description" names the credential / provider behind the model.
# NOTE(review): the "Gemini-3-*" labels do not match the gemini-2.5-* model
# names that ModelManager.call_llm actually requests — confirm which is intended.
MODEL_OPTIONS = {
    model_id: {"name": name, "description": description}
    for model_id, (name, description) in enumerate(
        (
            ("Gemini-3-flash", "Gemini Primary API Key"),
            ("Gemini-3-flash(Back-up)", "Gemini Secondary API Key"),
            ("Gemini-3-Pro", "Gemini Primary API Key"),
            ("GPT-4o-mini", "DuckDuckGo (Free)"),
            ("Claude-3-Haiku", "DuckDuckGo (Free)"),
        ),
        start=1,
    )
}
|
|
|
|
|
|
|
|
class QuotaExhaustedError(Exception):
    """An LLM backend refused a call because of quota or rate limiting.

    The /ask endpoint turns this into HTTP 429 with a Retry-After header.
    """
|
|
|
|
|
|
|
|
def clean_response(text: str) -> str:
    """Remove advertising lines from an LLM reply and trim trailing blank lines.

    A line is dropped in either of two cases:
    - its lowercased, whitespace-stripped form contains an entry of PROMO_PATTERNS;
    - that form starts with "http" and mentions "llmplayground".

    Kept lines are returned exactly as received (original case and
    indentation). Lines are split on "\\n" only.
    """
    def is_promo(raw: str) -> bool:
        probe = raw.lower().strip()
        if any(marker in probe for marker in PROMO_PATTERNS):
            return True
        return probe.startswith("http") and "llmplayground" in probe

    kept = [raw for raw in text.split('\n') if not is_promo(raw)]
    # Drop blank/whitespace-only lines from the end only; leading ones stay.
    end = len(kept)
    while end > 0 and not kept[end - 1].strip():
        end -= 1
    return '\n'.join(kept[:end])
|
|
|
|
|
|
|
|
class ModelManager:
    """Process-wide singleton holding the currently selected LLM backend.

    The selected model id (a key of MODEL_OPTIONS) starts at 1 and can be
    changed at runtime with switch_model(). call_llm() routes a prompt to the
    backend for the selected id.
    """

    _instance = None
    _lock = Lock()  # guards singleton creation in __new__

    def __new__(cls):
        # Double-checked locking. The instance is fully initialized *before* it
        # is published to cls._instance, so a thread taking the unlocked fast
        # path can never receive a half-initialized object.
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    instance = super().__new__(cls)
                    instance._initialize()
                    cls._instance = instance
        return cls._instance

    def _initialize(self):
        """Select model 1 and read the Gemini API keys from the environment."""
        self._current_model = 1
        self._model_lock = Lock()  # guards _current_model
        self._gemini_primary_key = os.getenv("GEMINI_API_KEY")
        self._gemini_secondary_key = os.getenv("GEMINI_2_API_KEY")
        # Model 3 reuses the primary key, matching MODEL_OPTIONS[3]["description"].
        self._gemini_pro_key = os.getenv("GEMINI_API_KEY")
        logger.info(f"ModelManager initialized with model {self._current_model}")

    @staticmethod
    def _model_info(model_id: int) -> dict:
        """Build the public info dict for model_id."""
        return {"model_id": model_id, "name": MODEL_OPTIONS[model_id]["name"]}

    @property
    def current_model(self) -> int:
        """Id of the currently selected model."""
        with self._model_lock:
            return self._current_model

    @property
    def current_model_info(self) -> dict:
        """{"model_id": ..., "name": ...} for the currently selected model."""
        with self._model_lock:
            return self._model_info(self._current_model)

    def switch_model(self, model_id: int) -> dict:
        """Select model_id for subsequent call_llm() calls.

        Returns:
            The info dict of the model that was just selected.

        Raises:
            ValueError: model_id is not a key of MODEL_OPTIONS.
        """
        if model_id not in MODEL_OPTIONS:
            raise ValueError(f"Invalid model ID: {model_id}. Valid options: 1-5")
        with self._model_lock:
            old_model = self._current_model
            self._current_model = model_id
        logger.info(f"Model switched from {old_model} to {model_id} ({MODEL_OPTIONS[model_id]['name']})")
        # Report the model this call selected, not whatever _current_model holds
        # by now (another request may already have switched it again).
        return self._model_info(model_id)

    def call_llm(self, prompt: str) -> str:
        """Send prompt to the currently selected model and return its text reply.

        Raises:
            QuotaExhaustedError: a Gemini backend reported quota / rate limiting.
            ValueError: the selected id has no backend. switch_model() validates
                ids, so this is a guard against an inconsistent state.
        """
        model_id = self.current_model
        # NOTE(review): MODEL_OPTIONS labels ids 1-3 "Gemini-3-*", but the API
        # model names requested here are gemini-2.5-*; confirm which is intended.
        if model_id == 1:
            return self._call_gemini(prompt, self._gemini_primary_key, "gemini-2.5-flash")
        if model_id == 2:
            return self._call_gemini(prompt, self._gemini_secondary_key, "gemini-2.5-flash")
        if model_id == 3:
            return self._call_gemini(prompt, self._gemini_pro_key, "gemini-2.5-pro")
        if model_id == 4:
            return self._call_gpt4o(prompt)
        if model_id == 5:
            return self._call_claude(prompt)
        # Previously fell through and returned None, which was then served and
        # stored as the "answer".
        raise ValueError(f"No LLM backend configured for model ID {model_id}")

    def _call_gemini(self, prompt: str, api_key: str, model: str) -> str:
        """Invoke a Gemini chat model at temperature 0 and return its content.

        Quota / rate-limit failures (detected from the error text) are re-raised
        as QuotaExhaustedError, chained to the original exception. Every other
        error propagates unchanged.
        """
        try:
            llm = ChatGoogleGenerativeAI(model=model, temperature=0, google_api_key=api_key)
            response = llm.invoke(prompt)
            return response.content
        except Exception as e:
            error_str = str(e).lower()
            quota_markers = ("429", "quota", "resource exhausted", "rate limit")
            if any(marker in error_str for marker in quota_markers):
                logger.error(f"Gemini quota exhausted: {e}")
                raise QuotaExhaustedError("LLM quota exhausted. Please try again later or switch to a different model.") from e
            raise

    def _call_ddg(self, prompt: str, model: str) -> str:
        """Query model through g4f's DDG provider and strip promotional lines."""
        from g4f.Provider import DDG  # imported on first use only

        response = g4f_client.chat.completions.create(
            model=model,
            provider=DDG,
            messages=[{"role": "user", "content": prompt}],
        )
        return clean_response(response.choices[0].message.content)

    def _call_gpt4o(self, prompt: str) -> str:
        """Model 4: GPT-4o-mini via DuckDuckGo."""
        return self._call_ddg(prompt, "gpt-4o-mini")

    def _call_claude(self, prompt: str) -> str:
        """Model 5: Claude-3-Haiku via DuckDuckGo."""
        return self._call_ddg(prompt, "claude-3-haiku")
|
|
|
|
|
|
|
|
# Module-level singleton used by the request handlers. ModelManager.__new__
# returns this same instance for any later ModelManager() call.
model_manager = ModelManager()
|
|
|
|
|
|
|
|
class RAGService:
    """Retrieve slide chunks from Qdrant and answer questions with the active LLM."""

    def __init__(self):
        self.qdrant_url = os.getenv("QDRANT_CLUSTER_URL")
        self.qdrant_key = os.getenv("QDRANT_API_KEY")
        # Loaded once here and reused for every query.
        self.embeddings = HuggingFaceEmbeddings(
            model_name="Alibaba-NLP/gte-modernbert-base",
            model_kwargs={"device": "cpu"},
            encode_kwargs={"normalize_embeddings": True},
        )

    def _search_once(self, query_text: str, collection_name: str, k: int) -> list:
        """Run one similarity search; returns (Document, score) pairs.

        The Qdrant client is always closed afterwards. Previously a new client
        was created per attempt and never closed, leaking its connections.
        """
        client = QdrantClient(url=self.qdrant_url, api_key=self.qdrant_key, timeout=120)
        try:
            vectorstore = QdrantVectorStore(client=client, collection_name=collection_name, embedding=self.embeddings)
            return vectorstore.similarity_search_with_score(query_text, k=k)
        finally:
            try:
                client.close()
            except Exception:
                # Closing is best-effort; never let it mask the search result.
                logger.debug("Failed to close Qdrant client", exc_info=True)

    def _search(self, query_text: str, collection_name: str, k: int, max_retries: int = 3) -> list:
        """Search with up to max_retries attempts.

        Sleeps 1 s, 2 s, ... between attempts. After the final failure, the last
        error is re-raised.
        """
        last_error = None
        for attempt in range(max_retries):
            try:
                return self._search_once(query_text, collection_name, k)
            except Exception as e:
                last_error = e
                if attempt < max_retries - 1:
                    time.sleep(2 ** attempt)
        raise last_error

    def ask(self, query_text: str, subject: str, unit: int, k: int = 5) -> dict:
        """Answer query_text from collection "askbookie_<subject>_unit-<unit>".

        Args:
            query_text: the user's question.
            subject: already-sanitized subject name.
            unit: unit number.
            k: number of chunks to retrieve and include as context (default 5).

        Returns:
            A dict with three keys:
            - "answer": the LLM reply;
            - "sources": the 'slide_number' metadata of each retrieved chunk,
              or 'Unknown' when a chunk has none;
            - "collection": the queried collection name.

        Raises:
            Whatever the final Qdrant attempt or ModelManager.call_llm raises.
        """
        collection_name = f"askbookie_{subject}_unit-{unit}"
        top_results = self._search(query_text, collection_name, k)[:k]
        context_text = "\n\n---\n\n".join(doc.page_content for doc, _ in top_results)
        full_prompt = PROMPT_TEMPLATE.format(context=context_text, question=query_text)
        answer = model_manager.call_llm(full_prompt)
        sources = [doc.metadata.get('slide_number', 'Unknown') for doc, _ in top_results]
        return {"answer": answer, "sources": sources, "collection": collection_name}
|
|
|
|
|
|
|
|
def get_client_ip(request: Request) -> str:
    """Best-effort client address, used as the key for failed-auth lockout.

    When X-Forwarded-For is present, the right-most entry (stripped) is used:
    that is the hop appended by the nearest proxy, while left-most entries are
    supplied by the client. Otherwise the socket peer address is used, or
    "unknown" when there is none.

    NOTE(review): trustworthy only if the app always sits behind exactly one
    proxy that appends to this header. Without one, clients can forge it and
    spread failed attempts across fake IPs.
    """
    forwarded_chain = request.headers.get("X-Forwarded-For")
    if not forwarded_chain:
        peer = request.client
        return peer.host if peer else "unknown"
    *_, last_hop = forwarded_chain.split(",")
    return last_hop.strip()
|
|
|
|
|
|
|
|
def verify_hmac_signature(request: Request) -> Optional[str]:
    """Validate the request's HMAC headers and return the caller's key id.

    Expected headers:
        X-API-Key-Id:    a key id present in API_KEYS.
        X-API-Timestamp: integer Unix time, within HMAC_TIME_WINDOW seconds of
                         the server clock (past or future).
        X-API-Signature: hex HMAC-SHA256, keyed with that key's secret, of the
                         timestamp, the upper-cased method and the URL path,
                         joined by newline characters.

    Returns:
        The key id when the key exists and is active, the timestamp is fresh,
        and the signature matches. Otherwise None. Malformed header values
        never raise.

    NOTE(review): the signed message covers neither the query string nor the
    body. A captured signature can therefore be replayed within the time
    window for the same method + path with different parameters or payload.
    """
    key_id = request.headers.get("X-API-Key-Id")
    sig = request.headers.get("X-API-Signature")
    ts = request.headers.get("X-API-Timestamp")
    if not all([key_id, sig, ts]):
        return None

    meta = API_KEYS.get(key_id)
    # Compute an HMAC even for unknown key ids (with a dummy secret), so response
    # timing does not reveal which key ids exist.
    secret_to_use = meta["secret"] if meta else "dummy_secret_for_timing_safety"
    is_valid_key = meta is not None and meta.get("active", False)

    try:
        ts_int = int(ts)
        time_valid = abs(time.time() - ts_int) <= HMAC_TIME_WINDOW
    except (ValueError, TypeError, OverflowError):
        # Non-numeric input, or a number too large to convert to float (which
        # previously escaped as an uncaught OverflowError -> HTTP 500).
        ts_int = 0
        time_valid = False

    message = f"{ts_int}\n{request.method.upper()}\n{request.url.path}"
    computed = hmac.new(secret_to_use.encode(), message.encode(), hashlib.sha256).hexdigest()
    # Compare bytes, not str: compare_digest raises TypeError for str arguments
    # containing non-ASCII characters, which a client can put in a header.
    sig_valid = secrets.compare_digest(computed.encode(), sig.encode())

    if is_valid_key and time_valid and sig_valid:
        return key_id
    return None
|
|
|
|
|
|
|
|
async def verify_api_key(request: Request) -> str:
    """FastAPI dependency that authenticates a request and returns its key id.

    Raises:
        HTTPException(429): db.check_auth_lockout reports the client IP as
            locked out.
        HTTPException(401): HMAC verification failed. The failure is first
            recorded for the IP, and the response is delayed by 0.1 s.
    """
    client_ip = get_client_ip(request)
    if db.check_auth_lockout(client_ip, FAILED_AUTH_LIMIT):
        raise HTTPException(status_code=429, detail="Too many failed attempts")

    authenticated = verify_hmac_signature(request)
    if not authenticated:
        db.record_failed_auth(client_ip)
        # Small fixed delay, presumably to slow down brute-force guessing.
        await asyncio.sleep(0.1)
        raise HTTPException(status_code=401, detail="Unauthorized")
    return authenticated
|
|
|
|
|
|
|
|
def sanitize_subject(subject: str) -> str:
    """Normalize a user-supplied subject into a collection-name fragment.

    Normalization steps:
    - strip surrounding whitespace and lowercase;
    - delete every character outside ASCII letters, digits, "_" and "-";
    - truncate to 50 characters.

    Returns "default" when nothing survives.
    """
    # Deleting disallowed characters is a no-op for an already-clean value, so
    # no separate "is it clean?" check is needed.
    normalized = re.sub(r'[^a-zA-Z0-9_-]', '', subject.strip().lower())
    return normalized[:50] or "default"
|
|
|
|
|
def get_metrics_summary() -> dict:
    """Return db.get_metrics_summary() enriched with process uptime and roles.

    Adds "uptime_hours" (rounded to 2 decimals). For every key id under
    "per_user", sets a "role" taken from API_KEYS, defaulting to "user".
    """
    summary = db.get_metrics_summary()
    summary["uptime_hours"] = round((time.time() - STARTUP_TIME) / 3600, 2)
    per_user = summary.get("per_user", {})
    for user_key, stats in per_user.items():
        stats["role"] = API_KEYS.get(user_key, {}).get("role", "user")
    return summary
|
|
|
|
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan.

    Touches the database and builds the shared RAGService before requests are
    served.
    """
    logger.info("Starting service")
    # Return value unused; presumably opens/validates the DB connection early.
    db.get_database()
    service = RAGService()  # loads the embedding model
    app.state.rag_service = service
    logger.info("RAG service initialized")
    yield
    logger.info("Shutting down")
|
|
|
|
|
|
|
|
app = FastAPI(
    lifespan=lifespan,
    title="AskBookie RAG API",
    version="3.0.0",
    # Schema and interactive docs, served without authentication.
    openapi_url="/openapi.json",
    docs_url="/docs",
    redoc_url="/redoc",
)
|
|
|
|
|
|
|
|
@app.middleware("http")
async def security_middleware(request: Request, call_next):
    """Tag each request with an id, add security headers, and log one access line.

    The random request id is stored on request.state.request_id for handlers
    and echoed in the X-Request-ID response header. key_id in the log line is
    whatever a handler stored on request.state.key_id, or None.
    """
    request_id = secrets.token_hex(8)
    request.state.request_id = request_id
    # perf_counter is monotonic: durations cannot go negative or jump when the
    # wall clock is adjusted (NTP sync, manual changes), unlike time.time().
    start = time.perf_counter()
    response = await call_next(request)
    response.headers.update({
        "X-Request-ID": request_id,
        "X-Content-Type-Options": "nosniff",
        "X-Frame-Options": "DENY",
        "Referrer-Policy": "strict-origin-when-cross-origin",
    })
    duration_ms = round((time.perf_counter() - start) * 1000, 2)
    key_id = getattr(request.state, "key_id", None)
    logger.info(f"{request.method} {request.url.path} {response.status_code} {duration_ms}ms key={key_id} rid={request_id}")
    return response
|
|
|
|
|
|
|
|
@app.exception_handler(HTTPException)
async def http_exception_handler(request: Request, exc: HTTPException):
    """Render an HTTPException as {"detail": ...}, keeping its status and headers.

    Headers such as Retry-After are preserved.
    """
    extra_headers = getattr(exc, "headers", None)
    payload = {"detail": exc.detail}
    return JSONResponse(content=payload, status_code=exc.status_code, headers=extra_headers)
|
|
|
|
|
|
|
|
@app.exception_handler(Exception)
async def general_exception_handler(request: Request, exc: Exception):
    """Last-resort handler: log the traceback, answer with a generic 500.

    The response body does not expose internals.
    """
    logger.exception(f"Unhandled error: {exc}")
    body = {"detail": "Internal error"}
    return JSONResponse(content=body, status_code=500)
|
|
|
|
|
|
|
|
# Body of POST /ask.
class AskRequest(BaseModel):
    query: str = Field(..., min_length=1, max_length=1000)
    subject: str = Field(..., min_length=1, max_length=100)
    unit: int = Field(..., ge=1, le=4)
    # NOTE(review): validated but not read by the /ask handler in this file.
    context_limit: int = Field(default=5, ge=1, le=20)

    @field_validator("query", "subject", mode="before")
    @classmethod
    def sanitize(cls, v):
        """Strip surrounding whitespace *before* the length constraints run.

        In the default "after" mode the min_length=1 check saw the unstripped
        value, so a whitespace-only query or subject passed validation and then
        became "". Non-string input is returned untouched, so pydantic's own
        type validation reports it.
        """
        if isinstance(v, str):
            return v.strip()
        return v
|
|
|
|
|
|
|
|
@app.post("/ask")
async def ask(request: Request, body: AskRequest, key_id: str = Depends(verify_api_key)):
    """Answer a question with retrieval-augmented generation.

    Steps:
    1. Retrieve slide chunks for (subject, unit) and ask the active LLM.
    2. Record the exchange in the query history.
    3. Return the answer with its sources, the model used, and the request id.

    Raises:
        HTTPException(429): the LLM quota is exhausted; a Retry-After hint is set.
        HTTPException(500): any other failure while producing the answer.
    """
    start_time = time.time()
    success = False
    request.state.key_id = key_id  # read by security_middleware for the access log
    subject = sanitize_subject(body.subject)
    if not subject:  # defensive; sanitize_subject falls back to "default"
        raise HTTPException(status_code=400, detail="Invalid subject")
    try:
        rag_service: RAGService = request.app.state.rag_service
        # RAGService.ask is synchronous and can block for a long time: embedding,
        # Qdrant calls with time.sleep() retries, and the LLM call. Running it in a
        # worker thread keeps the event loop free to serve other requests
        # (including /health) meanwhile.
        result = await asyncio.to_thread(rag_service.ask, body.query, subject, body.unit)
        success = True
        latency_ms = (time.time() - start_time) * 1000
        current_model = model_manager.current_model_info
        try:
            db.store_query_history(
                key_id=key_id,
                subject=subject,
                query=body.query,
                answer=result["answer"],
                sources=result["sources"],
                request_id=request.state.request_id,
                latency_ms=latency_ms,
                model_id=current_model["model_id"],
                model_name=current_model["name"],
            )
        except Exception:
            # History is bookkeeping: a DB hiccup must not turn an answer we
            # already have into a 500.
            logger.exception(f"Failed to store query history rid={request.state.request_id}")
        return {"answer": result["answer"], "sources": result["sources"], "collection": result["collection"], "model": current_model, "request_id": request.state.request_id}
    except QuotaExhaustedError as e:
        logger.warning(f"LLM quota exhausted: {e}")
        raise HTTPException(status_code=429, detail="LLM quota exhausted. Try again later or switch model.", headers={"Retry-After": "3600"})
    except Exception as e:
        logger.exception(f"RAG query failed: {e}")
        raise HTTPException(status_code=500, detail="Query failed")
    finally:
        try:
            db.record_metric(key_id, "/ask", success, (time.time() - start_time) * 1000)
        except Exception:
            # Raising from `finally` would replace the real response/exception.
            logger.exception("Failed to record /ask metric")
|
|
|
|
|
|
|
|
@app.get("/")
async def dashboard():
    """Serve the bundled dashboard page.

    The page is <parent of this file's directory>/assets/index.html. When it is
    missing, a small service descriptor is returned instead.
    """
    index_file = Path(__file__).parent.parent.joinpath("assets", "index.html")
    if not index_file.exists():
        return {"service": "AskBookie RAG API", "version": "3.0.0"}
    return FileResponse(index_file, media_type="text/html")
|
|
|
|
|
|
|
|
@app.get("/health")
async def health():
    """Report metrics, uptime and the active model.

    If gathering them fails, a minimal "degraded" status with uptime is
    returned instead.
    """
    try:
        payload = get_metrics_summary()
        payload["status"] = "healthy"
        payload["current_model"] = model_manager.current_model_info
    except Exception:
        logger.exception("Metrics error")
        return {"status": "degraded", "uptime_hours": round((time.time() - STARTUP_TIME) / 3600, 2)}
    return payload
|
|
|
|
|
|
|
|
@app.get("/history")
async def get_query_history(request: Request, limit: int = 100, offset: int = 0, key_id: str = Depends(verify_api_key)):
    """Admin-only paginated view of the stored query history.

    Raises:
        HTTPException(403): the caller's key is not an admin key.
        HTTPException(400): limit < 1 or offset < 0. Such values were
            previously passed straight to the database layer, where they may
            mean "no limit" or fail with a 500.
    """
    request.state.key_id = key_id
    if API_KEYS.get(key_id, {}).get("role") != "admin":
        raise HTTPException(status_code=403, detail="Forbidden")
    if limit < 1 or offset < 0:
        raise HTTPException(status_code=400, detail="limit must be >= 1 and offset must be >= 0")
    history, total = db.get_query_history(limit, offset)
    return {"history": history, "total": total, "limit": limit, "offset": offset}
|
|
|
|
|
# Body of POST /admin/models/switch; the bounds mirror MODEL_OPTIONS' keys 1-5.
class ModelSwitchRequest(BaseModel):
    model_id: int = Field(ge=1, le=5)
|
|
|
|
|
|
|
|
@app.post("/admin/models/switch")
async def switch_model(request: Request, body: ModelSwitchRequest, key_id: str = Depends(verify_api_key)):
    """Admin-only: select the LLM used for subsequent /ask requests.

    Raises:
        HTTPException(403): the caller's key is not an admin key.
        HTTPException(400): ModelManager rejects the model id.
    """
    request.state.key_id = key_id
    role = API_KEYS.get(key_id, {}).get("role")
    if role != "admin":
        raise HTTPException(status_code=403, detail="Forbidden")
    requested = body.model_id
    try:
        switched = model_manager.switch_model(requested)
    except ValueError as err:
        raise HTTPException(status_code=400, detail=str(err))
    logger.info(f"Admin switched model to {requested}")
    return {"status": "success", "message": f"Switched to model {requested}", "model": switched}
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    import uvicorn

    # Listens on all interfaces. Port 7860 stays the default; the PORT
    # environment variable can override it for other hosts or local runs.
    uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", "7860")))