Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import copy | |
| import logging | |
| import json | |
| import re | |
| from dataclasses import dataclass | |
| from datetime import datetime | |
| from typing import Any, Iterable | |
| import asyncpg | |
| from bson import ObjectId | |
| from app.config import settings | |
# Names of the document tables. PostgresDocumentDatabase.initialize() issues a
# CREATE TABLE IF NOT EXISTS for every entry.
COLLECTIONS = (
    "users",
    "documents",
    "verifications",
    "refresh_tokens",
    "password_reset_tokens",
    "revoked_access_tokens",
    "notifications",
    "document_versions",
)
# Index DDL executed, in order, by PostgresDocumentDatabase.initialize() after
# the tables exist. Every index is an expression index over text extracted
# from the JSONB ``data`` column (``data->>'field'``), and every statement uses
# IF NOT EXISTS so it can be re-run.
INDEX_STATEMENTS = (
    # users
    "CREATE UNIQUE INDEX IF NOT EXISTS users_email_unique ON users ((data->>'email'))",
    "CREATE INDEX IF NOT EXISTS users_role_idx ON users ((data->>'role'))",
    # documents
    "CREATE INDEX IF NOT EXISTS documents_user_created_idx ON documents ((data->>'user_id'), (data->>'created_at') DESC)",
    "CREATE INDEX IF NOT EXISTS documents_is_deleted_created_idx ON documents ((data->>'is_deleted'), (data->>'created_at') DESC)",
    # verifications
    "CREATE INDEX IF NOT EXISTS verifications_document_idx ON verifications ((data->>'document_id'))",
    "CREATE INDEX IF NOT EXISTS verifications_expert_idx ON verifications ((data->>'expert_id'))",
    "CREATE INDEX IF NOT EXISTS verifications_role_idx ON verifications ((data->>'reviewer_role'))",
    "CREATE UNIQUE INDEX IF NOT EXISTS verifications_document_expert_unique ON verifications ((data->>'document_id'), (data->>'expert_id'))",
    # refresh_tokens
    "CREATE UNIQUE INDEX IF NOT EXISTS refresh_tokens_jti_unique ON refresh_tokens ((data->>'jti'))",
    "CREATE INDEX IF NOT EXISTS refresh_tokens_user_idx ON refresh_tokens ((data->>'user_id'))",
    "CREATE INDEX IF NOT EXISTS refresh_tokens_expires_idx ON refresh_tokens ((data->>'expires_at'))",
    # password_reset_tokens
    "CREATE UNIQUE INDEX IF NOT EXISTS password_reset_tokens_hash_unique ON password_reset_tokens ((data->>'token_hash'))",
    "CREATE INDEX IF NOT EXISTS password_reset_tokens_expires_idx ON password_reset_tokens ((data->>'expires_at'))",
    # revoked_access_tokens
    "CREATE UNIQUE INDEX IF NOT EXISTS revoked_access_tokens_jti_unique ON revoked_access_tokens ((data->>'jti'))",
    "CREATE INDEX IF NOT EXISTS revoked_access_tokens_expires_idx ON revoked_access_tokens ((data->>'expires_at'))",
    # notifications
    "CREATE INDEX IF NOT EXISTS notifications_user_created_idx ON notifications ((data->>'user_id'), (data->>'created_at') DESC)",
    "CREATE INDEX IF NOT EXISTS notifications_user_is_read_idx ON notifications ((data->>'user_id'), (data->>'is_read'))",
    # document_versions
    "CREATE INDEX IF NOT EXISTS document_versions_document_created_idx ON document_versions ((data->>'document_id'), (data->>'created_at') DESC)",
)
# Field names whose string values _restore_special_types() tries to parse back
# into datetime objects (via datetime.fromisoformat) when a document is loaded.
DATETIME_FIELDS = {
    "created_at",
    "updated_at",
    "expires_at",
    "used_at",
    "deleted_at",
    "submitter_paid_at",
    "linguist_edited_at",
    "translator_edited_at",
    "linguist_approved_at",
    "translator_approved_at",
}
# Module-level connection state. connect_db() assigns both names; close_db()
# closes the pool; get_db() returns ``db`` and raises HTTP 503 while it is None.
client: asyncpg.Pool | None = None
db: PostgresDocumentDatabase | None = None
class _ExtendedJSONEncoder(json.JSONEncoder):
    """JSON encoder that additionally serializes ``datetime`` and ``ObjectId``.

    Datetimes become their ``isoformat()`` string and ObjectIds their ``str()``
    form; every other unsupported value is handed to the base encoder.
    """

    def default(self, obj: Any) -> Any:
        """Return a JSON-serializable replacement for *obj*."""
        if isinstance(obj, datetime):
            encoded: Any = obj.isoformat()
        elif isinstance(obj, ObjectId):
            encoded = str(obj)
        else:
            encoded = super().default(obj)
        return encoded
