Mirrowel committed on
Commit
65ec48f
·
1 Parent(s): d28e7c5

feat(auth): implement robust OAuth credential initialization and management

Browse files

- Introduce `initialize_token` methods for Gemini and Qwen providers to facilitate interactive OAuth setup upon first use or invalid tokens.
- Enable skipping of interactive OAuth validation at startup via the `SKIP_OAUTH_INIT_CHECK` environment variable for non-interactive environments.
- Allow customization of the background OAuth token refresh interval using the `OAUTH_REFRESH_INTERVAL` environment variable.
- Enhance Gemini project ID discovery with improved mechanisms, caching, and automatic retry logic for 401 Unauthorized errors.
- Implement Qwen-specific error handling for 'slow_down' responses, mapping them to rate limit exceptions, and add 401 retry logic.
- Update `.env.example` to reflect new configuration options for refresh interval and OAuth setup.

.env.example CHANGED
@@ -21,5 +21,16 @@ GEMINI_CLI_OAUTH_1=
21
  # Required for Gemini CLI: Your Google Cloud Project ID
22
  GEMINI_CLI_PROJECT_ID="gen-lang-client-..."
23
 
 
 
 
 
 
24
  # For Qwen Code (OpenAI Compatible)
25
- QWEN_CODE_OAUTH_1=
 
 
 
 
 
 
 
21
  # Required for Gemini CLI: Your Google Cloud Project ID
22
  GEMINI_CLI_PROJECT_ID="gen-lang-client-..."
23
 
24
+ # For Gemini CLI (uses a custom API)
25
+ GEMINI_CLI_OAUTH_1= # Leave blank to auto-discover from ~/.gemini/oauth_creds.json
26
+ # Optional: Overrides auto-discovery for Gemini CLI project ID
27
+ GEMINI_CLI_PROJECT_ID=
28
+
29
  # For Qwen Code (OpenAI Compatible)
30
+ QWEN_CODE_OAUTH_1= # Leave blank to auto-discover from ~/.qwen/oauth_creds.json
31
+
32
+ # [NEW] Optional: Set background OAuth refresh interval in seconds
33
+ OAUTH_REFRESH_INTERVAL=3600 # Default is 3600 seconds (1 hour)
34
+
35
+ # [NEW] Optional: Skip interactive OAuth validation/setup on startup. Set to "true" for non-interactive environments.
36
+ SKIP_OAUTH_INIT_CHECK=false
src/proxy_app/main.py CHANGED
@@ -163,6 +163,27 @@ for key, value in os.environ.items():
163
  @asynccontextmanager
164
  async def lifespan(app: FastAPI):
165
  """Manage the RotatingClient's lifecycle with the app's lifespan."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
166
  # [NEW] Load provider-specific params
167
  litellm_provider_params = {
168
  "gemini_cli": {"project_id": os.getenv("GEMINI_CLI_PROJECT_ID")}
 
163
  @asynccontextmanager
164
  async def lifespan(app: FastAPI):
165
  """Manage the RotatingClient's lifecycle with the app's lifespan."""
166
+ # [MODIFIED] Perform skippable OAuth initialization at startup
167
+ skip_oauth_init = os.getenv("SKIP_OAUTH_INIT_CHECK", "false").lower() == "true"
168
+
169
+ if not skip_oauth_init:
170
+ logging.info("Performing OAuth credential validation at startup...")
171
+ temp_cred_manager = CredentialManager(oauth_credentials)
172
+ discovered_creds = temp_cred_manager.discover_and_prepare()
173
+
174
+ init_tasks = []
175
+ for provider, paths in discovered_creds.items():
176
+ provider_plugin_class = PROVIDER_PLUGINS.get(provider)
177
+ if provider_plugin_class:
178
+ provider_instance = provider_plugin_class()
179
+ if hasattr(provider_instance, 'initialize_token'):
180
+ for path in paths:
181
+ init_tasks.append(provider_instance.initialize_token(path))
182
+
183
+ if init_tasks:
184
+ await asyncio.gather(*init_tasks)
185
+ logging.info("OAuth credential validation complete.")
186
+
187
  # [NEW] Load provider-specific params
188
  litellm_provider_params = {
189
  "gemini_cli": {"project_id": os.getenv("GEMINI_CLI_PROJECT_ID")}
src/rotator_library/background_refresher.py CHANGED
@@ -1,8 +1,9 @@
1
  # src/rotator_library/background_refresher.py
2
 
 
3
  import asyncio
4
  import logging
5
- from typing import TYPE_CHECKING
6
 
7
  if TYPE_CHECKING:
8
  from .client import RotatingClient
@@ -14,9 +15,14 @@ class BackgroundRefresher:
14
  A background task that periodically checks and refreshes OAuth tokens
15
  to ensure they remain valid.
16
  """
