import logging
import os
from datetime import datetime, timedelta, timezone
from typing import List, Optional

from pymongo import DESCENDING, MongoClient
from pymongo.errors import ConnectionFailure, ServerSelectionTimeoutError
# Shared logger for this module; the "rag-api" name groups it with the API's
# other log output.
logger = logging.getLogger("rag-api")

# Lazily-initialized connection singletons, populated by get_database() on
# first use and reused by every subsequent call.
_client: Optional[MongoClient] = None
_db = None
def get_database():
    """Return the cached MongoDB database handle, connecting on first use.

    Reads the connection string from the ``MONGODB_URI`` environment
    variable, pings the server to fail fast, caches the ``askbookie``
    database handle in the module-level singletons, and creates the
    collection indexes once per process.

    Returns:
        The ``askbookie`` pymongo database.

    Raises:
        ConnectionFailure, ServerSelectionTimeoutError: if the server cannot
            be reached within the 10s server-selection timeout.
    """
    global _client, _db
    if _db is not None:
        return _db
    mongodb_uri = os.getenv("MONGODB_URI")
    if not mongodb_uri:
        # MongoClient(None) silently targets pymongo's default
        # (localhost:27017); surface that so a missing env var is not
        # mistaken for a production connection.
        logger.warning("MONGODB_URI is not set; falling back to pymongo default URI")
    try:
        _client = MongoClient(mongodb_uri, serverSelectionTimeoutMS=10000)
        # Cheap round-trip so a bad URI fails here rather than on first query.
        _client.admin.command('ping')
        _db = _client.askbookie
        logger.info("MongoDB connected successfully")
        _create_indexes(_db)
        return _db
    except (ConnectionFailure, ServerSelectionTimeoutError) as e:
        logger.error(f"MongoDB connection failed: {e}")
        # Don't leak a half-initialized client: close it and reset the
        # singleton so the next call retries the connection from scratch.
        if _client is not None:
            _client.close()
            _client = None
        raise
def _create_indexes(db):
    """Best-effort creation of the indexes this module queries against.

    MongoDB index creation is idempotent, so calling this on every fresh
    connection is safe; any failure is logged as a warning and never
    propagated to the caller.
    """
    try:
        # TTL index: failed-auth records are purged 5 minutes after `timestamp`.
        db.failed_auth.create_index("timestamp", expireAfterSeconds=300)
        # Compound index backing the per-IP sliding-window lockout count.
        db.failed_auth.create_index([("ip", 1), ("timestamp", 1)])
        # Supports the per-key/per-endpoint metric aggregations, newest first.
        db.metrics.create_index([("key_id", 1), ("endpoint", 1), ("timestamp", -1)])
        # Supports paginated newest-first history listing.
        db.query_history.create_index([("timestamp", -1)])
    except Exception as exc:
        logger.warning(f"Index creation warning: {exc}")
    else:
        logger.info("MongoDB indexes created")
def record_failed_auth(ip: str):
    """Record one failed authentication attempt for *ip*, stamped with the
    current UTC time (expired from the collection by its TTL index)."""
    attempt = {"ip": ip, "timestamp": datetime.now(timezone.utc)}
    get_database().failed_auth.insert_one(attempt)
def check_auth_lockout(ip: str, limit: int = 5, window_seconds: int = 300) -> bool:
    """Return True if *ip* has reached the failed-auth limit within the window.

    Args:
        ip: Client IP address to check.
        limit: Number of failures inside the window that triggers a lockout.
        window_seconds: Length of the sliding window in seconds. Defaults to
            300, matching the TTL on the ``failed_auth`` collection.

    Returns:
        True when the IP should be locked out, False otherwise.
    """
    db = get_database()
    # Filter explicitly even though the TTL index also expires old records:
    # TTL deletion runs on a background cycle and may lag behind real time.
    window_start = datetime.now(timezone.utc) - timedelta(seconds=window_seconds)
    count = db.failed_auth.count_documents({"ip": ip, "timestamp": {"$gte": window_start}})
    return count >= limit
