File size: 15,636 Bytes
4e581ab
 
 
20f8f08
 
4e581ab
 
20f8f08
396f15b
7e8ec1d
20f8f08
 
396f15b
20f8f08
396f15b
20f8f08
6f26f80
 
396f15b
 
 
 
 
 
4e581ab
7e8ec1d
4e581ab
396f15b
 
 
20f8f08
4e581ab
20f8f08
396f15b
20f8f08
 
4e581ab
20f8f08
396f15b
 
 
 
 
 
20f8f08
 
 
396f15b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20f8f08
 
 
 
 
 
396f15b
20f8f08
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
396f15b
20f8f08
 
7e8ec1d
20f8f08
 
 
 
7e8ec1d
20f8f08
 
 
 
 
 
 
 
 
 
 
7e8ec1d
 
 
 
 
4e581ab
20f8f08
 
 
0af8e53
396f15b
20f8f08
 
 
 
4e581ab
396f15b
4e581ab
396f15b
 
20f8f08
 
 
 
4e581ab
 
20f8f08
 
 
 
 
 
 
 
 
 
396f15b
20f8f08
 
 
 
 
 
 
396f15b
20f8f08
 
 
 
396f15b
20f8f08
 
 
 
 
396f15b
20f8f08
 
396f15b
 
20f8f08
 
396f15b
20f8f08
 
0a2b63a
 
20f8f08
 
396f15b
20f8f08
396f15b
20f8f08
 
396f15b
 
 
 
4e581ab
20f8f08
396f15b
20f8f08
65663e2
396f15b
7e8ec1d
65663e2
 
 
 
 
 
396f15b
 
 
65663e2
396f15b
70d5168
 
 
4e581ab
20f8f08
 
4e581ab
7e8ec1d
4e581ab
 
396f15b
 
 
 
 
 
4e581ab
 
20f8f08
 
 
 
 
777895b
20f8f08
 
 
 
 
4e581ab
65663e2
396f15b
 
65663e2
 
7e8ec1d
65663e2
 
2528448
7e8ec1d
2528448
 
 
396f15b
 
2528448
 
 
 
 
396f15b
2528448
 
396f15b
 
4e581ab
 
20f8f08
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
import asyncio
import hashlib
import hmac
import logging
import os
import re
import secrets
import time
from contextlib import asynccontextmanager
from datetime import datetime, timezone
from pathlib import Path
from threading import Lock
from typing import Annotated, Optional

from dotenv import load_dotenv
from fastapi import Depends, FastAPI, HTTPException, Query, Request
from fastapi.responses import FileResponse, JSONResponse
from g4f.client import Client
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_qdrant import QdrantVectorStore
from pydantic import BaseModel, Field, field_validator
from qdrant_client import QdrantClient

import database as db

load_dotenv()

logging.basicConfig(
    level=logging.INFO,
    format='%(levelname)s:     %(message)s',
    datefmt='%H:%M:%S',
)
logger = logging.getLogger("rag-api")

# Maximum allowed difference (seconds) between a request's X-API-Timestamp and server time.
HMAC_TIME_WINDOW = 300
# Threshold handed to db.check_auth_lockout when deciding whether an IP is locked out.
FAILED_AUTH_LIMIT = 5
SUBJECT_PATTERN = re.compile(r'^[a-zA-Z0-9_-]+$')
STARTUP_TIME = time.time()

admin_secret = os.getenv("ADMIN_API_KEY")
bot_secret = os.getenv("BOT_API_KEY")

# key_id -> {"secret", "active", "role"}. A key is registered only when its
# environment variable is set to a non-empty value.
API_KEYS = {
    key_name: {"secret": key_secret, "active": True, "role": key_role}
    for key_name, key_secret, key_role in (
        ("admin", admin_secret, "admin"),
        ("bot", bot_secret, "user"),
    )
    if key_secret
}
if not API_KEYS:
    logger.warning("No API keys configured!")

# Prompt sent to the LLM for every /ask call: {context} receives the retrieved
# chunks and {question} the user's query (filled with str.format in RAGService.ask).
# NOTE(review): the body lines carry a four-space indent that is sent to the model
# verbatim; harmless, but textwrap.dedent would trim it if that is not intended.
PROMPT_TEMPLATE = """
    You are AskBookie, an assistant built on a RAG system using university slide data.

    Your rules:
    1. If the context has the answer, use it.
    2. If the context is related but incomplete, answer from your knowledge but mention the context.
    3. If unrelated, say it's not in context.
    4. Format your answer nicely in Markdown. Use LaTeX for math ($...$ for inline, $$...$$ for block).

    Context:
    {context}

    Question: {question}
"""

