File size: 8,852 Bytes
d28e7c5
5bc49f2
d28e7c5
 
 
467f294
d28e7c5
467f294
d28e7c5
467f294
d28e7c5
4600c4d
 
 
 
5ee03d6
97f1950
d28e7c5
 
 
5bc49f2
 
 
 
 
 
 
 
 
 
d28e7c5
 
 
 
467f294
5bc49f2
 
467f294
5bc49f2
 
467f294
5bc49f2
d28e7c5
467f294
 
 
 
 
 
 
 
 
 
 
 
 
 
4600c4d
467f294
 
d28e7c5
5bc49f2
 
 
467f294
5bc49f2
 
 
467f294
5bc49f2
 
 
 
467f294
5bc49f2
 
467f294
5bc49f2
 
 
467f294
5bc49f2
 
 
 
 
 
 
 
467f294
5bc49f2
 
 
 
 
467f294
 
 
 
 
 
5bc49f2
 
467f294
5bc49f2
 
467f294
 
 
 
5bc49f2
 
 
 
 
 
467f294
5bc49f2
 
d28e7c5
4600c4d
 
 
5bc49f2
 
 
 
467f294
 
 
5bc49f2
 
 
4600c4d
 
 
 
 
 
467f294
4600c4d
 
5bc49f2
4600c4d
5bc49f2
 
467f294
 
 
5bc49f2
467f294
4600c4d
467f294
 
 
4600c4d
467f294
 
 
 
 
 
4600c4d
d28e7c5
4600c4d
 
 
 
 
 
 
 
467f294
4600c4d
9bcb7cb
 
 
 
467f294
4600c4d
 
 
 
 
 
 
 
d28e7c5
467f294
d28e7c5
4600c4d
 
 
467f294
 
 
4600c4d
 
467f294
 
 
 
4600c4d
467f294
 
 
4600c4d
d28e7c5
4600c4d
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
import os
import re
import shutil
import logging
from pathlib import Path
from typing import Dict, List, Optional, Set, Union

from .utils.paths import get_oauth_dir

# Library logger, looked up under the fixed name "rotator_library" rather than
# this module's __name__.
lib_logger = logging.getLogger("rotator_library")

# Standard directories where tools like `gemini login` store credentials,
# keyed by provider name.
#
# Within this module the keys also define which providers
# CredentialManager.discover_and_prepare() considers for file-based discovery;
# the default-directory scan itself is currently disabled there, so the
# directory values are not read by that method.
DEFAULT_OAUTH_DIRS: Dict[str, Path] = {
    "gemini_cli": Path.home() / ".gemini",
    "qwen_code": Path.home() / ".qwen",
    "iflow": Path.home() / ".iflow",
    "antigravity": Path.home() / ".antigravity",
    # Add other providers like 'claude' here if they have a standard CLI path
}

# OAuth providers that support environment variable-based credentials.
# Maps provider name to the ENV_PREFIX used by the provider.
#
# CredentialManager derives the variable names it looks for from the prefix:
#   <PREFIX>_ACCESS_TOKEN / <PREFIX>_REFRESH_TOKEN          (single credential)
#   <PREFIX>_<N>_ACCESS_TOKEN / <PREFIX>_<N>_REFRESH_TOKEN  (numbered credentials)
ENV_OAUTH_PROVIDERS: Dict[str, str] = {
    "gemini_cli": "GEMINI_CLI",
    "antigravity": "ANTIGRAVITY",
    "qwen_code": "QWEN_CODE",
    "iflow": "IFLOW",
}


