"""Session-aware request tracking for fair resource sharing across Claude Code instances.""" from __future__ import annotations import asyncio import time from collections import defaultdict from dataclasses import dataclass from typing import ClassVar, cast from loguru import logger @dataclass(slots=True) class SessionState: """State for a single session across all providers.""" requests_in_window: int = 0 last_request_time: float = 0.0 total_requests: int = 0 @dataclass(frozen=True, slots=True) class ProviderLoad: """Load information for a single provider.""" provider_id: str active_requests: int session_count: int requests_per_minute: float is_healthy: bool # Not rate limited @dataclass(frozen=True, slots=True) class SessionLoad: """Load information for a session across all providers.""" session_id: str total_requests: int providers: dict[str, int] # provider_id -> request count class SessionTracker: """ Track request rates per session and per provider for fair resource sharing. This enables multiple Claude Code instances to share the proxy efficiently without one session starving others. """ _instance: ClassVar[SessionTracker | None] = None def __init__( self, *, max_sessions: int = 50, window_seconds: float = 60.0, per_session_rate_limit: int = 30, retention_seconds: float | None = None, ): if hasattr(self, "_initialized"): return self._sessions: dict[str, SessionState] = {} self._session_requests: dict[str, dict[str, int]] = defaultdict( lambda: defaultdict(int) ) self._provider_active: dict[str, int] = defaultdict(int) self._max_sessions = max_sessions self._window_seconds = window_seconds self._per_session_rate_limit = per_session_rate_limit self._retention_seconds = ( retention_seconds if retention_seconds is not None else window_seconds * 2 ) self._lock = asyncio.Lock() self._initialized = True logger.info( "SessionTracker initialized (max_sessions={}, window={}s, per_session_limit={}/min, retention={}s)", max_sessions, window_seconds, per_session_rate_limit, self._retention_seconds, ) @classmethod def get_instance(cls, **kwargs) -> SessionTracker: """Get or create the singleton instance.""" if cls._instance is None: cls._instance = cls(**kwargs) return cls._instance @classmethod def reset_instance(cls) -> None: """Reset singleton (for testing).""" cls._instance = None def _cleanup_old_sessions(self) -> None: """Remove sessions with no recent activity (must be called with lock held).""" now = time.monotonic() cutoff = now - self._retention_seconds to_remove = [ sid for sid, state in self._sessions.items() if state.last_request_time < cutoff ] for sid in to_remove: del self._sessions[sid] if sid in self._session_requests: del self._session_requests[sid] if to_remove: logger.debug( "SessionTracker: cleaned up {} stale sessions ({} remaining)", len(to_remove), len(self._sessions), ) async def start_cleanup_loop(self, interval: float = 60.0) -> None: """Background task: periodically clean up stale sessions.""" while True: await asyncio.sleep(interval) async with self._lock: self._cleanup_old_sessions() def _evict_lru_session(self) -> None: """Evict least recently used session when at capacity.""" if not self._sessions: return lru_sid = min(self._sessions.items(), key=lambda x: x[1].last_request_time)[0] del self._sessions[lru_sid] if lru_sid in self._session_requests: del self._session_requests[lru_sid] logger.warning("SessionTracker: Evicted LRU session '{}'", lru_sid) async def track_request(self, session_id: str, provider_id: str) -> None: """Record a request for a session to a provider (async-safe).""" self.track_request_sync(session_id, provider_id) def track_request_sync(self, session_id: str, provider_id: str) -> None: """Record a request for a session to a provider (sync version for hot path).""" # Hot path - no cleanup on every call, just update state # Cleanup runs periodically in background, not on every request if session_id not in self._sessions: if len(self._sessions) >= self._max_sessions: self._evict_lru_session() self._sessions[session_id] = SessionState() state = self._sessions[session_id] state.requests_in_window += 1 state.last_request_time = time.monotonic() state.total_requests += 1 self._session_requests[session_id][provider_id] += 1 self._provider_active[provider_id] += 1 async def track_request_async(self, session_id: str, provider_id: str) -> None: """Async version with lock for when called from async contexts that need guarantees.""" async with self._lock: self.track_request_sync(session_id, provider_id) async def release_request(self, session_id: str, provider_id: str) -> None: """Release a request slot when streaming completes.""" async with self._lock: self._provider_active[provider_id] = max( 0, self._provider_active[provider_id] - 1 ) def get_provider_load( self, provider_id: str, blocked: bool = False ) -> ProviderLoad: """Get current load information for a provider.""" session_count = sum( 1 for sid in self._sessions if self._session_requests[sid].get(provider_id, 0) > 0 ) total_requests = sum( self._session_requests[sid].get(provider_id, 0) for sid in self._sessions ) return ProviderLoad( provider_id=provider_id, active_requests=self._provider_active.get(provider_id, 0), session_count=session_count, requests_per_minute=total_requests, is_healthy=not blocked, ) def get_all_provider_loads( self, blocked_providers: set[str] | None = None ) -> dict[str, ProviderLoad]: """Get load information for all providers.""" blocked = blocked_providers or set() all_providers = set(self._provider_active.keys()) # Add providers from sessions even if not currently active for sid in self._session_requests: for provider_id in self._session_requests[sid]: all_providers.add(provider_id) return { pid: self.get_provider_load(pid, pid in blocked) for pid in all_providers } def get_session_load(self, session_id: str) -> SessionLoad | None: """Get load information for a specific session.""" if session_id not in self._sessions: return None state = self._sessions[session_id] provider_counts = dict(self._session_requests[session_id]) return SessionLoad( session_id=session_id, total_requests=state.total_requests, providers=provider_counts, ) def get_all_session_loads(self) -> dict[str, SessionLoad]: """Get load information for all active sessions.""" return { sid: cast(SessionLoad, self.get_session_load(sid)) for sid in self._sessions if self.get_session_load(sid) is not None } async def check_session_allowed(self, session_id: str) -> tuple[bool, str]: """ Check if a session is within its rate limit. Returns (allowed, reason) tuple. """ async with self._lock: if session_id not in self._sessions: return True, "new session" state = self._sessions[session_id] if state.requests_in_window > self._per_session_rate_limit: return ( False, f"rate limit exceeded ({state.requests_in_window}/{self._per_session_rate_limit}/min)", ) return True, "ok" def get_healthy_provider_priority( self, candidates: list[str], blocked_providers: set[str] | None = None, ) -> list[str]: """ Return candidates sorted by health/load priority. Healthy providers with lower load come first. """ blocked = blocked_providers or set() return sorted( candidates, key=lambda pid: ( pid in blocked, # Blocked providers go last self._provider_active.get(pid, 0), # Lower load first ), ) def stats(self) -> dict: """Return current statistics.""" return { "active_sessions": len(self._sessions), "total_providers": len(self._provider_active), "provider_active": dict(self._provider_active), }