# askbookie / src / database.py
# MongoDB persistence layer: auth-failure tracking, usage metrics, query history.
import os
import logging
from datetime import datetime, timedelta, timezone
from typing import List, Optional

from pymongo import MongoClient, DESCENDING
from pymongo.errors import ConnectionFailure, ServerSelectionTimeoutError
logger = logging.getLogger("rag-api")  # shared service-wide logger name
# Lazily-initialized singletons; populated by the first successful get_database() call.
_client: Optional[MongoClient] = None
_db = None
def get_database():
    """Return the shared MongoDB database handle, connecting on first use.

    Reads the connection string from the MONGODB_URI environment variable
    (if unset, pymongo falls back to its localhost default). Verifies
    connectivity with a ping, ensures indexes exist, and caches the handle
    in module globals so later calls are free.

    Raises:
        ConnectionFailure, ServerSelectionTimeoutError: if the server cannot
            be reached within the 10s server-selection timeout.
    """
    global _client, _db
    # Fast path: already connected.
    if _db is not None:
        return _db
    mongodb_uri = os.getenv("MONGODB_URI")
    try:
        _client = MongoClient(mongodb_uri, serverSelectionTimeoutMS=10000)
        # Force an immediate round-trip; MongoClient itself connects lazily.
        _client.admin.command('ping')
        _db = _client.askbookie
        logger.info("MongoDB connected successfully")
        _create_indexes(_db)
        return _db
    except (ConnectionFailure, ServerSelectionTimeoutError) as e:
        # Don't leave a dead client cached; close it so the next call retries cleanly.
        if _client is not None:
            _client.close()
            _client = None
        # Lazy %-formatting: message is only built if the record is emitted.
        logger.error("MongoDB connection failed: %s", e)
        raise
def _create_indexes(db):
    """Best-effort creation of the indexes this module's queries rely on."""
    try:
        # TTL index: failed-auth records expire automatically after 300s,
        # matching the lockout window used by check_auth_lockout.
        db.failed_auth.create_index("timestamp", expireAfterSeconds=300)
        db.failed_auth.create_index([("ip", 1), ("timestamp", 1)])
        db.metrics.create_index([("key_id", 1), ("endpoint", 1), ("timestamp", -1)])
        db.query_history.create_index([("timestamp", -1)])
    except Exception as e:
        # Index creation is non-fatal (e.g. insufficient privileges): warn and move on.
        logger.warning(f"Index creation warning: {e}")
    else:
        logger.info("MongoDB indexes created")
def record_failed_auth(ip: str):
    """Record one failed authentication attempt for *ip* (auto-expires via TTL index)."""
    attempt = {"ip": ip, "timestamp": datetime.now(timezone.utc)}
    get_database().failed_auth.insert_one(attempt)
def check_auth_lockout(ip: str, limit: int = 5) -> bool:
    """Return True when *ip* has accumulated *limit* or more failed auth
    attempts within the last 5 minutes.

    Args:
        ip: Client address to check.
        limit: Failure count within the window that triggers lockout.
    """
    db = get_database()
    # Timezone-aware datetime arithmetic instead of a raw time.time() epoch
    # offset; the 300s window matches the TTL on the failed_auth collection.
    window_start = datetime.now(timezone.utc) - timedelta(seconds=300)
    recent_failures = db.failed_auth.count_documents(
        {"ip": ip, "timestamp": {"$gte": window_start}}
    )
    return recent_failures >= limit
def record_metric(key_id: str, endpoint: str, success: bool, latency_ms: float):
    """Persist one API-call metric document for later aggregation."""
    entry = {
        "key_id": key_id,
        "endpoint": endpoint,
        "success": success,
        "latency_ms": latency_ms,
        "timestamp": datetime.now(timezone.utc),
    }
    get_database().metrics.insert_one(entry)
def get_metrics_summary() -> dict:
    """Aggregate global and per-key API usage statistics plus process memory.

    Returns a dict with total call/question counts, resident memory (MB),
    and a per-user breakdown: call count, questions asked, success rate (%),
    mean latency (seconds) and failed /ask calls.
    """
    import psutil  # local import: memory stats only needed on this endpoint
    db = get_database()
    total_calls = db.metrics.count_documents({})
    total_questions = db.metrics.count_documents({"endpoint": "/ask"})
    # Single $group pass computes every per-key statistic server-side.
    pipeline = [
        {"$group": {
            "_id": "$key_id",
            "api_calls": {"$sum": 1},
            "questions_asked": {"$sum": {"$cond": [{"$eq": ["$endpoint", "/ask"]}, 1, 0]}},
            "success_count": {"$sum": {"$cond": ["$success", 1, 0]}},
            "total_latency": {"$sum": "$latency_ms"},
            "ask_fails": {"$sum": {"$cond": [{"$and": [{"$eq": ["$endpoint", "/ask"]}, {"$not": "$success"}]}, 1, 0]}}
        }}
    ]
    per_user = {}
    for row in db.metrics.aggregate(pipeline):
        calls = row["api_calls"]
        if calls > 0:
            success_rate = round(row["success_count"] / calls * 100, 1)
            avg_latency_s = round(row["total_latency"] / calls / 1000, 2)
        else:
            # Defensive: $group never emits a zero-count bucket in practice.
            success_rate = 100
            avg_latency_s = 0
        per_user[row["_id"]] = {
            "api_calls": calls,
            "questions_asked": row["questions_asked"],
            "success_rate": success_rate,
            "average_latency_seconds": avg_latency_s,
            "ask_fails": row["ask_fails"],
        }
    rss_mb = round(psutil.Process().memory_info().rss / (1024 * 1024), 1)
    return {
        "total_api_calls": total_calls,
        "total_questions": total_questions,
        "memory_mb": rss_mb,
        "per_user": per_user,
    }
def store_query_history(key_id: str, subject: str, query: str, answer: str, sources: list, request_id: str, latency_ms: float, model_id: Optional[int] = None, model_name: Optional[str] = None):
    """Persist one answered query (question, answer, sources, timing, model info)
    to the query_history collection with a UTC timestamp.
    """
    db = get_database()
    db.query_history.insert_one({
        "key_id": key_id,
        "subject": subject,
        "query": query,
        "answer": answer,
        "sources": sources,
        "request_id": request_id,
        "latency_ms": latency_ms,
        "model_id": model_id,
        "model_name": model_name,
        "timestamp": datetime.now(timezone.utc)
    })
def get_query_history(limit: int = 100, offset: int = 0) -> tuple[List[dict], int]:
    """Return a (page, total_count) tuple of query history, newest first.

    Each page item is given a 1-based sequential "id" (offset-aware) and its
    stored datetime is converted to a Unix epoch float.
    """
    db = get_database()
    total = db.query_history.count_documents({})
    cursor = (
        db.query_history.find({}, {"_id": 0})
        .sort("timestamp", DESCENDING)
        .skip(offset)
        .limit(limit)
    )
    page = list(cursor)
    for position, entry in enumerate(page, start=1):
        entry["id"] = offset + position
        stamp = entry.get("timestamp")
        if stamp:
            entry["timestamp"] = stamp.timestamp()
    return page, total