class CredentialManager:
    """
    Builds the per-provider list of OAuth credentials the library should use.

    Credentials come from two sources:

    * Environment-variable tokens (for stateless deployments). Two formats are
      supported:

      1. Single credential (legacy): PROVIDER_ACCESS_TOKEN, PROVIDER_REFRESH_TOKEN
      2. Multiple credentials (numbered): PROVIDER_1_ACCESS_TOKEN,
         PROVIDER_1_REFRESH_TOKEN, PROVIDER_2_ACCESS_TOKEN, etc.

      When env-based credentials are detected, virtual paths like
      "env://provider/1" are created.

    * Credential files. Files already present in the local OAuth directory
      (``<provider>_oauth_*.json``) are used as-is. Otherwise, files named by
      ``<PROVIDER>_OAUTH_*`` environment variables are copied into the local
      directory once and the local copies are used from then on.
    """

    def __init__(
        self,
        env_vars: Dict[str, str],
        oauth_dir: Optional[Union[Path, str]] = None,
    ):
        """
        Initialize the CredentialManager.

        The OAuth directory is created (including parents) if it does not exist.

        Args:
            env_vars: Dictionary of environment variables (typically os.environ).
            oauth_dir: Directory for storing OAuth credentials.
                       If None, uses get_oauth_dir() which respects EXE vs script mode.
        """
        self.env_vars = env_vars
        self.oauth_base_dir = Path(oauth_dir) if oauth_dir else get_oauth_dir()
        self.oauth_base_dir.mkdir(parents=True, exist_ok=True)

    def _discover_env_oauth_credentials(self) -> Dict[str, List[str]]:
        """
        Discover OAuth credentials defined via environment variables.

        Supports two formats:
        1. Single credential: ANTIGRAVITY_ACCESS_TOKEN + ANTIGRAVITY_REFRESH_TOKEN
        2. Multiple credentials: ANTIGRAVITY_1_ACCESS_TOKEN + ANTIGRAVITY_1_REFRESH_TOKEN, etc.

        In both formats a credential only counts when the access token AND the
        refresh token are present and non-empty. The single (legacy) format is
        only consulted when no valid numbered credential exists, and is
        reported under index "0".

        Returns:
            Dict mapping provider name to list of virtual paths
            (e.g., "env://antigravity/1"), ordered by numeric index. Providers
            without any valid credential are omitted.
        """
        result: Dict[str, List[str]] = {}

        for provider, env_prefix in ENV_OAUTH_PROVIDERS.items():
            # Pattern: ANTIGRAVITY_1_ACCESS_TOKEN, ANTIGRAVITY_2_ACCESS_TOKEN, etc.
            # fullmatch + re.escape: the whole key must match and the prefix is
            # always treated literally.
            numbered_pattern = re.compile(
                rf"{re.escape(env_prefix)}_(\d+)_ACCESS_TOKEN"
            )

            found_indices: Set[str] = set()
            for key, access_token in self.env_vars.items():
                match = numbered_pattern.fullmatch(key)
                if match is None:
                    continue
                index = match.group(1)
                refresh_token = self.env_vars.get(
                    f"{env_prefix}_{index}_REFRESH_TOKEN"
                )
                # Require both tokens to be non-empty, exactly like the legacy
                # branch below. Otherwise a blank placeholder such as
                # `PREFIX_1_ACCESS_TOKEN=` would be reported as a credential
                # and would also hide a valid legacy credential.
                if access_token and refresh_token:
                    found_indices.add(index)

            # Legacy single credential (PROVIDER_ACCESS_TOKEN pattern).
            # Only used if no numbered credentials exist.
            if (
                not found_indices
                and self.env_vars.get(f"{env_prefix}_ACCESS_TOKEN")
                and self.env_vars.get(f"{env_prefix}_REFRESH_TOKEN")
            ):
                # Use "0" as the index for legacy single credential
                found_indices.add("0")

            if not found_indices:
                continue

            lib_logger.info(
                f"Found {len(found_indices)} env-based credential(s) for {provider}"
            )
            # Sort numerically; the string itself breaks ties (e.g. "1" vs "01")
            # so the order never depends on set iteration order.
            ordered = sorted(found_indices, key=lambda idx: (int(idx), idx))
            result[provider] = [f"env://{provider}/{idx}" for idx in ordered]

        return result

    def discover_and_prepare(self) -> Dict[str, List[str]]:
        """
        Discover all usable OAuth credentials and return them per provider.

        Resolution order for each provider:

        1. Env-based token credentials (see _discover_env_oauth_credentials).
           If any exist, file discovery is skipped for that provider.
        2. Existing local files ``<provider>_oauth_*.json`` in the OAuth
           directory. If any exist they are used and nothing is copied.
        3. Files named by environment variables whose key contains
           ``_OAUTH_`` (the part before it, lower-cased, is the provider
           name). Each existing file is copied into the OAuth directory as
           ``<provider>_oauth_<n>.json`` (n starting at 1, sources in sorted
           order). Copy failures are logged and that file is skipped.

        File-based discovery (steps 2-3) only runs for providers listed in
        DEFAULT_OAUTH_DIRS.

        Returns:
            Dict mapping provider name to a list of credential locations:
            either "env://provider/N" virtual paths or absolute local file
            paths. Providers with no credentials are omitted.
        """
        lib_logger.info("Starting automated OAuth credential discovery...")
        final_config: Dict[str, List[str]] = {}

        # PHASE 1: Discover environment variable-based OAuth credentials
        # These take priority for stateless deployments
        for provider, virtual_paths in self._discover_env_oauth_credentials().items():
            lib_logger.info(
                f"Using {len(virtual_paths)} env-based credential(s) for {provider}"
            )
            final_config[provider] = virtual_paths

        # Extract OAuth file paths from environment variables
        # (only non-empty values are considered).
        env_oauth_paths: Dict[str, List[str]] = {}
        for key, value in self.env_vars.items():
            if "_OAUTH_" in key and value:
                provider = key.split("_OAUTH_")[0].lower()
                env_oauth_paths.setdefault(provider, []).append(value)

        # PHASE 2: Discover file-based OAuth credentials
        for provider in DEFAULT_OAUTH_DIRS:
            # Skip if already discovered from environment variables
            if provider in final_config:
                lib_logger.debug(
                    f"Skipping file discovery for {provider} - using env-based credentials"
                )
                continue

            # Check for existing local credentials first. If found, use them and skip discovery.
            local_provider_creds = sorted(
                self.oauth_base_dir.glob(f"{provider}_oauth_*.json")
            )
            if local_provider_creds:
                lib_logger.info(
                    f"Found {len(local_provider_creds)} existing local credential(s) for {provider}. Skipping discovery."
                )
                final_config[provider] = [
                    str(p.resolve()) for p in local_provider_creds
                ]
                continue

            # If no local credentials exist, proceed with a one-time discovery and copy.
            # Paths from environment variables are the only source: scanning
            # the provider's default directory is intentionally disabled to
            # prefer local-first credential management.
            discovered_paths: Set[Path] = set()
            for path_str in env_oauth_paths.get(provider, []):
                path = Path(path_str).expanduser()
                if path.is_file():
                    discovered_paths.add(path)
                else:
                    # A directory or missing path can never be copied; report
                    # it instead of letting it consume an account number.
                    lib_logger.warning(
                        f"Ignoring OAuth path '{path}' for provider {provider}: not an existing file."
                    )

            if not discovered_paths:
                lib_logger.debug(f"No credential files found for provider: {provider}")
                continue

            prepared_paths: List[str] = []
            # Sort paths to ensure consistent numbering for the initial copy
            for account_id, source_path in enumerate(sorted(discovered_paths), start=1):
                local_path = self.oauth_base_dir / f"{provider}_oauth_{account_id}.json"

                try:
                    # Since we've established no local files exist, we can copy directly.
                    shutil.copy(source_path, local_path)
                    lib_logger.info(
                        f"Copied '{source_path.name}' to local pool at '{local_path}'."
                    )
                    prepared_paths.append(str(local_path.resolve()))
                except Exception as e:
                    # Deliberately best-effort: one unreadable file must not
                    # abort discovery for the remaining files/providers.
                    lib_logger.error(
                        f"Failed to process OAuth file from '{source_path}': {e}"
                    )

            if prepared_paths:
                lib_logger.info(
                    f"Discovered and prepared {len(prepared_paths)} credential(s) for provider: {provider}"
                )
                final_config[provider] = prepared_paths

        lib_logger.info("OAuth credential discovery complete.")
        return final_config