def _json_dumps(value: Any) -> str:
    """Serialize *value* to compact JSON text using ``_ExtendedJSONEncoder``."""
    compact_separators = (",", ":")  # no whitespace after item/key separators
    return json.dumps(value, separators=compact_separators, cls=_ExtendedJSONEncoder)
def _json_loads(value: Any) -> Any:
    """Decode a stored JSON payload and restore datetime fields.

    ``bytes`` input is decoded as UTF-8 and parsed, ``str`` input is parsed,
    and any other value is used as-is. The result is then passed through
    ``_restore_special_types``.
    """
    if isinstance(value, bytes):
        value = value.decode("utf-8")
    decoded = json.loads(value) if isinstance(value, str) else value
    return _restore_special_types(decoded)
def _restore_special_types(value: Any, key: str | None = None) -> Any:
    """Recursively turn ISO strings stored under datetime field names into datetimes.

    *key* is the dict key under which *value* was found (``None`` at the top
    level and for list items). A string is parsed only when its key is listed
    in ``DATETIME_FIELDS``; a string that fails to parse is returned unchanged.
    """
    if isinstance(value, dict):
        return {name: _restore_special_types(item, key=name) for name, item in value.items()}
    if isinstance(value, list):
        # List items are visited without a key, so bare strings in lists stay strings.
        return [_restore_special_types(item) for item in value]
    if not (isinstance(value, str) and key in DATETIME_FIELDS):
        return value
    try:
        return datetime.fromisoformat(value)
    except Exception:
        return value
def _normalize_scalar(value: Any) -> Any:
    """Return ``str(value)`` for an ``ObjectId``; return any other value unchanged."""
    return str(value) if isinstance(value, ObjectId) else value
def _is_operator_dict(value: Any) -> bool:
    """Return True when *value* is a dict with at least one ``$``-prefixed key.

    Non-string keys are ignored rather than raising ``AttributeError``, so a
    plain dict value such as ``{1: "x"}`` is treated as data, not operators.
    """
    if not isinstance(value, dict):
        return False
    return any(isinstance(k, str) and k.startswith("$") for k in value)
def _sort_key(value: Any) -> tuple[int, Any]:
    """Build a sort key that is comparable across the value types found in documents.

    The first tuple element is a type rank, so values of different kinds are
    never compared with each other (which would raise ``TypeError``):

    * rank 0 - datetimes (as POSIX timestamps), booleans (as 0/1) and numbers
    * rank 1 - everything else, compared as lower-cased text
    * rank 2 - ``None``, which therefore sorts last in ascending order
    """
    value = _normalize_scalar(value)
    if value is None:
        return (2, "")
    if isinstance(value, datetime):
        return (0, value.timestamp())
    if isinstance(value, bool):
        # bool is checked before int because bool is a subclass of int.
        return (0, int(value))
    if isinstance(value, (int, float)):
        return (0, value)
    return (1, str(value).lower())
