File size: 4,439 Bytes
7e8ec1d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
396f15b
 
7e8ec1d
396f15b
7e8ec1d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
396f15b
7e8ec1d
 
 
 
396f15b
7e8ec1d
396f15b
7e8ec1d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
396f15b
7e8ec1d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
396f15b
7e8ec1d
 
 
 
 
 
 
 
 
396f15b
 
7e8ec1d
 
 
 
 
 
 
396f15b
7e8ec1d
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import os
import logging
from datetime import datetime, timedelta, timezone
from typing import Optional, List

from pymongo import MongoClient, DESCENDING
from pymongo.errors import ConnectionFailure, ServerSelectionTimeoutError

logger = logging.getLogger("rag-api")

_client: Optional[MongoClient] = None
_db = None


def get_database():
    """Return the cached 'askbookie' database handle, connecting on first use.

    The connection and database handle are memoized in module globals, so
    the MongoClient (and its index creation) happens only once per process.

    Returns:
        The pymongo Database object for 'askbookie'.

    Raises:
        RuntimeError: if the MONGODB_URI environment variable is not set.
        ConnectionFailure, ServerSelectionTimeoutError: if the server cannot
            be reached within the 10-second selection timeout.
    """
    global _client, _db
    if _db is not None:
        return _db

    mongodb_uri = os.getenv("MONGODB_URI")
    if not mongodb_uri:
        # Fail loudly: MongoClient(None) would silently fall back to
        # localhost:27017, masking a misconfigured deployment.
        raise RuntimeError("MONGODB_URI environment variable is not set")

    try:
        _client = MongoClient(mongodb_uri, serverSelectionTimeoutMS=10000)
        # Force a round-trip so a bad URI fails here, not on first query.
        _client.admin.command('ping')
        _db = _client.askbookie
        logger.info("MongoDB connected successfully")
        _create_indexes(_db)
        return _db
    except (ConnectionFailure, ServerSelectionTimeoutError) as e:
        logger.error(f"MongoDB connection failed: {e}")
        raise


def _create_indexes(db):
    """Ensure the collection indexes exist; problems are logged, never fatal.

    Index creation is idempotent in MongoDB, so re-running on every startup
    is safe.
    """
    # (collection name, key spec, extra options) for every index we need.
    specs = [
        # TTL index: failed-auth records expire 300s after their timestamp.
        ("failed_auth", "timestamp", {"expireAfterSeconds": 300}),
        ("failed_auth", [("ip", 1), ("timestamp", 1)], {}),
        ("metrics", [("key_id", 1), ("endpoint", 1), ("timestamp", -1)], {}),
        ("query_history", [("timestamp", -1)], {}),
    ]
    try:
        for collection, keys, options in specs:
            db[collection].create_index(keys, **options)
        logger.info("MongoDB indexes created")
    except Exception as e:
        # Best-effort: a failed index must not prevent the app from starting.
        logger.warning(f"Index creation warning: {e}")


def record_failed_auth(ip: str):
    """Record one failed authentication attempt for *ip* at the current UTC time."""
    attempt = {"ip": ip, "timestamp": datetime.now(timezone.utc)}
    get_database().failed_auth.insert_one(attempt)


def check_auth_lockout(ip: str, limit: int = 5, window_seconds: int = 300) -> bool:
    """Return True when *ip* should be locked out of authentication.

    An IP is locked out once it has accumulated at least *limit* failed
    attempts within the last *window_seconds* seconds.

    Args:
        ip: client IP address to check.
        limit: failure count at which lockout triggers (default 5).
        window_seconds: sliding-window length in seconds. The 300s default
            matches the TTL on the failed_auth collection, so expired
            attempts are also purged server-side.

    Returns:
        True if the failure count in the window is >= limit.
    """
    db = get_database()
    # Idiomatic aware-datetime arithmetic instead of the previous
    # time.time()-based epoch round-trip.
    window_start = datetime.now(timezone.utc) - timedelta(seconds=window_seconds)
    count = db.failed_auth.count_documents({"ip": ip, "timestamp": {"$gte": window_start}})
    return count >= limit


