Spaces:
Paused
Paused
File size: 23,113 Bytes
21dcb11 3c52746 2be7535 bd84d38 d51ae99 d28e7c5 21dcb11 2be7535 4dfb828 3c52746 21dcb11 d28e7c5 21dcb11 4dfb828 63418af 4dfb828 bd84d38 3c52746 aefb706 21dcb11 d51ae99 21dcb11 d51ae99 21dcb11 576a3ec d28e7c5 4dfb828 d28e7c5 4dfb828 d28e7c5 4dfb828 d28e7c5 4dfb828 576a3ec 4dfb828 576a3ec 4dfb828 576a3ec 4dfb828 d28e7c5 b5da45c 4dfb828 b5da45c 3c52746 b5da45c 3c52746 4dfb828 3c52746 4dfb828 b5da45c 4dfb828 b5da45c 3c52746 b5da45c 3c52746 4dfb828 b5da45c 4dfb828 b5da45c 4dfb828 b5da45c 4dfb828 b5da45c 4dfb828 b5da45c 4dfb828 bd84d38 0ca1651 bd84d38 98f6823 3c52746 98f6823 3c52746 98f6823 0ca1651 98f6823 0ca1651 3c52746 98f6823 0ca1651 98f6823 3c52746 98f6823 0ca1651 3c52746 0ca1651 3c52746 0ca1651 3c52746 0ca1651 3c52746 0ca1651 3c52746 0ca1651 3c52746 aefb706 2be7535 d7c643f 2be7535 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 
329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 | from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import (
List,
Dict,
Any,
Optional,
AsyncGenerator,
Union,
FrozenSet,
TYPE_CHECKING,
)
import os
import httpx
import litellm
if TYPE_CHECKING:
from ..usage_manager import UsageManager
# =============================================================================
# TIER & USAGE CONFIGURATION TYPES
# =============================================================================
@dataclass(frozen=True)
class UsageResetConfigDef:
    """
    Immutable description of how usage statistics reset for a tier group.

    Providers declare instances of this class as class-level configuration
    so that usage windows can differ by credential tier (e.g. paid vs free).

    Attributes:
        window_seconds: Length of the usage tracking window, in seconds.
        mode: Window scope. "credential" keeps a single window per
            credential; "per_model" keeps a separate window per model
            (or model group).
        description: Human-readable label, intended for log output.
        field_name: Key under which usage is stored in the usage data JSON
            structure. Typically "models" for per_model mode and "daily"
            for credential mode.
    """

    # Duration of one tracking window (seconds).
    window_seconds: int
    # Window scope: "credential" or "per_model".
    mode: str
    # Label shown in logs.
    description: str
    # Storage key; "daily" is kept as the default for backwards compatibility.
    field_name: str = "daily"
