# proofly / project / database.py
# Commit c7893c0: "Fix: Backend OOM crashes via Vector Cache and worker reduction"
"""
project/database.py
MongoDB persistence layer with Borg singleton pattern.
Collections:
evidence – scraped text for FAISS/NLI pipeline
users – registered accounts (hashed + peppered passwords)
history – per-user claim history
revoked_tokens – JWT blocklist (TTL-indexed, self-cleaning)
"""
from datetime import datetime, timedelta, timezone
from bson import ObjectId
from pymongo import MongoClient, ASCENDING, DESCENDING
import certifi
from project import config
# ── Borg Singleton ─────────────────────────────────────────────────────────────
# All DatabaseManager instances share the same __dict__, giving us a singleton
# without a class-level lock β€” safe for Flask's threaded/forked model.
class DatabaseManager:
    """Borg-pattern holder for one shared MongoClient / database handle.

    Every instance aliases its ``__dict__`` to the class-level
    ``_shared_state`` dict, so the first successful ``_connect()`` is reused
    by all later instances without a class-level lock.
    """

    # Shared across ALL instances — this dict IS each instance's __dict__.
    _shared_state: dict = {}

    def __init__(self):
        # Borg: alias this instance's attribute storage to the shared dict.
        self.__dict__ = self._shared_state

    def _connect(self):
        """Create the MongoClient lazily, exactly once per shared state."""
        if getattr(self, '_db', None) is None:
            self._client = MongoClient(
                config.MONGO_URI,
                serverSelectionTimeoutMS=5000,  # fail fast if cluster unreachable
                connectTimeoutMS=5000,
                tlsCAFile=certifi.where(),
                # NOTE(review): disabling certificate validation defeats the
                # certifi CA bundle passed above and permits MITM — confirm
                # this is a deliberate workaround and remove if possible.
                tlsAllowInvalidCertificates=True
            )
            self._db = self._client[config.MONGO_DB_NAME]
        return self._db

    @property
    def db(self):
        """Lazily-connected database handle (connects on first access)."""
        return self._connect()
# Module-level shared instance; any later DatabaseManager() sees the same state.
_manager = DatabaseManager()

def get_db():
    """Return the shared MongoDB database handle (connecting lazily)."""
    return _manager.db
# ── Initialise indexes ─────────────────────────────────────────────────────────
def init_db():
    """Ensure all collection indexes exist (idempotent; safe to call at boot).

    NOTE(review): ``background=True`` is deprecated and ignored by MongoDB
    4.2+ (index builds are non-blocking there anyway) — harmless, but it
    could be dropped.
    """
    db = get_db()
    # evidence: TTL – auto-removes docs older than 30 days
    db.evidence.create_index(
        [("created_at", ASCENDING)],
        expireAfterSeconds=30 * 24 * 3600,  # 30 days in seconds
        background=True,
        name="evidence_ttl"
    )
    db.evidence.create_index([("source", ASCENDING)], background=True, name="source_idx")
    # users: unique email
    db.users.create_index([("email", ASCENDING)], unique=True, background=True, name="email_unique")
    db.users.create_index([("username", ASCENDING)], background=True, name="username_idx")
    # history: fast per-user lookup, newest first
    db.history.create_index(
        [("user_id", ASCENDING), ("created_at", DESCENDING)],
        background=True,
        name="user_history_idx"
    )
    # revoked_tokens: TTL auto-removes expired JTIs
    db.revoked_tokens.create_index(
        [("exp", ASCENDING)],
        expireAfterSeconds=0,  # expire exactly at the stored "exp" timestamp
        background=True,
        name="token_ttl"
    )
    db.revoked_tokens.create_index([("jti", ASCENDING)], unique=True, background=True, name="jti_unique")
    # cached_results: exact claim cache, indexed by normalized claim
    db.cached_results.create_index([("normalized_claim", ASCENDING)], unique=True, background=True, name="claim_cache_idx")
    print("[DB] MongoDB indexes ensured.")
# ── Evidence helpers ───────────────────────────────────────────────────────────
def clear_db():
    """Remove every document from the evidence collection."""
    database = get_db()
    database.evidence.delete_many({})
def save_evidence(text, source, embedding=None):
    """Insert one evidence document (text + source + optional vector).

    Best-effort: failures are logged and swallowed so scraping continues.
    """
    document = {
        "text": text,
        "source": source,
        # Vector stored inline as a plain list (or None when absent).
        "embedding": embedding,
        "created_at": datetime.now(timezone.utc),
    }
    try:
        get_db().evidence.insert_one(document)
    except Exception as e:
        print(f"[DB] save_evidence error: {e}")
def load_all_evidence():
    """Returns list of (id, text, source, embedding) — same shape the FAISS pipeline expects."""
    projection = {"_id": 1, "text": 1, "source": 1, "embedding": 1}
    rows = []
    for doc in get_db().evidence.find({}, projection):
        rows.append((str(doc["_id"]), doc["text"], doc["source"], doc.get("embedding")))
    return rows
