File size: 23,113 Bytes
21dcb11
3c52746
2be7535
 
 
 
 
 
 
 
 
 
bd84d38
d51ae99
d28e7c5
21dcb11
2be7535
 
 
4dfb828
3c52746
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21dcb11
 
d28e7c5
 
21dcb11
4dfb828
63418af
4dfb828
bd84d38
 
 
 
 
3c52746
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aefb706
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21dcb11
d51ae99
21dcb11
 
 
 
 
d51ae99
21dcb11
 
 
 
 
576a3ec
d28e7c5
 
 
 
 
 
 
 
4dfb828
 
 
d28e7c5
 
 
4dfb828
 
 
d28e7c5
4dfb828
 
 
d28e7c5
4dfb828
 
 
 
 
 
 
576a3ec
 
4dfb828
576a3ec
 
4dfb828
576a3ec
 
 
 
4dfb828
d28e7c5
 
 
 
 
 
 
 
 
 
 
 
b5da45c
4dfb828
b5da45c
3c52746
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b5da45c
 
 
 
3c52746
 
 
 
4dfb828
3c52746
 
 
 
4dfb828
b5da45c
 
4dfb828
b5da45c
3c52746
b5da45c
3c52746
 
 
 
4dfb828
b5da45c
 
 
 
4dfb828
b5da45c
 
4dfb828
b5da45c
 
4dfb828
b5da45c
 
4dfb828
b5da45c
 
 
 
 
4dfb828
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bd84d38
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0ca1651
bd84d38
 
 
 
98f6823
 
 
 
3c52746
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98f6823
 
 
 
3c52746
 
 
 
 
 
98f6823
 
0ca1651
 
 
98f6823
 
 
 
 
 
 
 
0ca1651
3c52746
98f6823
 
 
0ca1651
 
 
 
 
 
 
98f6823
3c52746
 
98f6823
 
 
 
 
 
 
 
 
 
 
 
0ca1651
 
 
 
 
3c52746
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0ca1651
 
 
 
3c52746
 
 
0ca1651
 
 
 
3c52746
 
0ca1651
 
 
 
 
 
 
3c52746
 
 
0ca1651
 
 
 
 
3c52746
 
0ca1651
 
 
 
 
 
 
3c52746
aefb706
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2be7535
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d7c643f
2be7535
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import (
    List,
    Dict,
    Any,
    Optional,
    AsyncGenerator,
    Union,
    FrozenSet,
    TYPE_CHECKING,
)
import os
import httpx
import litellm

if TYPE_CHECKING:
    from ..usage_manager import UsageManager


# =============================================================================
# TIER & USAGE CONFIGURATION TYPES
# =============================================================================


@dataclass(frozen=True)
class UsageResetConfigDef:
    """
    Definition for usage reset configuration per tier type.

    Providers define these as class attributes to specify how usage stats
    should reset based on credential tier (paid vs free).

    Attributes:
        window_seconds: Duration of the usage tracking window in seconds.
                        Must be positive.
        mode: Either "credential" (one window per credential) or "per_model"
              (separate window per model or model group).
        description: Human-readable description for logging.
        field_name: The key used in usage data JSON structure.
                    Typically "models" for per_model mode, "daily" for credential mode.

    Raises:
        ValueError: If ``mode`` is not "credential" or "per_model", or if
            ``window_seconds`` is not positive. Validation runs at
            construction time, so a misconfigured provider fails when its
            class attributes are built instead of silently producing wrong
            usage resets later.
    """

    window_seconds: int
    mode: str  # "credential" or "per_model"
    description: str
    field_name: str = "daily"  # Default for backwards compatibility

    def __post_init__(self) -> None:
        """Reject configurations whose mode or window cannot be honored."""
        if self.mode not in ("credential", "per_model"):
            raise ValueError(
                "UsageResetConfigDef.mode must be 'credential' or 'per_model', "
                f"got {self.mode!r}"
            )
        if self.window_seconds <= 0:
            raise ValueError(
                "UsageResetConfigDef.window_seconds must be positive, "
                f"got {self.window_seconds!r}"
            )


