Yash030 Claude Opus 4.7 committed on
Commit
cc3287d
·
1 Parent(s): d6a1875

Performance optimizations for proxy speed and shared sessions

Browse files

- Per-provider rate limiting: Zen gets unlimited, NVIDIA NIM gets 40/min
- Higher Zen concurrency: 4x max_concurrency for fast minimax model
- Connection pool tuning: Keepalive connections for faster reuse
- Session tracker: Fair resource sharing across Claude Code instances
- Smart auto-routing: Prioritize Zen (no limits), then by load

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>

api/model_router.py CHANGED
@@ -8,6 +8,7 @@ from loguru import logger
8
 
9
  from config.provider_ids import SUPPORTED_PROVIDER_IDS
10
  from config.settings import Settings
 
11
 
12
  from .gateway_model_ids import decode_gateway_model_id
13
  from .models.anthropic import MessagesRequest, TokenCountRequest
@@ -143,6 +144,7 @@ class ModelRouter:
143
  """Resolve a model name to a prioritized list of candidates.
144
 
145
  Used by the 'auto' routing logic to implement provider-side failover.
 
146
  """
147
  if not self._is_auto(claude_model_name):
148
  return [self.resolve(claude_model_name)]
@@ -150,7 +152,7 @@ class ModelRouter:
150
  healthy_candidates: list[ResolvedModel] = []
151
  blocked_candidates: list[ResolvedModel] = []
152
  seen: set[str] = set()
153
-
154
 
155
  def add_candidate(ref: str | None, source: str) -> None:
156
  normalized_ref = self._normalize_candidate_ref(ref or "")
@@ -169,7 +171,13 @@ class ModelRouter:
169
  )
170
 
171
  limiter = GlobalRateLimiter.get_scoped_instance(provider_id)
172
- if limiter.is_blocked():
 
 
 
 
 
 
173
  logger.debug(
174
  "Routing: candidate '{}' (from {}) is BLOCKED",
175
  normalized_ref,
@@ -177,12 +185,14 @@ class ModelRouter:
177
  )
178
  blocked_candidates.append(resolved)
179
  else:
 
180
  logger.debug(
181
  "Routing: added candidate '{}' (from {})",
182
  normalized_ref,
183
  source,
184
  )
185
  healthy_candidates.append(resolved)
 
186
  else:
