Mirrowel committed on
Commit
9bcb7cb
·
1 Parent(s): 3f8f2ac

refactor(auth): implement two-pass OAuth credential deduplication

Browse files

Refactors the OAuth token initialization process to proactively handle duplicate credentials based on user email across providers.

The new flow operates in two passes:
- Pass 1 (Pre-scan): Reads metadata from files to identify and skip duplicates based on existing email tags before initialization.
- Pass 2 (Post-init): Initializes unique tokens, retrieves user email via `get_user_info`, and performs a final deduplication check, updating metadata only for unique credentials.

Additionally:
- Automatically attempt a token refresh for the Gemini and Qwen providers when a token is expired and a refresh token is available, falling back to interactive login only if the refresh fails.
- Disable automatic scanning of default credential directories in `credential_manager` to prefer locally defined or explicit paths.

src/proxy_app/main.py CHANGED
@@ -163,42 +163,98 @@ async def lifespan(app: FastAPI):
163
  raise ValueError("No provider API keys or OAuth credentials found.")
164
 
165
  if not skip_oauth_init and oauth_credentials:
166
- logging.info("Validating OAuth credentials and checking for duplicates...")
167
- processed_emails = {} # email -> {provider: path}
 
 
 
 
 
168
  for provider, paths in oauth_credentials.items():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
169
  provider_plugin_class = PROVIDER_PLUGINS.get(provider)
170
  if not provider_plugin_class: continue
171
 
172
  provider_instance = provider_plugin_class()
 
173
  for path in paths:
174
  try:
175
  await provider_instance.initialize_token(path)
176
- if hasattr(provider_instance, 'get_user_info'):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
177
  with open(path, 'r+') as f:
178
  data = json.load(f)
179
  metadata = data.get("_proxy_metadata", {})
180
- last_check = metadata.get("last_check_timestamp", 0)
181
- if time.time() - last_check > 86400:
182
- user_info = await provider_instance.get_user_info(path)
183
- metadata["email"] = user_info.get("email")
184
- metadata["last_check_timestamp"] = time.time()
185
- data["_proxy_metadata"] = metadata
186
- f.seek(0)
187
- json.dump(data, f, indent=2)
188
- f.truncate()
189
-
190
- email = metadata.get("email")
191
- if email:
192
- if email not in processed_emails:
193
- processed_emails[email] = {}
194
- if provider in processed_emails[email]:
195
- original_path = processed_emails[email][provider]
196
- logging.warning(f"Duplicate credential for user '{email}' on provider '{provider}' found at '{Path(path).name}'. Original at '{Path(original_path).name}'.")
197
- else:
198
- processed_emails[email][provider] = path
199
  except Exception as e:
200
  logging.error(f"Failed to process OAuth token for {provider} at '{path}': {e}")
 
201
  logging.info("OAuth credential processing complete.")
 
202
 
203
  # [NEW] Load provider-specific params
