File size: 24,498 Bytes
cd3078d
 
 
2307ae4
cd3078d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c4757a8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cd3078d
 
 
 
 
 
 
 
c4757a8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cd3078d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4e1faa9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2307ae4
 
 
 
 
4e1faa9
 
 
 
cd3078d
 
 
 
 
 
 
 
 
 
 
2307ae4
 
 
 
 
 
 
cd3078d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
from __future__ import annotations

import copy
import logging
import json
import re
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Iterable

import asyncpg
from bson import ObjectId

from app.config import settings


# Logical "collections". Each one is a Postgres table with the columns
# (_id TEXT PRIMARY KEY, data JSONB NOT NULL), created idempotently by
# PostgresDocumentDatabase.initialize(); the document body minus _id is
# stored in `data`.
COLLECTIONS = (
    "users",
    "documents",
    "verifications",
    "refresh_tokens",
    "password_reset_tokens",
    "revoked_access_tokens",
    "notifications",
    "document_versions",
)

# Expression indexes over JSONB fields, executed after the tables exist.
# They index the *text* extraction `data->>'field'`, so a query can only use
# them when it filters or sorts on exactly that expression; comparisons on
# these values are therefore textual (not numeric or temporal).
# The UNIQUE ones emulate MongoDB unique indexes (e.g. one user per email,
# one verification per (document_id, expert_id) pair).
INDEX_STATEMENTS = (
    "CREATE UNIQUE INDEX IF NOT EXISTS users_email_unique ON users ((data->>'email'))",
    "CREATE INDEX IF NOT EXISTS users_role_idx ON users ((data->>'role'))",
    "CREATE INDEX IF NOT EXISTS documents_user_created_idx ON documents ((data->>'user_id'), (data->>'created_at') DESC)",
    "CREATE INDEX IF NOT EXISTS documents_is_deleted_created_idx ON documents ((data->>'is_deleted'), (data->>'created_at') DESC)",
    "CREATE INDEX IF NOT EXISTS verifications_document_idx ON verifications ((data->>'document_id'))",
    "CREATE INDEX IF NOT EXISTS verifications_expert_idx ON verifications ((data->>'expert_id'))",
    "CREATE INDEX IF NOT EXISTS verifications_role_idx ON verifications ((data->>'reviewer_role'))",
    "CREATE UNIQUE INDEX IF NOT EXISTS verifications_document_expert_unique ON verifications ((data->>'document_id'), (data->>'expert_id'))",
    "CREATE UNIQUE INDEX IF NOT EXISTS refresh_tokens_jti_unique ON refresh_tokens ((data->>'jti'))",
    "CREATE INDEX IF NOT EXISTS refresh_tokens_user_idx ON refresh_tokens ((data->>'user_id'))",
    "CREATE INDEX IF NOT EXISTS refresh_tokens_expires_idx ON refresh_tokens ((data->>'expires_at'))",
    "CREATE UNIQUE INDEX IF NOT EXISTS password_reset_tokens_hash_unique ON password_reset_tokens ((data->>'token_hash'))",
    "CREATE INDEX IF NOT EXISTS password_reset_tokens_expires_idx ON password_reset_tokens ((data->>'expires_at'))",
    "CREATE UNIQUE INDEX IF NOT EXISTS revoked_access_tokens_jti_unique ON revoked_access_tokens ((data->>'jti'))",
    "CREATE INDEX IF NOT EXISTS revoked_access_tokens_expires_idx ON revoked_access_tokens ((data->>'expires_at'))",
    "CREATE INDEX IF NOT EXISTS notifications_user_created_idx ON notifications ((data->>'user_id'), (data->>'created_at') DESC)",
    "CREATE INDEX IF NOT EXISTS notifications_user_is_read_idx ON notifications ((data->>'user_id'), (data->>'is_read'))",
    "CREATE INDEX IF NOT EXISTS document_versions_document_created_idx ON document_versions ((data->>'document_id'), (data->>'created_at') DESC)",
)

# Document keys whose string values are parsed back into `datetime` objects
# when rows are read (see _restore_special_types). Datetimes are written as
# ISO-8601 strings by _ExtendedJSONEncoder, so only keys listed here
# round-trip as datetimes; a datetime stored under any other key comes back
# as a plain string.
DATETIME_FIELDS = {
    "created_at",
    "updated_at",
    "expires_at",
    "used_at",
    "deleted_at",
    "submitter_paid_at",
    "linguist_edited_at",
    "translator_edited_at",
    "linguist_approved_at",
    "translator_approved_at",
}

