llm-proxy-rotate / src /rotator_library /credential_manager.py
bardd's picture
Upload 144 files
260d3dd verified
# SPDX-License-Identifier: LGPL-3.0-only
# Copyright (c) 2026 Mirrowel
import os
import re
import shutil
import logging
from pathlib import Path
from typing import Dict, List, Optional, Set, Union
from .utils.paths import get_oauth_dir
lib_logger = logging.getLogger("rotator_library")

# Per-user directories in which provider CLIs (e.g. `gemini login`) keep their
# OAuth credential files, keyed by provider name. Providers that have a
# standard CLI location (such as 'claude') can be registered by adding a pair.
DEFAULT_OAUTH_DIRS = {
    provider: Path.home() / dotdir
    for provider, dotdir in (
        ("gemini_cli", ".gemini"),
        ("qwen_code", ".qwen"),
        ("iflow", ".iflow"),
        ("antigravity", ".antigravity"),
    )
}

# OAuth providers whose credentials may also be supplied through environment
# variables. Each provider maps to the upper-case prefix of its variables,
# e.g. "ANTIGRAVITY" -> ANTIGRAVITY_ACCESS_TOKEN / ANTIGRAVITY_1_REFRESH_TOKEN.
ENV_OAUTH_PROVIDERS = dict(
    gemini_cli="GEMINI_CLI",
    antigravity="ANTIGRAVITY",
    qwen_code="QWEN_CODE",
    iflow="IFLOW",
)
class CredentialManager:
    """
    Discover OAuth credentials for the rotator and expose them as credential paths.

    Two sources are consulted, in priority order:

    1. Environment-variable credentials (for stateless deployments). Two formats
       are recognised per provider prefix:
         - numbered: PROVIDER_N_ACCESS_TOKEN + PROVIDER_N_REFRESH_TOKEN
         - legacy single: PROVIDER_ACCESS_TOKEN + PROVIDER_REFRESH_TOKEN
           (only consulted when no numbered credential exists; index "0")
       Both tokens must be non-empty. These credentials are reported as
       virtual paths such as "env://antigravity/1".
    2. Credential files in the local OAuth directory
       (``<provider>_oauth_<N>.json``). If none exist yet for a provider, files
       named by ``<PROVIDER>_OAUTH_*`` environment variables are copied into it.
    """

    def __init__(
        self,
        env_vars: Dict[str, str],
        oauth_dir: Optional[Union[Path, str]] = None,
    ):
        """
        Initialize the CredentialManager.

        Args:
            env_vars: Dictionary of environment variables (typically os.environ).
            oauth_dir: Directory for storing OAuth credentials. A leading ``~``
                is expanded. If None (or empty), uses get_oauth_dir(), which
                respects EXE vs script mode.
        """
        self.env_vars = env_vars
        self.oauth_base_dir = (
            Path(oauth_dir).expanduser() if oauth_dir else get_oauth_dir()
        )
        # Create the local pool up front so later glob/copy calls can rely on it.
        self.oauth_base_dir.mkdir(parents=True, exist_ok=True)

    @staticmethod
    def _env_credential_indices(env_vars: Dict[str, str], env_prefix: str) -> Set[str]:
        """
        Return the credential indices configured for one env-var prefix.

        A numbered credential ``<PREFIX>_<N>_ACCESS_TOKEN`` counts only when both
        it and ``<PREFIX>_<N>_REFRESH_TOKEN`` are non-empty. The legacy unnumbered
        pair is used only when no numbered credential was found, and is reported
        as index ``"0"``.

        Returns:
            Set of index strings exactly as written in the variable names
            (e.g. ``{"1", "2"}``); empty when nothing usable is configured.
        """
        # Escape the prefix and require a full match so only well-formed keys count.
        numbered_pattern = re.compile(
            rf"{re.escape(env_prefix)}_([0-9]+)_ACCESS_TOKEN"
        )
        found: Set[str] = set()
        for key, access_token in env_vars.items():
            match = numbered_pattern.fullmatch(key)
            # Both tokens must be non-empty, matching the legacy check below.
            if not match or not access_token:
                continue
            index = match.group(1)
            if env_vars.get(f"{env_prefix}_{index}_REFRESH_TOKEN"):
                found.add(index)

        if found:
            return found

        if env_vars.get(f"{env_prefix}_ACCESS_TOKEN") and env_vars.get(
            f"{env_prefix}_REFRESH_TOKEN"
        ):
            return {"0"}
        return set()

    def _discover_env_oauth_credentials(self) -> Dict[str, List[str]]:
        """
        Discover OAuth credentials defined via environment variables.

        Supports two formats:
        1. Single credential: ANTIGRAVITY_ACCESS_TOKEN + ANTIGRAVITY_REFRESH_TOKEN
        2. Multiple credentials: ANTIGRAVITY_1_ACCESS_TOKEN + ANTIGRAVITY_1_REFRESH_TOKEN, etc.

        Returns:
            Dict mapping provider name to list of virtual paths
            (e.g., "env://antigravity/1"), in ascending numeric index order.
        """
        result: Dict[str, List[str]] = {}
        for provider, env_prefix in ENV_OAUTH_PROVIDERS.items():
            indices = self._env_credential_indices(self.env_vars, env_prefix)
            if not indices:
                continue
            lib_logger.info(
                f"Found {len(indices)} env-based credential(s) for {provider}"
            )
            # Numeric order; ties such as "1" vs "01" are broken textually so the
            # output does not depend on set iteration order.
            ordered = sorted(indices, key=lambda idx: (int(idx), idx))
            result[provider] = [f"env://{provider}/{idx}" for idx in ordered]
        return result

    def _collect_env_oauth_paths(self) -> Dict[str, List[str]]:
        """
        Group ``<PROVIDER>_OAUTH_*`` env values by lower-cased provider name.

        Every matching key creates an entry for its provider; only non-empty
        values are added to that provider's list.
        """
        paths: Dict[str, List[str]] = {}
        for key, value in self.env_vars.items():
            if "_OAUTH_" not in key:
                continue
            provider = key.split("_OAUTH_", 1)[0].lower()
            bucket = paths.setdefault(provider, [])
            if value:  # Only consider non-empty values
                bucket.append(value)
        return paths

    @staticmethod
    def _local_credential_sort_key(path: Path) -> tuple:
        """
        Sort key ordering ``<provider>_oauth_<N>.json`` files by numeric N.

        Plain string sorting would place ``_oauth_10`` before ``_oauth_2``.
        Names without a numeric suffix sort after all numbered files, by name.
        """
        suffix = path.stem.rpartition("_oauth_")[2]
        if suffix.isdecimal():
            return (0, int(suffix), path.name)
        return (1, 0, path.name)

    def _prepare_file_credentials(
        self, provider: str, override_paths: List[str]
    ) -> List[str]:
        """
        Return absolute local credential file paths for one provider.

        Existing ``<provider>_oauth_*.json`` files in the local pool are used
        as-is. Otherwise, each existing override path is copied into the pool
        as ``<provider>_oauth_<N>.json`` (N starting at 1). Copy failures are
        logged and skipped.

        Returns:
            List of resolved path strings; empty if nothing was found or copied.
        """
        local_creds = sorted(
            self.oauth_base_dir.glob(f"{provider}_oauth_*.json"),
            key=self._local_credential_sort_key,
        )
        if local_creds:
            lib_logger.info(
                f"Found {len(local_creds)} existing local credential(s) for {provider}. Skipping discovery."
            )
            return [str(p.resolve()) for p in local_creds]

        # NOTE: scanning the provider's default CLI directory (DEFAULT_OAUTH_DIRS)
        # is intentionally disabled to prefer local-first credential management;
        # only explicit env-var overrides are imported.
        discovered = {
            path
            for path in (Path(raw).expanduser() for raw in override_paths)
            if path.exists()
        }
        if not discovered:
            lib_logger.debug(f"No credential files found for provider: {provider}")
            return []

        prepared: List[str] = []
        # Sorted so the initial numbering is stable across runs.
        for account_id, source_path in enumerate(sorted(discovered), start=1):
            local_path = self.oauth_base_dir / f"{provider}_oauth_{account_id}.json"
            try:
                # No local files exist for this provider, so copying cannot clobber one.
                shutil.copy(source_path, local_path)
                resolved = str(local_path.resolve())
            except OSError as e:
                lib_logger.error(
                    f"Failed to process OAuth file from '{source_path}': {e}"
                )
                continue
            lib_logger.info(
                f"Copied '{source_path.name}' to local pool at '{local_path}'."
            )
            prepared.append(resolved)

        if prepared:
            lib_logger.info(
                f"Discovered and prepared {len(prepared)} credential(s) for provider: {provider}"
            )
        return prepared

    def discover_and_prepare(self) -> Dict[str, List[str]]:
        """
        Build the provider -> credential-path mapping used by the rotator.

        Env-var credentials (virtual ``env://`` paths) take priority. Any provider
        in DEFAULT_OAUTH_DIRS without them falls back to credential files in the
        local OAuth directory, importing ``<PROVIDER>_OAUTH_*`` override files
        when the local pool is empty for that provider.

        Returns:
            Dict mapping provider name to a list of credential paths. Providers
            with no usable credential are omitted.
        """
        lib_logger.info("Starting automated OAuth credential discovery...")
        final_config: Dict[str, List[str]] = {}

        # PHASE 1: environment-variable credentials take priority (stateless deployments).
        for provider, virtual_paths in self._discover_env_oauth_credentials().items():
            lib_logger.info(
                f"Using {len(virtual_paths)} env-based credential(s) for {provider}"
            )
            final_config[provider] = virtual_paths

        env_oauth_paths = self._collect_env_oauth_paths()

        # PHASE 2: file-based credentials for providers not already configured.
        for provider in DEFAULT_OAUTH_DIRS:
            if provider in final_config:
                lib_logger.debug(
                    f"Skipping file discovery for {provider} - using env-based credentials"
                )
                continue
            prepared = self._prepare_file_credentials(
                provider, env_oauth_paths.get(provider, [])
            )
            if prepared:
                final_config[provider] = prepared

        lib_logger.info("OAuth credential discovery complete.")
        return final_config