llm-api-proxy / src /rotator_library /credential_manager.py
Mirrowel
refactor(core): 🔨 centralize path management for PyInstaller compatibility
467f294
import os
import re
import shutil
import logging
from pathlib import Path
from typing import Dict, List, Optional, Set, Union
from .utils.paths import get_oauth_dir
lib_logger = logging.getLogger("rotator_library")
# Standard directories where tools like `gemini login` store credentials.
DEFAULT_OAUTH_DIRS = {
"gemini_cli": Path.home() / ".gemini",
"qwen_code": Path.home() / ".qwen",
"iflow": Path.home() / ".iflow",
"antigravity": Path.home() / ".antigravity",
# Add other providers like 'claude' here if they have a standard CLI path
}
# OAuth providers that support environment variable-based credentials
# Maps provider name to the ENV_PREFIX used by the provider
ENV_OAUTH_PROVIDERS = {
"gemini_cli": "GEMINI_CLI",
"antigravity": "ANTIGRAVITY",
"qwen_code": "QWEN_CODE",
"iflow": "IFLOW",
}
class CredentialManager:
"""
Discovers OAuth credential files from standard locations, copies them locally,
and updates the configuration to use the local paths.
Also discovers environment variable-based OAuth credentials for stateless deployments.
Supports two env var formats:
1. Single credential (legacy): PROVIDER_ACCESS_TOKEN, PROVIDER_REFRESH_TOKEN
2. Multiple credentials (numbered): PROVIDER_1_ACCESS_TOKEN, PROVIDER_2_ACCESS_TOKEN, etc.
When env-based credentials are detected, virtual paths like "env://provider/1" are created.
"""
def __init__(
self,
env_vars: Dict[str, str],
oauth_dir: Optional[Union[Path, str]] = None,
):
"""
Initialize the CredentialManager.
Args:
env_vars: Dictionary of environment variables (typically os.environ).
oauth_dir: Directory for storing OAuth credentials.
If None, uses get_oauth_dir() which respects EXE vs script mode.
"""
self.env_vars = env_vars
self.oauth_base_dir = Path(oauth_dir) if oauth_dir else get_oauth_dir()
self.oauth_base_dir.mkdir(parents=True, exist_ok=True)
def _discover_env_oauth_credentials(self) -> Dict[str, List[str]]:
"""
Discover OAuth credentials defined via environment variables.
Supports two formats:
1. Single credential: ANTIGRAVITY_ACCESS_TOKEN + ANTIGRAVITY_REFRESH_TOKEN
2. Multiple credentials: ANTIGRAVITY_1_ACCESS_TOKEN + ANTIGRAVITY_1_REFRESH_TOKEN, etc.
Returns:
Dict mapping provider name to list of virtual paths (e.g., "env://antigravity/1")
"""
env_credentials: Dict[str, Set[str]] = {}
for provider, env_prefix in ENV_OAUTH_PROVIDERS.items():
found_indices: Set[str] = set()
# Check for numbered credentials (PROVIDER_N_ACCESS_TOKEN pattern)
# Pattern: ANTIGRAVITY_1_ACCESS_TOKEN, ANTIGRAVITY_2_ACCESS_TOKEN, etc.
numbered_pattern = re.compile(rf"^{env_prefix}_(\d+)_ACCESS_TOKEN$")
for key in self.env_vars.keys():
match = numbered_pattern.match(key)
if match:
index = match.group(1)
# Verify refresh token also exists
refresh_key = f"{env_prefix}_{index}_REFRESH_TOKEN"
if refresh_key in self.env_vars and self.env_vars[refresh_key]:
found_indices.add(index)
# Check for legacy single credential (PROVIDER_ACCESS_TOKEN pattern)
# Only use this if no numbered credentials exist
if not found_indices:
access_key = f"{env_prefix}_ACCESS_TOKEN"
refresh_key = f"{env_prefix}_REFRESH_TOKEN"
if (
access_key in self.env_vars
and self.env_vars[access_key]
and refresh_key in self.env_vars
and self.env_vars[refresh_key]
):
# Use "0" as the index for legacy single credential
found_indices.add("0")
if found_indices:
env_credentials[provider] = found_indices
lib_logger.info(
f"Found {len(found_indices)} env-based credential(s) for {provider}"
)
# Convert to virtual paths
result: Dict[str, List[str]] = {}
for provider, indices in env_credentials.items():
# Sort indices numerically for consistent ordering
sorted_indices = sorted(indices, key=lambda x: int(x))
result[provider] = [f"env://{provider}/{idx}" for idx in sorted_indices]
return result
def discover_and_prepare(self) -> Dict[str, List[str]]:
lib_logger.info("Starting automated OAuth credential discovery...")
final_config = {}
# PHASE 1: Discover environment variable-based OAuth credentials
# These take priority for stateless deployments
env_oauth_creds = self._discover_env_oauth_credentials()
for provider, virtual_paths in env_oauth_creds.items():
lib_logger.info(
f"Using {len(virtual_paths)} env-based credential(s) for {provider}"
)
final_config[provider] = virtual_paths
# Extract OAuth file paths from environment variables
env_oauth_paths = {}
for key, value in self.env_vars.items():
if "_OAUTH_" in key:
provider = key.split("_OAUTH_")[0].lower()
if provider not in env_oauth_paths:
env_oauth_paths[provider] = []
if value: # Only consider non-empty values
env_oauth_paths[provider].append(value)
# PHASE 2: Discover file-based OAuth credentials
for provider, default_dir in DEFAULT_OAUTH_DIRS.items():
# Skip if already discovered from environment variables
if provider in final_config:
lib_logger.debug(
f"Skipping file discovery for {provider} - using env-based credentials"
)
continue
# Check for existing local credentials first. If found, use them and skip discovery.
local_provider_creds = sorted(
list(self.oauth_base_dir.glob(f"{provider}_oauth_*.json"))
)
if local_provider_creds:
lib_logger.info(
f"Found {len(local_provider_creds)} existing local credential(s) for {provider}. Skipping discovery."
)
final_config[provider] = [
str(p.resolve()) for p in local_provider_creds
]
continue
# If no local credentials exist, proceed with a one-time discovery and copy.
discovered_paths = set()
# 1. Add paths from environment variables first, as they are overrides
for path_str in env_oauth_paths.get(provider, []):
path = Path(path_str).expanduser()
if path.exists():
discovered_paths.add(path)
# 2. If no overrides are provided via .env, scan the default directory
# [MODIFIED] This logic is now disabled to prefer local-first credential management.
# if not discovered_paths and default_dir.exists():
# for json_file in default_dir.glob('*.json'):
# discovered_paths.add(json_file)
if not discovered_paths:
lib_logger.debug(f"No credential files found for provider: {provider}")
continue
prepared_paths = []
# Sort paths to ensure consistent numbering for the initial copy
for i, source_path in enumerate(sorted(list(discovered_paths))):
account_id = i + 1
local_filename = f"{provider}_oauth_{account_id}.json"
local_path = self.oauth_base_dir / local_filename
try:
# Since we've established no local files exist, we can copy directly.
shutil.copy(source_path, local_path)
lib_logger.info(
f"Copied '{source_path.name}' to local pool at '{local_path}'."
)
prepared_paths.append(str(local_path.resolve()))
except Exception as e:
lib_logger.error(
f"Failed to process OAuth file from '{source_path}': {e}"
)
if prepared_paths:
lib_logger.info(
f"Discovered and prepared {len(prepared_paths)} credential(s) for provider: {provider}"
)
final_config[provider] = prepared_paths
lib_logger.info("OAuth credential discovery complete.")
return final_config