# Module-level connection state. Both are assigned by connect_db();
# close_db() resets `client` to None, and get_db() raises HTTP 503 while
# `db` is None.
client: asyncpg.Pool | None = None
db: PostgresDocumentDatabase | None = None


class _ExtendedJSONEncoder(json.JSONEncoder):
    def default(self, obj: Any) -> Any:
        if isinstance(obj, datetime):
            return obj.isoformat()
        if isinstance(obj, ObjectId):
            return str(obj)
        return super().default(obj)


def _json_dumps(value: Any) -> str:
    """Serialize ``value`` to compact JSON (no whitespace after separators).

    ``datetime`` and ``ObjectId`` values are encoded by ``_ExtendedJSONEncoder``.
    """
    compact_separators = (",", ":")
    return json.dumps(value, separators=compact_separators, cls=_ExtendedJSONEncoder)


def _json_loads(value: Any) -> Any:
    """Decode a stored JSON value and revive its datetime fields.

    ``str`` input is parsed as JSON and ``bytes`` input is UTF-8 decoded first.
    Any other input is assumed to be already decoded and is used as-is.
    The result is passed through ``_restore_special_types``.
    """
    if isinstance(value, bytes):
        value = value.decode("utf-8")
    decoded = json.loads(value) if isinstance(value, str) else value
    return _restore_special_types(decoded)


def _restore_special_types(value: Any, key: str | None = None) -> Any:
    if isinstance(value, str) and key in DATETIME_FIELDS:
        try:
            return datetime.fromisoformat(value)
        except Exception:
            return value
    if isinstance(value, list):
        return [_restore_special_types(v) for v in value]
    if isinstance(value, dict):
        return {k: _restore_special_types(v, key=k) for k, v in value.items()}
    return value


def _normalize_scalar(value: Any) -> Any:
    """Return an ``ObjectId`` as its string form; any other value is returned as-is."""
    return str(value) if isinstance(value, ObjectId) else value


def _is_operator_dict(value: Any) -> bool:
    return isinstance(value, dict) and any(k.startswith("$") for k in value.keys())


def _sort_key(value: Any) -> tuple[int, int, Any]:
    """Return a sort key that totally orders heterogeneous field values.

    Ordering (ascending):
      1. numbers, booleans (as 0/1) and datetimes (by POSIX timestamp),
         compared numerically;
      2. every other non-None value, by its lower-cased string form;
      3. ``None`` last.

    The middle element of the tuple is a type rank. Without it, a field that
    holds both numbers and strings would compare e.g. ``5`` with ``"abc"``
    and raise ``TypeError`` in ``list.sort``.
    """
    value = _normalize_scalar(value)
    if value is None:
        return (1, 0, "")
    if isinstance(value, datetime):
        return (0, 0, value.timestamp())
    if isinstance(value, bool):
        return (0, 0, int(value))
    if isinstance(value, (int, float)):
        return (0, 0, value)
    return (0, 1, str(value).lower())


def _compare(left: Any, right: Any, operator: str) -> bool:
    """Evaluate one MongoDB comparison operator against a field value.

    Args:
        left: The document's field value (``None`` when the field is missing).
        right: The operand taken from the query.
        operator: ``$in``, ``$ne``, ``$gt``, ``$gte``, ``$lt`` or ``$lte``.

    Returns:
        Whether the condition holds. A non-list ``$in`` operand never matches;
        a ``None`` field never satisfies a range operator; values of
        incomparable types compare as ``False``; any other operator also
        yields ``False``.
    """
    left = _normalize_scalar(left)
    right = _normalize_scalar(right)

    if operator == "$in":
        if not isinstance(right, list):
            return False
        # List membership (not a set) so that unhashable values such as
        # dicts or lists are compared by equality instead of raising TypeError.
        return left in [_normalize_scalar(candidate) for candidate in right]
    if operator == "$ne":
        return left != right

    if left is None:
        return False

    try:
        if operator == "$gt":
            return left > right
        if operator == "$gte":
            return left >= right
        if operator == "$lt":
            return left < right
        if operator == "$lte":
            return left <= right
    except TypeError:
        # e.g. str vs datetime, or naive vs aware datetimes.
        return False

    return False


