Spaces:
Running
Running
| """Session-aware request tracking for fair resource sharing across Claude Code instances.""" | |
| from __future__ import annotations | |
| import asyncio | |
| import time | |
| from collections import defaultdict | |
| from dataclasses import dataclass | |
| from typing import ClassVar, cast | |
| from loguru import logger | |
@dataclass
class SessionState:
    """Mutable per-session counters, shared across all providers.

    FIX: the class was missing the @dataclass decorator (the import existed
    but was unused), leaving the fields as class-level attributes with no
    generated __init__/__repr__/__eq__.
    """

    # Requests counted inside the current rate-limiting window.
    requests_in_window: int = 0
    # time.monotonic() timestamp of the most recent request (0.0 = never).
    last_request_time: float = 0.0
    # Lifetime request count for the session (never reset).
    total_requests: int = 0
@dataclass
class ProviderLoad:
    """Load snapshot for a single provider.

    FIX: missing @dataclass decorator — the class is constructed with keyword
    arguments (see SessionTracker.get_provider_load), which raises TypeError
    without a generated __init__.
    """

    provider_id: str
    # Requests currently in flight against this provider.
    active_requests: int
    # Number of live sessions that have sent at least one request here.
    session_count: int
    # NOTE(review): populated with a cumulative request count, not a true
    # per-minute rate — see SessionTracker.get_provider_load.
    requests_per_minute: float
    is_healthy: bool  # Not rate limited
@dataclass
class SessionLoad:
    """Load snapshot for one session across all providers.

    FIX: missing @dataclass decorator — the class is constructed with keyword
    arguments (see SessionTracker.get_session_load), which raises TypeError
    without a generated __init__.
    """

    session_id: str
    # Lifetime request count for the session.
    total_requests: int
    providers: dict[str, int]  # provider_id -> request count