# Shared g4f client used by the non-Gemini model options (GPT-4o-mini, Claude-3-Haiku).
g4f_client = Client()

# Lower-case substrings marking promotional lines; clean_response drops any g4f
# output line whose lower-cased text contains one of these.
PROMO_PATTERNS = [
    "want best roleplay experience",
    "llmplayground.net",
    "want the best roleplay",
    "best ai roleplay",
]

# Selectable LLM backends keyed by the integer id that ModelManager tracks and
# /admin/models/switch accepts. "name" is what /ask returns and stores in history.
# NOTE(review): ids 1-3 are labelled "Gemini-3-*", but ModelManager.call_llm invokes
# "gemini-2.5-flash" / "gemini-2.5-pro"; confirm whether the labels or the model
# strings are the intended ones.
MODEL_OPTIONS = {
    1: {"name": "Gemini-3-flash", "description": "Gemini Primary API Key"},
    2: {"name": "Gemini-3-flash(Back-up)", "description": "Gemini Secondary API Key"},
    3: {"name": "Gemini-3-Pro", "description": "Gemini Primary API Key"},
    4: {"name": "GPT-4o-mini", "description": "DuckDuckGo (Free)"},
    5: {"name": "Claude-3-Haiku", "description": "DuckDuckGo (Free)"},
}


class QuotaExhaustedError(Exception):
    """Raised when an LLM backend reports quota / rate-limit exhaustion; /ask maps it to HTTP 429."""


def clean_response(text: str, patterns=None) -> str:
    """Remove promotional lines from g4f model output.

    A line is dropped when its lower-cased, stripped text contains any entry of
    *patterns*, or when it starts with "http" and mentions "llmplayground".
    Blank lines left at the end are trimmed; every other line (including
    leading and inner blank lines) is kept unchanged.

    Args:
        text: Raw completion text.
        patterns: Lower-case substrings that mark a promo line. ``None`` (the
            default) uses the module-level ``PROMO_PATTERNS``. Entries must be
            lower-case because they are matched against lower-cased lines.

    Returns:
        The remaining lines joined with newlines.
    """
    # Materialise once, outside the per-line loop.
    needles = tuple(PROMO_PATTERNS if patterns is None else patterns)
    kept = []
    for line in text.split('\n'):
        lowered = line.lower().strip()
        if any(needle in lowered for needle in needles):
            continue
        # Catch bare promo links even when they don't match a listed pattern exactly.
        if lowered.startswith("http") and "llmplayground" in lowered:
            continue
        kept.append(line)
    while kept and not kept[-1].strip():
        kept.pop()
    return '\n'.join(kept)


