# src/rotator_library/background_refresher.py

import os
import asyncio
import logging
from typing import TYPE_CHECKING, Optional, Dict, Any, List

if TYPE_CHECKING:
    from .client import RotatingClient

lib_logger = logging.getLogger("rotator_library")


class BackgroundRefresher:
    """
    A background task manager that handles:
    1. Periodic OAuth token refresh for all providers
    2. Provider-specific background jobs (e.g., quota refresh) with independent timers

    Each provider can define its own background job via get_background_job_config()
    and run_background_job(). These run on their own schedules, independent of the
    OAuth refresh interval.
    """

    # Fallback (seconds) used when OAUTH_REFRESH_INTERVAL is missing or invalid.
    _DEFAULT_INTERVAL = 600

    def __init__(self, client: "RotatingClient"):
        """
        Create a refresher bound to *client*; no task is started yet.

        The OAuth refresh interval is read from the OAUTH_REFRESH_INTERVAL
        environment variable (seconds). Values that are not integers, or that
        are zero/negative, fall back to 600 seconds with a warning.
        """
        self._client = client
        self._task: Optional[asyncio.Task] = None
        self._provider_job_tasks: Dict[str, asyncio.Task] = {}  # provider -> task
        self._initialized = False

        interval_str = os.getenv("OAUTH_REFRESH_INTERVAL", "600")
        try:
            interval = int(interval_str)
            # A non-positive interval would make asyncio.sleep() return
            # immediately and turn the refresh loop into a busy loop.
            if interval <= 0:
                raise ValueError(interval_str)
        except ValueError:
            lib_logger.warning(
                f"Invalid OAUTH_REFRESH_INTERVAL '{interval_str}'. Falling back to 600s."
            )
            interval = self._DEFAULT_INTERVAL
        self._interval = interval

    def start(self):
        """
        Starts the background refresh task.

        Does nothing if a task is already registered. Uses
        asyncio.create_task(), so it must be called while an event loop is
        running.
        """
        if self._task is not None:
            return
        self._task = asyncio.create_task(self._run())
        lib_logger.info(
            f"Background token refresher started. Check interval: {self._interval} seconds."
        )

    async def stop(self):
        """
        Stops all background tasks (main loop + provider jobs).

        After this returns, start() may be called again to restart the
        refresher.
        """
        # Cancel provider job tasks first. Iterate over a snapshot because we
        # await inside the loop.
        for provider, task in list(self._provider_job_tasks.items()):
            if task and not task.done():
                await self._cancel_and_wait(task, f"background job for '{provider}'")
                lib_logger.debug(f"Stopped background job for '{provider}'")
        self._provider_job_tasks.clear()

        # Cancel main task. Reset the handle up front so start() works again.
        if self._task:
            main_task, self._task = self._task, None
            await self._cancel_and_wait(main_task, "background token refresher")
            lib_logger.info("Background token refresher stopped.")

    @staticmethod
    async def _cancel_and_wait(task: asyncio.Task, description: str) -> None:
        """Cancel *task* and wait for it to finish, logging (not raising) its errors."""
        task.cancel()
        try:
            await task
        except asyncio.CancelledError:
            pass
        except Exception as e:
            # The task had already failed on its own; shutdown should still
            # proceed for the remaining tasks.
            lib_logger.error(f"The {description} ended with an error: {e}")

    async def _initialize_credentials(self):
        """
        Initialize all providers by loading credentials and persisted tier data.
        Called once before the main refresh loop starts.

        For every provider with credentials, awaits the provider's
        initialize_credentials() when it exists (errors are logged), then logs
        a short summary of API-key and OAuth providers.
        """
        if self._initialized:
            return

        api_summary: Dict[str, int] = {}  # provider -> count
        oauth_summary: Dict[str, Dict[str, Any]] = {}  # provider -> {"count": N, "tiers": {tier: count}}

        all_credentials = self._client.all_credentials
        oauth_providers = self._client.oauth_providers

        for provider, credentials in all_credentials.items():
            if not credentials:
                continue

            provider_plugin = self._client._get_provider_instance(provider)

            # Call initialize_credentials if provider supports it
            if provider_plugin and hasattr(provider_plugin, "initialize_credentials"):
                try:
                    await provider_plugin.initialize_credentials(credentials)
                except Exception as e:
                    lib_logger.error(
                        f"Error initializing credentials for provider '{provider}': {e}"
                    )

            # Build summary based on provider type
            if provider in oauth_providers:
                oauth_summary[provider] = {
                    "count": len(credentials),
                    "tiers": self._count_tiers(provider, provider_plugin, credentials),
                }
            else:
                api_summary[provider] = len(credentials)

        self._log_provider_summary(api_summary, oauth_summary)
        self._initialized = True

    @staticmethod
    def _count_tiers(
        provider: str, provider_plugin: Any, credentials: List[str]
    ) -> Dict[str, int]:
        """Count credentials per tier name; used only for the start-up log summary."""
        tier_breakdown: Dict[str, int] = {}
        if not (provider_plugin and hasattr(provider_plugin, "get_credential_tier_name")):
            return tier_breakdown

        for cred in credentials:
            try:
                tier = provider_plugin.get_credential_tier_name(cred)
            except Exception as e:
                # The breakdown is cosmetic; a failing lookup must not abort
                # initialization of the refresher.
                lib_logger.debug(
                    f"Could not determine tier for a '{provider}' credential: {e}"
                )
                continue
            if tier:
                tier_breakdown[tier] = tier_breakdown.get(tier, 0) + 1
        return tier_breakdown

    @staticmethod
    def _log_provider_summary(
        api_summary: Dict[str, int], oauth_summary: Dict[str, Dict[str, Any]]
    ) -> None:
        """Log the 3-line provider summary (totals, API providers, OAuth providers)."""
        total_providers = len(api_summary) + len(oauth_summary)
        if total_providers <= 0:
            return

        total_credentials = sum(api_summary.values()) + sum(
            d["count"] for d in oauth_summary.values()
        )
        lib_logger.info(
            f"Providers initialized: {total_providers} providers, {total_credentials} credentials"
        )

        # API providers line
        if api_summary:
            api_parts = [f"{p}:{c}" for p, c in sorted(api_summary.items())]
            lib_logger.info(f" API: {', '.join(api_parts)}")

        # OAuth providers line with tier breakdown
        if oauth_summary:
            oauth_parts = []
            for provider, data in sorted(oauth_summary.items()):
                if data["tiers"]:
                    tier_str = ", ".join(
                        f"{t}:{c}" for t, c in sorted(data["tiers"].items())
                    )
                    oauth_parts.append(f"{provider}:{data['count']} ({tier_str})")
                else:
                    oauth_parts.append(f"{provider}:{data['count']}")
            lib_logger.info(f" OAuth: {', '.join(oauth_parts)}")

    def _start_provider_background_jobs(self):
        """
        Start independent background job tasks for providers that define them.

        Each provider with a get_background_job_config() that returns a config
        gets its own asyncio task running on its own schedule. A provider whose
        config lookup raises is logged and skipped so the others still start.
        """
        all_credentials = self._client.all_credentials

        for provider, credentials in all_credentials.items():
            if not credentials:
                continue

            provider_plugin = self._client._get_provider_instance(provider)
            if not provider_plugin:
                continue

            # Check if provider has a background job
            if not hasattr(provider_plugin, "get_background_job_config"):
                continue

            try:
                config = provider_plugin.get_background_job_config()
            except Exception as e:
                lib_logger.error(
                    f"Error reading background job config for provider '{provider}': {e}"
                )
                continue
            if not config:
                continue

            # Never run two loops for the same provider at once.
            existing = self._provider_job_tasks.get(provider)
            if existing is not None and not existing.done():
                continue

            # Start the provider's background job task
            task = asyncio.create_task(
                self._run_provider_background_job(
                    provider, provider_plugin, credentials, config
                )
            )
            self._provider_job_tasks[provider] = task

            job_name = config.get("name", "background_job")
            interval = config.get("interval", 300)
            lib_logger.info(f"Started {provider} {job_name} (interval: {interval}s)")

    async def _run_provider_background_job(
        self,
        provider_name: str,
        provider: Any,
        credentials: List[str],
        config: Dict[str, Any],
    ) -> None:
        """
        Independent loop for a single provider's background job.

        Args:
            provider_name: Name of the provider (for logging)
            provider: Provider plugin instance
            credentials: List of credential paths for this provider
            config: Background job configuration from get_background_job_config().
                Keys read here: "interval" (seconds, default 300), "name"
                (default "background_job"), "run_on_start" (default True).
        """
        interval = config.get("interval", 300)
        job_name = config.get("name", "background_job")
        run_on_start = config.get("run_on_start", True)

        # Run immediately on start if configured
        if run_on_start:
            try:
                await provider.run_background_job(
                    self._client.usage_manager, credentials
                )
                lib_logger.debug(f"{provider_name} {job_name}: initial run complete")
            except Exception as e:
                lib_logger.error(
                    f"Error in {provider_name} {job_name} (initial run): {e}"
                )

        # Main loop. The sleep comes first, so a failing job is retried only
        # after a full interval has passed.
        while True:
            try:
                await asyncio.sleep(interval)
                await provider.run_background_job(
                    self._client.usage_manager, credentials
                )
                lib_logger.debug(f"{provider_name} {job_name}: periodic run complete")
            except asyncio.CancelledError:
                lib_logger.debug(f"{provider_name} {job_name}: cancelled")
                break
            except Exception as e:
                lib_logger.error(f"Error in {provider_name} {job_name}: {e}")

    async def _refresh_oauth_tokens(self) -> None:
        """Run one proactive-refresh pass over every OAuth credential path."""
        oauth_configs = self._client.get_oauth_credentials()
        for provider, paths in oauth_configs.items():
            provider_plugin = self._client._get_provider_instance(provider)
            if not (provider_plugin and hasattr(provider_plugin, "proactively_refresh")):
                continue
            for path in paths:
                try:
                    await provider_plugin.proactively_refresh(path)
                except Exception as e:
                    lib_logger.error(
                        f"Error during proactive refresh for '{path}': {e}"
                    )

    async def _run(self):
        """
        The main loop for OAuth token refresh.

        Initializes credentials, starts provider background jobs, then runs a
        refresh pass every ``self._interval`` seconds until cancelled.
        """
        # Initialize credentials (load persisted tiers) before starting
        await self._initialize_credentials()

        # Start provider-specific background jobs with their own timers
        self._start_provider_background_jobs()

        # Main OAuth refresh loop
        while True:
            try:
                await self._refresh_oauth_tokens()
            except asyncio.CancelledError:
                break
            except Exception as e:
                lib_logger.error(f"Unexpected error in background refresher loop: {e}")

            # Sleep on every iteration, including after a failed pass;
            # otherwise a persistent error would retry in a tight loop and
            # starve the event loop.
            try:
                await asyncio.sleep(self._interval)
            except asyncio.CancelledError:
                break