def record_metric(key_id: str, endpoint: str, success: bool, latency_ms: float):
    """Persist a single API-call metric (key, endpoint, outcome, latency),
    timestamped with the current UTC time."""
    entry = {
        "key_id": key_id,
        "endpoint": endpoint,
        "success": success,
        "latency_ms": latency_ms,
        "timestamp": datetime.now(timezone.utc),
    }
    get_database().metrics.insert_one(entry)
def get_metrics_summary() -> dict:
    """Aggregate usage metrics across all API keys.

    Returns a dict with overall call/question totals, the current process's
    resident memory in MB, and a per-key breakdown (call counts, questions,
    success rate %, mean latency in seconds, and failed /ask calls).
    """
    # Local import keeps psutil off the import path for callers that never
    # request a summary.
    import psutil

    db = get_database()
    total_calls = db.metrics.count_documents({})
    total_questions = db.metrics.count_documents({"endpoint": "/ask"})

    # One $group pass computes every per-key counter server-side.
    grouping = [
        {"$group": {
            "_id": "$key_id",
            "api_calls": {"$sum": 1},
            "questions_asked": {"$sum": {"$cond": [{"$eq": ["$endpoint", "/ask"]}, 1, 0]}},
            "success_count": {"$sum": {"$cond": ["$success", 1, 0]}},
            "total_latency": {"$sum": "$latency_ms"},
            "ask_fails": {"$sum": {"$cond": [{"$and": [{"$eq": ["$endpoint", "/ask"]}, {"$not": "$success"}]}, 1, 0]}},
        }},
    ]

    per_user = {}
    for row in db.metrics.aggregate(grouping):
        calls = row["api_calls"]
        if calls > 0:
            success_rate = round(row["success_count"] / calls * 100, 1)
            avg_latency_s = round(row["total_latency"] / calls / 1000, 2)
        else:
            success_rate = 100
            avg_latency_s = 0
        per_user[row["_id"]] = {
            "api_calls": calls,
            "questions_asked": row["questions_asked"],
            "success_rate": success_rate,
            "average_latency_seconds": avg_latency_s,
            "ask_fails": row["ask_fails"],
        }

    rss_bytes = psutil.Process().memory_info().rss
    return {
        "total_api_calls": total_calls,
        "total_questions": total_questions,
        "memory_mb": round(rss_bytes / (1024 * 1024), 1),
        "per_user": per_user,
    }
def store_query_history(
    key_id: str,
    subject: str,
    query: str,
    answer: str,
    sources: list,
    request_id: str,
    latency_ms: float,
    model_id: Optional[int] = None,
    model_name: Optional[str] = None,
):
    """Persist one question/answer exchange to the query_history collection.

    Args:
        key_id: API key that issued the request.
        subject: Subject/topic the question was asked under.
        query: The user's question text.
        answer: The generated answer text.
        sources: Source documents backing the answer.
        request_id: Correlation id of the originating request.
        latency_ms: End-to-end handling time in milliseconds.
        model_id: Optional numeric id of the model used.
        model_name: Optional human-readable model name.
    """
    db = get_database()
    db.query_history.insert_one({
        "key_id": key_id,
        "subject": subject,
        "query": query,
        "answer": answer,
        "sources": sources,
        "request_id": request_id,
        "latency_ms": latency_ms,
        "model_id": model_id,
        "model_name": model_name,
        "timestamp": datetime.now(timezone.utc)
    })
def get_query_history(limit: int = 100, offset: int = 0) -> tuple[List[dict], int]:
    """Return one page of query history (newest first) plus the total count.

    Args:
        limit: Maximum number of records in the page.
        offset: Number of records to skip before the page starts.

    Returns:
        A (records, total) tuple. Each record gets a 1-based, offset-aware
        "id" and its "timestamp" converted to a POSIX float; "_id" is
        excluded by the projection.
    """
    db = get_database()
    total = db.query_history.count_documents({})
    cursor = (
        db.query_history.find({}, {"_id": 0})
        .sort("timestamp", DESCENDING)
        .skip(offset)
        .limit(limit)
    )
    page = []
    for position, record in enumerate(cursor, start=offset + 1):
        record["id"] = position
        ts = record.get("timestamp")
        if ts:
            record["timestamp"] = ts.timestamp()
        page.append(record)
    return page, total