class ModelManager:
    """Process-wide singleton that tracks the active LLM and sends prompts to it.

    The active model is an integer key of ``MODEL_OPTIONS`` (1 at startup) and can
    be changed at runtime with :meth:`switch_model`. Reads and writes of the active
    id are guarded by ``_model_lock``.
    """

    _instance = None
    _lock = Lock()  # guards first-time construction in __new__

    def __new__(cls):
        # Double-checked locking: the outer test skips the lock once the instance
        # exists; the inner test stops two racing threads from both initialising it.
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    cls._instance = super().__new__(cls)
                    cls._instance._initialize()
        return cls._instance

    def _initialize(self):
        """One-time setup: default model id and Gemini API keys from the environment."""
        self._current_model = 1
        self._model_lock = Lock()
        self._gemini_primary_key = os.getenv("GEMINI_API_KEY")
        self._gemini_secondary_key = os.getenv("GEMINI_2_API_KEY")
        # Model 3 reads the same GEMINI_API_KEY, matching its "Gemini Primary API Key" description.
        self._gemini_pro_key = os.getenv("GEMINI_API_KEY")
        logger.info(f"ModelManager initialized with model {self._current_model}")

    @staticmethod
    def _model_info(model_id: int) -> dict:
        """Return ``{"model_id": ..., "name": ...}`` for *model_id*."""
        return {"model_id": model_id, "name": MODEL_OPTIONS[model_id]["name"]}

    @property
    def current_model(self) -> int:
        """Id of the currently selected model."""
        with self._model_lock:
            return self._current_model

    @property
    def current_model_info(self) -> dict:
        """``{"model_id", "name"}`` of the currently selected model."""
        return self._model_info(self.current_model)

    def switch_model(self, model_id: int) -> dict:
        """Make *model_id* the active model and return its info.

        Raises:
            ValueError: if *model_id* is not a key of ``MODEL_OPTIONS``.
        """
        if model_id not in MODEL_OPTIONS:
            # Range derived from MODEL_OPTIONS so the message cannot drift from the table.
            raise ValueError(
                f"Invalid model ID: {model_id}. "
                f"Valid options: {min(MODEL_OPTIONS)}-{max(MODEL_OPTIONS)}"
            )
        with self._model_lock:
            old_model = self._current_model
            self._current_model = model_id
            logger.info(f"Model switched from {old_model} to {model_id} ({MODEL_OPTIONS[model_id]['name']})")
        # Describe the model just set rather than re-reading the shared state: a
        # concurrent switch after the lock is released would otherwise be reported.
        return self._model_info(model_id)

    def call_llm(self, prompt: str) -> str:
        """Send *prompt* to the active model and return its text reply.

        Raises:
            QuotaExhaustedError: a Gemini call failed with a quota/rate-limit error.
            RuntimeError: the active id has no backend mapping (should not happen,
                since switch_model only accepts MODEL_OPTIONS keys).
        """
        model_id = self.current_model
        # NOTE(review): MODEL_OPTIONS labels ids 1-3 "Gemini-3-*", but the model
        # strings below are gemini-2.5-*; confirm which one is intended.
        if model_id == 1:
            return self._call_gemini(prompt, self._gemini_primary_key, "gemini-2.5-flash")
        if model_id == 2:
            return self._call_gemini(prompt, self._gemini_secondary_key, "gemini-2.5-flash")
        if model_id == 3:
            return self._call_gemini(prompt, self._gemini_pro_key, "gemini-2.5-pro")
        if model_id == 4:
            return self._call_gpt4o(prompt)
        if model_id == 5:
            return self._call_claude(prompt)
        # Previously fell through and returned None, which was then stored as the answer.
        raise RuntimeError(f"No LLM backend configured for model ID {model_id}")

    def _call_gemini(self, prompt: str, api_key: str, model: str) -> str:
        """Invoke a Gemini model (temperature 0) via LangChain.

        Errors whose text mentions 429 / quota / resource exhausted / rate limit are
        re-raised as QuotaExhaustedError (original error kept as ``__cause__``);
        anything else propagates unchanged.
        """
        try:
            llm = ChatGoogleGenerativeAI(model=model, temperature=0, google_api_key=api_key)
            response = llm.invoke(prompt)
            return response.content
        except Exception as e:
            error_str = str(e).lower()
            if "429" in str(e) or "quota" in error_str or "resource exhausted" in error_str or "rate limit" in error_str:
                logger.error(f"Gemini quota exhausted: {e}")
                raise QuotaExhaustedError("LLM quota exhausted. Please try again later or switch to a different model.") from e
            raise

    def _call_g4f(self, prompt: str, model: str) -> str:
        """Query *model* through g4f's DDG provider and strip promo lines from the reply."""
        from g4f.Provider import DDG  # imported lazily, only when a g4f model is used
        response = g4f_client.chat.completions.create(
            model=model,
            provider=DDG,
            messages=[{"role": "user", "content": prompt}],
        )
        return clean_response(response.choices[0].message.content)

    def _call_gpt4o(self, prompt: str) -> str:
        """GPT-4o-mini via g4f (model id 4)."""
        return self._call_g4f(prompt, "gpt-4o-mini")

    def _call_claude(self, prompt: str) -> str:
        """Claude-3-Haiku via g4f (model id 5)."""
        return self._call_g4f(prompt, "claude-3-haiku")


# Module-level singleton used by the request handlers.
model_manager = ModelManager()


