File size: 13,493 Bytes
4f48a4e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c7893c0
4f48a4e
 
 
 
c7893c0
4f48a4e
 
 
 
 
 
c7893c0
 
 
4f48a4e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
"""
project/database.py

MongoDB persistence layer with Borg singleton pattern.
Collections:
  evidence       – scraped text for FAISS/NLI pipeline
  users          – registered accounts (hashed + peppered passwords)
  history        – per-user claim history
  revoked_tokens – JWT blocklist (TTL-indexed, self-cleaning)
"""

from datetime import datetime, timedelta, timezone
from bson import ObjectId
from pymongo import MongoClient, ASCENDING, DESCENDING
import certifi
from project import config


# ── Borg Singleton ─────────────────────────────────────────────────────────────
# All DatabaseManager instances share the same __dict__, giving us a singleton
# without a class-level lock β€” safe for Flask's threaded/forked model.
class DatabaseManager:
    """Borg-style singleton around the MongoDB connection.

    Every instance aliases its ``__dict__`` to one shared dict, so the
    client/database handles are created once and reused everywhere — a
    singleton without a class-level lock, which plays well with Flask's
    threaded/forked model.
    """

    _shared_state: dict = {}

    def __init__(self):
        # Borg pattern: all instances share the same attribute storage.
        self.__dict__ = self._shared_state

    def _connect(self):
        """Create the MongoClient/Database handles lazily on first access."""
        existing = getattr(self, '_db', None)
        if existing is not None:
            return existing
        # NOTE(review): tlsAllowInvalidCertificates=True disables certificate
        # verification, defeating the certifi CA bundle passed just above —
        # confirm this is a deliberate dev-only workaround before shipping.
        self._client = MongoClient(
            config.MONGO_URI,
            serverSelectionTimeoutMS=5000,
            connectTimeoutMS=5000,
            tlsCAFile=certifi.where(),
            tlsAllowInvalidCertificates=True
        )
        self._db = self._client[config.MONGO_DB_NAME]
        return self._db

    @property
    def db(self):
        """The shared pymongo Database handle (connects on first use)."""
        return self._connect()


# Single module-level manager — thanks to the Borg pattern, every accessor
# below reuses the same underlying connection state.
_manager = DatabaseManager()


def get_db():
    """Return the shared pymongo Database handle (lazily connected)."""
    return _manager.db


# ── Initialise indexes ─────────────────────────────────────────────────────────
def init_db():
    """Ensure every index the application relies on exists (idempotent).

    Creation order matches the original hand-written sequence; all indexes
    are built with background=True.
    """
    db = get_db()
    thirty_days_secs = 30 * 24 * 3600
    # (collection, key spec, extra create_index options)
    index_specs = [
        # evidence: TTL – auto-removes docs older than 30 days
        (db.evidence, [("created_at", ASCENDING)],
         {"expireAfterSeconds": thirty_days_secs, "name": "evidence_ttl"}),
        (db.evidence, [("source", ASCENDING)], {"name": "source_idx"}),
        # users: unique email, plus a plain username lookup index
        (db.users, [("email", ASCENDING)], {"unique": True, "name": "email_unique"}),
        (db.users, [("username", ASCENDING)], {"name": "username_idx"}),
        # history: fast per-user lookup, newest first
        (db.history, [("user_id", ASCENDING), ("created_at", DESCENDING)],
         {"name": "user_history_idx"}),
        # revoked_tokens: TTL on "exp" auto-removes expired JTIs the moment they lapse
        (db.revoked_tokens, [("exp", ASCENDING)],
         {"expireAfterSeconds": 0, "name": "token_ttl"}),
        (db.revoked_tokens, [("jti", ASCENDING)], {"unique": True, "name": "jti_unique"}),
        # cached_results: exact claim cache, keyed by normalized claim text
        (db.cached_results, [("normalized_claim", ASCENDING)],
         {"unique": True, "name": "claim_cache_idx"}),
    ]
    for collection, keys, options in index_specs:
        collection.create_index(keys, background=True, **options)
    print("[DB] MongoDB indexes ensured.")


# ── Evidence helpers ───────────────────────────────────────────────────────────
def clear_db():
    """Delete every document from the evidence collection."""
    db = get_db()
    db.evidence.delete_many({})

def save_evidence(text, source, embedding=None):
    """Insert one evidence document; errors are logged, never raised."""
    document = {
        "text": text,
        "source": source,
        "embedding": embedding,  # raw vector list, read back by the FAISS loader
        "created_at": datetime.now(timezone.utc),
    }
    try:
        get_db().evidence.insert_one(document)
    except Exception as e:
        print(f"[DB] save_evidence error: {e}")

def load_all_evidence():
    """Return [(id, text, source, embedding), ...] — the shape the FAISS pipeline expects."""
    projection = {"_id": 1, "text": 1, "source": 1, "embedding": 1}
    rows = []
    for doc in get_db().evidence.find({}, projection):
        rows.append((str(doc["_id"]), doc["text"], doc["source"], doc.get("embedding")))
    return rows

