"""Discovery of OAuth credentials (env-var based and file based) for the rotator library."""

import os
import re
import shutil
import logging
from pathlib import Path
from typing import Dict, List, Mapping, Optional, Set, Union

from .utils.paths import get_oauth_dir

lib_logger = logging.getLogger("rotator_library")

# Standard directories where tools like `gemini login` store credentials.
# NOTE: automatic scanning of these directories is currently disabled in
# CredentialManager.discover_and_prepare (local-first credential management);
# only the keys are used there, to select which providers take part in
# file-based discovery.
DEFAULT_OAUTH_DIRS = {
    "gemini_cli": Path.home() / ".gemini",
    "qwen_code": Path.home() / ".qwen",
    "iflow": Path.home() / ".iflow",
    "antigravity": Path.home() / ".antigravity",
    # Add other providers like 'claude' here if they have a standard CLI path
}

# OAuth providers that support environment variable-based credentials.
# Maps provider name to the ENV_PREFIX used by the provider.
ENV_OAUTH_PROVIDERS = {
    "gemini_cli": "GEMINI_CLI",
    "antigravity": "ANTIGRAVITY",
    "qwen_code": "QWEN_CODE",
    "iflow": "IFLOW",
}


class CredentialManager:
    """
    Discovers OAuth credential files from standard locations, copies them locally,
    and updates the configuration to use the local paths.

    Also discovers environment variable-based OAuth credentials for stateless deployments.
    Supports two env var formats:
    1. Single credential (legacy): PROVIDER_ACCESS_TOKEN, PROVIDER_REFRESH_TOKEN
    2. Multiple credentials (numbered): PROVIDER_1_ACCESS_TOKEN, PROVIDER_2_ACCESS_TOKEN, etc.

    When env-based credentials are detected, virtual paths like "env://provider/1"
    are created (the legacy single credential uses index 0, i.e. "env://provider/0").
    """

    def __init__(
        self,
        env_vars: Mapping[str, str],
        oauth_dir: Optional[Union[Path, str]] = None,
    ):
        """
        Initialize the CredentialManager.

        Args:
            env_vars: Mapping of environment variables (typically os.environ).
            oauth_dir: Directory for storing OAuth credentials.
                If None (or empty), uses get_oauth_dir() which respects EXE vs script mode.
                The directory is created if it does not exist.
        """
        self.env_vars = env_vars
        self.oauth_base_dir = Path(oauth_dir) if oauth_dir else get_oauth_dir()
        self.oauth_base_dir.mkdir(parents=True, exist_ok=True)

    def _discover_env_oauth_credentials(self) -> Dict[str, List[str]]:
        """
        Discover OAuth credentials defined via environment variables.

        Supports two formats:
        1. Single credential: ANTIGRAVITY_ACCESS_TOKEN + ANTIGRAVITY_REFRESH_TOKEN
        2. Multiple credentials: ANTIGRAVITY_1_ACCESS_TOKEN + ANTIGRAVITY_1_REFRESH_TOKEN, etc.

        The legacy single form is only consulted when no numbered credential was
        found for that provider.

        Returns:
            Dict mapping provider name to a list of virtual paths
            (e.g., "env://antigravity/1"), ordered by numeric index. Providers
            without any env-based credential are omitted.
        """
        result: Dict[str, List[str]] = {}

        for provider, env_prefix in ENV_OAUTH_PROVIDERS.items():
            found_indices: Set[str] = set()

            # Numbered credentials, e.g. ANTIGRAVITY_1_ACCESS_TOKEN, ANTIGRAVITY_2_ACCESS_TOKEN.
            # The prefix is escaped so it is always matched literally.
            numbered_pattern = re.compile(
                rf"{re.escape(env_prefix)}_(\d+)_ACCESS_TOKEN"
            )
            for key in self.env_vars:
                match = numbered_pattern.fullmatch(key)
                if match is None:
                    continue
                index = match.group(1)
                # A numbered entry counts only when its refresh token is set and non-empty.
                # NOTE(review): unlike the legacy form below, an empty numbered
                # ACCESS_TOKEN is still accepted here — confirm this is intended.
                if self.env_vars.get(f"{env_prefix}_{index}_REFRESH_TOKEN"):
                    found_indices.add(index)

            # Legacy single credential (PROVIDER_ACCESS_TOKEN); both values must be non-empty.
            if (
                not found_indices
                and self.env_vars.get(f"{env_prefix}_ACCESS_TOKEN")
                and self.env_vars.get(f"{env_prefix}_REFRESH_TOKEN")
            ):
                # "0" is the index reserved for the legacy single credential.
                found_indices.add("0")

            if found_indices:
                lib_logger.info(
                    f"Found {len(found_indices)} env-based credential(s) for {provider}"
                )
                # Numeric (not lexicographic) ordering keeps "10" after "2".
                result[provider] = [
                    f"env://{provider}/{idx}" for idx in sorted(found_indices, key=int)
                ]

        return result

    def _collect_env_oauth_file_paths(self) -> Dict[str, List[str]]:
        """
        Group credential-file overrides given through env vars containing "_OAUTH_".

        The provider name is the lower-cased part of the key before the first
        "_OAUTH_" (e.g. GEMINI_CLI_OAUTH_1 -> "gemini_cli"). Empty values are skipped.

        Returns:
            Dict mapping provider name to the list of raw path strings.
        """
        env_oauth_paths: Dict[str, List[str]] = {}
        for key, value in self.env_vars.items():
            if "_OAUTH_" not in key:
                continue
            provider = key.split("_OAUTH_")[0].lower()
            paths = env_oauth_paths.setdefault(provider, [])
            if value:  # Only consider non-empty values
                paths.append(value)
        return env_oauth_paths

    @staticmethod
    def _existing_credential_files(path_strs: List[str]) -> List[Path]:
        """
        Return the de-duplicated, sorted regular files among ``path_strs``.

        ``~`` is expanded. Missing paths are ignored silently. Paths that exist
        but are not regular files (e.g. directories) are skipped with a warning,
        because they cannot be copied.
        """
        found: Set[Path] = set()
        for path_str in path_strs:
            path = Path(path_str).expanduser()
            if path.is_file():
                found.add(path)
            elif path.exists():
                lib_logger.warning(
                    f"Ignoring OAuth path '{path}': not a regular file."
                )
        return sorted(found)

    def _copy_to_local_pool(
        self, provider: str, source_paths: List[Path]
    ) -> List[str]:
        """
        Copy each source file into the local pool as "{provider}_oauth_{N}.json".

        N starts at 1 and follows the order of ``source_paths``. Failed copies
        are logged and skipped, leaving a gap in the numbering.

        Returns:
            Resolved string paths of the successfully copied local files.
        """
        prepared_paths: List[str] = []
        for account_id, source_path in enumerate(source_paths, start=1):
            local_path = self.oauth_base_dir / f"{provider}_oauth_{account_id}.json"
            try:
                # discover_and_prepare only calls this when no
                # "{provider}_oauth_*.json" file exists locally, so copy directly.
                shutil.copy(source_path, local_path)
            except OSError as e:
                lib_logger.error(
                    f"Failed to process OAuth file from '{source_path}': {e}"
                )
                continue
            lib_logger.info(
                f"Copied '{source_path.name}' to local pool at '{local_path}'."
            )
            prepared_paths.append(str(local_path.resolve()))
        return prepared_paths

    def discover_and_prepare(self) -> Dict[str, List[str]]:
        """
        Build the provider -> credential-paths configuration.

        Phase 1: env-var based credentials (virtual "env://provider/N" paths)
        take priority.

        Phase 2: for each provider in DEFAULT_OAUTH_DIRS that has no env-based
        credentials, file-based credentials are resolved as follows.
          - Existing local files "{provider}_oauth_*.json" in ``oauth_base_dir``
            are used as-is.
          - Otherwise, files named by "*_OAUTH_*" env vars are copied once into
            the local pool.

        Returns:
            Dict mapping provider name to a list of credential paths (virtual
            env paths or resolved local file paths). Providers with nothing
            found are omitted.
        """
        lib_logger.info("Starting automated OAuth credential discovery...")
        final_config: Dict[str, List[str]] = {}

        # PHASE 1: Discover environment variable-based OAuth credentials.
        # These take priority for stateless deployments.
        for provider, virtual_paths in self._discover_env_oauth_credentials().items():
            lib_logger.info(
                f"Using {len(virtual_paths)} env-based credential(s) for {provider}"
            )
            final_config[provider] = virtual_paths

        # Explicit credential-file overrides from *_OAUTH_* environment variables.
        env_oauth_paths = self._collect_env_oauth_file_paths()

        # PHASE 2: Discover file-based OAuth credentials.
        for provider in DEFAULT_OAUTH_DIRS:
            # Skip if already discovered from environment variables
            if provider in final_config:
                lib_logger.debug(
                    f"Skipping file discovery for {provider} - using env-based credentials"
                )
                continue

            # Existing local credentials win; no discovery/copy happens for this provider.
            local_provider_creds = sorted(
                self.oauth_base_dir.glob(f"{provider}_oauth_*.json")
            )
            if local_provider_creds:
                lib_logger.info(
                    f"Found {len(local_provider_creds)} existing local credential(s) for {provider}. Skipping discovery."
                )
                final_config[provider] = [
                    str(p.resolve()) for p in local_provider_creds
                ]
                continue

            # One-time import of explicitly configured files into the local pool.
            # NOTE: scanning DEFAULT_OAUTH_DIRS[provider] for *.json files is
            # intentionally disabled to prefer local-first credential management.
            discovered_paths = self._existing_credential_files(
                env_oauth_paths.get(provider, [])
            )
            if not discovered_paths:
                lib_logger.debug(f"No credential files found for provider: {provider}")
                continue

            prepared_paths = self._copy_to_local_pool(provider, discovered_paths)
            if prepared_paths:
                lib_logger.info(
                    f"Discovered and prepared {len(prepared_paths)} credential(s) for provider: {provider}"
                )
                final_config[provider] = prepared_paths

        lib_logger.info("OAuth credential discovery complete.")
        return final_config