class RAGService:
    """Answer questions from slide chunks stored in Qdrant, one collection per subject/unit.

    Collections are named ``askbookie_{subject}_unit-{unit}``. Retrieved chunks are
    inserted into ``PROMPT_TEMPLATE`` and sent to whichever model ``model_manager``
    currently has selected.
    """

    def __init__(self):
        self.qdrant_url = os.getenv("QDRANT_CLUSTER_URL")
        self.qdrant_key = os.getenv("QDRANT_API_KEY")
        # Loaded once per service; queries are embedded with normalised vectors,
        # presumably matching how the collections were indexed — TODO confirm.
        self.embeddings = HuggingFaceEmbeddings(
            model_name="Alibaba-NLP/gte-modernbert-base",
            model_kwargs={"device": "cpu"},
            encode_kwargs={"normalize_embeddings": True}
        )

    def ask(self, query_text: str, subject: str, unit: int, k: int = 5) -> dict:
        """Retrieve the top-*k* chunks for *query_text* and ask the active LLM.

        Args:
            query_text: The user's question.
            subject: Subject slug used in the collection name (callers sanitise it).
            unit: Unit number used in the collection name.
            k: Number of chunks to retrieve; defaults to 5, the previously fixed value.

        Returns:
            dict with "answer" (LLM text), "sources" (each chunk's ``slide_number``
            metadata, or "Unknown") and "collection" (the collection queried).

        Raises:
            The last retrieval error if every attempt fails, or whatever
            ``model_manager.call_llm`` raises (e.g. QuotaExhaustedError).
        """
        collection_name = f"askbookie_{subject}_unit-{unit}"
        results = self._search(query_text, collection_name, k)
        context_text = "\n\n---\n\n".join(doc.page_content for doc, _ in results)
        full_prompt = PROMPT_TEMPLATE.format(context=context_text, question=query_text)
        answer = model_manager.call_llm(full_prompt)
        sources = [doc.metadata.get('slide_number', 'Unknown') for doc, _ in results]
        return {"answer": answer, "sources": sources, "collection": collection_name}

    def _search(self, query_text: str, collection_name: str, k: int) -> list:
        """Similarity search with up to three attempts and 1 s / 2 s backoff.

        A fresh QdrantClient is opened for each attempt and closed afterwards, so
        retries and completed searches don't leave open clients behind.
        """
        max_retries = 3
        for attempt in range(max_retries):
            client = None
            try:
                client = QdrantClient(url=self.qdrant_url, api_key=self.qdrant_key, timeout=120)
                vectorstore = QdrantVectorStore(client=client, collection_name=collection_name, embedding=self.embeddings)
                return vectorstore.similarity_search_with_score(query_text, k=k)
            except Exception as exc:
                if attempt >= max_retries - 1:
                    raise
                logger.warning(f"Qdrant search on {collection_name} failed (attempt {attempt + 1}/{max_retries}): {exc}")
            finally:
                if client is not None:
                    client.close()
            time.sleep(2 ** attempt)
        raise RuntimeError("no retrieval attempt was made")  # unreachable while max_retries >= 1


def get_client_ip(request: Request) -> str:
    """Best-effort client address, used as the key for failed-auth lockout.

    Returns the right-most non-empty entry of X-Forwarded-For when present,
    otherwise the socket peer host, otherwise "unknown". Empty entries (e.g. from
    a trailing comma) are skipped so they can't collapse unrelated clients into
    one "" lockout bucket.

    NOTE(review): X-Forwarded-For is trusted unconditionally. If the service is
    ever reachable without a proxy that sets this header, a client can pick any
    value and sidestep the per-IP lockout.
    """
    forwarded = request.headers.get("X-Forwarded-For")
    if forwarded:
        hops = [hop.strip() for hop in forwarded.split(",") if hop.strip()]
        if hops:
            return hops[-1]
    return request.client.host if request.client else "unknown"