def get_total_evidence_count():
    """Number of evidence documents currently stored."""
    evidence = get_db().evidence
    return evidence.count_documents({})

def prune_old_evidence(days=30):
    """Delete evidence older than `days` days; returns the number removed."""
    threshold = datetime.now(timezone.utc) - timedelta(days=days)
    outcome = get_db().evidence.delete_many({"created_at": {"$lt": threshold}})
    return outcome.deleted_count


# ── User helpers ───────────────────────────────────────────────────────────────
def create_user(username, email, password_hash, is_admin=False):
    """Insert a new user account.

    Returns the inserted ObjectId as a string, or None on duplicate/error
    (the unique email index makes duplicates raise).
    """
    record = {
        "username":      username,
        "email":         email,
        "password_hash": password_hash,
        "is_admin":      is_admin,
        "created_at":    datetime.now(timezone.utc),
    }
    try:
        inserted = get_db().users.insert_one(record)
        return str(inserted.inserted_id)
    except Exception as e:
        # Log only the exception class — never the email (privacy).
        print(f"[DB] create_user error: {type(e).__name__}")
        return None

def find_user_by_email(email):
    """Look up a user document by exact email match (None if absent)."""
    users = get_db().users
    return users.find_one({"email": email})

def find_user_by_id(user_id):
    """Fetch a user by its string ObjectId; None on malformed id or DB error."""
    result = None
    try:
        result = get_db().users.find_one({"_id": ObjectId(user_id)})
    except Exception:
        # Malformed ObjectId strings (and transient DB errors) map to "not found".
        pass
    return result


# ── JWT Blocklist helpers ──────────────────────────────────────────────────────
def add_token_to_blocklist(jti: str, exp: datetime):
    """Persist a revoked JWT jti; the TTL index on "exp" removes it automatically."""
    entry = {
        "jti": jti,
        "exp": exp,
        "revoked_at": datetime.now(timezone.utc),
    }
    try:
        get_db().revoked_tokens.insert_one(entry)
    except Exception as e:
        print(f"[DB] add_token_to_blocklist error: {type(e).__name__}")

def is_token_revoked(jti: str) -> bool:
    """True if this jti appears in the revoked_tokens blocklist."""
    match = get_db().revoked_tokens.find_one({"jti": jti})
    return match is not None


# ── History helpers ────────────────────────────────────────────────────────────
def save_history(user_id, claim, verdict, confidence, evidence_count):
    """Append one fact-check record to the user's history; errors are logged."""
    entry = {
        "user_id":        user_id,
        "claim":          claim,
        "verdict":        verdict,
        "confidence":     confidence,
        "evidence_count": evidence_count,
        "created_at":     datetime.now(timezone.utc),
    }
    try:
        get_db().history.insert_one(entry)
    except Exception as e:
        print(f"[DB] save_history error: {type(e).__name__}")

def get_user_history(user_id, limit=50):
    """Newest-first claim history for one user, capped at `limit` rows."""
    cursor = (
        get_db().history
        .find({"user_id": user_id})
        .sort("created_at", DESCENDING)
        .limit(limit)
    )
    return list(cursor)

def delete_history_item(user_id, item_id):
    """Delete one history row, scoped to its owner; True only if a row was removed.

    Filtering on user_id as well as _id prevents deleting another user's entry.
    """
    try:
        outcome = get_db().history.delete_one({
            "_id":     ObjectId(item_id),
            "user_id": user_id,
        })
        return outcome.deleted_count == 1
    except Exception as e:
        print(f"[DB] delete_history_item error: {type(e).__name__}")
        return False

def clear_user_history(user_id):
    """Delete every history row belonging to `user_id`; returns count removed (0 on error)."""
    try:
        outcome = get_db().history.delete_many({"user_id": user_id})
        return outcome.deleted_count
    except Exception as e:
        print(f"[DB] clear_user_history error: {type(e).__name__}")
        return 0


# ── EXACT CLAIM CACHE ────────────────────────────────────────────────────────
def _normalize_claim(claim: str) -> str:
    """Lowercase and strip whitespace for exact matching."""
    return claim.strip().lower()

def get_cached_result(claim: str) -> "dict | None":
    """Return the cached, fully-structured API result for `claim`, or None.

    The claim is normalized (stripped + lowercased) before lookup so cache
    hits are insensitive to casing and surrounding whitespace.  Any DB error
    is logged and treated as a cache miss.

    Fix: the return annotation previously claimed `dict`, but the function
    returns None on a miss or error — annotated as `dict | None` now.
    """
    norm = _normalize_claim(claim)
    try:
        doc = get_db().cached_results.find_one({"normalized_claim": norm})
        if doc and "result" in doc:
            return doc["result"]
    except Exception as e:
        print(f"[DB] get_cached_result error: {type(e).__name__}")
    return None