187
  logger.debug(
188
  "Routing: candidate '{}' (from {}) is NOT CONFIGURED",
@@ -210,6 +220,15 @@ class ModelRouter:
210
  add_candidate(self._settings.model_sonnet, "MODEL_SONNET")
211
  add_candidate(self._settings.model_haiku, "MODEL_HAIKU")
212
 
 
 
 
 
 
 
 
 
 
213
  all_candidates = healthy_candidates + blocked_candidates
214
  logger.info(
215
  "Routing: resolved '{}' to {} candidates: {}",
 
8
 
9
  from config.provider_ids import SUPPORTED_PROVIDER_IDS
10
  from config.settings import Settings
11
+ from core.session_tracker import SessionTracker
12
 
13
  from .gateway_model_ids import decode_gateway_model_id
14
  from .models.anthropic import MessagesRequest, TokenCountRequest
 
144
  """Resolve a model name to a prioritized list of candidates.
145
 
146
  Used by the 'auto' routing logic to implement provider-side failover.
147
+ Considers session load for fair resource sharing across multiple clients.
148
  """
149
  if not self._is_auto(claude_model_name):
150
  return [self.resolve(claude_model_name)]
 
152
  healthy_candidates: list[ResolvedModel] = []
153
  blocked_candidates: list[ResolvedModel] = []
154
  seen: set[str] = set()
155
+ session_tracker = SessionTracker.get_instance()
156
 
157
  def add_candidate(ref: str | None, source: str) -> None:
158
  normalized_ref = self._normalize_candidate_ref(ref or "")
 
171
  )
172
 
173
  limiter = GlobalRateLimiter.get_scoped_instance(provider_id)
174
+ is_blocked = limiter.is_blocked()
175
+
176
+ # For Zen provider, never consider it blocked (no rate limits)
177
+ if provider_id == "zen":
178
+ is_blocked = False
179
+
180
+ if is_blocked:
181
  logger.debug(
182
  "Routing: candidate '{}' (from {}) is BLOCKED",
183
  normalized_ref,
 
185
  )
186
  blocked_candidates.append(resolved)
187
  else:
188
+ # Smart ordering: Zen (no rate limits) gets priority, then by load
189
  logger.debug(
190
  "Routing: added candidate '{}' (from {})",
191
  normalized_ref,
192
  source,
193
  )
194
  healthy_candidates.append(resolved)
195
+
196
  else:
197
  logger.debug(
198
  "Routing: candidate '{}' (from {}) is NOT CONFIGURED",
 
220
  add_candidate(self._settings.model_sonnet, "MODEL_SONNET")
221
  add_candidate(self._settings.model_haiku, "MODEL_HAIKU")
222
 
223
+ # Smart ordering: Zen goes first (no rate limits), then sort by load
224
+ def provider_priority(c: ResolvedModel) -> tuple:
225
+ # Priority: zen > others, then by active request count
226
+ is_zen = 0 if c.provider_id == "zen" else 1
227
+ active = session_tracker._provider_active.get(c.provider_id, 0)
228
+ return (is_zen, active)
229
+
230
+ healthy_candidates.sort(key=provider_priority)
231
+
232
  all_candidates = healthy_candidates + blocked_candidates
233
  logger.info(
234
  "Routing: resolved '{}' to {} candidates: {}",
api/services.py CHANGED
@@ -14,6 +14,7 @@ from loguru import logger
14
  from config.settings import Settings
15
  from core.anthropic import get_token_count, get_user_facing_error_message
16
  from core.anthropic.sse import ANTHROPIC_SSE_RESPONSE_HEADERS, format_sse_event
 
17
  from providers.base import BaseProvider
18
  from providers.exceptions import (
19
  InvalidRequestError,
@@ -101,6 +102,15 @@ class ClaudeProxyService:
101
  self._provider_getter = provider_getter
102
  self._model_router = model_router or ModelRouter(settings)
103
  self._token_counter = token_counter
 
 
 
 
 
 
 
 
 
104
 
105
  def create_message(self, request_data: MessagesRequest) -> object:
106
  """Create a message response or streaming response with optional failover."""
 
14
  from config.settings import Settings
15
  from core.anthropic import get_token_count, get_user_facing_error_message
16
  from core.anthropic.sse import ANTHROPIC_SSE_RESPONSE_HEADERS, format_sse_event
17
+ from core.session_tracker import SessionTracker
18
  from providers.base import BaseProvider
19
  from providers.exceptions import (
20
  InvalidRequestError,
 
102
  self._provider_getter = provider_getter
103
  self._model_router = model_router or ModelRouter(settings)
104
  self._token_counter = token_counter
105
+ self._session_tracker = SessionTracker.get_instance()
106
+
107
def _get_session_id(self, request_data: MessagesRequest) -> str:
    """Extract or generate a session ID from the request.

    This allows multiple Claude Code instances to share the proxy fairly:
    a caller-supplied identifier is reused, otherwise a random one is made.
    """
    # Local import: the surrounding diff adds no module-level `import uuid`,
    # so importing here guarantees the name is in scope — TODO confirm.
    import uuid

    # NOTE(review): the previous hasattr check returned str(None) == "None"
    # when `custom_id` existed but was None, collapsing all such requests
    # into a single bogus session; treat falsy values as absent instead.
    custom_id = getattr(request_data, "custom_id", None)
    if custom_id:
        return str(custom_id)
    return f"session_{uuid.uuid4().hex[:12]}"
114
 
115
  def create_message(self, request_data: MessagesRequest) -> object:
116
  """Create a message response or streaming response with optional failover."""
core/session_tracker.py ADDED
@@ -0,0 +1,249 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Session-aware request tracking for fair resource sharing across Claude Code instances."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import asyncio
6
+ import time
7
+ from collections import defaultdict
8
+ from dataclasses import dataclass, field
9
+ from typing import ClassVar
10
+
11
+ from loguru import logger
12
+
13
+
14
@dataclass(slots=True)
class SessionState:
    """State for a single session across all providers."""
    # Requests counted toward the current rate-limit window.
    requests_in_window: int = 0
    # time.monotonic() timestamp of the most recent request (0.0 = never seen).
    last_request_time: float = 0.0
    # Lifetime request count for this session (never reset).
    total_requests: int = 0
20
+
21
+
22
@dataclass(frozen=True, slots=True)
class ProviderLoad:
    """Load information for a single provider (immutable snapshot)."""
    provider_id: str
    # Number of requests currently in flight to this provider.
    active_requests: int
    # Number of tracked sessions that have sent at least one request here.
    session_count: int
    # NOTE(review): populated with a cumulative request count by the tracker,
    # not a true per-minute rate — confirm before relying on the units.
    requests_per_minute: float
    is_healthy: bool  # Not rate limited
30
+
31
+
32
@dataclass(frozen=True, slots=True)
class SessionLoad:
    """Load information for a session across all providers (immutable snapshot)."""
    session_id: str
    # Lifetime request count for the session.
    total_requests: int
    providers: dict[str, int]  # provider_id -> request count
38
+
39
+
40
class SessionTracker:
    """
    Track request rates per session and per provider for fair resource sharing.

    This enables multiple Claude Code instances to share the proxy efficiently
    without one session starving others.

    Use :meth:`get_instance` for the shared singleton; constructing the class
    directly creates an independent tracker (useful in tests).
    """

    _instance: ClassVar[SessionTracker | None] = None

    def __init__(
        self,
        *,
        max_sessions: int = 50,
        window_seconds: float = 60.0,
        per_session_rate_limit: int = 30,
    ):
        """Initialize the tracker.

        Args:
            max_sessions: Maximum sessions retained before LRU eviction.
            window_seconds: Length of the per-session rate-limit window.
            per_session_rate_limit: Max requests allowed per session per window.
        """
        # Guard against re-running __init__ on an already-initialized object.
        if hasattr(self, "_initialized"):
            return

        self._sessions: dict[str, SessionState] = {}
        # session_id -> provider_id -> cumulative request count
        self._session_requests: dict[str, dict[str, int]] = defaultdict(lambda: defaultdict(int))
        # provider_id -> number of currently in-flight requests
        self._provider_active: dict[str, int] = defaultdict(int)
        self._max_sessions = max_sessions
        self._window_seconds = window_seconds
        self._per_session_rate_limit = per_session_rate_limit
        self._lock = asyncio.Lock()
        self._initialized = True

        logger.info(
            "SessionTracker initialized (max_sessions={}, window={}s, per_session_limit={}/min)",
            max_sessions,
            window_seconds,
            per_session_rate_limit,
        )

    @classmethod
    def get_instance(cls, **kwargs) -> SessionTracker:
        """Get or create the singleton instance.

        NOTE: ``kwargs`` are honored only on the first call; subsequent calls
        return the existing instance unchanged.
        """
        if cls._instance is None:
            cls._instance = cls(**kwargs)
        return cls._instance

    @classmethod
    def reset_instance(cls) -> None:
        """Reset singleton (for testing)."""
        cls._instance = None

    def _cleanup_old_sessions(self) -> None:
        """Drop sessions idle for two full windows. Caller must hold the lock."""
        cutoff = time.monotonic() - (self._window_seconds * 2)
        stale = [
            sid for sid, state in self._sessions.items()
            if state.last_request_time < cutoff
        ]
        for sid in stale:
            del self._sessions[sid]
            self._session_requests.pop(sid, None)

    def _evict_lru_session(self) -> None:
        """Evict the least-recently-used session at capacity. Caller must hold the lock."""
        if not self._sessions:
            return
        lru_sid = min(self._sessions, key=lambda sid: self._sessions[sid].last_request_time)
        del self._sessions[lru_sid]
        self._session_requests.pop(lru_sid, None)
        logger.warning("SessionTracker: Evicted LRU session '{}'", lru_sid)

    async def track_request(self, session_id: str, provider_id: str) -> None:
        """Record a request for a session to a provider."""
        async with self._lock:
            self._cleanup_old_sessions()

            if session_id not in self._sessions:
                if len(self._sessions) >= self._max_sessions:
                    self._evict_lru_session()
                self._sessions[session_id] = SessionState()

            state = self._sessions[session_id]
            now = time.monotonic()
            # Reset the window counter once a full window has elapsed since the
            # last request. Without this the counter only ever grows, so any
            # session that once exceeded the limit would be blocked forever.
            # This is an idle-based approximation of a rolling window.
            if now - state.last_request_time >= self._window_seconds:
                state.requests_in_window = 0
            state.requests_in_window += 1
            state.last_request_time = now
            state.total_requests += 1

            self._session_requests[session_id][provider_id] += 1
            self._provider_active[provider_id] += 1

    async def release_request(self, session_id: str, provider_id: str) -> None:
        """Release a request slot when streaming completes.

        ``session_id`` is currently unused; kept for interface symmetry with
        :meth:`track_request`.
        """
        async with self._lock:
            self._provider_active[provider_id] = max(0, self._provider_active[provider_id] - 1)

    def get_provider_load(self, provider_id: str, blocked: bool = False) -> ProviderLoad:
        """Get current load information for a provider.

        NOTE: ``requests_per_minute`` is filled with the cumulative request
        count across tracked sessions, not a true per-minute rate.
        """
        session_count = 0
        total_requests = 0
        for sid in self._sessions:
            # .get(...) avoids mutating the defaultdict from a read-only path.
            count = self._session_requests.get(sid, {}).get(provider_id, 0)
            if count > 0:
                session_count += 1
            total_requests += count

        return ProviderLoad(
            provider_id=provider_id,
            active_requests=self._provider_active.get(provider_id, 0),
            session_count=session_count,
            requests_per_minute=total_requests,
            is_healthy=not blocked,
        )

    def get_all_provider_loads(self, blocked_providers: set[str] | None = None) -> dict[str, ProviderLoad]:
        """Get load information for every provider seen so far."""
        blocked = blocked_providers or set()
        all_providers = set(self._provider_active)
        # Include providers known only through per-session history.
        for per_provider in self._session_requests.values():
            all_providers.update(per_provider)
        return {
            pid: self.get_provider_load(pid, pid in blocked)
            for pid in all_providers
        }

    def get_session_load(self, session_id: str) -> SessionLoad | None:
        """Get load information for a specific session, or None if unknown."""
        state = self._sessions.get(session_id)
        if state is None:
            return None
        return SessionLoad(
            session_id=session_id,
            total_requests=state.total_requests,
            providers=dict(self._session_requests.get(session_id, {})),
        )

    def get_all_session_loads(self) -> dict[str, SessionLoad]:
        """Get load information for all active sessions."""
        loads: dict[str, SessionLoad] = {}
        for sid in self._sessions:
            load = self.get_session_load(sid)
            if load is not None:
                loads[sid] = load
        return loads

    async def check_session_allowed(self, session_id: str) -> tuple[bool, str]:
        """
        Check if a session is within its rate limit.

        Returns (allowed, reason) tuple.
        """
        now = time.monotonic()

        async with self._lock:
            state = self._sessions.get(session_id)
            if state is None:
                return True, "new session"

            # A session idle for a full window gets a fresh budget; this
            # mirrors the counter reset performed in track_request.
            if now - state.last_request_time >= self._window_seconds:
                return True, "ok"

            if state.requests_in_window > self._per_session_rate_limit:
                return False, f"rate limit exceeded ({state.requests_in_window}/{self._per_session_rate_limit}/min)"

            return True, "ok"

    def get_healthy_provider_priority(
        self,
        candidates: list[str],
        blocked_providers: set[str] | None = None,
    ) -> list[str]:
        """
        Return candidates sorted by health/load priority.

        Healthy providers with lower load come first.
        """
        blocked = blocked_providers or set()
        return sorted(
            candidates,
            key=lambda pid: (
                pid in blocked,  # Blocked providers go last
                self._provider_active.get(pid, 0),  # Lower load first
            ),
        )

    def stats(self) -> dict:
        """Return current statistics."""
        return {
            "active_sessions": len(self._sessions),
            "total_providers": len(self._provider_active),
            "provider_active": dict(self._provider_active),
        }
providers/openai_compat.py CHANGED
@@ -77,14 +77,23 @@ class OpenAIChatTransport(BaseProvider):
77
  self._base_url = base_url.rstrip("/")
78
  self._http_client = None
79
  self._client_cache: dict[str, AsyncOpenAI] = {}
 
 
 
 
 
 
 
 
80
  self._global_rate_limiter = GlobalRateLimiter.get_scoped_instance(
81
  provider_name.lower(),
82
- rate_limit=config.rate_limit,
83
  rate_window=config.rate_window,
84
- max_concurrency=config.max_concurrency,
85
  )
86
  # Always create an explicit httpx.AsyncClient with trust_env=False to avoid
87
  # slow system proxy detection on Windows during initialization.
 
88
  http_client_args = {
89
  "timeout": httpx.Timeout(
90
  config.http_read_timeout,
@@ -94,6 +103,11 @@ class OpenAIChatTransport(BaseProvider):
94
  ),
95
  "trust_env": False,
96
  "http2": True,
 
 
 
 
 
97
  }
98
  if config.proxy:
99
  http_client_args["proxy"] = config.proxy
 
77
  self._base_url = base_url.rstrip("/")
78
  self._http_client = None
79
  self._client_cache: dict[str, AsyncOpenAI] = {}
80
+ # Zen has no rate limits - use very high limits to avoid throttling
81
+ # NVIDIA NIM has 40 req/min - respect that limit
82
+ if provider_name.lower() == "zen":
83
+ effective_rate_limit = 9999 # Effectively unlimited
84
+ effective_max_concurrency = config.max_concurrency * 4 # Higher concurrency for Zen
85
+ else:
86
+ effective_rate_limit = config.rate_limit or 40
87
+ effective_max_concurrency = config.max_concurrency
88
  self._global_rate_limiter = GlobalRateLimiter.get_scoped_instance(
89
  provider_name.lower(),
90
+ rate_limit=effective_rate_limit,
91
  rate_window=config.rate_window,
92
+ max_concurrency=effective_max_concurrency,
93
  )
94
  # Always create an explicit httpx.AsyncClient with trust_env=False to avoid
95
  # slow system proxy detection on Windows during initialization.
96
+ # Connection pool tuned for high throughput with keepalive optimization.
97
  http_client_args = {
98
  "timeout": httpx.Timeout(
99
  config.http_read_timeout,
 
103
  ),
104
  "trust_env": False,
105
  "http2": True,
106
+ "limits": httpx.Limits(
107
+ max_keepalive_connections=20,
108
+ max_connections=100,
109
+ max_keepalive_expiry=30.0,
110
+ ),
111
  }
112
  if config.proxy:
113
  http_client_args["proxy"] = config.proxy