| """Optional durable session persistence for the hosted backend. |
| |
| The public CLI must keep working without MongoDB. This module therefore |
| exposes one small async store interface and returns a no-op implementation |
| unless ``MONGODB_URI`` is configured and reachable. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import logging |
| import os |
| from datetime import UTC, datetime |
| from typing import Any |
|
|
| from bson import BSON |
| from pymongo import AsyncMongoClient, DeleteMany, ReturnDocument, UpdateOne |
| from pymongo.errors import DuplicateKeyError, InvalidDocument, PyMongoError |
|
|
| logger = logging.getLogger(__name__) |
|
|
| SCHEMA_VERSION = 1 |
| MAX_BSON_BYTES = 15 * 1024 * 1024 |
|
|
|
|
| def _now() -> datetime: |
| return datetime.now(UTC) |
|
|
|
|
| def _doc_id(session_id: str, idx: int) -> str: |
| return f"{session_id}:{idx}" |
|
|
|
|
| def _safe_message_doc(message: dict[str, Any]) -> dict[str, Any]: |
| """Return a Mongo-safe message document payload. |
| |
| Mongo's hard document limit is 16 MB. We stay below that and store an |
| explicit marker rather than failing the whole snapshot for one huge tool log. |
| """ |
| try: |
| if len(BSON.encode({"message": message})) <= MAX_BSON_BYTES: |
| return message |
| except (InvalidDocument, OverflowError): |
| pass |
| return { |
| "role": "tool", |
| "content": ( |
| "[SYSTEM: A single persisted message exceeded MongoDB's document " |
| "size/encoding limit and was replaced by this marker.]" |
| ), |
| "ml_intern_persistence_error": "message_too_large_or_invalid", |
| } |
|
|
|
|
| class NoopSessionStore: |
| """Async no-op store used when Mongo is not configured.""" |
|
|
| enabled = False |
|
|
| async def init(self) -> None: |
| return None |
|
|
| async def close(self) -> None: |
| return None |
|
|
| async def upsert_session(self, **_: Any) -> None: |
| return None |
|
|
| async def save_snapshot(self, **_: Any) -> None: |
| return None |
|
|
| async def load_session(self, *_: Any, **__: Any) -> dict[str, Any] | None: |
| return None |
|
|
| async def list_sessions(self, *_: Any, **__: Any) -> list[dict[str, Any]]: |
| return [] |
|
|
| async def soft_delete_session(self, *_: Any, **__: Any) -> None: |
| return None |
|
|
| async def update_session_fields(self, *_: Any, **__: Any) -> None: |
| return None |
|
|
| async def append_event(self, *_: Any, **__: Any) -> int | None: |
| return None |
|
|
| async def load_events_after(self, *_: Any, **__: Any) -> list[dict[str, Any]]: |
| return [] |
|
|
| async def append_trace_message(self, *_: Any, **__: Any) -> int | None: |
| return None |
|
|
| async def get_quota(self, *_: Any, **__: Any) -> int | None: |
| return None |
|
|
| async def try_increment_quota(self, *_: Any, **__: Any) -> int | None: |
| return None |
|
|
| async def refund_quota(self, *_: Any, **__: Any) -> None: |
| return None |
|
|
| async def mark_pro_seen(self, *_: Any, **__: Any) -> dict[str, Any] | None: |
| return None |
|
|
|
|
| class MongoSessionStore(NoopSessionStore): |
| """MongoDB-backed session store.""" |
|
|
| enabled = True |
|
|
| def __init__(self, uri: str, db_name: str) -> None: |
| self.uri = uri |
| self.db_name = db_name |
| self.enabled = False |
| self.client: AsyncMongoClient | None = None |
| self.db = None |
|
|
| async def init(self) -> None: |
| try: |
| self.client = AsyncMongoClient(self.uri, serverSelectionTimeoutMS=3000) |
| self.db = self.client[self.db_name] |
| await self.client.admin.command("ping") |
| await self._create_indexes() |
| self.enabled = True |
| logger.info("Mongo session persistence enabled (db=%s)", self.db_name) |
| except Exception as e: |
| logger.warning("Mongo session persistence disabled: %s", e) |
| self.enabled = False |
| if self.client is not None: |
| await self.client.close() |
| self.client = None |
| self.db = None |
|
|
| async def close(self) -> None: |
| if self.client is not None: |
| await self.client.close() |
| self.client = None |
| self.db = None |
|
|
| async def _create_indexes(self) -> None: |
| if self.db is None: |
| return |
| await self.db.sessions.create_index( |
| [("user_id", 1), ("visibility", 1), ("updated_at", -1)] |
| ) |
| await self.db.sessions.create_index( |
| [("visibility", 1), ("status", 1), ("last_active_at", -1)] |
| ) |
| await self.db.session_messages.create_index( |
| [("session_id", 1), ("idx", 1)], unique=True |
| ) |
| await self.db.session_events.create_index( |
| [("session_id", 1), ("seq", 1)], unique=True |
| ) |
| await self.db.session_trace_messages.create_index( |
| [("session_id", 1), ("seq", 1)], unique=True |
| ) |
| await self.db.session_trace_messages.create_index([("created_at", -1)]) |
| await self.db.pro_users.create_index([("first_seen_pro_at", -1)]) |
|
|
| def _ready(self) -> bool: |
| return bool(self.enabled and self.db is not None) |
|
|
| async def upsert_session( |
| self, |
| *, |
| session_id: str, |
| user_id: str, |
| model: str, |
| title: str | None = None, |
| surface: str = "frontend", |
| created_at: datetime | None = None, |
| runtime_state: str = "idle", |
| status: str = "active", |
| message_count: int = 0, |
| turn_count: int = 0, |
| pending_approval: list[dict[str, Any]] | None = None, |
| claude_counted: bool = False, |
| notification_destinations: list[str] | None = None, |
| auto_approval_enabled: bool = False, |
| auto_approval_cost_cap_usd: float | None = None, |
| auto_approval_estimated_spend_usd: float = 0.0, |
| ) -> None: |
| if not self._ready(): |
| return |
| now = _now() |
| await self.db.sessions.update_one( |
| {"_id": session_id}, |
| { |
| "$setOnInsert": { |
| "_id": session_id, |
| "session_id": session_id, |
| "user_id": user_id, |
| "surface": surface, |
| "created_at": created_at or now, |
| "schema_version": SCHEMA_VERSION, |
| "visibility": "live", |
| }, |
| "$set": { |
| "title": title, |
| "model": model, |
| "status": status, |
| "runtime_state": runtime_state, |
| "updated_at": now, |
| "last_active_at": now, |
| "message_count": message_count, |
| "turn_count": turn_count, |
| "pending_approval": pending_approval or [], |
| "claude_counted": claude_counted, |
| "notification_destinations": notification_destinations or [], |
| "auto_approval_enabled": auto_approval_enabled, |
| "auto_approval_cost_cap_usd": auto_approval_cost_cap_usd, |
| "auto_approval_estimated_spend_usd": auto_approval_estimated_spend_usd, |
| }, |
| }, |
| upsert=True, |
| ) |
|
|
| async def save_snapshot( |
| self, |
| *, |
| session_id: str, |
| user_id: str, |
| model: str, |
| messages: list[dict[str, Any]], |
| title: str | None = None, |
| runtime_state: str = "idle", |
| status: str = "active", |
| turn_count: int = 0, |
| pending_approval: list[dict[str, Any]] | None = None, |
| claude_counted: bool = False, |
| created_at: datetime | None = None, |
| notification_destinations: list[str] | None = None, |
| auto_approval_enabled: bool = False, |
| auto_approval_cost_cap_usd: float | None = None, |
| auto_approval_estimated_spend_usd: float = 0.0, |
| ) -> None: |
| if not self._ready(): |
| return |
| now = _now() |
| await self.upsert_session( |
| session_id=session_id, |
| user_id=user_id, |
| model=model, |
| title=title, |
| created_at=created_at, |
| runtime_state=runtime_state, |
| status=status, |
| message_count=len(messages), |
| turn_count=turn_count, |
| pending_approval=pending_approval, |
| claude_counted=claude_counted, |
| notification_destinations=notification_destinations, |
| auto_approval_enabled=auto_approval_enabled, |
| auto_approval_cost_cap_usd=auto_approval_cost_cap_usd, |
| auto_approval_estimated_spend_usd=auto_approval_estimated_spend_usd, |
| ) |
| ops: list[Any] = [] |
| for idx, raw in enumerate(messages): |
| ops.append( |
| UpdateOne( |
| {"_id": _doc_id(session_id, idx)}, |
| { |
| "$set": { |
| "session_id": session_id, |
| "idx": idx, |
| "message": _safe_message_doc(raw), |
| "updated_at": now, |
| }, |
| "$setOnInsert": {"created_at": now}, |
| }, |
| upsert=True, |
| ) |
| ) |
| ops.append( |
| DeleteMany({"session_id": session_id, "idx": {"$gte": len(messages)}}) |
| ) |
| try: |
| if ops: |
| await self.db.session_messages.bulk_write(ops, ordered=False) |
| except PyMongoError as e: |
| logger.warning("Failed to persist session %s snapshot: %s", session_id, e) |
|
|
| async def load_session( |
| self, session_id: str, *, include_deleted: bool = False |
| ) -> dict[str, Any] | None: |
| if not self._ready(): |
| return None |
| meta = await self.db.sessions.find_one({"_id": session_id}) |
| if not meta: |
| return None |
| if meta.get("visibility") == "deleted" and not include_deleted: |
| return None |
| cursor = self.db.session_messages.find({"session_id": session_id}).sort( |
| "idx", 1 |
| ) |
| messages = [row.get("message") async for row in cursor] |
| return {"metadata": meta, "messages": messages} |
|
|
| async def list_sessions( |
| self, user_id: str, *, include_deleted: bool = False |
| ) -> list[dict[str, Any]]: |
| if not self._ready(): |
| return [] |
| query: dict[str, Any] = {"user_id": user_id} |
| if user_id == "dev": |
| query = {} |
| if not include_deleted: |
| query["visibility"] = {"$ne": "deleted"} |
| cursor = self.db.sessions.find(query).sort("updated_at", -1) |
| return [row async for row in cursor] |
|
|
| async def soft_delete_session(self, session_id: str) -> None: |
| if not self._ready(): |
| return |
| await self.db.sessions.update_one( |
| {"_id": session_id}, |
| { |
| "$set": { |
| "visibility": "deleted", |
| "runtime_state": "idle", |
| "updated_at": _now(), |
| } |
| }, |
| ) |
|
|
| async def update_session_fields(self, session_id: str, **fields: Any) -> None: |
| if not self._ready() or not fields: |
| return |
| fields["updated_at"] = _now() |
| await self.db.sessions.update_one({"_id": session_id}, {"$set": fields}) |
|
|
| async def _next_seq(self, counter_id: str) -> int: |
| doc = await self.db.counters.find_one_and_update( |
| {"_id": counter_id}, |
| {"$inc": {"seq": 1}}, |
| upsert=True, |
| return_document=ReturnDocument.AFTER, |
| ) |
| return int(doc["seq"]) |
|
|
| async def append_event( |
| self, session_id: str, event_type: str, data: dict[str, Any] | None |
| ) -> int | None: |
| if not self._ready(): |
| return None |
| try: |
| seq = await self._next_seq(f"event:{session_id}") |
| await self.db.session_events.insert_one( |
| { |
| "_id": _doc_id(session_id, seq), |
| "session_id": session_id, |
| "seq": seq, |
| "event_type": event_type, |
| "data": data or {}, |
| "created_at": _now(), |
| } |
| ) |
| return seq |
| except PyMongoError as e: |
| logger.debug("Failed to append event for %s: %s", session_id, e) |
| return None |
|
|
| async def load_events_after( |
| self, session_id: str, after_seq: int = 0 |
| ) -> list[dict[str, Any]]: |
| if not self._ready(): |
| return [] |
| cursor = self.db.session_events.find( |
| {"session_id": session_id, "seq": {"$gt": int(after_seq or 0)}} |
| ).sort("seq", 1) |
| return [row async for row in cursor] |
|
|
| async def append_trace_message( |
| self, session_id: str, message: dict[str, Any], source: str = "message" |
| ) -> int | None: |
| if not self._ready(): |
| return None |
| try: |
| seq = await self._next_seq(f"trace:{session_id}") |
| await self.db.session_trace_messages.insert_one( |
| { |
| "_id": _doc_id(session_id, seq), |
| "session_id": session_id, |
| "seq": seq, |
| "role": message.get("role"), |
| "message": _safe_message_doc(message), |
| "source": source, |
| "created_at": _now(), |
| } |
| ) |
| return seq |
| except PyMongoError as e: |
| logger.debug("Failed to append trace message for %s: %s", session_id, e) |
| return None |
|
|
| async def get_quota(self, user_id: str, day: str) -> int | None: |
| if not self._ready(): |
| return None |
| doc = await self.db.claude_quotas.find_one({"_id": f"{user_id}:{day}"}) |
| return int(doc.get("count", 0)) if doc else 0 |
|
|
| async def try_increment_quota(self, user_id: str, day: str, cap: int) -> int | None: |
| if not self._ready(): |
| return None |
| key = f"{user_id}:{day}" |
| now = _now() |
| try: |
| await self.db.claude_quotas.insert_one( |
| { |
| "_id": key, |
| "user_id": user_id, |
| "day": day, |
| "count": 1, |
| "updated_at": now, |
| } |
| ) |
| return 1 |
| except DuplicateKeyError: |
| pass |
| doc = await self.db.claude_quotas.find_one_and_update( |
| {"_id": key, "count": {"$lt": cap}}, |
| {"$inc": {"count": 1}, "$set": {"updated_at": now}}, |
| return_document=ReturnDocument.AFTER, |
| ) |
| return int(doc["count"]) if doc else None |
|
|
| async def refund_quota(self, user_id: str, day: str) -> None: |
| if not self._ready(): |
| return |
| await self.db.claude_quotas.update_one( |
| {"_id": f"{user_id}:{day}", "count": {"$gt": 0}}, |
| {"$inc": {"count": -1}, "$set": {"updated_at": _now()}}, |
| ) |
|
|
| async def mark_pro_seen( |
| self, user_id: str, *, is_pro: bool |
| ) -> dict[str, Any] | None: |
| """Track per-user Pro state and detect free→Pro conversions. |
| |
| Returns ``{"converted": True, "first_seen_at": ..."}`` exactly once |
| per user — the first time we see them as Pro after having recorded |
| them as non-Pro at least once. Otherwise returns ``None``. |
| |
| Storing ``ever_non_pro`` lets us distinguish "user joined as Pro" |
| (no conversion) from "user upgraded" (conversion). The atomic |
| ``find_one_and_update`` on a guarded filter makes the conversion |
| emit at-most-once even under concurrent requests. |
| """ |
| if not self._ready() or not user_id: |
| return None |
| now = _now() |
| set_fields: dict[str, Any] = {"last_seen_at": now, "is_pro": bool(is_pro)} |
| if not is_pro: |
| set_fields["ever_non_pro"] = True |
| try: |
| await self.db.pro_users.update_one( |
| {"_id": user_id}, |
| { |
| "$setOnInsert": {"_id": user_id, "first_seen_at": now}, |
| "$set": set_fields, |
| }, |
| upsert=True, |
| ) |
| except PyMongoError as e: |
| logger.debug("mark_pro_seen upsert failed for %s: %s", user_id, e) |
| return None |
|
|
| if not is_pro: |
| return None |
|
|
| try: |
| doc = await self.db.pro_users.find_one_and_update( |
| { |
| "_id": user_id, |
| "ever_non_pro": True, |
| "first_seen_pro_at": {"$exists": False}, |
| }, |
| {"$set": {"first_seen_pro_at": now}}, |
| return_document=ReturnDocument.AFTER, |
| ) |
| except PyMongoError as e: |
| logger.debug("mark_pro_seen conversion check failed for %s: %s", user_id, e) |
| return None |
|
|
| if not doc: |
| return None |
| return { |
| "converted": True, |
| "first_seen_at": (doc.get("first_seen_at") or now).isoformat(), |
| } |
|
|
|
|
| _store: NoopSessionStore | MongoSessionStore | None = None |
|
|
|
|
| def get_session_store() -> NoopSessionStore | MongoSessionStore: |
| global _store |
| if _store is None: |
| uri = os.environ.get("MONGODB_URI") |
| db_name = os.environ.get("MONGODB_DB", "ml-intern") |
| _store = MongoSessionStore(uri, db_name) if uri else NoopSessionStore() |
| return _store |
|
|
|
|
| def _reset_store_for_tests( |
| store: NoopSessionStore | MongoSessionStore | None = None, |
| ) -> None: |
| global _store |
| _store = store |
|
|