def save_cached_result(claim: str, result: dict):
    """Upsert a successful API result into the exact-claim cache.

    Results without a truthy "success" flag are ignored so failed runs never
    poison the cache.  DB errors are logged, never raised.
    """
    if not result.get("success"):
        return
    norm = _normalize_claim(claim)
    update_doc = {
        "$set": {
            "result": result,
            "updated_at": datetime.now(timezone.utc),
        },
        # created_at is written only when the document is first inserted.
        "$setOnInsert": {
            "created_at": datetime.now(timezone.utc),
        },
    }
    try:
        get_db().cached_results.update_one(
            {"normalized_claim": norm},
            update_doc,
            upsert=True,
        )
    except Exception as e:
        print(f"[DB] save_cached_result error: {type(e).__name__}")


# ── ADMIN HELPERS ────────────────────────────────────────────────────────────
def get_system_stats():
    """Aggregate high-level system metrics + chart data for God Mode.

    Returns a dict with collection totals, a 24h activity count, a verdict
    breakdown, a zero-filled 7-day daily-checks series, and the top-5 users
    by check count.  On any DB error, logs and returns an empty dict.
    """
    db = get_db()
    try:
        total_users    = db.users.count_documents({})
        total_checks   = db.history.count_documents({})
        total_evidence = db.evidence.count_documents({})
        total_cached   = db.cached_results.count_documents({})

        day_ago  = datetime.now(timezone.utc) - timedelta(days=1)
        recent_checks = db.history.count_documents({"created_at": {"$gt": day_ago}})

        # ── Verdict breakdown (entire history — pipeline has no $limit) ──────
        verdict_pipeline = [
            {"$group": {"_id": "$verdict", "count": {"$sum": 1}}},
        ]
        verdict_raw = list(db.history.aggregate(verdict_pipeline))
        verdict_counts = {r["_id"]: r["count"] for r in verdict_raw}

        # ── Daily checks — last 7 days, grouped by UTC calendar day ─────────
        seven_days_ago = datetime.now(timezone.utc) - timedelta(days=7)
        daily_pipeline = [
            {"$match": {"created_at": {"$gt": seven_days_ago}}},
            {"$group": {
                "_id": {
                    "year":  {"$year": "$created_at"},
                    "month": {"$month": "$created_at"},
                    "day":   {"$dayOfMonth": "$created_at"},
                },
                "count": {"$sum": 1}
            }},
            {"$sort": {"_id.year": 1, "_id.month": 1, "_id.day": 1}}
        ]
        daily_raw = list(db.history.aggregate(daily_pipeline))
        # Build a filled 7-day array (fill missing days with 0).
        # Keys are zero-padded "YYYY-MM-DD" to match strftime below.
        daily_map = {}
        for r in daily_raw:
            d = r["_id"]
            key = f"{d['year']}-{str(d['month']).zfill(2)}-{str(d['day']).zfill(2)}"
            daily_map[key] = r["count"]

        daily_labels, daily_data = [], []
        for i in range(6, -1, -1):  # 6 days ago .. today, oldest first
            day = datetime.now(timezone.utc) - timedelta(days=i)
            label = day.strftime("%b %d")
            key   = day.strftime("%Y-%m-%d")
            daily_labels.append(label)
            daily_data.append(daily_map.get(key, 0))

        # ── Top 5 users by check count ────────────────────────────────────
        top_users_pipeline = [
            {"$group": {"_id": "$user_id", "count": {"$sum": 1}}},
            {"$sort": {"count": -1}},
            {"$limit": 5},
        ]
        top_users_raw = list(db.history.aggregate(top_users_pipeline))
        top_users = []
        for r in top_users_raw:
            # At most 5 extra lookups; missing/deleted users show as "Unknown".
            user = find_user_by_id(r["_id"])
            top_users.append({
                "username": user.get("username", "Unknown") if user else "Unknown",
                "count": r["count"]
            })

        # NOTE(review): this is cached-claims / total-checks — an approximation;
        # actual cache hits are not counted anywhere in this module.
        cache_hit_rate = round((total_cached / total_checks * 100), 1) if total_checks else 0

        return {
            "total_users":        total_users,
            "total_checks":       total_checks,
            "total_evidence":     total_evidence,
            "total_cached":       total_cached,
            "recent_checks_24h":  recent_checks,
            "cache_hit_rate":     cache_hit_rate,
            "verdict_counts":     verdict_counts,
            "daily_labels":       daily_labels,
            "daily_data":         daily_data,
            "top_users":          top_users,
        }
    except Exception as e:
        print(f"[DB] get_system_stats error: {e}")
        return {}


def list_all_users(limit=100):
    """All users for the admin dashboard — password hashes excluded, newest first."""
    cursor = (
        get_db().users
        .find({}, {"password_hash": 0})
        .sort("created_at", DESCENDING)
        .limit(limit)
    )
    return list(cursor)


def get_global_history(limit=500):
    """Most recent fact checks across all users, each row annotated with a username.

    User lookups are memoized per user_id so each distinct user is fetched
    at most once.
    """
    rows = list(get_db().history.find({}).sort("created_at", DESCENDING).limit(limit))
    username_cache = {}
    for row in rows:
        uid = row.get("user_id")
        if uid not in username_cache:
            account = find_user_by_id(uid)
            username_cache[uid] = account.get("username", "Unknown") if account else "Unknown"
        row["username"] = username_cache[uid]
    return rows