"""MongoDB persistence layer for the RAG API.

Provides a lazily-initialised client/database singleton plus helpers for:
failed-auth lockout tracking, per-key usage metrics, and query history.
"""

import logging
import os
from datetime import datetime, timedelta, timezone
from typing import List, Optional

from pymongo import MongoClient, DESCENDING
from pymongo.errors import ConnectionFailure, ServerSelectionTimeoutError

logger = logging.getLogger("rag-api")

# Process-wide singletons: one client and database handle, created on first use.
_client: Optional[MongoClient] = None
_db = None

# Sliding window (seconds) shared by the failed-auth TTL index and the
# lockout check in check_auth_lockout() — the two must stay in sync.
AUTH_WINDOW_SECONDS = 300


def get_database():
    """Return the cached database handle, connecting on first call.

    Reads the connection string from the MONGODB_URI environment variable
    (a missing variable makes MongoClient fall back to localhost:27017).

    Raises:
        ConnectionFailure, ServerSelectionTimeoutError: if the server does
            not answer the ping within the 10 s selection timeout.
    """
    global _client, _db
    if _db is not None:
        return _db
    mongodb_uri = os.getenv("MONGODB_URI")
    try:
        _client = MongoClient(mongodb_uri, serverSelectionTimeoutMS=10000)
        # Force an actual round-trip now so connection errors surface here,
        # not on the first real query.
        _client.admin.command('ping')
        _db = _client.askbookie
        logger.info("MongoDB connected successfully")
        _create_indexes(_db)
        return _db
    except (ConnectionFailure, ServerSelectionTimeoutError) as e:
        logger.error(f"MongoDB connection failed: {e}")
        raise


def _create_indexes(db):
    """Best-effort creation of the indexes the helpers below rely on.

    Index creation is idempotent in MongoDB, and a failure here is not
    fatal — the app can still serve traffic — so errors are only logged.
    """
    try:
        # TTL index: failed-auth records expire once the lockout window passes.
        db.failed_auth.create_index("timestamp", expireAfterSeconds=AUTH_WINDOW_SECONDS)
        db.failed_auth.create_index([("ip", 1), ("timestamp", 1)])
        db.metrics.create_index([("key_id", 1), ("endpoint", 1), ("timestamp", -1)])
        db.query_history.create_index([("timestamp", -1)])
        logger.info("MongoDB indexes created")
    except Exception as e:
        logger.warning(f"Index creation warning: {e}")


def record_failed_auth(ip: str):
    """Record one failed authentication attempt for *ip* (UTC timestamped)."""
    db = get_database()
    db.failed_auth.insert_one({"ip": ip, "timestamp": datetime.now(timezone.utc)})


def check_auth_lockout(ip: str, limit: int = 5) -> bool:
    """Return True when *ip* has *limit* or more failures in the window.

    The window length matches the TTL index on failed_auth, so expired
    documents never inflate the count.
    """
    db = get_database()
    window_start = datetime.now(timezone.utc) - timedelta(seconds=AUTH_WINDOW_SECONDS)
    count = db.failed_auth.count_documents(
        {"ip": ip, "timestamp": {"$gte": window_start}}
    )
    return count >= limit


def record_metric(key_id: str, endpoint: str, success: bool, latency_ms: float):
    """Store one API-call metric document (UTC timestamped)."""
    db = get_database()
    db.metrics.insert_one({
        "key_id": key_id,
        "endpoint": endpoint,
        "success": success,
        "latency_ms": latency_ms,
        "timestamp": datetime.now(timezone.utc),
    })


def get_metrics_summary() -> dict:
    """Aggregate usage metrics: global totals plus a per-key breakdown.

    Returns a dict with total_api_calls, total_questions, the current
    process RSS in MiB, and per_user stats keyed by key_id.
    """
    import psutil  # local import: psutil is only needed for this endpoint
    db = get_database()
    total_calls = db.metrics.count_documents({})
    total_questions = db.metrics.count_documents({"endpoint": "/ask"})
    pipeline = [
        {"$group": {
            "_id": "$key_id",
            "api_calls": {"$sum": 1},
            "questions_asked": {"$sum": {"$cond": [{"$eq": ["$endpoint", "/ask"]}, 1, 0]}},
            "success_count": {"$sum": {"$cond": ["$success", 1, 0]}},
            "total_latency": {"$sum": "$latency_ms"},
            # Failed /ask calls specifically (endpoint == /ask AND not success).
            "ask_fails": {"$sum": {"$cond": [
                {"$and": [{"$eq": ["$endpoint", "/ask"]}, {"$not": "$success"}]}, 1, 0]}},
        }}
    ]
    per_user = {}
    for doc in db.metrics.aggregate(pipeline):
        kid = doc["_id"]
        total = doc["api_calls"]
        per_user[kid] = {
            "api_calls": total,
            "questions_asked": doc["questions_asked"],
            # A key with zero calls reports a 100% success rate by convention.
            "success_rate": round((doc["success_count"] / total * 100) if total > 0 else 100, 1),
            "average_latency_seconds": round((doc["total_latency"] / total / 1000) if total > 0 else 0, 2),
            "ask_fails": doc["ask_fails"],
        }
    process = psutil.Process()
    return {
        "total_api_calls": total_calls,
        "total_questions": total_questions,
        "memory_mb": round(process.memory_info().rss / (1024 * 1024), 1),
        "per_user": per_user,
    }


def store_query_history(key_id: str, subject: str, query: str, answer: str,
                        sources: list, request_id: str, latency_ms: float,
                        model_id: Optional[int] = None,
                        model_name: Optional[str] = None):
    """Persist one answered query (UTC timestamped) for later review."""
    db = get_database()
    db.query_history.insert_one({
        "key_id": key_id,
        "subject": subject,
        "query": query,
        "answer": answer,
        "sources": sources,
        "request_id": request_id,
        "latency_ms": latency_ms,
        "model_id": model_id,
        "model_name": model_name,
        "timestamp": datetime.now(timezone.utc),
    })


def get_query_history(limit: int = 100, offset: int = 0) -> tuple[List[dict], int]:
    """Return (page_of_history, total_count), newest first.

    Each item gets a 1-based sequential "id" (offset-aware) and its
    timestamp converted to a Unix epoch float.
    """
    db = get_database()
    total = db.query_history.count_documents({})
    history = list(
        db.query_history.find({}, {"_id": 0})
        .sort("timestamp", DESCENDING)
        .skip(offset)
        .limit(limit)
    )
    for i, item in enumerate(history):
        item["id"] = offset + i + 1
        if item.get("timestamp"):
            ts = item["timestamp"]
            # BUG FIX: pymongo returns *naive* UTC datetimes by default, and
            # naive .timestamp() assumes local time — skewing the epoch by the
            # host's UTC offset. Attach UTC explicitly before converting.
            if ts.tzinfo is None:
                ts = ts.replace(tzinfo=timezone.utc)
            item["timestamp"] = ts.timestamp()
    return history, total