| """Daily quota for premium model session creations. |
| |
| Tracks per-user premium model session starts against a daily cap derived from |
| the user's HF plan. MongoDB is the source of truth when configured; the |
| in-process dict remains the fallback for local/dev/test runs. |
| |
| The public names still say ``claude`` because this quota bucket originally |
| only covered Claude and the persisted session field uses that name. |
| |
| Unit: session *creations*, not messages. A user who sends with a premium model |
| in a new session consumes one quota point; switching an already-counted session |
| back to a premium model doesn't (`AgentSession.claude_counted` guards that). |
| |
| Cap tiers: |
| free user → CLAUDE_FREE_DAILY (1) |
| pro user → CLAUDE_PRO_DAILY (20) |
| """ |
|
|
| import asyncio |
| import os |
| from datetime import UTC, datetime |
|
|
| from agent.core.session_persistence import ( |
| NoopSessionStore, |
| get_session_store, |
| _reset_store_for_tests, |
| ) |
|
|
| CLAUDE_FREE_DAILY: int = int(os.environ.get("CLAUDE_FREE_DAILY", "1")) |
| CLAUDE_PRO_DAILY: int = int(os.environ.get("CLAUDE_PRO_DAILY", "20")) |
|
|
| |
| _claude_counts: dict[str, tuple[str, int]] = {} |
| _lock = asyncio.Lock() |
|
|
|
|
| def _today() -> str: |
| return datetime.now(UTC).date().isoformat() |
|
|
|
|
| def daily_cap_for(plan: str | None) -> int: |
| """Return the daily Claude-session cap for the given plan.""" |
| return CLAUDE_PRO_DAILY if plan == "pro" else CLAUDE_FREE_DAILY |
|
|
|
|
| async def get_claude_used_today(user_id: str) -> int: |
| """Return today's Claude session count for the user (0 if none / stale day).""" |
| store = get_session_store() |
| if getattr(store, "enabled", False): |
| db_count = await store.get_quota(user_id, _today()) |
| return db_count or 0 |
|
|
| async with _lock: |
| entry = _claude_counts.get(user_id) |
| if entry is None: |
| return 0 |
| day, count = entry |
| if day != _today(): |
| |
| _claude_counts.pop(user_id, None) |
| return 0 |
| return count |
|
|
|
|
| async def increment_claude(user_id: str) -> int: |
| """Bump today's Claude session count for the user. Returns the new value.""" |
| store = get_session_store() |
| if getattr(store, "enabled", False): |
| db_count = await store.try_increment_quota(user_id, _today(), cap=10**9) |
| return db_count or 0 |
|
|
| async with _lock: |
| today = _today() |
| day, count = _claude_counts.get(user_id, (today, 0)) |
| if day != today: |
| count = 0 |
| count += 1 |
| _claude_counts[user_id] = (today, count) |
| return count |
|
|
|
|
| async def try_increment_claude(user_id: str, cap: int) -> int | None: |
| """Atomically bump today's count if below *cap*. |
| |
| Returns the new count, or None when the user is already at the cap. |
| """ |
| store = get_session_store() |
| if getattr(store, "enabled", False): |
| return await store.try_increment_quota(user_id, _today(), cap) |
|
|
| async with _lock: |
| today = _today() |
| day, count = _claude_counts.get(user_id, (today, 0)) |
| if day != today: |
| count = 0 |
| if count >= cap: |
| return None |
| count += 1 |
| _claude_counts[user_id] = (today, count) |
| return count |
|
|
|
|
| async def refund_claude(user_id: str) -> None: |
| """Decrement today's count — used when session creation fails after a successful gate.""" |
| store = get_session_store() |
| if getattr(store, "enabled", False): |
| await store.refund_quota(user_id, _today()) |
| return |
|
|
| async with _lock: |
| entry = _claude_counts.get(user_id) |
| if entry is None: |
| return |
| day, count = entry |
| if day != _today(): |
| _claude_counts.pop(user_id, None) |
| return |
| new_count = max(0, count - 1) |
| if new_count == 0: |
| _claude_counts.pop(user_id, None) |
| else: |
| _claude_counts[user_id] = (day, new_count) |
|
|
|
|
| def _reset_for_tests() -> None: |
| """Test-only: clear the in-memory store.""" |
| _claude_counts.clear() |
| _reset_store_for_tests(NoopSessionStore()) |
|
|