204
  litellm_provider_params = {
 
163
  raise ValueError("No provider API keys or OAuth credentials found.")
164
 
165
  if not skip_oauth_init and oauth_credentials:
166
+ logging.info("Starting OAuth credential validation and deduplication...")
167
+ processed_emails = {} # email -> {provider: path}
168
+ credentials_to_initialize = {} # provider -> [paths]
169
+ final_oauth_credentials = {}
170
+
171
+ # --- Pass 1: Pre-initialization Scan & Deduplication ---
172
+ #logging.info("Pass 1: Scanning for existing metadata to find duplicates...")
173
  for provider, paths in oauth_credentials.items():
174
+ if provider not in credentials_to_initialize:
175
+ credentials_to_initialize[provider] = []
176
+ for path in paths:
177
+ try:
178
+ with open(path, 'r') as f:
179
+ data = json.load(f)
180
+ metadata = data.get("_proxy_metadata", {})
181
+ email = metadata.get("email")
182
+
183
+ if email:
184
+ if email not in processed_emails:
185
+ processed_emails[email] = {}
186
+
187
+ if provider in processed_emails[email]:
188
+ original_path = processed_emails[email][provider]
189
+ logging.warning(f"Duplicate for '{email}' on '{provider}' found in pre-scan: '{Path(path).name}'. Original: '{Path(original_path).name}'. Skipping.")
190
+ continue
191
+ else:
192
+ processed_emails[email][provider] = path
193
+
194
+ credentials_to_initialize[provider].append(path)
195
+
196
+ except (FileNotFoundError, json.JSONDecodeError) as e:
197
+ logging.warning(f"Could not pre-read metadata from '{path}': {e}. Will process during initialization.")
198
+ credentials_to_initialize[provider].append(path)
199
+
200
+ # --- Pass 2: Initialization of Filtered Credentials & Final Check ---
201
+ #logging.info("Pass 2: Initializing unique credentials and performing final check...")
202
+ for provider, paths in credentials_to_initialize.items():
203
+ if not paths: continue
204
+
205
  provider_plugin_class = PROVIDER_PLUGINS.get(provider)
206
  if not provider_plugin_class: continue
207
 
208
  provider_instance = provider_plugin_class()
209
+
210
  for path in paths:
211
  try:
212
  await provider_instance.initialize_token(path)
213
+
214
+ if not hasattr(provider_instance, 'get_user_info'):
215
+ if provider not in final_oauth_credentials:
216
+ final_oauth_credentials[provider] = []
217
+ final_oauth_credentials[provider].append(path)
218
+ continue
219
+
220
+ user_info = await provider_instance.get_user_info(path)
221
+ email = user_info.get("email")
222
+
223
+ if not email:
224
+ logging.warning(f"Could not retrieve email for '{path}'. Treating as unique.")
225
+ if provider not in final_oauth_credentials:
226
+ final_oauth_credentials[provider] = []
227
+ final_oauth_credentials[provider].append(path)
228
+ continue
229
+
230
+ if email not in processed_emails:
231
+ processed_emails[email] = {}
232
+
233
+ if provider in processed_emails[email] and processed_emails[email][provider] != path:
234
+ original_path = processed_emails[email][provider]
235
+ logging.warning(f"Duplicate for '{email}' on '{provider}' found post-init: '{Path(path).name}'. Original: '{Path(original_path).name}'. Skipping.")
236
+ continue
237
+ else:
238
+ processed_emails[email][provider] = path
239
+ if provider not in final_oauth_credentials:
240
+ final_oauth_credentials[provider] = []
241
+ final_oauth_credentials[provider].append(path)
242
+
243
  with open(path, 'r+') as f:
244
  data = json.load(f)
245
  metadata = data.get("_proxy_metadata", {})
246
+ metadata["email"] = email
247
+ metadata["last_check_timestamp"] = time.time()
248
+ data["_proxy_metadata"] = metadata
249
+ f.seek(0)
250
+ json.dump(data, f, indent=2)
251
+ f.truncate()
252
+
 
 
 
 
 
 
 
 
 
 
 
 
253
  except Exception as e:
254
  logging.error(f"Failed to process OAuth token for {provider} at '{path}': {e}")
255
+
256
  logging.info("OAuth credential processing complete.")
257
+ oauth_credentials = final_oauth_credentials
258
 
259
  # [NEW] Load provider-specific params
260
  litellm_provider_params = {
src/rotator_library/credential_manager.py CHANGED
@@ -56,9 +56,10 @@ class CredentialManager:
56
  discovered_paths.add(path)
57
 
58
  # 2. If no overrides are provided via .env, scan the default directory
59
- if not discovered_paths and default_dir.exists():
60
- for json_file in default_dir.glob('*.json'):
61
- discovered_paths.add(json_file)
 
62
 
63
  if not discovered_paths:
64
  lib_logger.debug(f"No credential files found for provider: {provider}")
 
56
  discovered_paths.add(path)
57
 
58
  # 2. If no overrides are provided via .env, scan the default directory
59
+ # [MODIFIED] This logic is now disabled to prefer local-first credential management.
60
+ # if not discovered_paths and default_dir.exists():
61
+ # for json_file in default_dir.glob('*.json'):
62
+ # discovered_paths.add(json_file)
63
 
64
  if not discovered_paths:
65
  lib_logger.debug(f"No credential files found for provider: {provider}")
src/rotator_library/providers/gemini_auth_base.py CHANGED
@@ -122,6 +122,12 @@ class GeminiAuthBase:
122
  reason = "token is expired"
123
 
124
  if reason:
 
 
 
 
 
 
125
  lib_logger.warning(f"Gemini OAuth token for '{file_name}' needs setup: {reason}.")
126
  auth_code_future = asyncio.get_event_loop().create_future()
127
  server = None
 
122
  reason = "token is expired"
123
 
124
  if reason:
125
+ if reason == "token is expired" and creds.get("refresh_token"):
126
+ try:
127
+ return await self._refresh_token(path, creds)
128
+ except Exception as e:
129
+ lib_logger.warning(f"Automatic token refresh for '{file_name}' failed: {e}. Proceeding to interactive login.")
130
+
131
  lib_logger.warning(f"Gemini OAuth token for '{file_name}' needs setup: {reason}.")
132
  auth_code_future = asyncio.get_event_loop().create_future()
133
  server = None
src/rotator_library/providers/qwen_auth_base.py CHANGED
@@ -141,6 +141,12 @@ class QwenAuthBase:
141
  reason = "token is expired"
142
 
143
  if reason:
 
 
 
 
 
 
144
  lib_logger.warning(f"Qwen OAuth token for '{file_name}' needs setup: {reason}.")
145
  code_verifier = base64.urlsafe_b64encode(secrets.token_bytes(32)).decode('utf-8').rstrip('=')
146
  code_challenge = base64.urlsafe_b64encode(
 
141
  reason = "token is expired"
142
 
143
  if reason:
144
+ if reason == "token is expired" and creds.get("refresh_token"):
145
+ try:
146
+ return await self._refresh_token(path)
147
+ except Exception as e:
148
+ lib_logger.warning(f"Automatic token refresh for '{file_name}' failed: {e}. Proceeding to interactive login.")
149
+
150
  lib_logger.warning(f"Qwen OAuth token for '{file_name}' needs setup: {reason}.")
151
  code_verifier = base64.urlsafe_b64encode(secrets.token_bytes(32)).decode('utf-8').rstrip('=')
152
  code_challenge = base64.urlsafe_b64encode(