def verify_hmac_signature(request: Request) -> Optional[str]:
    """Authenticate a request from its HMAC headers.

    Expects X-API-Key-Id, X-API-Signature and X-API-Timestamp (Unix seconds). The
    signature must be the hex HMAC-SHA256, keyed with the key's secret, of the
    timestamp, upper-case HTTP method and URL path joined by newlines. It does not
    cover the query string or the body.

    Returns:
        The key id when the key exists, is active, the timestamp is within
        HMAC_TIME_WINDOW of now and the signature matches; otherwise None.
        Malformed headers yield None rather than an exception.
    """
    key_id = request.headers.get("X-API-Key-Id")
    sig = request.headers.get("X-API-Signature")
    ts = request.headers.get("X-API-Timestamp")
    if not all([key_id, sig, ts]):
        return None
    meta = API_KEYS.get(key_id)
    # Unknown key ids still go through a full HMAC computation with a dummy secret,
    # so timing does not reveal which key ids exist.
    secret_to_use = meta["secret"] if meta else "dummy_secret_for_timing_safety"
    is_valid_key = meta is not None and meta.get("active", False)
    try:
        ts_int = int(ts)
        # Inside the try: an absurdly large timestamp overflows the float
        # subtraction and must count as a failed check, not a 500.
        time_valid = abs(time.time() - ts_int) <= HMAC_TIME_WINDOW
    except (ValueError, TypeError, OverflowError):
        ts_int = 0
        time_valid = False
        is_valid_key = False
    is_valid_key = is_valid_key and time_valid
    message = f"{ts_int}\n{request.method.upper()}\n{request.url.path}"
    computed = hmac.new(secret_to_use.encode(), message.encode(), hashlib.sha256).hexdigest()
    # Compare bytes: compare_digest raises TypeError for str arguments containing
    # non-ASCII characters, which a client can place in a header.
    sig_valid = secrets.compare_digest(computed.encode(), sig.encode("utf-8", "replace"))
    if is_valid_key and sig_valid:
        return key_id
    return None


async def verify_api_key(request: Request) -> str:
    """FastAPI dependency that authenticates the caller or raises.

    Returns:
        The authenticated key id.

    Raises:
        HTTPException(429): db.check_auth_lockout reports the client IP as locked
            out (threshold FAILED_AUTH_LIMIT).
        HTTPException(401): HMAC verification failed. The failure is recorded
            against the IP and the response is delayed by 0.1 s.
    """
    client_ip = get_client_ip(request)
    if db.check_auth_lockout(client_ip, FAILED_AUTH_LIMIT):
        raise HTTPException(status_code=429, detail="Too many failed attempts")

    authenticated_key = verify_hmac_signature(request)
    if not authenticated_key:
        db.record_failed_auth(client_ip)
        # Fixed pause on failure, presumably to slow down brute-force attempts.
        await asyncio.sleep(0.1)
        raise HTTPException(status_code=401, detail="Unauthorized")
    return authenticated_key


def sanitize_subject(subject: str) -> str:
    """Normalise a subject into a collection-name slug.

    Strips surrounding whitespace, lower-cases, removes every character outside
    ASCII letters, digits, "_" and "-", and truncates to 50 characters. An empty
    result becomes "default".

    NOTE(review): because of the "default" fallback this never returns "", so the
    ``if not subject`` guard in the /ask handler cannot fire and unusable subjects
    are routed to an ``askbookie_default_unit-N`` collection; confirm intent.
    """
    # Deleting disallowed characters leaves an already-valid slug unchanged, so no
    # separate validity check is needed first.
    slug = re.sub(r'[^a-zA-Z0-9_-]', '', subject.strip().lower())
    return slug[:50] or "default"