def _matches_query(document: dict[str, Any], query: dict[str, Any]) -> bool:
    """Evaluate a MongoDB-style filter against an in-memory document.

    Supported at the top level:
      * ``$or`` / ``$and`` whose value must be a list of sub-queries;
      * plain equality (ObjectIds compared by string form);
      * operator dicts: ``$regex`` (case-insensitive when ``$options``
        contains ``i``) plus whatever ``_compare`` understands.

    Every top-level condition must hold; a ``None``/empty query matches everything.
    """
    for field, condition in (query or {}).items():
        if field in ("$or", "$and"):
            if not isinstance(condition, list):
                return False
            combine = any if field == "$or" else all
            if not combine(_matches_query(document, sub_query) for sub_query in condition):
                return False
            continue

        actual = document.get(field)

        if not _is_operator_dict(condition):
            if _normalize_scalar(actual) != _normalize_scalar(condition):
                return False
            continue

        pattern = condition.get("$regex")
        if pattern is not None:
            regex_flags = re.IGNORECASE if "i" in str(condition.get("$options", "")) else 0
            if actual is None or re.search(str(pattern), str(actual), regex_flags) is None:
                return False

        for op, operand in condition.items():
            if op not in ("$regex", "$options") and not _compare(actual, operand, op):
                return False

    return True


def _apply_projection(document: dict[str, Any], projection: dict[str, Any] | None) -> dict[str, Any]:
    if projection is None:
        return copy.deepcopy(document)

    include_fields = [k for k, v in projection.items() if bool(v) and k != "_id"]
    include_id = projection.get("_id", 1) != 0

    if include_fields:
        out: dict[str, Any] = {}
        if include_id and "_id" in document:
            out["_id"] = copy.deepcopy(document["_id"])
        for field in include_fields:
            if field in document:
                out[field] = copy.deepcopy(document[field])
        return out

    out = copy.deepcopy(document)
    if not include_id:
        out.pop("_id", None)
    return out


def _apply_update(document: dict[str, Any], update: dict[str, Any]) -> dict[str, Any]:
    out = copy.deepcopy(document)
    for op, payload in (update or {}).items():
        if op == "$set":
            for key, value in payload.items():
                out[key] = value
            continue
        if op == "$inc":
            for key, value in payload.items():
                out[key] = out.get(key, 0) + value
            continue
        raise ValueError(f"Unsupported update operator: {op}")
    return out


def _seed_from_query(query: dict[str, Any]) -> dict[str, Any]:
    """Build the starting document for an upsert from a query's equality terms.

    Top-level operators (``$or``, ``$and``, ...) and operator-dict conditions
    (``{"$gt": ...}``) are skipped. Plain values are copied over, with
    ObjectIds converted to strings.
    """
    return {
        field: _normalize_scalar(value)
        for field, value in (query or {}).items()
        if not field.startswith("$") and not _is_operator_dict(value)
    }


@dataclass()
class InsertOneResult:
    """Result of ``AsyncCollection.insert_one``."""

    # ObjectId when the stored _id parses as one, otherwise the raw string id.
    inserted_id: Any


@dataclass()
class UpdateResult:
    """Result of ``AsyncCollection.update_one`` / ``update_many``."""

    # Number of documents that matched the filter.
    matched_count: int
    # Number of documents whose stored content changed (1 when an upsert inserted).
    modified_count: int
    # String _id of the document created by an upsert, otherwise None.
    upserted_id: Any | None = None


@dataclass()
class DeleteResult:
    """Result of ``AsyncCollection.delete_one``."""

    # Number of rows removed (0 or 1 for delete_one).
    deleted_count: int


