Spaces:
Paused
Paused
| # src/rotator_library/background_refresher.py | |
| import os | |
| import asyncio | |
| import logging | |
| from typing import TYPE_CHECKING, Optional, Dict, Any, List | |
| if TYPE_CHECKING: | |
| from .client import RotatingClient | |
| lib_logger = logging.getLogger("rotator_library") | |
| class BackgroundRefresher: | |
| """ | |
| A background task manager that handles: | |
| 1. Periodic OAuth token refresh for all providers | |
| 2. Provider-specific background jobs (e.g., quota refresh) with independent timers | |
| Each provider can define its own background job via get_background_job_config() | |
| and run_background_job(). These run on their own schedules, independent of the | |
| OAuth refresh interval. | |
| """ | |
| def __init__(self, client: "RotatingClient"): | |
| self._client = client | |
| self._task: Optional[asyncio.Task] = None | |
| self._provider_job_tasks: Dict[str, asyncio.Task] = {} # provider -> task | |
| self._initialized = False | |
| try: | |
| interval_str = os.getenv("OAUTH_REFRESH_INTERVAL", "600") | |
| self._interval = int(interval_str) | |
| except ValueError: | |
| lib_logger.warning( | |
| f"Invalid OAUTH_REFRESH_INTERVAL '{interval_str}'. Falling back to 600s." | |
| ) | |
| self._interval = 600 | |
| def start(self): | |
| """Starts the background refresh task.""" | |
| if self._task is None: | |
| self._task = asyncio.create_task(self._run()) | |
| lib_logger.info( | |
| f"Background token refresher started. Check interval: {self._interval} seconds." | |
| ) | |
| async def stop(self): | |
| """Stops all background tasks (main loop + provider jobs).""" | |
| # Cancel provider job tasks first | |
| for provider, task in self._provider_job_tasks.items(): | |
| if task and not task.done(): | |
| task.cancel() | |
| try: | |
| await task | |
| except asyncio.CancelledError: | |
| pass | |
| lib_logger.debug(f"Stopped background job for '{provider}'") | |
| self._provider_job_tasks.clear() | |
| # Cancel main task | |
| if self._task: | |
| self._task.cancel() | |
| try: | |
| await self._task | |
| except asyncio.CancelledError: | |
| pass | |
| lib_logger.info("Background token refresher stopped.") | |
| async def _initialize_credentials(self): | |
| """ | |
| Initialize all providers by loading credentials and persisted tier data. | |
| Called once before the main refresh loop starts. | |
| """ | |
| if self._initialized: | |
| return | |
| api_summary = {} # provider -> count | |
| oauth_summary = {} # provider -> {"count": N, "tiers": {tier: count}} | |
| all_credentials = self._client.all_credentials | |
| oauth_providers = self._client.oauth_providers | |
| for provider, credentials in all_credentials.items(): | |
| if not credentials: | |
| continue | |
| provider_plugin = self._client._get_provider_instance(provider) | |
| # Call initialize_credentials if provider supports it | |
| if provider_plugin and hasattr(provider_plugin, "initialize_credentials"): | |
| try: | |
| await provider_plugin.initialize_credentials(credentials) | |
| except Exception as e: | |
| lib_logger.error( | |
| f"Error initializing credentials for provider '{provider}': {e}" | |
| ) | |
| # Build summary based on provider type | |
| if provider in oauth_providers: | |
| tier_breakdown = {} | |
| if provider_plugin and hasattr( | |
| provider_plugin, "get_credential_tier_name" | |
| ): | |
| for cred in credentials: | |
| tier = provider_plugin.get_credential_tier_name(cred) | |
| if tier: | |
| tier_breakdown[tier] = tier_breakdown.get(tier, 0) + 1 | |
| oauth_summary[provider] = { | |
| "count": len(credentials), | |
| "tiers": tier_breakdown, | |
| } | |
| else: | |
| api_summary[provider] = len(credentials) | |
| # Log 3-line summary | |
| total_providers = len(api_summary) + len(oauth_summary) | |
| total_credentials = sum(api_summary.values()) + sum( | |
| d["count"] for d in oauth_summary.values() | |
| ) | |
| if total_providers > 0: | |
| lib_logger.info( | |
| f"Providers initialized: {total_providers} providers, {total_credentials} credentials" | |
| ) | |
| # API providers line | |
| if api_summary: | |
| api_parts = [f"{p}:{c}" for p, c in sorted(api_summary.items())] | |
| lib_logger.info(f" API: {', '.join(api_parts)}") | |
| # OAuth providers line with tier breakdown | |
| if oauth_summary: | |
| oauth_parts = [] | |
| for provider, data in sorted(oauth_summary.items()): | |
| if data["tiers"]: | |
| tier_str = ", ".join( | |
| f"{t}:{c}" for t, c in sorted(data["tiers"].items()) | |
| ) | |
| oauth_parts.append(f"{provider}:{data['count']} ({tier_str})") | |
| else: | |
| oauth_parts.append(f"{provider}:{data['count']}") | |
| lib_logger.info(f" OAuth: {', '.join(oauth_parts)}") | |
| self._initialized = True | |
| def _start_provider_background_jobs(self): | |
| """ | |
| Start independent background job tasks for providers that define them. | |
| Each provider with a get_background_job_config() that returns a config | |
| gets its own asyncio task running on its own schedule. | |
| """ | |
| all_credentials = self._client.all_credentials | |
| for provider, credentials in all_credentials.items(): | |
| if not credentials: | |
| continue | |
| provider_plugin = self._client._get_provider_instance(provider) | |
| if not provider_plugin: | |
| continue | |
| # Check if provider has a background job | |
| if not hasattr(provider_plugin, "get_background_job_config"): | |
| continue | |
| config = provider_plugin.get_background_job_config() | |
| if not config: | |
| continue | |
| # Start the provider's background job task | |
| task = asyncio.create_task( | |
| self._run_provider_background_job( | |
| provider, provider_plugin, credentials, config | |
| ) | |
| ) | |
| self._provider_job_tasks[provider] = task | |
| job_name = config.get("name", "background_job") | |
| interval = config.get("interval", 300) | |
| lib_logger.info(f"Started {provider} {job_name} (interval: {interval}s)") | |
| async def _run_provider_background_job( | |
| self, | |
| provider_name: str, | |
| provider: Any, | |
| credentials: List[str], | |
| config: Dict[str, Any], | |
| ) -> None: | |
| """ | |
| Independent loop for a single provider's background job. | |
| Args: | |
| provider_name: Name of the provider (for logging) | |
| provider: Provider plugin instance | |
| credentials: List of credential paths for this provider | |
| config: Background job configuration from get_background_job_config() | |
| """ | |
| interval = config.get("interval", 300) | |
| job_name = config.get("name", "background_job") | |
| run_on_start = config.get("run_on_start", True) | |
| # Run immediately on start if configured | |
| if run_on_start: | |
| try: | |
| await provider.run_background_job( | |
| self._client.usage_manager, credentials | |
| ) | |
| lib_logger.debug(f"{provider_name} {job_name}: initial run complete") | |
| except Exception as e: | |
| lib_logger.error( | |
| f"Error in {provider_name} {job_name} (initial run): {e}" | |
| ) | |
| # Main loop | |
| while True: | |
| try: | |
| await asyncio.sleep(interval) | |
| await provider.run_background_job( | |
| self._client.usage_manager, credentials | |
| ) | |
| lib_logger.debug(f"{provider_name} {job_name}: periodic run complete") | |
| except asyncio.CancelledError: | |
| lib_logger.debug(f"{provider_name} {job_name}: cancelled") | |
| break | |
| except Exception as e: | |
| lib_logger.error(f"Error in {provider_name} {job_name}: {e}") | |
| async def _run(self): | |
| """The main loop for OAuth token refresh.""" | |
| # Initialize credentials (load persisted tiers) before starting | |
| await self._initialize_credentials() | |
| # Start provider-specific background jobs with their own timers | |
| self._start_provider_background_jobs() | |
| # Main OAuth refresh loop | |
| while True: | |
| try: | |
| oauth_configs = self._client.get_oauth_credentials() | |
| for provider, paths in oauth_configs.items(): | |
| provider_plugin = self._client._get_provider_instance(provider) | |
| if provider_plugin and hasattr( | |
| provider_plugin, "proactively_refresh" | |
| ): | |
| for path in paths: | |
| try: | |
| await provider_plugin.proactively_refresh(path) | |
| except Exception as e: | |
| lib_logger.error( | |
| f"Error during proactive refresh for '{path}': {e}" | |
| ) | |
| await asyncio.sleep(self._interval) | |
| except asyncio.CancelledError: | |
| break | |
| except Exception as e: | |
| lib_logger.error(f"Unexpected error in background refresher loop: {e}") | |