# Type aliases for provider configuration
# Type aliases used by provider configuration class attributes.
# Tier name -> priority level (lower number = higher priority).
TierPriorityMap = Dict[str, int]
# A usage-config key: a frozenset of priority values, or the string "default".
UsageConfigKey = Union[FrozenSet[int], str]
# Usage-config key -> the reset configuration that applies to it.
UsageConfigMap = Dict[UsageConfigKey, UsageResetConfigDef]
# Quota group name -> list of model names sharing that quota.
QuotaGroupMap = Dict[str, List[str]]


class ProviderInterface(ABC):
    """
    An interface for API provider-specific functionality, including model
    discovery and custom API call handling for non-standard providers.

    Subclasses configure tier priorities, usage-reset windows, model quota
    groups and concurrency multipliers by overriding the class attributes
    declared below. The centralized resolution methods further down read
    those attributes, so most subclasses never need to override them.

    NOTE: the dict-valued configuration attributes are class-level objects
    shared with every subclass that does not redefine them. Subclasses should
    rebind them (``tier_priorities = {...}``) rather than mutate them in place.
    """

    skip_cost_calculation: bool = False

    # Default rotation mode for this provider ("balanced" or "sequential")
    # - "balanced": Rotate credentials to distribute load evenly
    # - "sequential": Use one credential until exhausted, then switch to next
    # Can be overridden at runtime via ROTATION_MODE_{PROVIDER}
    # (see get_rotation_mode()).
    default_rotation_mode: str = "balanced"

    # =========================================================================
    # TIER CONFIGURATION - Override in subclass
    # =========================================================================

    # Provider name for env var lookups (e.g., "antigravity", "gemini_cli")
    # Used for: QUOTA_GROUPS_{provider_env_name}_{GROUP}
    provider_env_name: str = ""

    # Tier name -> priority mapping (Single Source of Truth)
    # Lower numbers = higher priority (1 is highest)
    # Multiple tiers can map to the same priority
    # Unknown tiers fall back to default_tier_priority
    tier_priorities: TierPriorityMap = {}

    # Default priority for tiers not in tier_priorities mapping
    default_tier_priority: int = 10

    # =========================================================================
    # USAGE RESET CONFIGURATION - Override in subclass
    # =========================================================================

    # Usage reset configurations keyed by priority sets
    # Keys: frozenset of priority values (e.g., frozenset({1, 2})) OR "default"
    # The "default" key is used for any priority not matched by a frozenset
    usage_reset_configs: UsageConfigMap = {}

    # =========================================================================
    # MODEL QUOTA GROUPS - Override in subclass
    # =========================================================================

    # Models that share quota/cooldown timing
    # Can be overridden via env: QUOTA_GROUPS_{PROVIDER}_{GROUP}="model1,model2"
    model_quota_groups: QuotaGroupMap = {}

    # Model usage weights for grouped usage calculation
    # When calculating combined usage for quota groups, each model's usage
    # is multiplied by its weight. This accounts for models that consume
    # more quota per request (e.g., Opus uses more than Sonnet).
    # Models not in the map default to weight 1.
    # Example: {"claude-opus-4-5": 2} means Opus usage counts 2x
    model_usage_weights: Dict[str, int] = {}

    # =========================================================================
    # PRIORITY CONCURRENCY MULTIPLIERS - Override in subclass
    # =========================================================================

    # Priority-based concurrency multipliers (universal, applies to all modes)
    # Maps priority level -> multiplier
    # Higher priority credentials (lower number) can have higher multipliers
    # to allow more concurrent requests
    # Example: {1: 5, 2: 3} means Priority 1 gets 5x, Priority 2 gets 3x
    default_priority_multipliers: Dict[int, int] = {}

    # Fallback multiplier for sequential mode when priority not in default_priority_multipliers
    # This is used for lower-priority tiers in sequential mode to maintain some stickiness
    # Default: 1 (no multiplier effect)
    default_sequential_fallback_multiplier: int = 1

    @abstractmethod
    async def get_models(self, api_key: str, client: httpx.AsyncClient) -> List[str]:
        """
        Fetches the list of available model names from the provider's API.

        Args:
            api_key: The API key required for authentication.
            client: An httpx.AsyncClient instance for making requests.

        Returns:
            A list of model name strings.
        """
        pass

    # =========================================================================
    # Custom call hooks for providers that bypass litellm
    # =========================================================================

    def has_custom_logic(self) -> bool:
        """
        Returns True if the provider implements its own acompletion/aembedding logic,
        bypassing the standard litellm call.
        """
        return False

    async def acompletion(
        self, client: httpx.AsyncClient, **kwargs
    ) -> Union[litellm.ModelResponse, AsyncGenerator[litellm.ModelResponse, None]]:
        """
        Handles the entire completion call for non-standard providers.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError(
            f"{self.__class__.__name__} does not implement custom acompletion."
        )

    async def aembedding(
        self, client: httpx.AsyncClient, **kwargs
    ) -> litellm.EmbeddingResponse:
        """
        Handles the entire embedding call for non-standard providers.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError(
            f"{self.__class__.__name__} does not implement custom aembedding."
        )

    def convert_safety_settings(
        self, settings: Dict[str, str]
    ) -> Optional[List[Dict[str, Any]]]:
        """
        Converts a generic safety settings dictionary to the provider-specific format.

        Args:
            settings: A dictionary with generic harm categories and thresholds.

        Returns:
            A list of provider-specific safety setting objects or None.
            The base implementation always returns None.
        """
        return None

    # =========================================================================
    # OAuth support hooks
    # =========================================================================

    async def get_auth_header(self, credential_identifier: str) -> Dict[str, str]:
        """
        For OAuth providers, this method returns the Authorization header.
        For API key providers, this can be a no-op or raise NotImplementedError.

        Raises:
            NotImplementedError: In the base implementation.
        """
        raise NotImplementedError("This provider does not support OAuth.")

    async def proactively_refresh(self, credential_path: str):
        """
        Proactively refreshes a token if it's nearing expiry.

        The base implementation does nothing.
        """
        pass

    # =========================================================================
    # TIER RESOLUTION LOGIC (Centralized credential prioritization)
    # =========================================================================

    def _resolve_tier_priority(self, tier_name: Optional[str]) -> int:
        """
        Resolve priority for a tier name using provider's tier_priorities mapping.

        Args:
            tier_name: The tier name string (e.g., "free-tier", "standard-tier")

        Returns:
            Priority level from tier_priorities, or default_tier_priority if
            tier_name is None or not found in the mapping.
        """
        if tier_name is None:
            return self.default_tier_priority
        return self.tier_priorities.get(tier_name, self.default_tier_priority)

    def get_credential_priority(self, credential: str) -> Optional[int]:
        """
        Returns the priority level for a credential.
        Lower numbers = higher priority (1 is highest).
        Returns None if tier not yet discovered.

        Uses the provider's tier_priorities mapping to resolve priority from
        tier name. Unknown tiers fall back to default_tier_priority.

        Subclasses should:
        1. Define tier_priorities dict with all known tier names
        2. Override get_credential_tier_name() for tier lookup
        Do NOT override this method.

        Args:
            credential: The credential identifier (API key or path)

        Returns:
            The priority from tier_priorities (or default_tier_priority for an
            unmapped tier), or None if the tier is not yet discovered
        """
        tier = self.get_credential_tier_name(credential)
        if tier is None:
            return None  # Tier not yet discovered
        return self._resolve_tier_priority(tier)

    def get_model_tier_requirement(self, model: str) -> Optional[int]:
        """
        Returns the minimum priority tier required for a model.
        If a model requires priority 1, only credentials with priority <= 1 can use it.

        This allows providers to restrict certain models to specific credential tiers.
        For example, Gemini 3 models require paid-tier credentials.

        Args:
            model: The model name (with or without provider prefix)

        Returns:
            Minimum required priority level or None if no restrictions.
            The base implementation always returns None.

        Example:
            For Gemini CLI:
            - gemini-3-*: requires priority 1 (paid tier only)
            - gemini-2.5-*: no restriction (None)
        """
        return None

    async def initialize_credentials(self, credential_paths: List[str]) -> None:
        """
        Called at startup to initialize provider with all available credentials.

        Providers can override this to load cached tier data, discover priorities,
        or perform any other initialization needed before the first API request.

        This is called once during startup by the BackgroundRefresher before
        the main refresh loop begins.

        Args:
            credential_paths: List of credential file paths for this provider
        """
        pass

    def get_credential_tier_name(self, credential: str) -> Optional[str]:
        """
        Returns the human-readable tier name for a credential.

        Used for logging, and as the input to priority and usage-config
        resolution (get_credential_priority / get_usage_reset_config).

        Args:
            credential: The credential identifier (API key or path)

        Returns:
            Tier name string (e.g., "free-tier", "paid-tier") or None if unknown.
            The base implementation always returns None.
        """
        return None

    # =========================================================================
    # Sequential Rotation Support
    # =========================================================================

    @classmethod
    def get_rotation_mode(cls, provider_name: str) -> str:
        """
        Get the rotation mode for this provider.

        Checks ROTATION_MODE_{PROVIDER} environment variable first,
        then falls back to the class's default_rotation_mode.

        The env value is matched case-insensitively with surrounding whitespace
        ignored. An empty or unrecognized value falls back to
        default_rotation_mode instead of being returned verbatim, so callers
        only ever see a mode they understand.

        Args:
            provider_name: The provider name (e.g., "antigravity", "gemini_cli")

        Returns:
            "balanced" or "sequential" (or default_rotation_mode as defined
            by the class)
        """
        env_key = f"ROTATION_MODE_{provider_name.upper()}"
        raw_value = os.getenv(env_key)
        if raw_value is None:
            return cls.default_rotation_mode
        normalized = raw_value.strip().lower()
        if normalized in ("balanced", "sequential"):
            return normalized
        # Empty string or typo: keep the provider's default rather than
        # propagating a mode string no caller recognizes.
        return cls.default_rotation_mode

    @staticmethod
    def parse_quota_error(
        error: Exception, error_body: Optional[str] = None
    ) -> Optional[Dict[str, Any]]:
        """
        Parse a quota/rate-limit error and extract structured information.

        Providers should override this method to handle their specific error formats.
        This allows the error_handler to use provider-specific parsing when available,
        falling back to generic parsing otherwise.

        Args:
            error: The caught exception
            error_body: Optional raw response body string

        Returns:
            None if not a parseable quota error, otherwise:
            {
                "retry_after": int,  # seconds until quota resets
                "reason": str,       # e.g., "QUOTA_EXHAUSTED", "RATE_LIMITED"
                "reset_timestamp": str | None,  # ISO timestamp if available
                "quota_reset_timestamp": float | None,  # Unix timestamp for quota reset
            }
        """
        return None  # Default: no provider-specific parsing

    # =========================================================================
    # USAGE RESET CONFIG LOGIC (Centralized per-provider usage tracking)
    # =========================================================================

    def _find_usage_config_for_priority(
        self, priority: int
    ) -> Optional[UsageResetConfigDef]:
        """
        Find usage config that applies to a priority value.

        Checks frozenset keys first (priority must be in the set),
        then falls back to "default" key if no match found. If several
        frozensets contain the priority, the first one in the mapping's
        iteration order wins.

        Args:
            priority: The credential priority level

        Returns:
            UsageResetConfigDef if found, None otherwise
        """
        # First, check frozenset keys for explicit priority match
        for key, config in self.usage_reset_configs.items():
            if isinstance(key, frozenset) and priority in key:
                return config

        # Fall back to "default" key
        return self.usage_reset_configs.get("default")

    def _build_usage_reset_config(
        self, tier_name: Optional[str]
    ) -> Optional[Dict[str, Any]]:
        """
        Build usage reset configuration dict for a tier.

        Resolves tier to priority, then finds matching usage config.
        Returns None if provider doesn't define usage_reset_configs.
        A None tier_name resolves to default_tier_priority.

        Args:
            tier_name: The tier name string

        Returns:
            Usage config dict with window_seconds, mode, priority, description,
            field_name, or None if no config applies
        """
        if not self.usage_reset_configs:
            return None

        priority = self._resolve_tier_priority(tier_name)
        config = self._find_usage_config_for_priority(priority)

        if config is None:
            return None

        return {
            "window_seconds": config.window_seconds,
            "mode": config.mode,
            "priority": priority,
            "description": config.description,
            "field_name": config.field_name,
        }

    def get_usage_reset_config(self, credential: str) -> Optional[Dict[str, Any]]:
        """
        Get provider-specific usage tracking configuration for a credential.

        Uses the provider's usage_reset_configs class attribute to build
        the configuration dict. Priority is auto-derived from tier.

        Subclasses should define usage_reset_configs as a class attribute
        instead of overriding this method. Only override get_credential_tier_name()
        to provide the tier lookup mechanism.

        The UsageManager will use this configuration to:
        1. Track usage per-model or per-credential based on mode
        2. Reset usage based on a rolling window OR quota exhausted timestamp
        3. Archive stats to "global" when the window/quota expires

        Args:
            credential: The credential identifier (API key or path)

        Returns:
            None to use default daily reset, otherwise a dict with:
            {
                "window_seconds": int,     # Duration in seconds (e.g., 18000 for 5h)
                "mode": str,               # "credential" or "per_model"
                "priority": int,           # Priority level (auto-derived from tier)
                "description": str,        # Human-readable description (for logging)
                "field_name": str,         # Key used in the usage data structure
            }

        Modes:
            - "credential": One window per credential. Window starts from first
              request of ANY model. All models reset together when window expires.
            - "per_model": Separate window per model (or model group). Window starts
              from first request of THAT model. Models reset independently unless
              grouped. If a quota_exhausted error provides exact reset time, that
              becomes the authoritative reset time for the model.
        """
        tier = self.get_credential_tier_name(credential)
        return self._build_usage_reset_config(tier)

    def get_default_usage_field_name(self) -> str:
        """
        Get the default usage tracking field name for this provider.

        Providers can override this to use a custom field name for usage tracking
        when no credential-specific config is available.

        Returns:
            Field name string (default: "daily")
        """
        return "daily"

    # =========================================================================
    # QUOTA GROUPS LOGIC (Centralized model quota grouping)
    # =========================================================================

    @staticmethod
    def _strip_provider_prefix(model: str) -> str:
        """Return the last "/"-separated segment of a model name (bare model id)."""
        return model.rsplit("/", 1)[-1]

    def _get_effective_quota_groups(self) -> QuotaGroupMap:
        """
        Get quota groups with .env overrides applied.

        Env format: QUOTA_GROUPS_{PROVIDER}_{GROUP}="model1,model2"
        Set empty string to disable a default group.

        Only groups already declared in model_quota_groups can be overridden;
        env vars naming other groups are not consulted.

        Returns:
            A freshly built mapping with fresh lists, so callers may mutate the
            result without affecting the shared class-level model_quota_groups.
        """
        if not self.provider_env_name or not self.model_quota_groups:
            # No env overrides possible. Still return copies so the shared
            # class attribute cannot be mutated through the returned value.
            return {
                group_name: list(models)
                for group_name, models in self.model_quota_groups.items()
            }

        result: QuotaGroupMap = {}

        for group_name, default_models in self.model_quota_groups.items():
            env_key = (
                f"QUOTA_GROUPS_{self.provider_env_name.upper()}_{group_name.upper()}"
            )
            env_value = os.getenv(env_key)

            if env_value is not None:
                # Env override present
                if env_value.strip():
                    # Parse comma-separated models
                    result[group_name] = [
                        m.strip() for m in env_value.split(",") if m.strip()
                    ]
                # Empty string = group disabled, don't add to result
            else:
                # Use default
                result[group_name] = list(default_models)

        return result

    def _find_model_quota_group(self, model: str) -> Optional[str]:
        """Find which quota group a model belongs to (first match wins)."""
        groups = self._get_effective_quota_groups()
        for group_name, models in groups.items():
            if model in models:
                return group_name
        return None

    def _get_quota_group_models(self, group: str) -> List[str]:
        """Get all models in a quota group (a new list; safe to mutate)."""
        groups = self._get_effective_quota_groups()
        return groups.get(group, [])

    def get_model_quota_group(self, model: str) -> Optional[str]:
        """
        Returns the quota group name for a model, or None if not grouped.

        Uses the provider's model_quota_groups class attribute with .env overrides
        via QUOTA_GROUPS_{PROVIDER}_{GROUP}="model1,model2".

        Models in the same quota group share cooldown timing - when one model
        hits a quota exhausted error, all models in the group get the same
        reset timestamp. They also reset (archive stats) together.

        Subclasses should define model_quota_groups as a class attribute
        instead of overriding this method.

        Args:
            model: Model name (with or without provider prefix)

        Returns:
            Group name string (e.g., "claude") or None if model is not grouped
        """
        return self._find_model_quota_group(self._strip_provider_prefix(model))

    def get_models_in_quota_group(self, group: str) -> List[str]:
        """
        Returns all model names that belong to a quota group.

        Uses the provider's model_quota_groups class attribute with .env overrides.

        Args:
            group: Group name (e.g., "claude")

        Returns:
            List of model names (WITHOUT provider prefix) in the group.
            Empty list if group doesn't exist. The list is a copy; mutating it
            does not alter the provider's configuration.
        """
        return self._get_quota_group_models(group)

    def get_model_usage_weight(self, model: str) -> int:
        """
        Returns the usage weight for a model when calculating grouped usage.

        Models with higher weights contribute more to the combined group usage.
        This accounts for models that consume more quota per request.

        Args:
            model: Model name (with or without provider prefix)

        Returns:
            Weight multiplier (default 1 if not configured)
        """
        return self.model_usage_weights.get(self._strip_provider_prefix(model), 1)

    # =========================================================================
    # BACKGROUND JOB INTERFACE - Override in subclass for periodic tasks
    # =========================================================================

    def get_background_job_config(self) -> Optional[Dict[str, Any]]:
        """
        Return configuration for provider-specific background job, or None if none.

        Providers that need periodic background tasks (e.g., quota refresh,
        cache cleanup) should override this method.

        The BackgroundRefresher will call run_background_job() at the specified
        interval for each provider that returns a config.

        Returns:
            None if no background job, otherwise:
            {
                "interval": 300,  # seconds between runs
                "name": "my_job",  # for logging (e.g., "quota_refresh")
                "run_on_start": True,  # whether to run immediately at startup
            }
        """
        return None

    async def run_background_job(
        self,
        usage_manager: "UsageManager",
        credentials: List[str],
    ) -> None:
        """
        Execute the provider's periodic background job.

        Called by BackgroundRefresher at the interval specified in
        get_background_job_config(). Override this method to implement
        provider-specific periodic tasks. The base implementation does nothing.

        Args:
            usage_manager: UsageManager instance for storing/reading usage data
            credentials: List of credential paths for this provider
        """
        pass