17
- def __init__(self, client: 'RotatingClient', interval_seconds: int = 300):
 
 
 
 
 
 
18
  self._client = client
19
- self._interval = interval_seconds
20
  self._task: Optional[asyncio.Task] = None
21
 
22
  def start(self):
@@ -24,6 +30,7 @@ class BackgroundRefresher:
24
  if self._task is None:
25
  self._task = asyncio.create_task(self._run())
26
  lib_logger.info(f"Background token refresher started. Check interval: {self._interval} seconds.")
 
27
 
28
  async def stop(self):
29
  """Stops the background refresh task."""
 
1
  # src/rotator_library/background_refresher.py
2
 
3
+ import os
4
  import asyncio
5
  import logging
6
+ from typing import TYPE_CHECKING, Optional
7
 
8
  if TYPE_CHECKING:
9
  from .client import RotatingClient
 
15
  A background task that periodically checks and refreshes OAuth tokens
16
  to ensure they remain valid.
17
  """
18
+ def __init__(self, client: 'RotatingClient'):
19
+ try:
20
+ interval_str = os.getenv("OAUTH_REFRESH_INTERVAL", "3600")
21
+ self._interval = int(interval_str)
22
+ except ValueError:
23
+ lib_logger.warning(f"Invalid OAUTH_REFRESH_INTERVAL '{interval_str}'. Falling back to 3600s.")
24
+ self._interval = 3600
25
  self._client = client
 
26
  self._task: Optional[asyncio.Task] = None
27
 
28
  def start(self):
 
30
  if self._task is None:
31
  self._task = asyncio.create_task(self._run())
32
  lib_logger.info(f"Background token refresher started. Check interval: {self._interval} seconds.")
33
+ # [NEW] Log if custom interval is set
34
 
35
  async def stop(self):
36
  """Stops the background refresh task."""
src/rotator_library/providers/gemini_auth_base.py CHANGED
@@ -1,5 +1,7 @@
1
  # src/rotator_library/providers/gemini_auth_base.py
2
 
 
 
3
  import json
4
  import time
5
  import asyncio
@@ -85,12 +87,6 @@ class GeminiAuthBase:
85
  lib_logger.info(f"Successfully refreshed Gemini OAuth token for '{Path(path).name}'.")
86
  return creds
87
 
88
- async def get_auth_header(self, credential_path: str) -> Dict[str, str]:
89
- creds = await self._load_credentials(credential_path)
90
- if self._is_token_expired(creds):
91
- creds = await self._refresh_token(credential_path, creds)
92
- return {"Authorization": f"Bearer {creds['access_token']}"}
93
-
94
  async def proactively_refresh(self, credential_path: str):
95
  creds = await self._load_credentials(credential_path)
96
  if self._is_token_expired(creds):
@@ -99,4 +95,59 @@ class GeminiAuthBase:
99
  def _get_lock(self, path: str) -> asyncio.Lock:
100
  if path not in self._refresh_locks:
101
  self._refresh_locks[path] = asyncio.Lock()
102
- return self._refresh_locks[path]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  # src/rotator_library/providers/gemini_auth_base.py
2
 
3
+ import subprocess
4
+ from typing import Optional
5
  import json
6
  import time
7
  import asyncio
 
87
  lib_logger.info(f"Successfully refreshed Gemini OAuth token for '{Path(path).name}'.")
88
  return creds
89
 
 
 
 
 
 
 
90
  async def proactively_refresh(self, credential_path: str):
91
  creds = await self._load_credentials(credential_path)
92
  if self._is_token_expired(creds):
 
95
  def _get_lock(self, path: str) -> asyncio.Lock:
96
  if path not in self._refresh_locks:
97
  self._refresh_locks[path] = asyncio.Lock()
98
+ return self._refresh_locks[path]
99
+
100
+ # [NEW] Add init flow for invalid/expired tokens
101
+ async def initialize_token(self, path: str) -> Dict[str, Any]:
102
+ """Initiates OAuth flow if tokens are missing or invalid."""
103
+ try:
104
+ creds = await self._load_credentials(path)
105
+ if not creds.get("refresh_token") or self._is_token_expired(creds):
106
+ lib_logger.warning(f"Invalid or missing Gemini OAuth tokens at '{path}'. Initiating setup...")
107
+ # Use subprocess to run gemini-cli setup or simulate web flow
108
+ # Based on CLIProxyAPI-main/gemini/gemini_auth.go: Use web flow with local server
109
+ # For simplicity, prompt user to run manual setup or integrate browser flow
110
+ print("Gemini CLI OAuth setup required. Please visit the authorization URL and paste the code.")
111
+ # Simulate getTokenFromWeb logic
112
+ from urllib.parse import urlencode
113
+ auth_url = "https://accounts.google.com/oauth2/v2/auth?" + urlencode({
114
+ "client_id": CLIENT_ID,
115
+ "redirect_uri": "http://localhost:8085/oauth2callback",
116
+ "scope": " ".join(["https://www.googleapis.com/auth/cloud-platform", "https://www.googleapis.com/auth/userinfo.email", "https://www.googleapis.com/auth/userinfo.profile"]),
117
+ "access_type": "offline",
118
+ "response_type": "code",
119
+ "prompt": "consent"
120
+ })
121
+ print(f"\n--- Gemini OAuth Setup Required for {Path(path).name} ---")
122
+ print(f"Please open this URL in your browser:\n\n{auth_url}\n")
123
+ auth_code = input("After authorizing, paste the 'code' from the redirected URL here: ")
124
+
125
+ async with httpx.AsyncClient() as client:
126
+ response = await client.post(TOKEN_URI, data={
127
+ "code": auth_code.strip(),
128
+ "client_id": CLIENT_ID,
129
+ "client_secret": CLIENT_SECRET,
130
+ "redirect_uri": "http://localhost:8085/oauth2callback",
131
+ "grant_type": "authorization_code"
132
+ })
133
+ response.raise_for_status()
134
+ token_data = response.json()
135
+ creds = {
136
+ "access_token": token_data["access_token"],
137
+ "refresh_token": token_data["refresh_token"],
138
+ "expiry_date": (time.time() + token_data["expires_in"]) * 1000,
139
+ "client_id": CLIENT_ID,
140
+ "client_secret": CLIENT_SECRET
141
+ }
142
+ await self._save_credentials(path, creds)
143
+ lib_logger.info(f"Gemini OAuth initialized successfully for '{path}'.")
144
+ return creds
145
+ return creds
146
+ except Exception as e:
147
+ raise ValueError(f"Failed to initialize Gemini OAuth: {e}")
148
+
149
+ async def get_auth_header(self, credential_path: str) -> Dict[str, str]:
150
+ creds = await self.initialize_token(credential_path) # [NEW] Call init if needed
151
+ if self._is_token_expired(creds):
152
+ creds = await self._refresh_token(credential_path, creds)
153
+ return {"Authorization": f"Bearer {creds['access_token']}"}
src/rotator_library/providers/gemini_cli_provider.py CHANGED
@@ -18,35 +18,52 @@ CODE_ASSIST_ENDPOINT = "https://cloudcode-pa.googleapis.com/v1internal"
18
  class GeminiCliProvider(GeminiAuthBase, ProviderInterface):
19
  def __init__(self):
20
  super().__init__()
21
- self.project_id: Optional[str] = None
22
 
23
- async def _discover_project_id(self, litellm_params: Dict[str, Any]) -> str:
24
- """Discovers the Google Cloud Project ID."""
25
- if self.project_id:
26
- return self.project_id
27
 
28
- # 1. Prioritize explicitly configured project_id
29
  if litellm_params.get("project_id"):
30
- self.project_id = litellm_params["project_id"]
31
- lib_logger.info(f"Using configured Gemini CLI project ID: {self.project_id}")
32
- return self.project_id
 
33
 
34
- # 2. Fallback: Look for .env file in the standard .gemini directory
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  try:
36
- gemini_env_path = Path.home() / ".gemini" / ".env"
37
- if gemini_env_path.exists():
38
- with open(gemini_env_path, 'r') as f:
39
- for line in f:
40
- if line.startswith("GOOGLE_CLOUD_PROJECT="):
41
- self.project_id = line.strip().split("=")[1]
42
- lib_logger.info(f"Discovered Gemini CLI project ID from ~/.gemini/.env: {self.project_id}")
43
- return self.project_id
44
- except Exception as e:
45
- lib_logger.warning(f"Could not read project ID from ~/.gemini/.env: {e}")
 
 
46
 
47
  raise ValueError(
48
- "Gemini CLI project ID not found. Please set `GEMINI_CLI_PROJECT_ID` in your main .env file "
49
- "or ensure it is present in `~/.gemini/.env`."
50
  )
51
  def has_custom_logic(self) -> bool:
52
  return True
@@ -109,53 +126,64 @@ class GeminiCliProvider(GeminiAuthBase, ProviderInterface):
109
  async def acompletion(self, client: httpx.AsyncClient, **kwargs) -> Union[litellm.ModelResponse, AsyncGenerator[litellm.ModelResponse, None]]:
110
  model = kwargs["model"]
111
  credential_path = kwargs.pop("credential_identifier")
112
- auth_header = await self.get_auth_header(credential_path)
113
-
114
- project_id = await self._discover_project_id(kwargs.get("litellm_params", {}))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
115
 
116
- # Handle :thinking suffix from Kilo example
117
- model_name = model.split('/')[-1]
118
- enable_thinking = model_name.endswith(':thinking')
119
- if enable_thinking:
120
- model_name = model_name.replace(':thinking', '')
 
 
 
 
 
 
 
 
 
 
121
 
122
- gen_config = {
123
- "temperature": kwargs.get("temperature", 0.7),
124
- "maxOutputTokens": kwargs.get("max_tokens", 8192),
125
- }
126
- if enable_thinking:
127
- gen_config["thinkingConfig"] = {"thinkingBudget": -1}
128
-
129
- request_payload = {
130
- "model": model_name,
131
- "project": project_id,
132
- "request": {
133
- "contents": self._transform_messages(kwargs.get("messages", [])),
134
- "generationConfig": gen_config,
135
- },
136
- }
137
 
138
- url = f"{CODE_ASSIST_ENDPOINT}:streamGenerateContent"
139
-
140
- async def stream_handler():
141
- async with client.stream("POST", url, headers=auth_header, json=request_payload, params={"alt": "sse"}, timeout=600) as response:
142
- response.raise_for_status()
143
- async for line in response.aiter_lines():
144
- if line.startswith('data: '):
145
- data_str = line[6:]
146
- if data_str == "[DONE]": break
147
- try:
148
- chunk = json.loads(data_str)
149
- openai_chunk = self._convert_chunk_to_openai(chunk, model)
150
- yield litellm.ModelResponse(**openai_chunk)
151
- except json.JSONDecodeError:
152
- lib_logger.warning(f"Could not decode JSON from Gemini CLI: {line}")
153
-
154
  if kwargs.get("stream", False):
155
- return stream_handler()
156
  else:
157
  # Accumulate stream for non-streaming response
158
- chunks = [chunk async for chunk in stream_handler()]
159
  return litellm.utils.stream_to_completion_response(chunks)
160
 
161
  # [NEW] Hardcoded model list based on Kilo example
 
18
  class GeminiCliProvider(GeminiAuthBase, ProviderInterface):
19
  def __init__(self):
20
  super().__init__()
21
+ self.project_id_cache: Dict[str, str] = {} # Cache project ID per credential path
22
 
23
+ async def _discover_project_id(self, credential_path: str, access_token: str, litellm_params: Dict[str, Any]) -> str:
24
+ """Discovers the Google Cloud Project ID, with caching."""
25
+ if credential_path in self.project_id_cache:
26
+ return self.project_id_cache[credential_path]
27
 
 
28
  if litellm_params.get("project_id"):
29
+ project_id = litellm_params["project_id"]
30
+ lib_logger.info(f"Using configured Gemini CLI project ID: {project_id}")
31
+ self.project_id_cache[credential_path] = project_id
32
+ return project_id
33
 
34
+ headers = {'Authorization': f'Bearer {access_token}', 'Content-Type': 'application/json'}
35
+
36
+ # 1. Try Gemini-specific discovery endpoint
37
+ try:
38
+ async with httpx.AsyncClient() as client:
39
+ response = await client.post(f"{CODE_ASSIST_ENDPOINT}:loadCodeAssist", headers=headers, json={"metadata": {"pluginType": "GEMINI"}})
40
+ response.raise_for_status()
41
+ data = response.json()
42
+ if data.get('cloudaicompanionProject'):
43
+ project_id = data['cloudaicompanionProject']
44
+ lib_logger.info(f"Discovered Gemini project ID via loadCodeAssist: {project_id}")
45
+ self.project_id_cache[credential_path] = project_id
46
+ return project_id
47
+ except httpx.RequestError as e:
48
+ lib_logger.warning(f"Gemini loadCodeAssist failed, falling back to project listing: {e}")
49
+
50
+ # 2. Fallback to listing all available GCP projects
51
  try:
52
+ async with httpx.AsyncClient() as client:
53
+ response = await client.get("https://cloudresourcemanager.googleapis.com/v1/projects", headers=headers)
54
+ response.raise_for_status()
55
+ projects = response.json().get('projects', [])
56
+ active_projects = [p for p in projects if p.get('lifecycleState') == 'ACTIVE']
57
+ if active_projects:
58
+ project_id = active_projects[0]['projectId']
59
+ lib_logger.info(f"Discovered Gemini project ID from active projects list: {project_id}")
60
+ self.project_id_cache[credential_path] = project_id
61
+ return project_id
62
+ except httpx.RequestError as e:
63
+ lib_logger.error(f"Failed to list GCP projects: {e}")
64
 
65
  raise ValueError(
66
+ "Could not auto-discover Gemini project ID. Please set GEMINI_CLI_PROJECT_ID in your .env file."
 
67
  )
68
  def has_custom_logic(self) -> bool:
69
  return True
 
126
  async def acompletion(self, client: httpx.AsyncClient, **kwargs) -> Union[litellm.ModelResponse, AsyncGenerator[litellm.ModelResponse, None]]:
127
  model = kwargs["model"]
128
  credential_path = kwargs.pop("credential_identifier")
129
+
130
+ async def do_call():
131
+ auth_header = await self.get_auth_header(credential_path)
132
+ project_id = await self._discover_project_id(credential_path, auth_header['Authorization'].split(' ')[1], kwargs.get("litellm_params", {}))
133
+
134
+ # Handle :thinking suffix from Kilo example
135
+ model_name = model.split('/')[-1]
136
+ enable_thinking = model_name.endswith(':thinking')
137
+ if enable_thinking:
138
+ model_name = model_name.replace(':thinking', '')
139
+
140
+ gen_config = {
141
+ "temperature": kwargs.get("temperature", 0.7),
142
+ "maxOutputTokens": kwargs.get("max_tokens", 8192),
143
+ }
144
+ if enable_thinking:
145
+ gen_config["thinkingConfig"] = {"thinkingBudget": -1}
146
+
147
+ request_payload = {
148
+ "model": model_name,
149
+ "project": project_id,
150
+ "request": {
151
+ "contents": self._transform_messages(kwargs.get("messages", [])),
152
+ "generationConfig": gen_config,
153
+ },
154
+ }
155
 
156
+ url = f"{CODE_ASSIST_ENDPOINT}:streamGenerateContent"
157
+ async def stream_handler():
158
+ async with client.stream("POST", url, headers=auth_header, json=request_payload, params={"alt": "sse"}, timeout=600) as response:
159
+ response.raise_for_status()
160
+ async for line in response.aiter_lines():
161
+ if line.startswith('data: '):
162
+ data_str = line[6:]
163
+ if data_str == "[DONE]": break
164
+ try:
165
+ chunk = json.loads(data_str)
166
+ openai_chunk = self._convert_chunk_to_openai(chunk, model)
167
+ yield litellm.ModelResponse(**openai_chunk)
168
+ except json.JSONDecodeError:
169
+ lib_logger.warning(f"Could not decode JSON from Gemini CLI: {line}")
170
+ return stream_handler()
171
 
172
+ try:
173
+ response_gen = await do_call()
174
+ except httpx.HTTPStatusError as e:
175
+ if e.response.status_code == 401:
176
+ lib_logger.warning("Gemini CLI returned 401. Forcing token refresh and retrying once.")
177
+ await self._refresh_token(credential_path, force=True)
178
+ response_gen = await do_call()
179
+ else:
180
+ raise e
 
 
 
 
 
 
181
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
182
  if kwargs.get("stream", False):
183
+ return response_gen
184
  else:
185
  # Accumulate stream for non-streaming response
186
+ chunks = [chunk async for chunk in response_gen]
187
  return litellm.utils.stream_to_completion_response(chunks)
188
 
189
  # [NEW] Hardcoded model list based on Kilo example
src/rotator_library/providers/qwen_auth_base.py CHANGED
@@ -1,5 +1,8 @@
1
  # src/rotator_library/providers/qwen_auth_base.py
2
 
 
 
 
3
  import json
4
  import time
5
  import asyncio
@@ -77,12 +80,6 @@ class QwenAuthBase:
77
  lib_logger.info(f"Successfully refreshed Qwen OAuth token for '{Path(path).name}'.")
78
  return creds_from_file
79
 
80
- async def get_auth_header(self, credential_path: str) -> Dict[str, str]:
81
- creds = await self._load_credentials(credential_path)
82
- if self._is_token_expired(creds):
83
- creds = await self._refresh_token(credential_path)
84
- return {"Authorization": f"Bearer {creds['access_token']}"}
85
-
86
  def get_api_details(self, credential_path: str) -> Tuple[str, str]:
87
  creds = self._credentials_cache[credential_path]
88
  base_url = creds.get("resource_url", "https://dashscope.aliyuncs.com/compatible-mode/v1")
@@ -98,4 +95,73 @@ class QwenAuthBase:
98
  def _get_lock(self, path: str) -> asyncio.Lock:
99
  if path not in self._refresh_locks:
100
  self._refresh_locks[path] = asyncio.Lock()
101
- return self._refresh_locks[path]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  # src/rotator_library/providers/qwen_auth_base.py
2
 
3
+ import secrets
4
+ import hashlib
5
+ import base64
6
  import json
7
  import time
8
  import asyncio
 
80
  lib_logger.info(f"Successfully refreshed Qwen OAuth token for '{Path(path).name}'.")
81
  return creds_from_file
82
 
 
 
 
 
 
 
83
  def get_api_details(self, credential_path: str) -> Tuple[str, str]:
84
  creds = self._credentials_cache[credential_path]
85
  base_url = creds.get("resource_url", "https://dashscope.aliyuncs.com/compatible-mode/v1")
 
95
  def _get_lock(self, path: str) -> asyncio.Lock:
96
  if path not in self._refresh_locks:
97
  self._refresh_locks[path] = asyncio.Lock()
98
+ return self._refresh_locks[path]
99
+
100
+ # [NEW] Add init flow for invalid/expired tokens
101
+ async def initialize_token(self, path: str) -> Dict[str, Any]:
102
+ """Initiates device flow if tokens are missing or invalid."""
103
+ try:
104
+ creds = await self._load_credentials(path)
105
+ if not creds.get("refresh_token") or self._is_token_expired(creds):
106
+ lib_logger.warning(f"Invalid or missing Qwen OAuth tokens at '{path}'. Initiating device flow...")
107
+ # Based on CLIProxyAPI-main/qwen/qwen_auth.go: Use device code with PKCE
108
+ code_verifier = base64.urlsafe_b64encode(secrets.token_bytes(32)).decode('utf-8').rstrip('=')
109
+ code_challenge = base64.urlsafe_b64encode(
110
+ hashlib.sha256(code_verifier.encode('utf-8')).digest()
111
+ ).decode('utf-8').rstrip('=')
112
+
113
+ async with httpx.AsyncClient() as client:
114
+ dev_response = await client.post(
115
+ "https://chat.qwen.ai/api/v1/oauth2/device/code",
116
+ data={
117
+ "client_id": CLIENT_ID,
118
+ "scope": "openid profile email model.completion",
119
+ "code_challenge": code_challenge,
120
+ "code_challenge_method": "S256"
121
+ }
122
+ )
123
+ dev_response.raise_for_status()
124
+ dev_data = dev_response.json()
125
+
126
+ print(f"\n--- Qwen OAuth Setup Required for {Path(path).name} ---")
127
+ print(f"Please visit: {dev_data['verification_uri_complete']}")
128
+ print(f"And enter code: {dev_data['user_code']}\n")
129
+
130
+ token_data = None
131
+ start_time = time.time()
132
+ while time.time() - start_time < dev_data['expires_in']:
133
+ poll_response = await client.post(
134
+ TOKEN_ENDPOINT,
135
+ data={
136
+ "grant_type": "urn:ietf:params:oauth:grant-type:device_code",
137
+ "device_code": dev_data['device_code'],
138
+ "client_id": CLIENT_ID,
139
+ "code_verifier": code_verifier
140
+ }
141
+ )
142
+ if poll_response.status_code == 200:
143
+ token_data = poll_response.json()
144
+ break
145
+ await asyncio.sleep(dev_data['interval'])
146
+
147
+ if not token_data:
148
+ raise TimeoutError("Qwen device flow timed out.")
149
+
150
+ creds.update({
151
+ "access_token": token_data["access_token"],
152
+ "refresh_token": token_data.get("refresh_token"),
153
+ "expiry_date": (time.time() + token_data["expires_in"]) * 1000,
154
+ "resource_url": token_data.get("resource_url")
155
+ })
156
+ await self._save_credentials(path, creds)
157
+ lib_logger.info(f"Qwen OAuth initialized successfully for '{path}'.")
158
+ return creds
159
+ return creds
160
+ except Exception as e:
161
+ raise ValueError(f"Failed to initialize Qwen OAuth: {e}")
162
+
163
+ async def get_auth_header(self, credential_path: str) -> Dict[str, str]:
164
+ creds = await self.initialize_token(credential_path) # [NEW] Call init if needed
165
+ if self._is_token_expired(creds):
166
+ creds = await self._refresh_token(credential_path)
167
+ return {"Authorization": f"Bearer {creds['access_token']}"}
src/rotator_library/providers/qwen_code_provider.py CHANGED
@@ -1,5 +1,6 @@
1
  # src/rotator_library/providers/qwen_code_provider.py
2
 
 
3
  import httpx
4
  import logging
5
  from typing import Union, AsyncGenerator
@@ -31,6 +32,12 @@ class QwenCodeProvider(QwenAuthBase, ProviderInterface):
31
  if content and ("<think>" in content or "</think>" in content):
32
  parts = content.replace("<think>", "||THINK||").replace("</think>", "||/THINK||").split("||")
33
  for part in parts:
 
 
 
 
 
 
34
  if not part: continue
35
  new_chunk = chunk.copy()
36
  if part.startswith("THINK||"):
@@ -52,8 +59,15 @@ class QwenCodeProvider(QwenAuthBase, ProviderInterface):
52
  async def do_call():
53
  api_base, access_token = self.get_api_details(credential_path)
54
  response = await litellm.acompletion(
 
55
  **kwargs, api_key=access_token, api_base=api_base
56
  )
 
 
 
 
 
 
57
  return response
58
 
59
  try:
@@ -63,6 +77,11 @@ class QwenCodeProvider(QwenAuthBase, ProviderInterface):
63
  lib_logger.warning("Qwen Code returned 401. Forcing token refresh and retrying once.")
64
  await self._refresh_token(credential_path, force=True)
65
  response = await do_call()
 
 
 
 
 
66
  else:
67
  raise e
68
 
 
1
  # src/rotator_library/providers/qwen_code_provider.py
2
 
3
+ import litellm.exceptions as litellm_exc
4
  import httpx
5
  import logging
6
  from typing import Union, AsyncGenerator
 
32
  if content and ("<think>" in content or "</think>" in content):
33
  parts = content.replace("<think>", "||THINK||").replace("</think>", "||/THINK||").split("||")
34
  for part in parts:
35
+ # [NEW] Check for provider-specific errors in content
36
+ if "slow_down" in part.lower():
37
+ lib_logger.warning("Qwen 'slow_down' detected in response content. Treating as rate limit.")
38
+ raise litellm_exc.RateLimitError(
39
+ message="Qwen slow_down error detected.", llm_provider="qwen_code"
40
+ )
41
  if not part: continue
42
  new_chunk = chunk.copy()
43
  if part.startswith("THINK||"):
 
59
  async def do_call():
60
  api_base, access_token = self.get_api_details(credential_path)
61
  response = await litellm.acompletion(
62
+ # [NEW] Add timeout and retry params if needed, but since rotation handles retries, this is optional
63
  **kwargs, api_key=access_token, api_base=api_base
64
  )
65
+ # [NEW] Post-call check for specific finish reasons or errors
66
+ if not kwargs.get("stream") and response.choices[0].finish_reason == "slow_down":
67
+ lib_logger.warning("Qwen 'slow_down' finish reason detected. Treating as rate limit.")
68
+ raise litellm_exc.RateLimitError(
69
+ message="Qwen slow_down finish reason.", llm_provider="qwen_code"
70
+ )
71
  return response
72
 
73
  try:
 
77
  lib_logger.warning("Qwen Code returned 401. Forcing token refresh and retrying once.")
78
  await self._refresh_token(credential_path, force=True)
79
  response = await do_call()
80
+ # [NEW] Catch provider-specific exceptions
81
+ elif "slow_down" in str(e).lower():
82
+ raise litellm_exc.RateLimitError(
83
+ message="Qwen slow_down error in exception.", llm_provider="qwen_code"
84
+ )
85
  else:
86
  raise e
87