class AsyncCursor:
    """Lazily evaluated cursor over an ``AsyncCollection`` query.

    ``sort``, ``skip`` and ``limit`` only record options and return the cursor
    for chaining. The query runs once, on the first ``to_list()`` call or
    async-iteration step, and those rows are reused afterwards. Every
    document handed out is a deep copy, so callers may mutate results freely.
    """

    def __init__(self, collection: "AsyncCollection", query: dict[str, Any] | None, projection: dict[str, Any] | None):
        self._collection = collection
        self._query = query or {}
        self._projection = projection
        self._sort_fields: list[tuple[str, int]] = []
        self._skip = 0
        self._limit: int | None = None
        self._loaded: list[dict[str, Any]] | None = None
        self._index = 0

    def sort(self, key_or_list: Any, direction: int | None = None) -> "AsyncCursor":
        # Accept either sort("field", -1) or sort([("a", 1), ("b", -1)]).
        if isinstance(key_or_list, list):
            pairs = key_or_list
        else:
            pairs = [(key_or_list, direction or 1)]
        self._sort_fields.extend((str(field), int(order)) for field, order in pairs)
        return self

    def skip(self, count: int) -> "AsyncCursor":
        amount = int(count)
        self._skip = amount if amount > 0 else 0
        return self

    def limit(self, count: int) -> "AsyncCursor":
        amount = int(count)
        self._limit = amount if amount > 0 else 0
        return self

    async def _ensure_loaded(self) -> None:
        if self._loaded is None:
            self._loaded = await self._collection._find_docs(
                query=self._query,
                projection=self._projection,
                sort_fields=self._sort_fields,
                skip=self._skip,
                limit=self._limit,
            )

    async def to_list(self, length: int | None = None) -> list[dict[str, Any]]:
        await self._ensure_loaded()
        rows = self._loaded or []
        if length is not None:
            rows = rows[: max(0, int(length))]
        return copy.deepcopy(rows)

    def __aiter__(self) -> "AsyncCursor":
        return self

    async def __anext__(self) -> dict[str, Any]:
        await self._ensure_loaded()
        assert self._loaded is not None
        position = self._index
        if position >= len(self._loaded):
            raise StopAsyncIteration
        self._index = position + 1
        return copy.deepcopy(self._loaded[position])