def _compare(left: Any, right: Any, operator: str) -> bool:
    """Evaluate one MongoDB-style comparison operator in Python.

    Args:
        left: The value read from the document (``None`` when absent).
        right: The operand given in the query.
        operator: One of ``$eq``, ``$ne``, ``$in``, ``$nin``, ``$gt``,
            ``$gte``, ``$lt``, ``$lte``.

    Returns:
        The result of the comparison. Unknown operators, ordering comparisons
        against a ``None`` document value, and ordering comparisons between
        incompatible types all evaluate to ``False``.
    """
    left = _normalize_scalar(left)
    right = _normalize_scalar(right)

    if operator in ("$in", "$nin"):
        if not isinstance(right, (list, tuple, set, frozenset)):
            return False
        # A list (not a set) is used for membership so that unhashable values
        # such as lists or dicts on either side do not raise TypeError.
        found = left in [_normalize_scalar(candidate) for candidate in right]
        return found if operator == "$in" else not found
    if operator == "$eq":
        return left == right
    if operator == "$ne":
        return left != right

    if left is None:
        return False
    try:
        if operator == "$gt":
            return left > right
        if operator == "$gte":
            return left >= right
        if operator == "$lt":
            return left < right
        if operator == "$lte":
            return left <= right
    except TypeError:
        # Incomparable types (e.g. str vs datetime) simply do not match.
        return False
    return False
def _matches_query(document: dict[str, Any], query: dict[str, Any]) -> bool:
    """Return True when *document* satisfies every clause of *query*.

    Supported clauses: ``$or`` / ``$and`` with a list of sub-queries, plain
    equality on a top-level field, and operator dicts (``$regex`` with an
    optional ``$options`` containing ``i``, plus whatever ``_compare``
    understands). An empty or ``None`` query matches everything.
    """
    for field, condition in (query or {}).items():
        if field in ("$or", "$and"):
            if not isinstance(condition, list):
                return False
            combine = any if field == "$or" else all
            if not combine(_matches_query(document, clause) for clause in condition):
                return False
            continue

        actual = document.get(field)

        if not _is_operator_dict(condition):
            if _normalize_scalar(actual) != _normalize_scalar(condition):
                return False
            continue

        pattern = condition.get("$regex")
        if pattern is not None:
            ignore_case = "i" in str(condition.get("$options", ""))
            flags = re.IGNORECASE if ignore_case else 0
            if actual is None or re.search(str(pattern), str(actual), flags) is None:
                return False

        # $regex/$options were handled above; everything else goes to _compare.
        comparisons = (
            (name, operand)
            for name, operand in condition.items()
            if name not in {"$regex", "$options"}
        )
        if not all(_compare(actual, operand, name) for name, operand in comparisons):
            return False
    return True
def _apply_projection(document: dict[str, Any], projection: dict[str, Any] | None) -> dict[str, Any]:
    """Return a deep copy of *document* restricted by a MongoDB-style projection.

    * No projection (``None`` or empty): the whole document is copied.
    * Inclusion (any non-``_id`` key with a truthy value): only those fields
      are copied, plus ``_id`` unless it is explicitly set to ``0``.
    * Exclusion (non-``_id`` keys with falsy values only): every field except
      the listed ones is copied; ``_id`` is dropped only when set to ``0``.

    If inclusion and exclusion keys are mixed, the inclusion keys win.
    """
    if not projection:
        return copy.deepcopy(document)

    include_id = projection.get("_id", 1) != 0
    included = [name for name, flag in projection.items() if name != "_id" and flag]

    if included:
        out: dict[str, Any] = {}
        if include_id and "_id" in document:
            out["_id"] = copy.deepcopy(document["_id"])
        for field in included:
            if field in document:
                out[field] = copy.deepcopy(document[field])
        return out

    # Exclusion projection: previously these keys were ignored, so fields the
    # caller asked to hide (e.g. a password hash) were still returned.
    excluded = {name for name, flag in projection.items() if name != "_id" and not flag}
    out = {
        name: copy.deepcopy(item)
        for name, item in document.items()
        if name not in excluded
    }
    if not include_id:
        out.pop("_id", None)
    return out
def _apply_update(document: dict[str, Any], update: dict[str, Any]) -> dict[str, Any]:
    """Return a deep copy of *document* with a MongoDB-style update applied.

    Supports ``$set`` (assign top-level fields) and ``$inc`` (add to a
    top-level field, treating a missing field as 0). Operators are applied in
    the order they appear in *update*.

    Raises:
        ValueError: If *update* contains any other operator.
    """
    result = copy.deepcopy(document)
    for operator_name, changes in (update or {}).items():
        if operator_name == "$set":
            result.update(changes.items())
        elif operator_name == "$inc":
            for field, delta in changes.items():
                result[field] = result.get(field, 0) + delta
        else:
            raise ValueError(f"Unsupported update operator: {operator_name}")
    return result