class SessionTracker:
    """
    Track request rates per session and per provider for fair resource sharing.

    This enables multiple Claude Code instances to share the proxy efficiently
    without one session starving others.

    Intended to be used as a process-wide singleton via ``get_instance()``.
    """

    _instance: ClassVar[SessionTracker | None] = None

    def __init__(
        self,
        *,
        max_sessions: int = 50,
        window_seconds: float = 60.0,
        per_session_rate_limit: int = 30,
        retention_seconds: float | None = None,
    ):
        """Initialize tracker state.

        Args:
            max_sessions: Max tracked sessions before LRU eviction kicks in.
            window_seconds: Length of the per-session rate-limiting window.
            per_session_rate_limit: Max requests allowed per session per window.
            retention_seconds: Idle time after which a session is cleaned up;
                defaults to two windows.
        """
        # Guard against double-initialization of the singleton (only fires if
        # __init__ is ever invoked again on an already-initialized object).
        if hasattr(self, "_initialized"):
            return
        # session_id -> per-session counters
        self._sessions: dict[str, SessionState] = {}
        # session_id -> provider_id -> cumulative request count
        self._session_requests: dict[str, dict[str, int]] = defaultdict(
            lambda: defaultdict(int)
        )
        # provider_id -> number of currently in-flight requests
        self._provider_active: dict[str, int] = defaultdict(int)
        self._max_sessions = max_sessions
        self._window_seconds = window_seconds
        self._per_session_rate_limit = per_session_rate_limit
        self._retention_seconds = (
            retention_seconds if retention_seconds is not None else window_seconds * 2
        )
        self._lock = asyncio.Lock()
        self._initialized = True
        logger.info(
            "SessionTracker initialized (max_sessions={}, window={}s, per_session_limit={}/min, retention={}s)",
            max_sessions,
            window_seconds,
            per_session_rate_limit,
            self._retention_seconds,
        )

    @classmethod
    def get_instance(cls, **kwargs) -> SessionTracker:
        """Get or create the singleton instance.

        FIX: was missing @classmethod, so ``SessionTracker.get_instance()``
        raised TypeError because ``cls`` was never bound.
        """
        if cls._instance is None:
            cls._instance = cls(**kwargs)
        return cls._instance

    @classmethod
    def reset_instance(cls) -> None:
        """Reset singleton (for testing).

        FIX: was missing @classmethod (same defect as get_instance).
        """
        cls._instance = None

    def _maybe_reset_window(self, state: SessionState, now: float) -> None:
        """Zero the window counter once a full window has passed with no activity.

        FIX: previously ``requests_in_window`` only ever grew, so a session
        that once exceeded the limit stayed rate-limited until it was evicted
        or cleaned up.
        """
        if now - state.last_request_time > self._window_seconds:
            state.requests_in_window = 0

    def _cleanup_old_sessions(self) -> None:
        """Remove sessions with no recent activity (must be called with lock held)."""
        cutoff = time.monotonic() - self._retention_seconds
        to_remove = [
            sid
            for sid, state in self._sessions.items()
            if state.last_request_time < cutoff
        ]
        for sid in to_remove:
            del self._sessions[sid]
            # pop() replaces the original's explicit membership check.
            self._session_requests.pop(sid, None)
        if to_remove:
            logger.debug(
                "SessionTracker: cleaned up {} stale sessions ({} remaining)",
                len(to_remove),
                len(self._sessions),
            )

    async def start_cleanup_loop(self, interval: float = 60.0) -> None:
        """Background task: periodically clean up stale sessions."""
        while True:
            await asyncio.sleep(interval)
            async with self._lock:
                self._cleanup_old_sessions()

    def _evict_lru_session(self) -> None:
        """Evict the least recently used session when at capacity."""
        if not self._sessions:
            return
        lru_sid = min(self._sessions.items(), key=lambda x: x[1].last_request_time)[0]
        del self._sessions[lru_sid]
        self._session_requests.pop(lru_sid, None)
        logger.warning("SessionTracker: Evicted LRU session '{}'", lru_sid)

    async def track_request(self, session_id: str, provider_id: str) -> None:
        """Record a request for a session to a provider.

        NOTE: despite being async, this intentionally does NOT take the lock
        (hot path); use track_request_async when locking guarantees are needed.
        """
        self.track_request_sync(session_id, provider_id)

    def track_request_sync(self, session_id: str, provider_id: str) -> None:
        """Record a request for a session to a provider (sync version for hot path)."""
        # Hot path - no cleanup on every call, just update state.
        # Cleanup runs periodically in background, not on every request.
        if session_id not in self._sessions:
            if len(self._sessions) >= self._max_sessions:
                self._evict_lru_session()
            self._sessions[session_id] = SessionState()
        state = self._sessions[session_id]
        now = time.monotonic()
        # FIX: expire a stale window before counting (see _maybe_reset_window).
        self._maybe_reset_window(state, now)
        state.requests_in_window += 1
        state.last_request_time = now
        state.total_requests += 1
        self._session_requests[session_id][provider_id] += 1
        self._provider_active[provider_id] += 1

    async def track_request_async(self, session_id: str, provider_id: str) -> None:
        """Async version with lock for when called from async contexts that need guarantees."""
        async with self._lock:
            self.track_request_sync(session_id, provider_id)

    async def release_request(self, session_id: str, provider_id: str) -> None:
        """Release a request slot when streaming completes."""
        async with self._lock:
            # Clamp at zero so a double release cannot go negative.
            self._provider_active[provider_id] = max(
                0, self._provider_active[provider_id] - 1
            )

    def get_provider_load(
        self, provider_id: str, blocked: bool = False
    ) -> ProviderLoad:
        """Get current load information for a provider.

        NOTE(review): requests_per_minute is filled with the cumulative
        request count for this provider across live sessions, not a true
        per-minute rate — kept as-is since callers may rely on the ordering.
        """
        session_count = sum(
            1
            for sid in self._sessions
            if self._session_requests[sid].get(provider_id, 0) > 0
        )
        total_requests = sum(
            self._session_requests[sid].get(provider_id, 0) for sid in self._sessions
        )
        return ProviderLoad(
            provider_id=provider_id,
            active_requests=self._provider_active.get(provider_id, 0),
            session_count=session_count,
            requests_per_minute=total_requests,
            is_healthy=not blocked,
        )

    def get_all_provider_loads(
        self, blocked_providers: set[str] | None = None
    ) -> dict[str, ProviderLoad]:
        """Get load information for all providers."""
        blocked = blocked_providers or set()
        all_providers = set(self._provider_active.keys())
        # Add providers seen by sessions even if not currently active.
        for per_provider in self._session_requests.values():
            all_providers.update(per_provider)
        return {
            pid: self.get_provider_load(pid, pid in blocked) for pid in all_providers
        }

    def get_session_load(self, session_id: str) -> SessionLoad | None:
        """Get load information for a specific session, or None if unknown."""
        if session_id not in self._sessions:
            return None
        state = self._sessions[session_id]
        return SessionLoad(
            session_id=session_id,
            total_requests=state.total_requests,
            providers=dict(self._session_requests[session_id]),
        )

    def get_all_session_loads(self) -> dict[str, SessionLoad]:
        """Get load information for all active sessions."""
        # FIX: the original called get_session_load twice per session
        # (once for the filter, once for the value) and needed a cast().
        loads: dict[str, SessionLoad] = {}
        for sid in list(self._sessions):
            load = self.get_session_load(sid)
            if load is not None:
                loads[sid] = load
        return loads

    async def check_session_allowed(self, session_id: str) -> tuple[bool, str]:
        """
        Check if a session is within its rate limit.

        Returns (allowed, reason) tuple.
        """
        async with self._lock:
            if session_id not in self._sessions:
                return True, "new session"
            state = self._sessions[session_id]
            # Expire a stale window first so an idle session recovers
            # instead of being blocked forever.
            self._maybe_reset_window(state, time.monotonic())
            # FIX: use >= — the original's > allowed limit+1 requests/window.
            if state.requests_in_window >= self._per_session_rate_limit:
                return (
                    False,
                    f"rate limit exceeded ({state.requests_in_window}/{self._per_session_rate_limit}/min)",
                )
            return True, "ok"

    def get_healthy_provider_priority(
        self,
        candidates: list[str],
        blocked_providers: set[str] | None = None,
    ) -> list[str]:
        """
        Return candidates sorted by health/load priority.

        Healthy providers with lower load come first; blocked providers last.
        """
        blocked = blocked_providers or set()
        return sorted(
            candidates,
            key=lambda pid: (
                pid in blocked,  # Blocked providers go last
                self._provider_active.get(pid, 0),  # Lower load first
            ),
        )

    def stats(self) -> dict:
        """Return current statistics."""
        return {
            "active_sessions": len(self._sessions),
            "total_providers": len(self._provider_active),
            "provider_active": dict(self._provider_active),
        }