class AsyncCollection:
    """MongoDB-style async collection stored in a ``(_id TEXT, data JSONB)`` table.

    Query filters are translated to SQL by ``_build_sql_conditions`` wherever
    the SQL result agrees with the in-process matcher ``_matches_query``.
    If any condition cannot be translated faithfully, the SQL conditions only
    pre-filter the rows; they are then filtered, sorted and paginated in Python.
    """

    # MongoDB range operator -> SQL comparison operator.
    _RANGE_OPERATORS = {"$gt": ">", "$gte": ">=", "$lt": "<", "$lte": "<="}

    def __init__(self, database: "PostgresDocumentDatabase", name: str):
        self._database = database
        self._name = name

    async def create_index(self, keys: Any, unique: bool = False, expireAfterSeconds: int | None = None) -> None:
        """Accept a Motor-style ``create_index`` call and do nothing.

        Indexes are created in bootstrap DDL instead. ``unique`` and
        ``expireAfterSeconds`` are ignored here; no TTL expiry is emulated by
        this method.
        """
        return None

    # ------------------------------------------------------------------
    # Row <-> document helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _text_field(key: str) -> str:
        """SQL expression extracting ``key`` from ``data`` as text (``data->>'key'``).

        Single quotes are doubled so an unusual key cannot break out of the
        literal. The key is inlined rather than bound so the expression stays
        textually identical to the expression indexes in the bootstrap DDL.
        """
        return "data->>'" + str(key).replace("'", "''") + "'"

    @staticmethod
    def _json_field(key: str) -> str:
        """SQL expression extracting ``key`` from ``data`` as jsonb (``data->'key'``)."""
        return "data->'" + str(key).replace("'", "''") + "'"

    @staticmethod
    def _row_to_document(row: Any) -> dict[str, Any]:
        """Decode a ``(_id, data)`` row into a document dict that includes ``_id``."""
        data = _json_loads(row["data"]) or {}
        data["_id"] = row["_id"]
        return data

    async def _fetch_all_documents(self) -> list[dict[str, Any]]:
        rows = await self._database.pool.fetch(f'SELECT _id, data FROM "{self._name}"')
        return [self._row_to_document(row) for row in rows]

    async def _store_document(self, document: dict[str, Any]) -> None:
        """Insert or overwrite (upsert on ``_id``) the row for ``document``."""
        doc_id = str(_normalize_scalar(document["_id"]))
        payload = copy.deepcopy(document)
        payload.pop("_id", None)
        await self._database.pool.execute(
            f'INSERT INTO "{self._name}" (_id, data) VALUES ($1, $2::jsonb) '
            f'ON CONFLICT (_id) DO UPDATE SET data = EXCLUDED.data',
            doc_id,
            _json_dumps(payload),
        )

    # ------------------------------------------------------------------
    # Query -> SQL translation
    # ------------------------------------------------------------------

    @staticmethod
    def _sql_kind(value: Any) -> str | None:
        """Classify an already-normalized query value for SQL translation.

        Returns ``"null"``, ``"bool"``, ``"text"`` (str or datetime),
        ``"number"`` (finite int/float), or ``None`` when the value has no
        faithful SQL form (dicts, lists, NaN, other types) and must be
        matched in Python.
        """
        import math

        if value is None:
            return "null"
        if isinstance(value, bool):
            return "bool"
        if isinstance(value, (str, datetime)):
            return "text"
        if isinstance(value, (int, float)):
            if isinstance(value, float) and not math.isfinite(value):
                return None
            return "number"
        return None

    @staticmethod
    def _as_text(value: Any) -> str:
        """Render a ``"text"``-kind value the way it appears in ``data->>'key'``.

        Datetimes use ``isoformat()``, the form ``_ExtendedJSONEncoder`` stores.
        ``str(datetime)`` would use a space instead of ``T`` and so would never
        equal, or compare correctly against, the stored value.
        """
        return value.isoformat() if isinstance(value, datetime) else str(value)

    @staticmethod
    def _sql_equals(key: str, value: Any, bind: Any) -> str | None:
        """SQL predicate for ``field == value``, or None if not translatable.

        ``bind`` appends a parameter and returns its ``$N`` placeholder; it is
        only called once the predicate is known to be usable.
        """
        value = _normalize_scalar(value)
        kind = AsyncCollection._sql_kind(value)
        text = AsyncCollection._text_field(key)
        if kind == "null":
            # Missing keys and JSON null both extract as SQL NULL.
            return f"{text} IS NULL"
        if kind == "bool":
            literal = "true" if value else "false"
            return f"{text} = '{literal}'"
        if kind == "text":
            return f"{text} = {bind(AsyncCollection._as_text(value))}"
        if kind == "number":
            # jsonb equality compares numbers numerically (5 == 5.0) and never
            # matches a JSON string such as "5", mirroring Python's ==.
            placeholder = bind(_json_dumps(value))
            return f"{AsyncCollection._json_field(key)} = {placeholder}::jsonb"
        return None

    @staticmethod
    def _sql_in(key: str, values: Any, bind: Any) -> str | None:
        """SQL predicate for ``{"$in": values}``, or None if not translatable."""
        if not isinstance(values, list):
            # _compare treats a non-list $in operand as matching nothing.
            return "FALSE"
        normalized = [_normalize_scalar(v) for v in values]
        kinds = [AsyncCollection._sql_kind(v) for v in normalized]
        # Decide before binding anything so no orphan parameters are left behind.
        if any(kind is None for kind in kinds):
            return None

        text_params: list[str] = []
        json_params: list[str] = []
        match_null = False
        for value, kind in zip(normalized, kinds):
            if kind == "null":
                match_null = True
            elif kind == "bool":
                text_params.append("true" if value else "false")
            elif kind == "text":
                text_params.append(AsyncCollection._as_text(value))
            else:
                json_params.append(_json_dumps(value))

        text = AsyncCollection._text_field(key)
        alternatives: list[str] = []
        if text_params:
            placeholders = ", ".join(bind(p) for p in text_params)
            alternatives.append(f"{text} IN ({placeholders})")
        if json_params:
            placeholders = ", ".join(f"{bind(p)}::jsonb" for p in json_params)
            alternatives.append(f"{AsyncCollection._json_field(key)} IN ({placeholders})")
        if match_null:
            alternatives.append(f"{text} IS NULL")
        if not alternatives:
            # Empty list: nothing can match.
            return "FALSE"
        return "(" + " OR ".join(alternatives) + ")"

    @staticmethod
    def _sql_operator(key: str, op: str, op_value: Any, bind: Any) -> str | None:
        """SQL predicate for one ``{op: op_value}`` condition, or None for Python-only ops."""
        text = AsyncCollection._text_field(key)

        if op == "$ne":
            value = _normalize_scalar(op_value)
            if value is None:
                # Python's `field != None` holds exactly when the field has a non-null value.
                return f"{text} IS NOT NULL"
            equals = AsyncCollection._sql_equals(key, value, bind)
            if equals is None:
                return None
            return f"({text} IS NULL OR NOT ({equals}))"

        if op == "$in":
            return AsyncCollection._sql_in(key, op_value, bind)

        if op in AsyncCollection._RANGE_OPERATORS:
            sql_op = AsyncCollection._RANGE_OPERATORS[op]
            value = _normalize_scalar(op_value)
            kind = AsyncCollection._sql_kind(value)
            if kind == "text":
                # Text comparison. For datetimes this is only correct when the
                # stored and queried values use the same UTC offset.
                # NOTE(review): assumes timestamps are consistently UTC -- confirm.
                return f"{text} {sql_op} {bind(AsyncCollection._as_text(value))}"
            if kind == "number":
                # Compare as jsonb numbers (numeric, not lexical), restricted to
                # numeric fields like Python, where str-vs-int compares as False.
                json_col = AsyncCollection._json_field(key)
                placeholder = bind(_json_dumps(value))
                return (
                    f"(jsonb_typeof({json_col}) = 'number' "
                    f"AND {json_col} {sql_op} {placeholder}::jsonb)"
                )
            return None

        # $regex / $options and unknown operators are evaluated in Python.
        return None

    @staticmethod
    def _build_sql_conditions(
        query: dict[str, Any] | None,
    ) -> tuple[list[str], list[Any], bool]:
        """Try to convert a MongoDB-style query to SQL WHERE clauses.

        Returns ``(conditions, params, needs_python_filter)``. ``conditions``
        are ANDed and use ``$1..$N`` placeholders matching ``params``. If
        ``needs_python_filter`` is True, part of the query could not be
        translated. The SQL result must then still be filtered in Python with
        ``_matches_query``; the emitted conditions act only as a pre-filter.
        """
        if not query:
            return [], [], False

        conditions: list[str] = []
        params: list[Any] = []
        needs_python = False

        def bind(value: Any) -> str:
            params.append(value)
            return f"${len(params)}"

        for key, expected in query.items():
            if key in ("$or", "$and"):
                needs_python = True
                continue

            if key == "_id":
                if isinstance(expected, (str, ObjectId)):
                    conditions.append(f"_id = {bind(str(_normalize_scalar(expected)))}")
                else:
                    needs_python = True
                continue

            if not _is_operator_dict(expected):
                predicate = AsyncCollection._sql_equals(key, expected, bind)
                if predicate is None:
                    needs_python = True
                else:
                    conditions.append(predicate)
                continue

            for op, op_value in expected.items():
                predicate = AsyncCollection._sql_operator(key, op, op_value, bind)
                if predicate is None:
                    needs_python = True
                else:
                    conditions.append(predicate)

        return conditions, params, needs_python

    async def _find_docs(
        self,
        query: dict[str, Any] | None,
        projection: dict[str, Any] | None,
        sort_fields: Iterable[tuple[str, int]] | None = None,
        skip: int = 0,
        limit: int | None = None,
    ) -> list[dict[str, Any]]:
        """Run ``query`` and return projected documents.

        When the whole query translates to SQL, sorting, OFFSET and LIMIT are
        done by Postgres. SQL sorts by the *text* of ``data->>'field'`` under
        the database collation, which can differ from the Python ``_sort_key``
        order (e.g. numbers sort lexically). Otherwise filtering, sorting and
        slicing happen in Python on the SQL-narrowed rows.
        """
        conditions, params, needs_python = self._build_sql_conditions(query)
        sort_list = [(str(field), int(direction)) for field, direction in (sort_fields or [])]
        offset = max(0, int(skip or 0))
        row_limit = None if limit is None else max(0, int(limit))

        sql = f'SELECT _id, data FROM "{self._name}"'
        if conditions:
            sql += " WHERE " + " AND ".join(conditions)

        if not needs_python:
            if sort_list:
                order_terms = []
                for field, direction in sort_list:
                    keyword = "DESC" if direction == -1 else "ASC"
                    order_terms.append(f"{self._text_field(field)} {keyword}")
                sql += " ORDER BY " + ", ".join(order_terms)
            # Safe without ORDER BY too: an unsorted result has no defined order anyway.
            if offset:
                sql += f" OFFSET {offset}"
            if row_limit is not None:
                sql += f" LIMIT {row_limit}"

        rows = await self._database.pool.fetch(sql, *params)
        documents = [self._row_to_document(row) for row in rows]

        if needs_python:
            documents = [doc for doc in documents if _matches_query(doc, query or {})]
            # Stable multi-key sort: apply the least significant key first.
            for field, direction in reversed(sort_list):
                documents.sort(
                    key=lambda item, f=field: _sort_key(item.get(f)),
                    reverse=direction == -1,
                )
            if offset:
                documents = documents[offset:]
            if row_limit is not None:
                documents = documents[:row_limit]

        return [_apply_projection(doc, projection) for doc in documents]

    def find(self, query: dict[str, Any] | None = None, projection: dict[str, Any] | None = None) -> AsyncCursor:
        return AsyncCursor(self, query, projection)

    async def find_one(
        self,
        query: dict[str, Any] | None = None,
        projection: dict[str, Any] | None = None,
        sort: list[tuple[str, int]] | None = None,
    ) -> dict[str, Any] | None:
        docs = await self._find_docs(query=query, projection=projection, sort_fields=sort, limit=1)
        return docs[0] if docs else None

    async def insert_one(self, document: dict[str, Any]) -> InsertOneResult:
        # NOTE: _store_document upserts on _id, so inserting an _id that already
        # exists overwrites that row instead of raising a duplicate-key error
        # (unique expression indexes on other fields still reject duplicates).
        payload = copy.deepcopy(document)
        existing_id = payload.get("_id")
        if existing_id is None:
            doc_id = str(ObjectId())
        else:
            doc_id = str(_normalize_scalar(existing_id))
        payload["_id"] = doc_id
        await self._store_document(payload)

        try:
            inserted_id: Any = ObjectId(doc_id)
        except Exception:
            inserted_id = doc_id
        return InsertOneResult(inserted_id=inserted_id)

    async def update_one(self, query: dict[str, Any], update: dict[str, Any], upsert: bool = False) -> UpdateResult:
        # NOTE(review): read-modify-write without a transaction or row lock;
        # concurrent updates of the same document (e.g. two $inc) can be lost.
        docs = await self._find_docs(query=query, projection=None, limit=1)
        if docs:
            original = docs[0]
            updated = _apply_update(original, update)
            updated["_id"] = original["_id"]
            modified = int(updated != original)
            await self._store_document(updated)
            return UpdateResult(matched_count=1, modified_count=modified)

        if not upsert:
            return UpdateResult(matched_count=0, modified_count=0)

        upsert_doc = _seed_from_query(query)
        upsert_doc = _apply_update(upsert_doc, update)
        if "_id" not in upsert_doc:
            upsert_doc["_id"] = str(ObjectId())
        else:
            upsert_doc["_id"] = str(_normalize_scalar(upsert_doc["_id"]))

        await self._store_document(upsert_doc)
        return UpdateResult(matched_count=0, modified_count=1, upserted_id=upsert_doc["_id"])

    async def update_many(self, query: dict[str, Any], update: dict[str, Any], upsert: bool = False) -> UpdateResult:
        # NOTE(review): same non-atomic read-modify-write caveat as update_one.
        docs = await self._find_docs(query=query, projection=None)
        modified = 0
        for original in docs:
            updated = _apply_update(original, update)
            updated["_id"] = original["_id"]
            if updated != original:
                modified += 1
            await self._store_document(updated)

        if docs:
            return UpdateResult(matched_count=len(docs), modified_count=modified)

        if upsert:
            upsert_doc = _seed_from_query(query)
            upsert_doc = _apply_update(upsert_doc, update)
            if "_id" not in upsert_doc:
                upsert_doc["_id"] = str(ObjectId())
            else:
                upsert_doc["_id"] = str(_normalize_scalar(upsert_doc["_id"]))
            await self._store_document(upsert_doc)
            return UpdateResult(matched_count=0, modified_count=1, upserted_id=upsert_doc["_id"])

        return UpdateResult(matched_count=0, modified_count=0)

    async def count_documents(self, query: dict[str, Any] | None = None) -> int:
        """Count matching documents, using ``count(*)`` when the query is fully SQL-translatable."""
        conditions, params, needs_python = self._build_sql_conditions(query)
        if needs_python:
            docs = await self._find_docs(query=query, projection={"_id": 1})
            return len(docs)
        sql = f'SELECT count(*) FROM "{self._name}"'
        if conditions:
            sql += " WHERE " + " AND ".join(conditions)
        return int(await self._database.pool.fetchval(sql, *params))

    async def delete_one(self, query: dict[str, Any]) -> DeleteResult:
        docs = await self._find_docs(query=query, projection={"_id": 1}, limit=1)
        if not docs:
            return DeleteResult(deleted_count=0)
        await self._database.pool.execute(
            f'DELETE FROM "{self._name}" WHERE _id = $1',
            str(_normalize_scalar(docs[0]["_id"])),
        )
        return DeleteResult(deleted_count=1)