def get_total_evidence_count():
    """Count every document currently stored in the evidence collection."""
    evidence = get_db().evidence
    return evidence.count_documents({})
def prune_old_evidence(days=30):
    """Delete evidence documents older than *days*; returns the delete count."""
    threshold = datetime.now(timezone.utc) - timedelta(days=days)
    outcome = get_db().evidence.delete_many({"created_at": {"$lt": threshold}})
    return outcome.deleted_count
# ── User helpers ───────────────────────────────────────────────────────────────
def create_user(username, email, password_hash, is_admin=False):
    """Returns inserted ObjectId string, or None on duplicate/error."""
    account = {
        "username": username,
        "email": email,
        "password_hash": password_hash,
        "is_admin": is_admin,
        "created_at": datetime.now(timezone.utc),
    }
    try:
        inserted = get_db().users.insert_one(account)
    except Exception as e:
        # Avoid logging email — privacy
        print(f"[DB] create_user error: {type(e).__name__}")
        return None
    return str(inserted.inserted_id)
def find_user_by_email(email):
    """Look up one user document by exact email; None when absent."""
    users = get_db().users
    return users.find_one({"email": email})
def find_user_by_id(user_id):
    """Fetch a user by ObjectId string; None on malformed id or lookup error."""
    try:
        oid = ObjectId(user_id)  # raises on malformed ids
        return get_db().users.find_one({"_id": oid})
    except Exception:
        return None
# ── JWT Blocklist helpers ──────────────────────────────────────────────────────
def add_token_to_blocklist(jti: str, exp: datetime):
    """Persist a revoked JWT jti. TTL index auto-removes it after exp."""
    entry = {
        "jti": jti,
        "exp": exp,
        "revoked_at": datetime.now(timezone.utc),
    }
    try:
        get_db().revoked_tokens.insert_one(entry)
    except Exception as e:
        print(f"[DB] add_token_to_blocklist error: {type(e).__name__}")
def is_token_revoked(jti: str) -> bool:
    """True when the jti is present in the blocklist collection."""
    match = get_db().revoked_tokens.find_one({"jti": jti})
    return match is not None
# ── History helpers ────────────────────────────────────────────────────────────
def save_history(user_id, claim, verdict, confidence, evidence_count):
    """Best-effort insert of one fact-check record into a user's history."""
    record = {
        "user_id": user_id,
        "claim": claim,
        "verdict": verdict,
        "confidence": confidence,
        "evidence_count": evidence_count,
        "created_at": datetime.now(timezone.utc),
    }
    try:
        get_db().history.insert_one(record)
    except Exception as e:
        print(f"[DB] save_history error: {type(e).__name__}")
def get_user_history(user_id, limit=50):
    """Newest-first history entries for one user, capped at *limit*."""
    cursor = get_db().history.find({"user_id": user_id})
    cursor = cursor.sort("created_at", DESCENDING).limit(limit)
    return list(cursor)
def delete_history_item(user_id, item_id):
    """Delete one history row owned by *user_id*; True only if removed."""
    try:
        # ObjectId() raises on malformed ids — treated as "not deleted".
        query = {"_id": ObjectId(item_id), "user_id": user_id}
        outcome = get_db().history.delete_one(query)
        return outcome.deleted_count == 1
    except Exception as e:
        print(f"[DB] delete_history_item error: {type(e).__name__}")
        return False
def clear_user_history(user_id):
    """Drop all history rows for a user; returns how many were removed."""
    try:
        removed = get_db().history.delete_many({"user_id": user_id})
    except Exception as e:
        print(f"[DB] clear_user_history error: {type(e).__name__}")
        return 0
    return removed.deleted_count
# ── EXACT CLAIM CACHE ────────────────────────────────────────────────────────
def _normalize_claim(claim: str) -> str:
"""Lowercase and strip whitespace for exact matching."""
return claim.strip().lower()
def get_cached_result(claim: str) -> dict:
    """Returns the cached fully structured API dictionary if it exists."""
    key = _normalize_claim(claim)
    try:
        record = get_db().cached_results.find_one({"normalized_claim": key})
    except Exception as e:
        print(f"[DB] get_cached_result error: {type(e).__name__}")
        return None
    if record is not None and "result" in record:
        return record["result"]
    # Cache miss (or malformed cache entry).
    return None
def save_cached_result(claim: str, result: dict):
    """Saves a successful API run structure into the cache."""
    # Only successful runs are worth caching.
    if not result.get("success"):
        return
    key = _normalize_claim(claim)
    update_spec = {
        "$set": {
            "result": result,
            "updated_at": datetime.now(timezone.utc),
        },
        # created_at is written only when the upsert inserts a new doc.
        "$setOnInsert": {
            "created_at": datetime.now(timezone.utc),
        },
    }
    try:
        get_db().cached_results.update_one(
            {"normalized_claim": key},
            update_spec,
            upsert=True,
        )
    except Exception as e:
        print(f"[DB] save_cached_result error: {type(e).__name__}")