# Type aliases for the class-level configuration attributes that providers
# declare on their ProviderInterface subclass.
TierPriorityMap = Dict[str, int] # tier name -> priority (lower number = higher priority)
UsageConfigKey = Union[FrozenSet[int], str] # frozenset of priority values, OR the literal "default"
UsageConfigMap = Dict[UsageConfigKey, UsageResetConfigDef] # priority set (or "default") -> reset config
QuotaGroupMap = Dict[str, List[str]] # quota group name -> model names in that group
class ProviderInterface(ABC):
    """
    An interface for API provider-specific functionality, including model
    discovery and custom API call handling for non-standard providers.

    Besides the abstract ``get_models`` hook, the class centralizes three
    pieces of logic that subclasses configure through class attributes
    rather than by overriding methods:

    - tier name -> priority resolution (``tier_priorities``),
    - usage reset configuration lookup (``usage_reset_configs``),
    - model quota grouping with environment overrides
      (``model_quota_groups`` / ``provider_env_name``).

    NOTE: the dict-valued class attributes below are shared by every class
    that does not rebind them. Subclasses should assign a new dict instead
    of mutating the inherited one in place.
    """

    skip_cost_calculation: bool = False

    # Default rotation mode for this provider ("balanced" or "sequential")
    # - "balanced": Rotate credentials to distribute load evenly
    # - "sequential": Use one credential until exhausted, then switch to next
    default_rotation_mode: str = "balanced"

    # =========================================================================
    # TIER CONFIGURATION - Override in subclass
    # =========================================================================

    # Provider name for env var lookups (e.g., "antigravity", "gemini_cli")
    # Used for: QUOTA_GROUPS_{provider_env_name}_{GROUP}
    provider_env_name: str = ""

    # Tier name -> priority mapping (Single Source of Truth)
    # Lower numbers = higher priority (1 is highest)
    # Multiple tiers can map to the same priority
    # Unknown tiers fall back to default_tier_priority
    tier_priorities: TierPriorityMap = {}

    # Default priority for tiers not in tier_priorities mapping
    default_tier_priority: int = 10

    # =========================================================================
    # USAGE RESET CONFIGURATION - Override in subclass
    # =========================================================================

    # Usage reset configurations keyed by priority sets
    # Keys: frozenset of priority values (e.g., frozenset({1, 2})) OR "default"
    # The "default" key is used for any priority not matched by a frozenset
    usage_reset_configs: UsageConfigMap = {}

    # =========================================================================
    # MODEL QUOTA GROUPS - Override in subclass
    # =========================================================================

    # Models that share quota/cooldown timing
    # Can be overridden via env: QUOTA_GROUPS_{PROVIDER}_{GROUP}="model1,model2"
    model_quota_groups: QuotaGroupMap = {}

    # Model usage weights for grouped usage calculation
    # When calculating combined usage for quota groups, each model's usage
    # is multiplied by its weight. This accounts for models that consume
    # more quota per request (e.g., Opus uses more than Sonnet).
    # Models not in the map default to weight 1.
    # Example: {"claude-opus-4-5": 2} means Opus usage counts 2x
    model_usage_weights: Dict[str, int] = {}

    # =========================================================================
    # PRIORITY CONCURRENCY MULTIPLIERS - Override in subclass
    # =========================================================================

    # Priority-based concurrency multipliers (universal, applies to all modes)
    # Maps priority level -> multiplier
    # Higher priority credentials (lower number) can have higher multipliers
    # to allow more concurrent requests
    # Example: {1: 5, 2: 3} means Priority 1 gets 5x, Priority 2 gets 3x
    default_priority_multipliers: Dict[int, int] = {}

    # Fallback multiplier for sequential mode when priority not in
    # default_priority_multipliers. This is used for lower-priority tiers in
    # sequential mode to maintain some stickiness.
    # Default: 1 (no multiplier effect)
    default_sequential_fallback_multiplier: int = 1

    @abstractmethod
    async def get_models(self, api_key: str, client: httpx.AsyncClient) -> List[str]:
        """
        Fetches the list of available model names from the provider's API.

        Args:
            api_key: The API key required for authentication.
            client: An httpx.AsyncClient instance for making requests.

        Returns:
            A list of model name strings.
        """
        pass

    # -------------------------------------------------------------------------
    # Hooks for providers that need to bypass litellm
    # -------------------------------------------------------------------------

    def has_custom_logic(self) -> bool:
        """
        Returns True if the provider implements its own acompletion/aembedding
        logic, bypassing the standard litellm call.

        The base implementation returns False.
        """
        return False

    async def acompletion(
        self, client: httpx.AsyncClient, **kwargs
    ) -> Union[litellm.ModelResponse, AsyncGenerator[litellm.ModelResponse, None]]:
        """
        Handles the entire completion call for non-standard providers.

        Args:
            client: An httpx.AsyncClient instance for making requests.
            **kwargs: Completion arguments, passed through unchanged.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError(
            f"{self.__class__.__name__} does not implement custom acompletion."
        )

    async def aembedding(
        self, client: httpx.AsyncClient, **kwargs
    ) -> litellm.EmbeddingResponse:
        """
        Handles the entire embedding call for non-standard providers.

        Args:
            client: An httpx.AsyncClient instance for making requests.
            **kwargs: Embedding arguments, passed through unchanged.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError(
            f"{self.__class__.__name__} does not implement custom aembedding."
        )

    def convert_safety_settings(
        self, settings: Dict[str, str]
    ) -> Optional[List[Dict[str, Any]]]:
        """
        Converts a generic safety settings dictionary to the provider-specific
        format.

        Args:
            settings: A dictionary with generic harm categories and thresholds.

        Returns:
            A list of provider-specific safety setting objects or None.
            The base implementation performs no conversion and returns None.
        """
        return None

    # -------------------------------------------------------------------------
    # Hooks for OAuth providers
    # -------------------------------------------------------------------------

    async def get_auth_header(self, credential_identifier: str) -> Dict[str, str]:
        """
        For OAuth providers, this method returns the Authorization header.
        For API key providers, this can be a no-op or raise NotImplementedError.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError("This provider does not support OAuth.")

    async def proactively_refresh(self, credential_path: str):
        """
        Proactively refreshes a token if it's nearing expiry.

        The base implementation does nothing.
        """
        pass

    # =========================================================================
    # TIER RESOLUTION LOGIC (Centralized)
    # =========================================================================

    def _resolve_tier_priority(self, tier_name: Optional[str]) -> int:
        """
        Resolve priority for a tier name using provider's tier_priorities
        mapping.

        Args:
            tier_name: The tier name string (e.g., "free-tier", "standard-tier")

        Returns:
            Priority level from tier_priorities, or default_tier_priority if
            tier_name is None or not found in the mapping.
        """
        if tier_name is None:
            return self.default_tier_priority
        return self.tier_priorities.get(tier_name, self.default_tier_priority)

    def get_credential_priority(self, credential: str) -> Optional[int]:
        """
        Returns the priority level for a credential.

        Lower numbers = higher priority (1 is highest).
        Returns None if tier not yet discovered.

        Uses the provider's tier_priorities mapping to resolve priority from
        tier name. Unknown tiers fall back to default_tier_priority.

        Subclasses should:
            1. Define tier_priorities dict with all known tier names
            2. Override get_credential_tier_name() for tier lookup
        Do NOT override this method.

        Args:
            credential: The credential identifier (API key or path)

        Returns:
            Priority level, or None if tier not yet discovered
        """
        tier = self.get_credential_tier_name(credential)
        if tier is None:
            return None  # Tier not yet discovered
        return self._resolve_tier_priority(tier)

    def get_model_tier_requirement(self, model: str) -> Optional[int]:
        """
        Returns the minimum priority tier required for a model.

        If a model requires priority 1, only credentials with priority <= 1
        can use it. This allows providers to restrict certain models to
        specific credential tiers. For example, Gemini 3 models require
        paid-tier credentials.

        Args:
            model: The model name (with or without provider prefix)

        Returns:
            Minimum required priority level or None if no restrictions.
            The base implementation imposes no restriction.

        Example:
            For Gemini CLI:
            - gemini-3-*: requires priority 1 (paid tier only)
            - gemini-2.5-*: no restriction (None)
        """
        return None

    async def initialize_credentials(self, credential_paths: List[str]) -> None:
        """
        Called at startup to initialize provider with all available
        credentials.

        Providers can override this to load cached tier data, discover
        priorities, or perform any other initialization needed before the
        first API request.

        This is called once during startup by the BackgroundRefresher before
        the main refresh loop begins.

        Args:
            credential_paths: List of credential file paths for this provider
        """
        pass

    def get_credential_tier_name(self, credential: str) -> Optional[str]:
        """
        Returns the human-readable tier name for a credential.

        Used for logging, and as the lookup source for
        get_credential_priority() and get_usage_reset_config().

        Args:
            credential: The credential identifier (API key or path)

        Returns:
            Tier name string (e.g., "free-tier", "paid-tier") or None if
            unknown. The base implementation always returns None.
        """
        return None

    # =========================================================================
    # Sequential Rotation Support
    # =========================================================================

    @classmethod
    def get_rotation_mode(cls, provider_name: str) -> str:
        """
        Get the rotation mode for this provider.

        Checks ROTATION_MODE_{PROVIDER} environment variable first,
        then falls back to the class's default_rotation_mode.

        The environment value is stripped and lower-cased so that values
        such as " Sequential " match the documented mode names. A blank
        value is treated as "not set".

        Args:
            provider_name: The provider name (e.g., "antigravity", "gemini_cli")

        Returns:
            "balanced" or "sequential"
        """
        env_key = f"ROTATION_MODE_{provider_name.upper()}"
        override = os.getenv(env_key, "").strip().lower()
        # Empty/blank override falls through to the class default.
        return override or cls.default_rotation_mode

    @staticmethod
    def parse_quota_error(
        error: Exception, error_body: Optional[str] = None
    ) -> Optional[Dict[str, Any]]:
        """
        Parse a quota/rate-limit error and extract structured information.

        Providers should override this method to handle their specific error
        formats. This allows the error_handler to use provider-specific
        parsing when available, falling back to generic parsing otherwise.

        Args:
            error: The caught exception
            error_body: Optional raw response body string

        Returns:
            None if not a parseable quota error, otherwise:
            {
                "retry_after": int,  # seconds until quota resets
                "reason": str,  # e.g., "QUOTA_EXHAUSTED", "RATE_LIMITED"
                "reset_timestamp": str | None,  # ISO timestamp if available
                "quota_reset_timestamp": float | None,  # Unix timestamp for quota reset
            }
        """
        return None  # Default: no provider-specific parsing

    # =========================================================================
    # USAGE RESET CONFIG LOGIC (Centralized)
    # =========================================================================

    def _find_usage_config_for_priority(
        self, priority: int
    ) -> Optional[UsageResetConfigDef]:
        """
        Find usage config that applies to a priority value.

        Checks frozenset keys first (priority must be in the set),
        then falls back to "default" key if no match found.

        Args:
            priority: The credential priority level

        Returns:
            UsageResetConfigDef if found, None otherwise
        """
        # First, check frozenset keys for explicit priority match.
        # The first matching key in dict iteration order wins.
        for key, config in self.usage_reset_configs.items():
            if isinstance(key, frozenset) and priority in key:
                return config
        # Fall back to "default" key
        return self.usage_reset_configs.get("default")

    def _build_usage_reset_config(
        self, tier_name: Optional[str]
    ) -> Optional[Dict[str, Any]]:
        """
        Build usage reset configuration dict for a tier.

        Resolves tier to priority, then finds matching usage config.
        Returns None if provider doesn't define usage_reset_configs.

        Args:
            tier_name: The tier name string

        Returns:
            Usage config dict with window_seconds, mode, priority, description,
            field_name, or None if no config applies
        """
        if not self.usage_reset_configs:
            return None
        priority = self._resolve_tier_priority(tier_name)
        config = self._find_usage_config_for_priority(priority)
        if config is None:
            return None
        return {
            "window_seconds": config.window_seconds,
            "mode": config.mode,
            "priority": priority,
            "description": config.description,
            "field_name": config.field_name,
        }

    def get_usage_reset_config(self, credential: str) -> Optional[Dict[str, Any]]:
        """
        Get provider-specific usage tracking configuration for a credential.

        Uses the provider's usage_reset_configs class attribute to build
        the configuration dict. Priority is auto-derived from tier.

        Subclasses should define usage_reset_configs as a class attribute
        instead of overriding this method. Only override
        get_credential_tier_name() to provide the tier lookup mechanism.

        The UsageManager will use this configuration to:
            1. Track usage per-model or per-credential based on mode
            2. Reset usage based on a rolling window OR quota exhausted timestamp
            3. Archive stats to "global" when the window/quota expires

        Args:
            credential: The credential identifier (API key or path)

        Returns:
            None to use default daily reset, otherwise a dict with:
            {
                "window_seconds": int,  # Duration in seconds (e.g., 18000 for 5h)
                "mode": str,  # "credential" or "per_model"
                "priority": int,  # Priority level (auto-derived from tier)
                "description": str,  # Human-readable description (for logging)
                "field_name": str,  # Key used in the usage data structure
            }

        Modes:
            - "credential": One window per credential. Window starts from first
              request of ANY model. All models reset together when window
              expires.
            - "per_model": Separate window per model (or model group). Window
              starts from first request of THAT model. Models reset
              independently unless grouped. If a quota_exhausted error provides
              exact reset time, that becomes the authoritative reset time for
              the model.
        """
        tier = self.get_credential_tier_name(credential)
        return self._build_usage_reset_config(tier)

    def get_default_usage_field_name(self) -> str:
        """
        Get the default usage tracking field name for this provider.

        Providers can override this to use a custom field name for usage
        tracking when no credential-specific config is available.

        Returns:
            Field name string (default: "daily")
        """
        return "daily"

    # =========================================================================
    # QUOTA GROUPS LOGIC (Centralized)
    # =========================================================================

    @staticmethod
    def _strip_provider_prefix(model: str) -> str:
        """Return the model name without any leading "provider/" segments."""
        return model.rsplit("/", 1)[-1]

    def _get_effective_quota_groups(self) -> QuotaGroupMap:
        """
        Get quota groups with .env overrides applied.

        Env format: QUOTA_GROUPS_{PROVIDER}_{GROUP}="model1,model2"
        Set empty string to disable a default group.

        Returns:
            A freshly built mapping of group name -> model list. The result
            never aliases the class-level model_quota_groups dict or its
            lists, so callers may mutate it without affecting configuration.
        """
        defaults = self.model_quota_groups
        if not self.provider_env_name or not defaults:
            # No env lookups possible/needed: hand back a copy of the defaults
            # rather than the shared class attribute itself.
            return {name: list(models) for name, models in defaults.items()}

        env_prefix = f"QUOTA_GROUPS_{self.provider_env_name.upper()}_"
        effective: QuotaGroupMap = {}
        for group_name, default_models in defaults.items():
            env_value = os.getenv(f"{env_prefix}{group_name.upper()}")
            if env_value is None:
                # No override: use the default membership.
                effective[group_name] = list(default_models)
            elif env_value.strip():
                # Override present: parse comma-separated models.
                effective[group_name] = [
                    m.strip() for m in env_value.split(",") if m.strip()
                ]
            # Blank override = group disabled, so it is left out entirely.
        return effective

    def _find_model_quota_group(self, model: str) -> Optional[str]:
        """
        Find which quota group a model belongs to.

        Args:
            model: Model name without provider prefix.

        Returns:
            Name of the first group (in dict order) containing the model,
            or None if no group contains it.
        """
        groups = self._get_effective_quota_groups()
        for group_name, models in groups.items():
            if model in models:
                return group_name
        return None

    def _get_quota_group_models(self, group: str) -> List[str]:
        """
        Get all models in a quota group.

        Returns:
            A new list of model names; empty if the group does not exist.
        """
        return self._get_effective_quota_groups().get(group, [])

    def get_model_quota_group(self, model: str) -> Optional[str]:
        """
        Returns the quota group name for a model, or None if not grouped.

        Uses the provider's model_quota_groups class attribute with .env
        overrides via QUOTA_GROUPS_{PROVIDER}_{GROUP}="model1,model2".

        Models in the same quota group share cooldown timing - when one model
        hits a quota exhausted error, all models in the group get the same
        reset timestamp. They also reset (archive stats) together.

        Subclasses should define model_quota_groups as a class attribute
        instead of overriding this method.

        Args:
            model: Model name (with or without provider prefix)

        Returns:
            Group name string (e.g., "claude") or None if model is not grouped
        """
        return self._find_model_quota_group(self._strip_provider_prefix(model))

    def get_models_in_quota_group(self, group: str) -> List[str]:
        """
        Returns all model names that belong to a quota group.

        Uses the provider's model_quota_groups class attribute with .env
        overrides.

        Args:
            group: Group name (e.g., "claude")

        Returns:
            List of model names (WITHOUT provider prefix) in the group.
            Empty list if group doesn't exist.
        """
        return self._get_quota_group_models(group)

    def get_model_usage_weight(self, model: str) -> int:
        """
        Returns the usage weight for a model when calculating grouped usage.

        Models with higher weights contribute more to the combined group
        usage. This accounts for models that consume more quota per request.

        Args:
            model: Model name (with or without provider prefix)

        Returns:
            Weight multiplier (default 1 if not configured)
        """
        return self.model_usage_weights.get(self._strip_provider_prefix(model), 1)

    # =========================================================================
    # BACKGROUND JOB INTERFACE - Override in subclass for periodic tasks
    # =========================================================================

    def get_background_job_config(self) -> Optional[Dict[str, Any]]:
        """
        Return configuration for provider-specific background job, or None if
        none.

        Providers that need periodic background tasks (e.g., quota refresh,
        cache cleanup) should override this method.

        The BackgroundRefresher will call run_background_job() at the
        specified interval for each provider that returns a config.

        Returns:
            None if no background job, otherwise:
            {
                "interval": 300,  # seconds between runs
                "name": "my_job",  # for logging (e.g., "quota_refresh")
                "run_on_start": True,  # whether to run immediately at startup
            }
        """
        return None

    async def run_background_job(
        self,
        usage_manager: "UsageManager",
        credentials: List[str],
    ) -> None:
        """
        Execute the provider's periodic background job.

        Called by BackgroundRefresher at the interval specified in
        get_background_job_config(). Override this method to implement
        provider-specific periodic tasks. The base implementation does
        nothing.

        Args:
            usage_manager: UsageManager instance for storing/reading usage data
            credentials: List of credential paths for this provider
        """
        pass
|