class PostgresDocumentDatabase:
    """Database handle exposing one ``AsyncCollection`` attribute per table.

    Gives MongoDB-style attribute access (``db.users``, ``db.documents``, ...)
    on top of a single shared asyncpg pool.
    """

    def __init__(self, pool: asyncpg.Pool):
        self.pool = pool

        def collection(table: str) -> AsyncCollection:
            return AsyncCollection(self, table)

        self.users = collection("users")
        self.documents = collection("documents")
        self.verifications = collection("verifications")
        self.refresh_tokens = collection("refresh_tokens")
        self.password_reset_tokens = collection("password_reset_tokens")
        self.revoked_access_tokens = collection("revoked_access_tokens")
        self.notifications = collection("notifications")
        self.document_versions = collection("document_versions")

    async def initialize(self) -> None:
        """Create every collection table, then the expression indexes (all idempotent)."""
        table_ddl = 'CREATE TABLE IF NOT EXISTS "{}" (_id TEXT PRIMARY KEY, data JSONB NOT NULL)'
        for table in COLLECTIONS:
            await self.pool.execute(table_ddl.format(table))
        for statement in INDEX_STATEMENTS:
            await self.pool.execute(statement)

    async def command(self, name: str) -> dict[str, int]:
        """Support only the ``ping`` command: run ``SELECT 1`` and report ``{"ok": 1}``.

        Raises:
            ValueError: for any other command name.
        """
        if name == "ping":
            await self.pool.fetchval("SELECT 1")
            return {"ok": 1}
        raise ValueError(f"Unsupported command: {name}")