def record_metric(key_id: str, endpoint: str, success: bool, latency_ms: float):
    """Persist one API-call measurement to the metrics collection."""
    measurement = {
        "key_id": key_id,
        "endpoint": endpoint,
        "success": success,
        "latency_ms": latency_ms,
        "timestamp": datetime.now(timezone.utc),
    }
    get_database().metrics.insert_one(measurement)


def get_metrics_summary() -> dict:
    """Aggregate request metrics into overall totals plus per-key statistics.

    Returns a dict with total API-call and /ask counts, the current process
    RSS in MB (via psutil), and a per_user mapping keyed by key_id with call
    counts, success rate (%), mean latency (s), and failed /ask count.
    """
    import psutil
    db = get_database()

    # Counts one document only when it is a failed call to /ask.
    failed_ask = {
        "$cond": [{"$and": [{"$eq": ["$endpoint", "/ask"]}, {"$not": "$success"}]}, 1, 0]
    }
    group_stage = {"$group": {
        "_id": "$key_id",
        "api_calls": {"$sum": 1},
        "questions_asked": {"$sum": {"$cond": [{"$eq": ["$endpoint", "/ask"]}, 1, 0]}},
        "success_count": {"$sum": {"$cond": ["$success", 1, 0]}},
        "total_latency": {"$sum": "$latency_ms"},
        "ask_fails": {"$sum": failed_ask},
    }}

    breakdown = {}
    for row in db.metrics.aggregate([group_stage]):
        calls = row["api_calls"]
        breakdown[row["_id"]] = {
            "api_calls": calls,
            "questions_asked": row["questions_asked"],
            # An empty group is reported as a perfect 100% / 0s baseline.
            "success_rate": round((row["success_count"] / calls * 100) if calls > 0 else 100, 1),
            "average_latency_seconds": round((row["total_latency"] / calls / 1000) if calls > 0 else 0, 2),
            "ask_fails": row["ask_fails"],
        }

    return {
        "total_api_calls": db.metrics.count_documents({}),
        "total_questions": db.metrics.count_documents({"endpoint": "/ask"}),
        "memory_mb": round(psutil.Process().memory_info().rss / (1024 * 1024), 1),
        "per_user": breakdown,
    }


def store_query_history(key_id: str, subject: str, query: str, answer: str, sources: list, request_id: str, latency_ms: float, model_id: Optional[int] = None, model_name: Optional[str] = None):
    """Persist one answered query (question, answer, sources, model info)
    to the query_history collection, stamped with the current UTC time.
    """
    db = get_database()
    db.query_history.insert_one({
        "key_id": key_id,
        "subject": subject,
        "query": query,
        "answer": answer,
        "sources": sources,
        "request_id": request_id,
        "latency_ms": latency_ms,
        "model_id": model_id,
        "model_name": model_name,
        "timestamp": datetime.now(timezone.utc)
    })


def get_query_history(limit: int = 100, offset: int = 0) -> tuple[List[dict], int]:
    """Return one page of query history, newest first, plus the total count.

    Each entry gains a 1-based sequential "id" (offset-aware) and has its
    timestamp converted from a datetime to a Unix epoch float.
    """
    db = get_database()
    total = db.query_history.count_documents({})
    cursor = (
        db.query_history
        .find({}, {"_id": 0})          # drop Mongo's ObjectId from the payload
        .sort("timestamp", DESCENDING)
        .skip(offset)
        .limit(limit)
    )
    page = []
    for position, entry in enumerate(cursor, start=1):
        entry["id"] = offset + position
        ts = entry.get("timestamp")
        if ts:
            entry["timestamp"] = ts.timestamp()
        page.append(entry)
    return page, total