# ── ADMIN HELPERS ────────────────────────────────────────────────────────────
def get_system_stats():
    """Aggregate high-level system metrics + chart data for God Mode.

    Returns a dict of totals, a verdict breakdown, a filled 7-day daily
    series, and the top-5 users by check count; {} on any error.
    """
    db = get_db()
    try:
        total_users = db.users.count_documents({})
        total_checks = db.history.count_documents({})
        total_evidence = db.evidence.count_documents({})
        total_cached = db.cached_results.count_documents({})
        day_ago = datetime.now(timezone.utc) - timedelta(days=1)
        recent_checks = db.history.count_documents({"created_at": {"$gt": day_ago}})
        # ── Verdict breakdown (over the entire history collection —
        # the pipeline has no $limit/$match stage) ───────────────────────────
        verdict_pipeline = [
            {"$group": {"_id": "$verdict", "count": {"$sum": 1}}},
        ]
        verdict_raw = list(db.history.aggregate(verdict_pipeline))
        verdict_counts = {r["_id"]: r["count"] for r in verdict_raw}
        # ── Daily checks — last 7 days ──────────────────────────────────────
        seven_days_ago = datetime.now(timezone.utc) - timedelta(days=7)
        daily_pipeline = [
            {"$match": {"created_at": {"$gt": seven_days_ago}}},
            # Group by calendar (year, month, day) of created_at.
            {"$group": {
                "_id": {
                    "year": {"$year": "$created_at"},
                    "month": {"$month": "$created_at"},
                    "day": {"$dayOfMonth": "$created_at"},
                },
                "count": {"$sum": 1}
            }},
            {"$sort": {"_id.year": 1, "_id.month": 1, "_id.day": 1}}
        ]
        daily_raw = list(db.history.aggregate(daily_pipeline))
        # Build a filled 7-day array (fill missing days with 0)
        daily_map = {}
        for r in daily_raw:
            d = r["_id"]
            # Key as zero-padded "YYYY-MM-DD" to match strftime below.
            key = f"{d['year']}-{str(d['month']).zfill(2)}-{str(d['day']).zfill(2)}"
            daily_map[key] = r["count"]
        daily_labels, daily_data = [], []
        # Walk from 6 days ago up to today so the chart reads oldest → newest.
        for i in range(6, -1, -1):
            day = datetime.now(timezone.utc) - timedelta(days=i)
            label = day.strftime("%b %d")
            key = day.strftime("%Y-%m-%d")
            daily_labels.append(label)
            daily_data.append(daily_map.get(key, 0))
        # ── Top 5 users by check count ────────────────────────────────────
        top_users_pipeline = [
            {"$group": {"_id": "$user_id", "count": {"$sum": 1}}},
            {"$sort": {"count": -1}},
            {"$limit": 5},
        ]
        top_users_raw = list(db.history.aggregate(top_users_pipeline))
        top_users = []
        for r in top_users_raw:
            # One lookup per top user (at most 5) to resolve the username.
            user = find_user_by_id(r["_id"])
            top_users.append({
                "username": user.get("username", "Unknown") if user else "Unknown",
                "count": r["count"]
            })
        # Proxy metric: cached entries as a percentage of all checks
        # (not a true hit rate — hits themselves are not counted here).
        cache_hit_rate = round((total_cached / total_checks * 100), 1) if total_checks else 0
        return {
            "total_users": total_users,
            "total_checks": total_checks,
            "total_evidence": total_evidence,
            "total_cached": total_cached,
            "recent_checks_24h": recent_checks,
            "cache_hit_rate": cache_hit_rate,
            "verdict_counts": verdict_counts,
            "daily_labels": daily_labels,
            "daily_data": daily_data,
            "top_users": top_users,
        }
    except Exception as e:
        print(f"[DB] get_system_stats error: {e}")
        return {}
def list_all_users(limit=100):
    """Returns all users for the admin dashboard."""
    # Projection excludes password hashes; newest accounts come first.
    cursor = get_db().users.find({}, {"password_hash": 0})
    return list(cursor.sort("created_at", DESCENDING).limit(limit))
def get_global_history(limit=500):
    """Returns the most recent fact checks across all users, enriched with usernames."""
    entries = list(
        get_db().history.find({}).sort("created_at", DESCENDING).limit(limit)
    )
    # Memoize user_id → username so each user is looked up at most once.
    username_cache = {}
    for entry in entries:
        uid = entry.get("user_id")
        if uid not in username_cache:
            account = find_user_by_id(uid)
            if account:
                username_cache[uid] = account.get("username", "Unknown")
            else:
                username_cache[uid] = "Unknown"
        entry["username"] = username_cache[uid]
    return entries