def get_metrics_summary() -> dict:
    """Return db.get_metrics_summary() plus process uptime and each key's role.

    Adds "uptime_hours" (rounded to 2 decimals) and, for every entry under
    "per_user", a "role" taken from API_KEYS ("user" for unknown key ids).
    """
    summary = db.get_metrics_summary()
    summary["uptime_hours"] = round((time.time() - STARTUP_TIME) / 3600, 2)
    for user_key, user_stats in summary.get("per_user", {}).items():
        user_stats["role"] = API_KEYS.get(user_key, {}).get("role", "user")
    return summary


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Startup/shutdown hook passed to FastAPI(lifespan=...).

    Startup calls db.get_database() (presumably to initialise storage up front —
    confirm in database.py) and stores a freshly built RAGService, which loads the
    embedding model, on ``app.state.rag_service``. Shutdown only logs.
    """
    logger.info("Starting service")
    db.get_database()
    service = RAGService()
    app.state.rag_service = service
    logger.info("RAG service initialized")
    yield
    logger.info("Shutting down")


# ASGI application. `lifespan` builds the shared RAGService at startup; the
# interactive docs and schema are served at FastAPI's conventional paths.
app = FastAPI(
    lifespan=lifespan,
    title="AskBookie RAG API",
    version="3.0.0",
    openapi_url="/openapi.json",
    docs_url="/docs",
    redoc_url="/redoc",
)


@app.middleware("http")
async def security_middleware(request: Request, call_next):
    """Tag each request with an id, add hardening headers and log one access line.

    ``request.state.request_id`` is a 16-hex-char random id, echoed back as
    X-Request-ID. Endpoints may set ``request.state.key_id``, which is included
    in the log line. If the downstream app raises, an error line with the
    elapsed time is logged and the exception is re-raised unchanged.
    """
    request_id = secrets.token_hex(8)
    request.state.request_id = request_id
    # perf_counter is monotonic, so durations are unaffected by wall-clock adjustments.
    started = time.perf_counter()
    try:
        response = await call_next(request)
    except Exception:
        duration_ms = round((time.perf_counter() - started) * 1000, 2)
        logger.error(f"{request.method} {request.url.path} raised after {duration_ms}ms rid={request_id}")
        raise
    response.headers.update({
        "X-Request-ID": request_id,
        "X-Content-Type-Options": "nosniff",
        "X-Frame-Options": "DENY",
        "Referrer-Policy": "strict-origin-when-cross-origin",
    })
    duration_ms = round((time.perf_counter() - started) * 1000, 2)
    key_id = getattr(request.state, "key_id", None)
    logger.info(f"{request.method} {request.url.path} {response.status_code} {duration_ms}ms key={key_id} rid={request_id}")
    return response


@app.exception_handler(HTTPException)
async def http_exception_handler(request: Request, exc: HTTPException):
    """Render an HTTPException as ``{"detail": ...}``, forwarding any headers it carries (e.g. Retry-After)."""
    payload = {"detail": exc.detail}
    extra_headers = getattr(exc, "headers", None)
    return JSONResponse(content=payload, status_code=exc.status_code, headers=extra_headers)


@app.exception_handler(Exception)
async def general_exception_handler(request: Request, exc: Exception):
    """Last-resort handler: log with traceback and return a generic 500 that leaks no details."""
    logger.exception(f"Unhandled error: {exc}")
    body = {"detail": "Internal error"}
    return JSONResponse(content=body, status_code=500)


class AskRequest(BaseModel):
    """Request body for POST /ask."""

    query: str = Field(..., min_length=1, max_length=1000)
    subject: str = Field(..., min_length=1, max_length=100)
    unit: int = Field(..., ge=1, le=4)
    # NOTE(review): validated but never read by the /ask handler; confirm whether
    # it is meant to control how many chunks are retrieved.
    context_limit: int = Field(default=5, ge=1, le=20)

    @field_validator("query", "subject", mode="before")
    @classmethod
    def sanitize(cls, v: object) -> object:
        """Strip surrounding whitespace before the length constraints are checked.

        Running in "before" mode makes min_length/max_length apply to the stripped
        value, so whitespace-only input is rejected (422) instead of slipping
        through as "". Non-string input is passed through untouched so pydantic's
        own type validation reports it.
        """
        if isinstance(v, str):
            return v.strip()
        return v


@app.post("/ask")
async def ask(request: Request, body: AskRequest, key_id: str = Depends(verify_api_key)):
    """Answer a question from the slides of one subject/unit.

    Returns the answer, slide sources, collection name, active model info and the
    request id. LLM quota exhaustion maps to 429 (Retry-After: 3600); any other
    failure maps to a generic 500. Once the RAG step is reached, a metric is
    always recorded, and a history row is stored for every successful answer
    (best-effort).
    """
    start_time = time.time()
    success = False
    request.state.key_id = key_id
    subject = sanitize_subject(body.subject)
    if not subject:
        raise HTTPException(status_code=400, detail="Invalid subject")
    # NOTE(review): body.context_limit is validated by AskRequest but not used here;
    # confirm whether it should set how many chunks are retrieved.
    try:
        rag_service: RAGService = request.app.state.rag_service
        # RAGService.ask is blocking (embedding, Qdrant I/O, LLM call, retry sleeps).
        # Run it in a worker thread so the event loop keeps serving other requests.
        result = await asyncio.to_thread(rag_service.ask, body.query, subject, body.unit)
        success = True
        latency_ms = (time.time() - start_time) * 1000
        current_model = model_manager.current_model_info
        try:
            db.store_query_history(
                key_id=key_id,
                subject=subject,
                query=body.query,
                answer=result["answer"],
                sources=result["sources"],
                request_id=request.state.request_id,
                latency_ms=latency_ms,
                model_id=current_model["model_id"],
                model_name=current_model["name"]
            )
        except Exception:
            # History is bookkeeping: a failed write must not turn an answer the
            # client is owed into a 500 (previously it did, while the metric said success).
            logger.exception(f"Failed to store query history rid={request.state.request_id}")
        return {"answer": result["answer"], "sources": result["sources"], "collection": result["collection"], "model": current_model, "request_id": request.state.request_id}
    except QuotaExhaustedError as e:
        logger.warning(f"LLM quota exhausted: {e}")
        raise HTTPException(status_code=429, detail="LLM quota exhausted. Try again later or switch model.", headers={"Retry-After": "3600"})
    except Exception as e:
        logger.exception(f"RAG query failed: {e}")
        raise HTTPException(status_code=500, detail="Query failed")
    finally:
        db.record_metric(key_id, "/ask", success, (time.time() - start_time) * 1000)


@app.get("/")
async def dashboard():
    """Serve ``assets/index.html`` from the parent of this file's directory, or a JSON banner if it is absent."""
    index_file = Path(__file__).parent.parent / "assets" / "index.html"
    if not index_file.exists():
        return {"service": "AskBookie RAG API", "version": "3.0.0"}
    return FileResponse(index_file, media_type="text/html")