def _seed_from_query(query: dict[str, Any]) -> dict[str, Any]:
    """Build the starting document for an upsert from a query's equality clauses.

    ``$``-prefixed keys and fields whose condition is an operator dict are
    skipped; the remaining values are copied after ``_normalize_scalar``.
    """
    return {
        field: _normalize_scalar(expected)
        for field, expected in (query or {}).items()
        if not field.startswith("$") and not _is_operator_dict(expected)
    }
@dataclass
class InsertOneResult:
    """Outcome of ``AsyncCollection.insert_one``."""

    # Id of the stored document, as returned by insert_one.
    inserted_id: Any
@dataclass
class UpdateResult:
    """Outcome of ``AsyncCollection.update_one`` / ``update_many``."""

    # Number of documents the query selected.
    matched_count: int
    # Number of documents reported as changed (1 for an upsert insert).
    modified_count: int
    # Id of the document created by an upsert; None when nothing was upserted.
    upserted_id: Any | None = None
@dataclass
class DeleteResult:
    """Outcome of ``AsyncCollection.delete_one``."""

    # Number of rows removed.
    deleted_count: int
class AsyncCursor:
    """Lazy, chainable result set returned by ``AsyncCollection.find``.

    ``sort``/``skip``/``limit`` only record options. The query runs once, on
    the first ``to_list`` call or the first step of ``async for``, and the
    result is cached on the cursor. Every document handed out is a deep copy.
    """

    def __init__(self, collection: "AsyncCollection", query: dict[str, Any] | None, projection: dict[str, Any] | None):
        self._collection = collection
        self._query = query or {}
        self._projection = projection
        self._sort_fields: list[tuple[str, int]] = []
        self._skip = 0
        self._limit: int | None = None
        self._loaded: list[dict[str, Any]] | None = None  # None until the query has run
        self._index = 0  # position of the next document for async iteration

    def sort(self, key_or_list: Any, direction: int | None = None) -> "AsyncCursor":
        """Append sort keys: a ``[(field, direction), ...]`` list or one field name."""
        if isinstance(key_or_list, list):
            self._sort_fields.extend((str(name), int(order)) for name, order in key_or_list)
        else:
            # A missing (or zero) direction means ascending.
            self._sort_fields.append((str(key_or_list), int(direction or 1)))
        return self

    def skip(self, count: int) -> "AsyncCursor":
        """Set how many leading documents to drop (negative values become 0)."""
        self._skip = max(0, int(count))
        return self

    def limit(self, count: int) -> "AsyncCursor":
        """Set the maximum number of documents to return (negative values become 0)."""
        self._limit = max(0, int(count))
        return self

    async def _ensure_loaded(self) -> None:
        """Run the query through the collection unless it has already run."""
        if self._loaded is None:
            self._loaded = await self._collection._find_docs(
                query=self._query,
                projection=self._projection,
                sort_fields=self._sort_fields,
                skip=self._skip,
                limit=self._limit,
            )

    async def to_list(self, length: int | None = None) -> list[dict[str, Any]]:
        """Return deep copies of the results, at most *length* of them when given."""
        await self._ensure_loaded()
        loaded = self._loaded or []
        window = loaded if length is None else loaded[: max(0, int(length))]
        return copy.deepcopy(window)

    def __aiter__(self) -> "AsyncCursor":
        return self

    async def __anext__(self) -> dict[str, Any]:
        await self._ensure_loaded()
        assert self._loaded is not None
        position = self._index
        if position >= len(self._loaded):
            raise StopAsyncIteration
        self._index = position + 1
        return copy.deepcopy(self._loaded[position])