async def connect_db():
    """Create the asyncpg pool, bootstrap the schema and publish ``client``/``db``.

    Up to three attempts are made, sleeping 2s and then 4s between them. The
    module globals are assigned only after ``initialize()`` succeeds. A failed
    attempt therefore never leaves a half-initialised ``db`` for ``get_db`` to
    hand out, and the pool opened by that attempt is terminated rather than
    leaked. After the final failure the error is logged and the function
    returns normally, so the app can still start (e.g. for health checks).
    """
    global client, db
    import asyncio

    logger = logging.getLogger(__name__)
    uri = settings.postgres_uri

    # Ensure SSL for remote connections (required by Supabase)
    if "supabase" in uri and "sslmode" not in uri:
        separator = "&" if "?" in uri else "?"
        uri = f"{uri}{separator}sslmode=require"

    max_retries = 3
    for attempt in range(1, max_retries + 1):
        pool = None
        try:
            pool = await asyncpg.create_pool(
                uri, min_size=1, max_size=10,
                command_timeout=30, timeout=30,
            )
            database = PostgresDocumentDatabase(pool)
            await database.initialize()
        except Exception as e:
            logger.warning("DB connection attempt %d/%d failed: %s", attempt, max_retries, e)
            if pool is not None:
                # Don't leak the connections opened by this failed attempt.
                try:
                    pool.terminate()
                except Exception:
                    logger.debug("Failed to terminate pool after failed attempt", exc_info=True)
            if attempt < max_retries:
                await asyncio.sleep(2 * attempt)
                continue
            logger.error(
                "\n" + "=" * 60 +
                "\n  DATABASE CONNECTION FAILED after %d attempts!"
                "\n  All API endpoints requiring the database will return 503."
                "\n  Check your POSTGRES_URI environment variable."
                "\n" + "=" * 60,
                max_retries,
            )
            # Don't crash the app — let it start for health checks
            return

        # Publish only a fully initialised database.
        client = pool
        db = database
        logger.info("Connected to PostgreSQL")
        return


async def close_db():
    """Close the connection pool, if one is open, and reset the module state.

    ``client`` and ``db`` are both cleared before the pool is closed. After
    shutdown, ``get_db`` therefore reports 503 instead of handing out a
    handle whose pool is closed.
    """
    global client, db
    pool = client
    if pool is None:
        return
    client = None
    db = None
    await pool.close()
    logging.getLogger(__name__).info("PostgreSQL connection closed")


def get_db():
    """Return the connected database handle.

    Raises:
        fastapi.HTTPException: status 503 when ``connect_db`` has not
        (successfully) run. The logger is looked up here because this module
        defines no module-level ``logger``. The previous bare ``logger``
        reference raised NameError instead of the intended 503.
    """
    if db is None:
        logging.getLogger(__name__).error("get_db() called but database is not connected!")
        from fastapi import HTTPException
        raise HTTPException(
            status_code=503,
            detail="Database is not connected. Check server logs for POSTGRES_URI issues.",
        )
    return db