@app.get("/health")
async def health():
    """Health probe: metrics summary plus status and the active model.

    Falls back to ``{"status": "degraded", "uptime_hours": ...}`` (and logs the
    error) if gathering metrics or model info fails.

    NOTE(review): this endpoint takes no API key, yet the summary includes
    per-key statistics and roles; confirm that exposure is intended.
    """
    try:
        payload = get_metrics_summary()
        payload["status"] = "healthy"
        payload["current_model"] = model_manager.current_model_info
    except Exception:
        logger.exception("Metrics error")
        uptime = round((time.time() - STARTUP_TIME) / 3600, 2)
        return {"status": "degraded", "uptime_hours": uptime}
    return payload


@app.get("/history")
async def get_query_history(
    request: Request,
    limit: Annotated[int, Query(ge=1)] = 100,
    offset: Annotated[int, Query(ge=0)] = 0,
    key_id: str = Depends(verify_api_key),
):
    """Admin-only: page through stored /ask history.

    Query params:
        limit: rows per page, >= 1 (default 100).
        offset: rows to skip, >= 0 (default 0).

    Out-of-range values are now rejected with 422 by FastAPI instead of being
    passed straight to db.get_query_history, where (depending on the backend) a
    negative LIMIT/OFFSET may not mean what the caller expects.

    Raises:
        HTTPException(403): the authenticated key is not an admin key.
    """
    request.state.key_id = key_id
    if API_KEYS.get(key_id, {}).get("role") != "admin":
        raise HTTPException(status_code=403, detail="Forbidden")
    history, total = db.get_query_history(limit, offset)
    return {"history": history, "total": total, "limit": limit, "offset": offset}

class ModelSwitchRequest(BaseModel):
    """Body for POST /admin/models/switch."""

    # Bounds derived from MODEL_OPTIONS (currently 1-5) so they cannot drift from the
    # registered models; ModelManager.switch_model still checks exact membership.
    model_id: int = Field(..., ge=min(MODEL_OPTIONS), le=max(MODEL_OPTIONS))


@app.post("/admin/models/switch")
async def switch_model(request: Request, body: ModelSwitchRequest, key_id: str = Depends(verify_api_key)):
    """Admin-only: change the LLM used by subsequent /ask calls.

    Raises:
        HTTPException(403): caller's key is not an admin key.
        HTTPException(400): model_manager rejects the model id.
    """
    request.state.key_id = key_id
    caller_role = API_KEYS.get(key_id, {}).get("role")
    if caller_role != "admin":
        raise HTTPException(status_code=403, detail="Forbidden")
    requested_id = body.model_id
    try:
        model_details = model_manager.switch_model(requested_id)
        logger.info(f"Admin switched model to {requested_id}")
        return {
            "status": "success",
            "message": f"Switched to model {requested_id}",
            "model": model_details,
        }
    except ValueError as err:
        raise HTTPException(status_code=400, detail=str(err))


if __name__ == "__main__":
    import uvicorn

    # Direct launch: serve on all interfaces, port 7860.
    uvicorn.run(app=app, host="0.0.0.0", port=7860)