class AsyncCollection:
    """MongoDB-style async collection backed by one PostgreSQL table.

    Every document is one row of the table called ``name``: the ``_id``
    column holds the id as text and the ``data`` JSONB column holds all other
    fields. Queries are translated to SQL where that is known to be safe and
    are otherwise (or additionally) filtered in Python with ``_matches_query``.
    """

    # SQL spelling of the range operators that can be pushed down to PostgreSQL.
    _RANGE_OPERATORS = {"$gt": ">", "$gte": ">=", "$lt": "<", "$lte": "<="}

    def __init__(self, database: "PostgresDocumentDatabase", name: str):
        self._database = database
        self._name = name

    async def create_index(self, keys: Any, unique: bool = False, expireAfterSeconds: int | None = None) -> None:
        """Accept an index request and do nothing."""
        # Indexes are created in bootstrap DDL.
        return None

    @staticmethod
    def _rows_to_documents(rows: Iterable[Any]) -> list[dict[str, Any]]:
        """Decode ``(_id, data)`` rows into documents that carry their ``_id``."""
        documents: list[dict[str, Any]] = []
        for row in rows:
            data = _json_loads(row["data"]) or {}
            data["_id"] = row["_id"]
            documents.append(data)
        return documents

    @staticmethod
    def _field_sql(field: str) -> str:
        """Return the SQL text expression that reads *field* from a row.

        ``_id`` lives in its own column (it is stripped from ``data`` before
        storing); every other field is read with ``data->>'field'``. Single
        quotes in the name are doubled so the name cannot end the SQL literal.
        """
        if field == "_id":
            return "_id"
        return "data->>'" + str(field).replace("'", "''") + "'"

    @staticmethod
    def _sql_text(value: Any) -> str | None:
        """Return the text to compare against ``data->>'field'`` for *value*.

        Returns ``None`` for values (``None``, lists, dicts, ...) that have no
        reliable text form, in which case the caller must filter in Python.
        """
        if isinstance(value, bool):
            # JSONB renders booleans as 'true'/'false', not 'True'/'False'.
            return "true" if value else "false"
        if isinstance(value, datetime):
            # Must match _ExtendedJSONEncoder, which stores isoformat() ('T'
            # separator); str(datetime) uses a space and would never match.
            return value.isoformat()
        if isinstance(value, str):
            return value
        if isinstance(value, (int, float)):
            return str(value)
        if isinstance(value, ObjectId):
            return str(value)
        return None

    async def _fetch_all_documents(self) -> list[dict[str, Any]]:
        """Load every document of the collection."""
        rows = await self._database.pool.fetch(f'SELECT _id, data FROM "{self._name}"')
        return self._rows_to_documents(rows)

    async def _store_document(self, document: dict[str, Any]) -> None:
        """Insert *document*, or replace the row that already has its ``_id``."""
        doc_id = str(_normalize_scalar(document["_id"]))
        payload = copy.deepcopy(document)
        payload.pop("_id", None)  # the id is kept in its own column
        await self._database.pool.execute(
            f'INSERT INTO "{self._name}" (_id, data) VALUES ($1, $2::jsonb) '
            f'ON CONFLICT (_id) DO UPDATE SET data = EXCLUDED.data',
            doc_id,
            _json_dumps(payload),
        )

    @staticmethod
    def _build_sql_conditions(
        query: dict[str, Any] | None,
    ) -> tuple[list[str], list[Any], bool]:
        """Try to convert a MongoDB-style query to SQL WHERE clauses.

        Returns ``(conditions, params, needs_python_filter)``. The conditions
        are AND-ed by the caller and use ``$1..$N`` placeholders matching
        ``params``. If ``needs_python_filter`` is True, some part of the query
        could not be expressed in SQL and the rows must still be filtered with
        ``_matches_query``; the emitted conditions only narrow the row set.

        All pushed-down comparisons are text comparisons on ``data->>'field'``,
        so range operators are pushed down only for string and datetime
        operands (numeric operands would compare lexicographically).
        NOTE(review): datetime ranges rely on stored timestamps sharing one
        ISO format/offset so that text order equals time order.
        """
        if not query:
            return [], [], False

        conditions: list[str] = []
        params: list[Any] = []
        needs_python = False

        def bind(text: str) -> str:
            """Register *text* as the next parameter and return its placeholder."""
            params.append(text)
            return f"${len(params)}"

        for key, expected in query.items():
            if key in ("$or", "$and"):
                needs_python = True
                continue

            if key == "_id" and not _is_operator_dict(expected):
                if isinstance(expected, (str, ObjectId)):
                    conditions.append(f"_id = {bind(str(_normalize_scalar(expected)))}")
                else:
                    needs_python = True
                continue

            column = AsyncCollection._field_sql(key)

            if not _is_operator_dict(expected):
                # Simple equality.
                if expected is None:
                    conditions.append(f"{column} IS NULL")
                    continue
                text = AsyncCollection._sql_text(expected)
                if text is None:
                    needs_python = True
                else:
                    conditions.append(f"{column} = {bind(text)}")
                continue

            for op, op_value in expected.items():
                if op == "$ne":
                    if op_value is None:
                        conditions.append(f"{column} IS NOT NULL")
                        continue
                    text = AsyncCollection._sql_text(op_value)
                    if text is None:
                        needs_python = True
                    else:
                        # NULL != x is NULL in SQL, so absent fields are matched explicitly.
                        conditions.append(f"({column} IS NULL OR {column} != {bind(text)})")
                elif op == "$in":
                    texts = (
                        [AsyncCollection._sql_text(v) for v in op_value]
                        if isinstance(op_value, list)
                        else None
                    )
                    if not texts or any(t is None for t in texts):
                        # Empty/non-list operands and members without a text
                        # form (e.g. None) are left to the Python filter.
                        needs_python = True
                    else:
                        placeholders = ", ".join(bind(t) for t in texts)
                        conditions.append(f"{column} IN ({placeholders})")
                elif op in AsyncCollection._RANGE_OPERATORS:
                    if isinstance(op_value, (str, datetime)):
                        sql_op = AsyncCollection._RANGE_OPERATORS[op]
                        text = AsyncCollection._sql_text(op_value)
                        conditions.append(f"{column} {sql_op} {bind(text)}")
                    else:
                        needs_python = True
                else:
                    # $regex, $options and anything unknown.
                    needs_python = True
        return conditions, params, needs_python

    async def _find_docs(
        self,
        query: dict[str, Any] | None,
        projection: dict[str, Any] | None,
        sort_fields: Iterable[tuple[str, int]] | None = None,
        skip: int = 0,
        limit: int | None = None,
    ) -> list[dict[str, Any]]:
        """Run a query and return the matching documents, projected.

        When the whole query is expressible in SQL, sorting, OFFSET and LIMIT
        are pushed to PostgreSQL as well. Otherwise the SQL conditions narrow
        the rows and filtering, sorting and slicing happen in Python.
        NOTE(review): SQL sorting orders by the field's text form, so numeric
        fields sort lexicographically on that path.
        """
        conditions, params, needs_python = self._build_sql_conditions(query)
        sort_list = list(sort_fields or [])

        sql = f'SELECT _id, data FROM "{self._name}"'
        if conditions:
            sql += " WHERE " + " AND ".join(conditions)

        if not needs_python:
            if sort_list:
                order_clauses = [
                    f"{self._field_sql(field)} {'DESC' if int(direction) == -1 else 'ASC'}"
                    for field, direction in sort_list
                ]
                sql += " ORDER BY " + ", ".join(order_clauses)
            if limit is not None:
                sql += f" LIMIT {max(0, int(limit))}"
            if skip:
                sql += f" OFFSET {max(0, int(skip))}"

        rows = await self._database.pool.fetch(sql, *params)
        documents = self._rows_to_documents(rows)

        if needs_python:
            documents = [doc for doc in documents if _matches_query(doc, query or {})]
            # Stable sorts applied from the last key to the first give a
            # multi-key ordering.
            for field, direction in reversed(sort_list):
                documents.sort(
                    key=lambda item, field=field: _sort_key(item.get(field)),
                    reverse=int(direction) == -1,
                )
            if skip:
                documents = documents[max(0, int(skip)):]
            if limit is not None:
                documents = documents[:max(0, int(limit))]

        return [_apply_projection(doc, projection) for doc in documents]

    def find(self, query: dict[str, Any] | None = None, projection: dict[str, Any] | None = None) -> AsyncCursor:
        """Return a lazy cursor over the documents matching *query*."""
        return AsyncCursor(self, query, projection)

    async def find_one(
        self,
        query: dict[str, Any] | None = None,
        projection: dict[str, Any] | None = None,
        sort: list[tuple[str, int]] | None = None,
    ) -> dict[str, Any] | None:
        """Return the first matching document, or ``None`` when there is none."""
        docs = await self._find_docs(query=query, projection=projection, sort_fields=sort, limit=1)
        return docs[0] if docs else None

    async def insert_one(self, document: dict[str, Any]) -> InsertOneResult:
        """Store a copy of *document*, generating an ``_id`` when it has none.

        NOTE(review): storing goes through an upsert, so an existing row with
        the same ``_id`` is replaced rather than reported as a duplicate.
        """
        payload = copy.deepcopy(document)
        existing_id = payload.get("_id")
        if existing_id is None:
            doc_id = str(ObjectId())
        else:
            doc_id = str(_normalize_scalar(existing_id))
        payload["_id"] = doc_id
        await self._store_document(payload)
        try:
            inserted_id: Any = ObjectId(doc_id)
        except Exception:
            # Ids that are not valid ObjectId strings are reported as plain text.
            inserted_id = doc_id
        return InsertOneResult(inserted_id=inserted_id)

    async def _upsert_from_query(self, query: dict[str, Any], update: dict[str, Any]) -> UpdateResult:
        """Create the document an upsert inserts when nothing matched *query*."""
        upsert_doc = _apply_update(_seed_from_query(query), update)
        if "_id" not in upsert_doc:
            upsert_doc["_id"] = str(ObjectId())
        else:
            upsert_doc["_id"] = str(_normalize_scalar(upsert_doc["_id"]))
        await self._store_document(upsert_doc)
        return UpdateResult(matched_count=0, modified_count=1, upserted_id=upsert_doc["_id"])

    async def update_one(self, query: dict[str, Any], update: dict[str, Any], upsert: bool = False) -> UpdateResult:
        """Apply *update* to the first document matching *query*.

        With ``upsert=True`` and no match, a new document is built from the
        query's equality clauses plus the update. The read and the write are
        separate statements (no transaction is used here).
        """
        docs = await self._find_docs(query=query, projection=None, limit=1)
        if docs:
            original = docs[0]
            updated = _apply_update(original, update)
            updated["_id"] = original["_id"]
            modified = int(updated != original)
            if modified:
                await self._store_document(updated)
            return UpdateResult(matched_count=1, modified_count=modified)
        if not upsert:
            return UpdateResult(matched_count=0, modified_count=0)
        return await self._upsert_from_query(query, update)

    async def update_many(self, query: dict[str, Any], update: dict[str, Any], upsert: bool = False) -> UpdateResult:
        """Apply *update* to every document matching *query*; see ``update_one``."""
        docs = await self._find_docs(query=query, projection=None)
        modified = 0
        for original in docs:
            updated = _apply_update(original, update)
            updated["_id"] = original["_id"]
            if updated != original:
                modified += 1
                await self._store_document(updated)
        if docs:
            return UpdateResult(matched_count=len(docs), modified_count=modified)
        if upsert:
            return await self._upsert_from_query(query, update)
        return UpdateResult(matched_count=0, modified_count=0)

    async def count_documents(self, query: dict[str, Any] | None = None) -> int:
        """Return the number of documents matching *query*."""
        conditions, params, needs_python = self._build_sql_conditions(query)
        if needs_python:
            docs = await self._find_docs(query=query, projection={"_id": 1})
            return len(docs)
        # Fully expressible in SQL: let PostgreSQL count instead of loading rows.
        sql = f'SELECT count(*) FROM "{self._name}"'
        if conditions:
            sql += " WHERE " + " AND ".join(conditions)
        return int(await self._database.pool.fetchval(sql, *params))

    async def delete_one(self, query: dict[str, Any]) -> DeleteResult:
        """Delete the first document matching *query*, if any."""
        docs = await self._find_docs(query=query, projection={"_id": 1}, limit=1)
        if not docs:
            return DeleteResult(deleted_count=0)
        await self._database.pool.execute(
            f'DELETE FROM "{self._name}" WHERE _id = $1',
            str(_normalize_scalar(docs[0]["_id"])),
        )
        return DeleteResult(deleted_count=1)
