claude-code-proxy / core /session_tracker.py
Yash030's picture
Extend session visibility in admin dashboard with configurable retention
6339a53
"""Session-aware request tracking for fair resource sharing across Claude Code instances."""
from __future__ import annotations
import asyncio
import time
from collections import defaultdict
from dataclasses import dataclass
from typing import ClassVar, cast
from loguru import logger
@dataclass(slots=True)
class SessionState:
"""State for a single session across all providers."""
requests_in_window: int = 0
last_request_time: float = 0.0
total_requests: int = 0
@dataclass(frozen=True, slots=True)
class ProviderLoad:
"""Load information for a single provider."""
provider_id: str
active_requests: int
session_count: int
requests_per_minute: float
is_healthy: bool # Not rate limited
@dataclass(frozen=True, slots=True)
class SessionLoad:
"""Load information for a session across all providers."""
session_id: str
total_requests: int
providers: dict[str, int] # provider_id -> request count
class SessionTracker:
    """
    Track request rates per session and per provider for fair resource sharing.

    This enables multiple Claude Code instances to share the proxy efficiently
    without one session starving others.

    Per-session rate limiting uses a fixed window: a session's
    ``requests_in_window`` counter is reset to zero once ``window_seconds``
    have elapsed since that session's current window started.
    """

    _instance: ClassVar[SessionTracker | None] = None

    def __init__(
        self,
        *,
        max_sessions: int = 50,
        window_seconds: float = 60.0,
        per_session_rate_limit: int = 30,
        retention_seconds: float | None = None,
    ):
        """
        Create a tracker.

        Args:
            max_sessions: Maximum number of sessions kept at once; the least
                recently used session is evicted to make room for a new one.
            window_seconds: Length of the per-session rate-limit window.
            per_session_rate_limit: Request count a session may reach within
                one window before ``check_session_allowed`` rejects it.
            retention_seconds: Idle time after which a session is dropped by
                cleanup. Defaults to ``window_seconds * 2`` when ``None``.
        """
        # Do nothing if this object has already been initialised.
        if hasattr(self, "_initialized"):
            return
        self._sessions: dict[str, SessionState] = {}
        self._session_requests: dict[str, dict[str, int]] = defaultdict(
            lambda: defaultdict(int)
        )
        self._provider_active: dict[str, int] = defaultdict(int)
        # session_id -> monotonic time at which its current window began.
        self._window_started: dict[str, float] = {}
        self._max_sessions = max_sessions
        self._window_seconds = window_seconds
        self._per_session_rate_limit = per_session_rate_limit
        self._retention_seconds = (
            retention_seconds if retention_seconds is not None else window_seconds * 2
        )
        self._lock = asyncio.Lock()
        self._initialized = True
        logger.info(
            "SessionTracker initialized (max_sessions={}, window={}s, per_session_limit={}/min, retention={}s)",
            max_sessions,
            window_seconds,
            per_session_rate_limit,
            self._retention_seconds,
        )

    @classmethod
    def get_instance(cls, **kwargs) -> SessionTracker:
        """
        Get or create the singleton instance.

        ``kwargs`` are forwarded to the constructor only when the instance is
        first created; on later calls they are ignored.
        """
        if cls._instance is None:
            cls._instance = cls(**kwargs)
        return cls._instance

    @classmethod
    def reset_instance(cls) -> None:
        """Reset singleton (for testing)."""
        cls._instance = None

    def _forget_session(self, session_id: str) -> None:
        """Drop every per-session record kept for ``session_id``."""
        self._sessions.pop(session_id, None)
        self._session_requests.pop(session_id, None)
        self._window_started.pop(session_id, None)

    def _roll_window(self, session_id: str, state: SessionState, now: float) -> None:
        """Start a fresh rate-limit window for the session if the current one expired."""
        started = self._window_started.setdefault(session_id, now)
        if now - started >= self._window_seconds:
            state.requests_in_window = 0
            self._window_started[session_id] = now

    def _cleanup_old_sessions(self) -> None:
        """Remove sessions with no recent activity (must be called with lock held)."""
        cutoff = time.monotonic() - self._retention_seconds
        # Collect first: the dict must not change size while it is iterated.
        stale = [
            sid
            for sid, state in self._sessions.items()
            if state.last_request_time < cutoff
        ]
        for sid in stale:
            self._forget_session(sid)
        if stale:
            logger.debug(
                "SessionTracker: cleaned up {} stale sessions ({} remaining)",
                len(stale),
                len(self._sessions),
            )

    async def start_cleanup_loop(self, interval: float = 60.0) -> None:
        """
        Background task: periodically clean up stale sessions.

        Runs until cancelled, sleeping ``interval`` seconds between passes.
        """
        while True:
            await asyncio.sleep(interval)
            async with self._lock:
                self._cleanup_old_sessions()

    def _evict_lru_session(self) -> None:
        """Evict least recently used session when at capacity."""
        if not self._sessions:
            return
        lru_sid = min(
            self._sessions, key=lambda sid: self._sessions[sid].last_request_time
        )
        self._forget_session(lru_sid)
        logger.warning("SessionTracker: Evicted LRU session '{}'", lru_sid)

    async def track_request(self, session_id: str, provider_id: str) -> None:
        """
        Record a request for a session to a provider.

        Delegates to ``track_request_sync`` without acquiring the lock; use
        ``track_request_async`` when the lock must be held.
        """
        self.track_request_sync(session_id, provider_id)

    def track_request_sync(self, session_id: str, provider_id: str) -> None:
        """
        Record a request for a session to a provider (sync version for hot path).

        Creates the session on first use, evicting the least recently used
        session when the tracker is already at ``max_sessions``.
        """
        # Hot path - no stale-session cleanup here; that runs periodically in
        # the background. Only this session's own window is rolled over.
        now = time.monotonic()
        state = self._sessions.get(session_id)
        if state is None:
            if len(self._sessions) >= self._max_sessions:
                self._evict_lru_session()
            state = self._sessions[session_id] = SessionState()
        self._roll_window(session_id, state, now)
        state.requests_in_window += 1
        state.last_request_time = now
        state.total_requests += 1
        self._session_requests[session_id][provider_id] += 1
        self._provider_active[provider_id] += 1

    async def track_request_async(self, session_id: str, provider_id: str) -> None:
        """Async version with lock for when called from async contexts that need guarantees."""
        async with self._lock:
            self.track_request_sync(session_id, provider_id)

    async def release_request(self, session_id: str, provider_id: str) -> None:
        """
        Release a request slot when streaming completes.

        Decrements the provider's in-flight counter, never below zero.
        ``session_id`` is accepted for interface symmetry and is not used.
        """
        async with self._lock:
            self._provider_active[provider_id] = max(
                0, self._provider_active[provider_id] - 1
            )

    def get_provider_load(
        self, provider_id: str, blocked: bool = False
    ) -> ProviderLoad:
        """
        Get current load information for a provider.

        Args:
            provider_id: Provider to report on.
            blocked: Whether the caller considers the provider rate limited;
                reported back inverted as ``is_healthy``.

        Note:
            ``requests_per_minute`` is populated with the cumulative number of
            requests recorded for this provider across tracked sessions; it is
            not normalised to an actual per-minute rate.
        """
        # .get() keeps this read-only: indexing the defaultdict would insert
        # empty entries for sessions that have no request record.
        per_session = [
            self._session_requests.get(sid, {}).get(provider_id, 0)
            for sid in self._sessions
        ]
        return ProviderLoad(
            provider_id=provider_id,
            active_requests=self._provider_active.get(provider_id, 0),
            session_count=sum(1 for count in per_session if count > 0),
            requests_per_minute=sum(per_session),
            is_healthy=not blocked,
        )

    def get_all_provider_loads(
        self, blocked_providers: set[str] | None = None
    ) -> dict[str, ProviderLoad]:
        """
        Get load information for all providers.

        Covers every provider with an in-flight counter plus every provider
        that appears in any session's request record.
        """
        blocked = blocked_providers or set()
        all_providers = set(self._provider_active)
        # Add providers from sessions even if not currently active
        for per_provider in self._session_requests.values():
            all_providers.update(per_provider)
        return {
            pid: self.get_provider_load(pid, pid in blocked) for pid in all_providers
        }

    def get_session_load(self, session_id: str) -> SessionLoad | None:
        """
        Get load information for a specific session.

        Returns ``None`` when the session is not currently tracked.
        """
        state = self._sessions.get(session_id)
        if state is None:
            return None
        return SessionLoad(
            session_id=session_id,
            total_requests=state.total_requests,
            # Copy so the frozen snapshot does not alias live counters.
            providers=dict(self._session_requests.get(session_id, {})),
        )

    def get_all_session_loads(self) -> dict[str, SessionLoad]:
        """Get load information for all active sessions."""
        loads: dict[str, SessionLoad] = {}
        for sid in self._sessions:
            load = self.get_session_load(sid)
            if load is not None:
                loads[sid] = load
        return loads

    async def check_session_allowed(self, session_id: str) -> tuple[bool, str]:
        """
        Check if a session is within its rate limit.

        An expired window is rolled over before the check, so a session is
        only judged on requests recorded in its current window. The session
        is rejected when that count is strictly greater than the limit.

        Returns (allowed, reason) tuple.
        """
        async with self._lock:
            state = self._sessions.get(session_id)
            if state is None:
                return True, "new session"
            self._roll_window(session_id, state, time.monotonic())
            if state.requests_in_window > self._per_session_rate_limit:
                return (
                    False,
                    f"rate limit exceeded ({state.requests_in_window}/{self._per_session_rate_limit}/min)",
                )
            return True, "ok"

    def get_healthy_provider_priority(
        self,
        candidates: list[str],
        blocked_providers: set[str] | None = None,
    ) -> list[str]:
        """
        Return candidates sorted by health/load priority.

        Healthy providers with lower load come first. The sort is stable, so
        candidates that tie keep their input order.
        """
        blocked = blocked_providers or set()
        return sorted(
            candidates,
            key=lambda pid: (
                pid in blocked,  # Blocked providers go last
                self._provider_active.get(pid, 0),  # Lower load first
            ),
        )

    def stats(self) -> dict:
        """Return current statistics."""
        return {
            "active_sessions": len(self._sessions),
            "total_providers": len(self._provider_active),
            "provider_active": dict(self._provider_active),
        }