class PostgresDocumentDatabase:
    """Database facade exposing one ``AsyncCollection`` attribute per table."""

    def __init__(self, pool: asyncpg.Pool):
        self.pool = pool
        # One attribute per collection, named after its table.
        self.users = AsyncCollection(self, "users")
        self.documents = AsyncCollection(self, "documents")
        self.verifications = AsyncCollection(self, "verifications")
        self.refresh_tokens = AsyncCollection(self, "refresh_tokens")
        self.password_reset_tokens = AsyncCollection(self, "password_reset_tokens")
        self.revoked_access_tokens = AsyncCollection(self, "revoked_access_tokens")
        self.notifications = AsyncCollection(self, "notifications")
        self.document_versions = AsyncCollection(self, "document_versions")

    async def initialize(self) -> None:
        """Create any missing tables, then any missing indexes."""
        execute = self.pool.execute
        for table in COLLECTIONS:
            ddl = f'CREATE TABLE IF NOT EXISTS "{table}" (_id TEXT PRIMARY KEY, data JSONB NOT NULL)'
            await execute(ddl)
        for statement in INDEX_STATEMENTS:
            await execute(statement)

    async def command(self, name: str) -> dict[str, int]:
        """Run a database command; only ``"ping"`` is supported.

        Raises:
            ValueError: For any other command name.
        """
        if name == "ping":
            await self.pool.fetchval("SELECT 1")
            return {"ok": 1}
        raise ValueError(f"Unsupported command: {name}")
async def connect_db():
    """Create the connection pool and bootstrap the schema, retrying on failure.

    Up to three attempts are made, sleeping ``2 * attempt`` seconds between
    them. The module-level ``client`` and ``db`` are assigned only after the
    pool was created AND ``initialize()`` succeeded, so a failed attempt never
    leaves a half-initialised database visible to ``get_db()``. A pool whose
    initialisation failed is terminated before the next attempt so its
    connections are not leaked.

    On final failure the error is logged and the function returns normally so
    the application can still start.
    """
    global client, db
    import asyncio

    logger = logging.getLogger(__name__)
    uri = settings.postgres_uri
    # Ensure SSL for remote connections (required by Supabase)
    if "supabase" in uri and "sslmode" not in uri:
        separator = "&" if "?" in uri else "?"
        uri = f"{uri}{separator}sslmode=require"

    max_retries = 3
    for attempt in range(1, max_retries + 1):
        pool = None
        try:
            pool = await asyncpg.create_pool(
                uri, min_size=1, max_size=10,
                command_timeout=30, timeout=30,
            )
            database = PostgresDocumentDatabase(pool)
            await database.initialize()
        except Exception as e:
            logger.warning("DB connection attempt %d/%d failed: %s", attempt, max_retries, e)
            if pool is not None:
                # The pool opened but schema bootstrap failed: drop its
                # connections immediately instead of leaking them.
                pool.terminate()
            if attempt < max_retries:
                await asyncio.sleep(2 * attempt)
        else:
            client, db = pool, database
            logger.info("Connected to PostgreSQL")
            return

    logger.error(
        "\n" + "=" * 60 +
        "\n DATABASE CONNECTION FAILED after %d attempts!"
        "\n All API endpoints requiring the database will return 503."
        "\n Check your POSTGRES_URI environment variable."
        "\n" + "=" * 60,
        max_retries,
    )
    # Don't crash the app — let it start for health checks
    return
async def close_db():
    """Close the connection pool and clear the module-level connection state.

    Both ``client`` and ``db`` are reset (even if closing raises), so that
    ``get_db()`` reports the database as unavailable instead of handing out a
    database object bound to a closed pool.
    """
    global client, db
    if client:
        try:
            await client.close()
        finally:
            client = None
            db = None
        print("PostgreSQL connection closed")
def get_db():
    """Return the connected database.

    Raises:
        HTTPException: With status 503 when no database is connected.
    """
    if db is None:
        # No module-level ``logger`` exists in this file, so fetch one here;
        # referencing an undefined name would raise NameError and hide the 503.
        logging.getLogger(__name__).error("get_db() called but database is not connected!")
        from fastapi import HTTPException
        raise HTTPException(
            status_code=503,
            detail="Database is not connected. Check server logs for POSTGRES_URI issues.",
        )
    return db