Mirrowel committed
Commit 5ee03d6 · Parent: 74bc3cb

feat(auth): introduce iFlow provider integration via OAuth

Implement full integration for iFlow, supporting automatic credential discovery and management using the OAuth 2.0 authorization code flow.

- Added `IFlowAuthBase` to handle the interactive OAuth flow and token refresh using a local `aiohttp` callback server.
- Enabled automatic discovery of iFlow credentials from the standard `~/.iflow/` directory.
- Ensured the dedicated API key (derived from user info after OAuth) is always used for API calls and is refreshed proactively.
- Updated `requirements.txt` to include `aiohttp` for the local server functionality.
- Refactored Qwen providers to clean tool schemas and standardize streaming assembly logic for better robustness.
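
The proactive refresh mentioned above relies on treating a token as expired well before its actual expiry. A minimal standalone sketch of such a check, mirroring the parsing fallbacks in the new `_is_token_expired` and the 24-hour buffer (the function name and standalone shape here are illustrative, not the library's public API):

```python
import time
from datetime import datetime, timezone, timedelta

REFRESH_EXPIRY_BUFFER_SECONDS = 24 * 60 * 60  # refresh 24h before expiry

def is_token_expired(expiry_date, buffer_seconds=REFRESH_EXPIRY_BUFFER_SECONDS):
    """Treat a token as expired if it lapses within the buffer window."""
    if not expiry_date:
        return True  # no expiry recorded: force a refresh
    try:
        # ISO 8601, optionally with a trailing 'Z' (e.g. "2025-01-17T12:00:00Z")
        expiry_ts = datetime.fromisoformat(str(expiry_date).replace('Z', '+00:00')).timestamp()
    except ValueError:
        try:
            expiry_ts = float(expiry_date)  # fall back to a numeric timestamp
        except (ValueError, TypeError):
            return True  # unparseable: err on the side of refreshing
    return expiry_ts < time.time() + buffer_seconds

# A token expiring in two days is still fresh; one expiring in an hour is not.
far = (datetime.now(timezone.utc) + timedelta(days=2)).isoformat()
soon = (datetime.now(timezone.utc) + timedelta(hours=1)).isoformat()
print(is_token_expired(far), is_token_expired(soon))  # False True
```

The buffer deliberately errs toward refreshing early, so a credential is never handed out with only minutes of validity left.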

README.md CHANGED
```diff
@@ -28,7 +28,7 @@ This project provides a powerful solution for developers building complex applic
 - **Resilient Performance**: A global timeout on all requests prevents your application from hanging on unresponsive provider APIs.
 - **Efficient Concurrency**: Maximizes throughput by allowing a single API key to handle multiple concurrent requests to different models.
 - **Intelligent Key Management**: Optimizes request distribution across your pool of keys by selecting the best available one for each call.
-- **Automated OAuth Discovery**: Automatically discovers, validates, and manages OAuth credentials from standard provider directories (e.g., `~/.gemini/`, `~/.qwen/`). No manual `.env` configuration is required for supported providers.
+- **Automated OAuth Discovery**: Automatically discovers, validates, and manages OAuth credentials from standard provider directories (e.g., `~/.gemini/`, `~/.qwen/`, `~/.iflow/`). No manual `.env` configuration is required for supported providers.
 - **Duplicate Credential Detection**: Intelligently detects if multiple local credential files belong to the same user account and logs a warning, preventing redundancy in your key pool.
 - **Escalating Per-Model Cooldowns**: If a key fails for a specific model, it's placed on a temporary, escalating cooldown for that model, allowing it to be used with others.
 - **Automatic Daily Resets**: Cooldowns and usage statistics are automatically reset daily, making the system self-maintaining.
@@ -115,7 +115,7 @@ The proxy supports two types of credentials:
 
 For many providers, **no configuration is necessary**. The proxy automatically discovers and manages credentials from their default locations:
 - **API Keys**: Scans your environment variables for keys matching the format `PROVIDER_API_KEY_1` (e.g., `GEMINI_API_KEY_1`).
-- **OAuth Credentials**: Scans default system directories (e.g., `~/.gemini/`, `~/.qwen/`) for all `*.json` credential files.
+- **OAuth Credentials**: Scans default system directories (e.g., `~/.gemini/`, `~/.qwen/`, `~/.iflow/`) for all `*.json` credential files.
 
 You only need to create a `.env` file to set your `PROXY_API_KEY` and to override or add credentials if the automatic discovery doesn't suit your needs.
```
requirements.txt CHANGED
```diff
@@ -14,6 +14,7 @@ litellm
 filelock
 httpx
 aiofiles
+aiohttp
 
 colorlog
 
```
src/rotator_library/credential_manager.py CHANGED
```diff
@@ -13,6 +13,7 @@ OAUTH_BASE_DIR.mkdir(exist_ok=True)
 DEFAULT_OAUTH_DIRS = {
     "gemini_cli": Path.home() / ".gemini",
     "qwen_code": Path.home() / ".qwen",
+    "iflow": Path.home() / ".iflow",
     # Add other providers like 'claude' here if they have a standard CLI path
 }
 
```
src/rotator_library/credential_tool.py CHANGED
```diff
@@ -80,7 +80,7 @@ async def setup_api_key():
     }
 
     # Discover custom providers and add them to the list
-    oauth_providers = {'gemini_cli', 'qwen_code'}
+    oauth_providers = {'gemini_cli', 'qwen_code', 'iflow'}
     discovered_providers = {
         p.replace('_', ' ').title(): p.upper() + "_API_KEY"
         for p in PROVIDER_PLUGINS.keys()
@@ -222,7 +222,8 @@ async def main():
     available_providers = get_available_providers()
     oauth_friendly_names = {
         "gemini_cli": "Gemini CLI (OAuth)",
-        "qwen_code": "Qwen Code (OAuth)"
+        "qwen_code": "Qwen Code (OAuth)",
+        "iflow": "iFlow (OAuth)"
     }
 
     provider_text = Text()
```
src/rotator_library/provider_factory.py CHANGED
```diff
@@ -2,10 +2,12 @@
 
 from .providers.gemini_auth_base import GeminiAuthBase
 from .providers.qwen_auth_base import QwenAuthBase
+from .providers.iflow_auth_base import IFlowAuthBase
 
 PROVIDER_MAP = {
     "gemini_cli": GeminiAuthBase,
     "qwen_code": QwenAuthBase,
+    "iflow": IFlowAuthBase,
 }
 
 def get_provider_auth_class(provider_name: str):
```
src/rotator_library/providers/iflow_auth_base.py ADDED
@@ -0,0 +1,543 @@

```python
# src/rotator_library/providers/iflow_auth_base.py

import secrets
import base64
import json
import time
import asyncio
import logging
import webbrowser
import socket
from pathlib import Path
from typing import Dict, Any, Tuple, Union, Optional
from urllib.parse import urlencode, parse_qs, urlparse

import httpx
from aiohttp import web
from rich.console import Console
from rich.panel import Panel
from rich.prompt import Prompt
from rich.text import Text

lib_logger = logging.getLogger('rotator_library')

# OAuth endpoints and credentials from Go example
IFLOW_OAUTH_AUTHORIZE_ENDPOINT = "https://iflow.cn/oauth"
IFLOW_OAUTH_TOKEN_ENDPOINT = "https://iflow.cn/oauth/token"
IFLOW_USER_INFO_ENDPOINT = "https://iflow.cn/api/oauth/getUserInfo"
IFLOW_SUCCESS_REDIRECT_URL = "https://iflow.cn/oauth/success"
IFLOW_ERROR_REDIRECT_URL = "https://iflow.cn/oauth/error"

# Client credentials provided by iFlow
IFLOW_CLIENT_ID = "10009311001"
IFLOW_CLIENT_SECRET = "4Z3YjXycVsQvyGF1etiNlIBB4RsqSDtW"

# Local callback server port
CALLBACK_PORT = 11451

# Refresh tokens 24 hours before expiry (from Go example)
REFRESH_EXPIRY_BUFFER_SECONDS = 24 * 60 * 60

console = Console()


class OAuthCallbackServer:
    """
    Minimal HTTP server for handling iFlow OAuth callbacks.
    Based on the Go example's oauth_server.go implementation.
    """

    def __init__(self, port: int = CALLBACK_PORT):
        self.port = port
        self.app = web.Application()
        self.runner: Optional[web.AppRunner] = None
        self.site: Optional[web.TCPSite] = None
        self.result_future: Optional[asyncio.Future] = None
        self.expected_state: Optional[str] = None

    def _is_port_available(self) -> bool:
        """Checks if the callback port is available."""
        try:
            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            sock.bind(('', self.port))
            sock.close()
            return True
        except OSError:
            return False

    async def start(self, expected_state: str):
        """Starts the OAuth callback server."""
        if not self._is_port_available():
            raise RuntimeError(f"Port {self.port} is already in use")

        self.expected_state = expected_state
        self.result_future = asyncio.Future()

        # Setup route
        self.app.router.add_get('/oauth2callback', self._handle_callback)

        # Start server
        self.runner = web.AppRunner(self.app)
        await self.runner.setup()
        self.site = web.TCPSite(self.runner, 'localhost', self.port)
        await self.site.start()

        lib_logger.debug(f"iFlow OAuth callback server started on port {self.port}")

    async def stop(self):
        """Stops the OAuth callback server."""
        if self.site:
            await self.site.stop()
        if self.runner:
            await self.runner.cleanup()
        lib_logger.debug("iFlow OAuth callback server stopped")

    async def _handle_callback(self, request: web.Request) -> web.Response:
        """Handles the OAuth callback request."""
        query = request.query

        # Check for error parameter
        if 'error' in query:
            error = query.get('error', 'unknown_error')
            lib_logger.error(f"iFlow OAuth callback received error: {error}")
            if not self.result_future.done():
                self.result_future.set_exception(ValueError(f"OAuth error: {error}"))
            return web.Response(status=302, headers={'Location': IFLOW_ERROR_REDIRECT_URL})

        # Check for authorization code
        code = query.get('code')
        if not code:
            lib_logger.error("iFlow OAuth callback missing authorization code")
            if not self.result_future.done():
                self.result_future.set_exception(ValueError("Missing authorization code"))
            return web.Response(status=302, headers={'Location': IFLOW_ERROR_REDIRECT_URL})

        # Validate state parameter
        state = query.get('state', '')
        if state != self.expected_state:
            lib_logger.error(f"iFlow OAuth state mismatch. Expected: {self.expected_state}, Got: {state}")
            if not self.result_future.done():
                self.result_future.set_exception(ValueError("State parameter mismatch"))
            return web.Response(status=302, headers={'Location': IFLOW_ERROR_REDIRECT_URL})

        # Success - set result and redirect to success page
        if not self.result_future.done():
            self.result_future.set_result(code)

        return web.Response(status=302, headers={'Location': IFLOW_SUCCESS_REDIRECT_URL})

    async def wait_for_callback(self, timeout: float = 300.0) -> str:
        """Waits for the OAuth callback and returns the authorization code."""
        try:
            code = await asyncio.wait_for(self.result_future, timeout=timeout)
            return code
        except asyncio.TimeoutError:
            raise TimeoutError("Timeout waiting for OAuth callback")


class IFlowAuthBase:
    """
    iFlow OAuth authentication base class.
    Implements authorization code flow with local callback server.
    Based on the Go example implementation.
    """

    def __init__(self):
        self._credentials_cache: Dict[str, Dict[str, Any]] = {}
        self._refresh_locks: Dict[str, asyncio.Lock] = {}

    async def _read_creds_from_file(self, path: str) -> Dict[str, Any]:
        """Reads credentials from file and populates the cache. No locking."""
        try:
            lib_logger.debug(f"Reading iFlow credentials from file: {path}")
            with open(path, 'r') as f:
                creds = json.load(f)
            self._credentials_cache[path] = creds
            return creds
        except FileNotFoundError:
            raise IOError(f"iFlow OAuth credential file not found at '{path}'")
        except Exception as e:
            raise IOError(f"Failed to load iFlow OAuth credentials from '{path}': {e}")

    async def _load_credentials(self, path: str) -> Dict[str, Any]:
        """Loads credentials from cache or file."""
        if path in self._credentials_cache:
            return self._credentials_cache[path]

        async with self._get_lock(path):
            # Re-check cache after acquiring lock
            if path in self._credentials_cache:
                return self._credentials_cache[path]
            return await self._read_creds_from_file(path)

    async def _save_credentials(self, path: str, creds: Dict[str, Any]):
        """Saves credentials to cache and file."""
        self._credentials_cache[path] = creds
        try:
            with open(path, 'w') as f:
                json.dump(creds, f, indent=2)
            lib_logger.debug(f"Saved updated iFlow OAuth credentials to '{path}'.")
        except Exception as e:
            lib_logger.error(f"Failed to save updated iFlow OAuth credentials to '{path}': {e}")

    def _is_token_expired(self, creds: Dict[str, Any]) -> bool:
        """Checks if the token is expired (with buffer for proactive refresh)."""
        # Try to parse expiry_date as ISO 8601 string (from Go example)
        expiry_str = creds.get("expiry_date")
        if not expiry_str:
            return True

        try:
            # Parse ISO 8601 format (e.g., "2025-01-17T12:00:00Z")
            from datetime import datetime
            expiry_dt = datetime.fromisoformat(expiry_str.replace('Z', '+00:00'))
            expiry_timestamp = expiry_dt.timestamp()
        except (ValueError, AttributeError):
            # Fallback: treat as numeric timestamp
            try:
                expiry_timestamp = float(expiry_str)
            except (ValueError, TypeError):
                lib_logger.warning(f"Could not parse expiry_date: {expiry_str}")
                return True

        return expiry_timestamp < time.time() + REFRESH_EXPIRY_BUFFER_SECONDS

    async def _fetch_user_info(self, access_token: str) -> Dict[str, Any]:
        """
        Fetches user info (including API key) from iFlow API.
        This is critical: iFlow uses a separate API key for actual API calls.
        """
        if not access_token or not access_token.strip():
            raise ValueError("Access token is empty")

        url = f"{IFLOW_USER_INFO_ENDPOINT}?accessToken={access_token}"
        headers = {"Accept": "application/json"}

        async with httpx.AsyncClient(timeout=30.0) as client:
            response = await client.get(url, headers=headers)
            response.raise_for_status()
            result = response.json()

        if not result.get("success"):
            raise ValueError("iFlow user info request not successful")

        data = result.get("data", {})
        api_key = data.get("apiKey", "").strip()
        if not api_key:
            raise ValueError("Missing API key in user info response")

        email = data.get("email", "").strip()
        if not email:
            email = data.get("phone", "").strip()
        if not email:
            raise ValueError("Missing email/phone in user info response")

        return {"api_key": api_key, "email": email}

    async def _exchange_code_for_tokens(self, code: str, redirect_uri: str) -> Dict[str, Any]:
        """
        Exchanges authorization code for access and refresh tokens.
        Uses Basic Auth with client credentials (from Go example).
        """
        # Create Basic Auth header
        auth_string = f"{IFLOW_CLIENT_ID}:{IFLOW_CLIENT_SECRET}"
        basic_auth = base64.b64encode(auth_string.encode()).decode()

        headers = {
            "Content-Type": "application/x-www-form-urlencoded",
            "Accept": "application/json",
            "Authorization": f"Basic {basic_auth}"
        }

        data = {
            "grant_type": "authorization_code",
            "code": code,
            "redirect_uri": redirect_uri,
            "client_id": IFLOW_CLIENT_ID,
            "client_secret": IFLOW_CLIENT_SECRET
        }

        async with httpx.AsyncClient(timeout=30.0) as client:
            response = await client.post(IFLOW_OAUTH_TOKEN_ENDPOINT, headers=headers, data=data)

        if response.status_code != 200:
            error_text = response.text
            lib_logger.error(f"iFlow token exchange failed: {response.status_code} {error_text}")
            raise ValueError(f"Token exchange failed: {response.status_code} {error_text}")

        token_data = response.json()

        access_token = token_data.get("access_token")
        if not access_token:
            raise ValueError("Missing access_token in token response")

        refresh_token = token_data.get("refresh_token", "")
        expires_in = token_data.get("expires_in", 3600)
        token_type = token_data.get("token_type", "Bearer")
        scope = token_data.get("scope", "")

        # Fetch user info to get API key
        user_info = await self._fetch_user_info(access_token)

        # Calculate expiry date
        from datetime import datetime, timedelta
        expiry_date = (datetime.utcnow() + timedelta(seconds=expires_in)).isoformat() + 'Z'

        return {
            "access_token": access_token,
            "refresh_token": refresh_token,
            "api_key": user_info["api_key"],
            "email": user_info["email"],
            "expiry_date": expiry_date,
            "token_type": token_type,
            "scope": scope
        }

    async def _refresh_token(self, path: str, force: bool = False) -> Dict[str, Any]:
        """
        Refreshes the OAuth tokens and re-fetches the API key.
        CRITICAL: Must re-fetch user info to get potentially updated API key.
        """
        async with self._get_lock(path):
            cached_creds = self._credentials_cache.get(path)
            if not force and cached_creds and not self._is_token_expired(cached_creds):
                return cached_creds

            # If cache is empty, read from file
            if path not in self._credentials_cache:
                await self._read_creds_from_file(path)

            creds_from_file = self._credentials_cache[path]

            lib_logger.info(f"Refreshing iFlow OAuth token for '{Path(path).name}'...")
            refresh_token = creds_from_file.get("refresh_token")
            if not refresh_token:
                raise ValueError("No refresh_token found in iFlow credentials file.")

            # Create Basic Auth header
            auth_string = f"{IFLOW_CLIENT_ID}:{IFLOW_CLIENT_SECRET}"
            basic_auth = base64.b64encode(auth_string.encode()).decode()

            headers = {
                "Content-Type": "application/x-www-form-urlencoded",
                "Accept": "application/json",
                "Authorization": f"Basic {basic_auth}"
            }

            data = {
                "grant_type": "refresh_token",
                "refresh_token": refresh_token,
                "client_id": IFLOW_CLIENT_ID,
                "client_secret": IFLOW_CLIENT_SECRET
            }

            async with httpx.AsyncClient(timeout=30.0) as client:
                response = await client.post(IFLOW_OAUTH_TOKEN_ENDPOINT, headers=headers, data=data)
                response.raise_for_status()
                new_token_data = response.json()

            # Update tokens
            access_token = new_token_data.get("access_token")
            if not access_token:
                raise ValueError("Missing access_token in refresh response")

            creds_from_file["access_token"] = access_token
            creds_from_file["refresh_token"] = new_token_data.get("refresh_token", creds_from_file["refresh_token"])

            expires_in = new_token_data.get("expires_in", 3600)
            from datetime import datetime, timedelta
            creds_from_file["expiry_date"] = (datetime.utcnow() + timedelta(seconds=expires_in)).isoformat() + 'Z'

            creds_from_file["token_type"] = new_token_data.get("token_type", creds_from_file.get("token_type", "Bearer"))
            creds_from_file["scope"] = new_token_data.get("scope", creds_from_file.get("scope", ""))

            # CRITICAL: Re-fetch user info to get potentially updated API key
            try:
                user_info = await self._fetch_user_info(access_token)
                if user_info.get("api_key"):
                    creds_from_file["api_key"] = user_info["api_key"]
                if user_info.get("email"):
                    creds_from_file["email"] = user_info["email"]
            except Exception as e:
                lib_logger.warning(f"Failed to update API key during token refresh: {e}")

            # Update timestamp in metadata if it exists
            if creds_from_file.get("_proxy_metadata"):
                creds_from_file["_proxy_metadata"]["last_check_timestamp"] = time.time()

            await self._save_credentials(path, creds_from_file)
            lib_logger.info(f"Successfully refreshed iFlow OAuth token for '{Path(path).name}'.")
            return creds_from_file

    async def get_api_details(self, credential_path: str) -> Tuple[str, str]:
        """
        Returns the API base URL and API key (NOT access_token).
        CRITICAL: iFlow uses the api_key for API requests, not the OAuth access_token.
        """
        creds = await self._load_credentials(credential_path)

        # Check if token needs refresh
        if self._is_token_expired(creds):
            creds = await self._refresh_token(credential_path)

        api_key = creds.get("api_key")
        if not api_key:
            raise ValueError("Missing api_key in iFlow credentials")

        base_url = "https://apis.iflow.cn/v1"
        return base_url, api_key

    async def proactively_refresh(self, credential_path: str):
        """Proactively refreshes tokens if they're close to expiry."""
        creds = await self._load_credentials(credential_path)
        if self._is_token_expired(creds):
            await self._refresh_token(credential_path)

    def _get_lock(self, path: str) -> asyncio.Lock:
        """Gets or creates a lock for the given credential path."""
        if path not in self._refresh_locks:
            self._refresh_locks[path] = asyncio.Lock()
        return self._refresh_locks[path]

    async def initialize_token(self, creds_or_path: Union[Dict[str, Any], str]) -> Dict[str, Any]:
        """
        Initiates OAuth authorization code flow if tokens are missing or invalid.
        Uses local callback server to receive authorization code.
        """
        path = creds_or_path if isinstance(creds_or_path, str) else None
        file_name = Path(path).name if path else "in-memory object"
        lib_logger.debug(f"Initializing iFlow token for '{file_name}'...")

        try:
            creds = await self._load_credentials(creds_or_path) if path else creds_or_path

            reason = ""
            if not creds.get("refresh_token"):
                reason = "refresh token is missing"
            elif self._is_token_expired(creds):
                reason = "token is expired"

            if reason:
                # Try automatic refresh first if we have a refresh token
                if reason == "token is expired" and creds.get("refresh_token"):
                    try:
                        return await self._refresh_token(path)
                    except Exception as e:
                        lib_logger.warning(f"Automatic token refresh for '{file_name}' failed: {e}. Proceeding to interactive login.")

                # Interactive OAuth flow
                lib_logger.warning(f"iFlow OAuth token for '{file_name}' needs setup: {reason}.")

                # Generate random state for CSRF protection
                state = secrets.token_urlsafe(32)

                # Build authorization URL
                redirect_uri = f"http://localhost:{CALLBACK_PORT}/oauth2callback"
                auth_params = {
                    "loginMethod": "phone",
                    "type": "phone",
                    "redirect": redirect_uri,
                    "state": state,
                    "client_id": IFLOW_CLIENT_ID
                }
                auth_url = f"{IFLOW_OAUTH_AUTHORIZE_ENDPOINT}?{urlencode(auth_params)}"

                # Start OAuth callback server
                callback_server = OAuthCallbackServer(port=CALLBACK_PORT)
                try:
                    await callback_server.start(expected_state=state)

                    # Display instructions to user
                    auth_panel_text = Text.from_markup(
                        "1. Visit the URL below to sign in with your phone number.\n"
                        "2. [bold]Authorize the application[/bold] to access your account.\n"
                        "3. You will be automatically redirected after authorization."
                    )
                    console.print(Panel(auth_panel_text, title=f"iFlow OAuth Setup for [bold yellow]{file_name}[/bold yellow]", style="bold blue"))
                    console.print(f"[bold]URL:[/bold] [link={auth_url}]{auth_url}[/link]\n")

                    # Open browser
                    webbrowser.open(auth_url)

                    # Wait for callback
                    with console.status("[bold green]Waiting for authorization in the browser...[/bold green]", spinner="dots"):
                        code = await callback_server.wait_for_callback(timeout=300.0)

                    lib_logger.info("Received authorization code, exchanging for tokens...")

                    # Exchange code for tokens and API key
                    token_data = await self._exchange_code_for_tokens(code, redirect_uri)

                    # Update credentials
                    creds.update({
                        "access_token": token_data["access_token"],
                        "refresh_token": token_data["refresh_token"],
                        "api_key": token_data["api_key"],
                        "email": token_data["email"],
                        "expiry_date": token_data["expiry_date"],
                        "token_type": token_data["token_type"],
                        "scope": token_data["scope"]
                    })

                    # Create metadata object
                    if not creds.get("_proxy_metadata"):
                        creds["_proxy_metadata"] = {
                            "email": token_data["email"],
                            "last_check_timestamp": time.time()
                        }

                    if path:
                        await self._save_credentials(path, creds)

                    lib_logger.info(f"iFlow OAuth initialized successfully for '{file_name}'.")
                    return creds

                finally:
                    await callback_server.stop()

            lib_logger.info(f"iFlow OAuth token at '{file_name}' is valid.")
            return creds

        except Exception as e:
            raise ValueError(f"Failed to initialize iFlow OAuth for '{path}': {e}")

    async def get_auth_header(self, credential_path: str) -> Dict[str, str]:
        """
        Returns auth header with API key (NOT OAuth access_token).
        CRITICAL: iFlow API requests use the api_key, not the OAuth tokens.
        """
        creds = await self._load_credentials(credential_path)
        if self._is_token_expired(creds):
            creds = await self._refresh_token(credential_path)

        api_key = creds.get("api_key")
        if not api_key:
            raise ValueError("Missing api_key in iFlow credentials")

        return {"Authorization": f"Bearer {api_key}"}

    async def get_user_info(self, creds_or_path: Union[Dict[str, Any], str]) -> Dict[str, Any]:
        """Retrieves user info from the _proxy_metadata in the credential file."""
        try:
            path = creds_or_path if isinstance(creds_or_path, str) else None
            creds = await self._load_credentials(creds_or_path) if path else creds_or_path

            # Ensure the token is valid
            if path:
                await self.initialize_token(path)
                creds = await self._load_credentials(path)

            email = creds.get("email") or creds.get("_proxy_metadata", {}).get("email")

            if not email:
                lib_logger.warning(f"No email found in iFlow credentials for '{path or 'in-memory object'}'.")

            # Update timestamp on check
            if path and "_proxy_metadata" in creds:
                creds["_proxy_metadata"]["last_check_timestamp"] = time.time()
                await self._save_credentials(path, creds)

            return {"email": email}
        except Exception as e:
            lib_logger.error(f"Failed to get iFlow user info from credentials: {e}")
            return {"email": None}
```
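
Both the code exchange and the refresh call in `iflow_auth_base.py` authenticate to the token endpoint with an HTTP Basic header built from the client credentials. The construction can be checked in isolation; the `demo-id`/`demo-secret` values below are placeholders, not the real iFlow credentials:

```python
import base64

def build_basic_auth(client_id: str, client_secret: str) -> str:
    """RFC 7617 Basic auth: base64("id:secret"), as sent to the token endpoint."""
    token = base64.b64encode(f"{client_id}:{client_secret}".encode()).decode()
    return f"Basic {token}"

header = build_basic_auth("demo-id", "demo-secret")
print(header)  # Basic ZGVtby1pZDpkZW1vLXNlY3JldA==
```

Sending the pair both in the `Authorization` header and in the form body, as the file above does, is redundant but harmless: servers that accept either scheme will find the credentials.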
src/rotator_library/providers/iflow_provider.py ADDED
@@ -0,0 +1,348 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/rotator_library/providers/iflow_provider.py ADDED
@@ -0,0 +1,348 @@
+# src/rotator_library/providers/iflow_provider.py
+
+import json
+import time
+import httpx
+import logging
+from typing import Union, AsyncGenerator, List, Dict, Any
+from .provider_interface import ProviderInterface
+from .iflow_auth_base import IFlowAuthBase
+import litellm
+from litellm.exceptions import RateLimitError, AuthenticationError
+
+lib_logger = logging.getLogger('rotator_library')
+
+# Model list can be expanded as iFlow supports more models
+HARDCODED_MODELS = [
+    "deepseek-v3",
+    "deepseek-chat",
+    "deepseek-coder"
+]
+
+# OpenAI-compatible parameters supported by iFlow API
+SUPPORTED_PARAMS = {
+    'model', 'messages', 'temperature', 'top_p', 'max_tokens',
+    'stream', 'tools', 'tool_choice', 'presence_penalty',
+    'frequency_penalty', 'n', 'stop', 'seed', 'response_format'
+}
+
+
+class IFlowProvider(IFlowAuthBase, ProviderInterface):
+    """
+    iFlow provider using OAuth authentication with local callback server.
+    API requests use the derived API key (NOT OAuth access_token).
+    Based on the Go example implementation.
+    """
+    skip_cost_calculation = True
+
+    def __init__(self):
+        super().__init__()
+
+    def has_custom_logic(self) -> bool:
+        return True
+
+    async def get_models(self, credential: str, client: httpx.AsyncClient) -> List[str]:
+        """Returns a hardcoded list of known compatible iFlow models."""
+        return [f"iflow/{model_id}" for model_id in HARDCODED_MODELS]
+
+    def _clean_tool_schemas(self, tools: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+        """
+        Removes unsupported properties from tool schemas to prevent API errors.
+        Similar to Qwen Code implementation.
+        """
+        import copy
+        cleaned_tools = []
+
+        for tool in tools:
+            cleaned_tool = copy.deepcopy(tool)
+
+            if "function" in cleaned_tool:
+                func = cleaned_tool["function"]
+
+                # Remove strict mode (may not be supported)
+                func.pop("strict", None)
+
+                # Clean parameter schema if present
+                if "parameters" in func and isinstance(func["parameters"], dict):
+                    params = func["parameters"]
+
+                    # Remove additionalProperties if present
+                    params.pop("additionalProperties", None)
+
+                    # Recursively clean nested properties
+                    if "properties" in params:
+                        self._clean_schema_properties(params["properties"])
+
+            cleaned_tools.append(cleaned_tool)
+
+        return cleaned_tools
+
+    def _clean_schema_properties(self, properties: Dict[str, Any]) -> None:
+        """Recursively cleans schema properties."""
+        for prop_name, prop_schema in properties.items():
+            if isinstance(prop_schema, dict):
+                # Remove unsupported fields
+                prop_schema.pop("strict", None)
+                prop_schema.pop("additionalProperties", None)
+
+                # Recurse into nested properties
+                if "properties" in prop_schema:
+                    self._clean_schema_properties(prop_schema["properties"])
+
+                # Recurse into array items
+                if "items" in prop_schema and isinstance(prop_schema["items"], dict):
+                    self._clean_schema_properties({"item": prop_schema["items"]})
+
+    def _build_request_payload(self, **kwargs) -> Dict[str, Any]:
+        """
+        Builds a clean request payload with only supported parameters.
+        This prevents 400 Bad Request errors from litellm-internal parameters.
+        """
+        # Extract only supported OpenAI parameters
+        payload = {k: v for k, v in kwargs.items() if k in SUPPORTED_PARAMS}
+
+        # Always force streaming for internal processing
+        payload['stream'] = True
+
+        # Always include usage data in stream
+        payload['stream_options'] = {"include_usage": True}
+
+        # Handle tool schema cleaning
+        if "tools" in payload and payload["tools"]:
+            payload["tools"] = self._clean_tool_schemas(payload["tools"])
+            lib_logger.debug(f"Cleaned {len(payload['tools'])} tool schemas")
+
+        return payload
+
+    def _convert_chunk_to_openai(self, chunk: Dict[str, Any], model_id: str):
+        """
+        Converts a raw iFlow SSE chunk to an OpenAI-compatible chunk.
+        Since iFlow is OpenAI-compatible, minimal conversion is needed.
+        """
+        if not isinstance(chunk, dict):
+            return
+
+        # Handle usage data
+        if usage_data := chunk.get("usage"):
+            yield {
+                "choices": [], "model": model_id, "object": "chat.completion.chunk",
+                "id": chunk.get("id", f"chatcmpl-iflow-{time.time()}"),
+                "created": chunk.get("created", int(time.time())),
+                "usage": {
+                    "prompt_tokens": usage_data.get("prompt_tokens", 0),
+                    "completion_tokens": usage_data.get("completion_tokens", 0),
+                    "total_tokens": usage_data.get("total_tokens", 0),
+                }
+            }
+            return
+
+        # Handle content data
+        choices = chunk.get("choices", [])
+        if not choices:
+            return
+
+        # iFlow returns OpenAI-compatible format, so we can mostly pass through
+        yield {
+            "choices": choices,
+            "model": model_id,
+            "object": "chat.completion.chunk",
+            "id": chunk.get("id", f"chatcmpl-iflow-{time.time()}"),
+            "created": chunk.get("created", int(time.time()))
+        }
+
+    def _stream_to_completion_response(self, chunks: List[litellm.ModelResponse]) -> litellm.ModelResponse:
+        """
+        Manually reassembles streaming chunks into a complete response.
+        """
+        if not chunks:
+            raise ValueError("No chunks provided for reassembly")
+
+        # Initialize the final response structure
+        final_message = {"role": "assistant"}
+        aggregated_tool_calls = {}
+        usage_data = None
+        finish_reason = None
+
+        # Get the first chunk for basic response metadata
+        first_chunk = chunks[0]
+
+        # Process each chunk to aggregate content
+        for chunk in chunks:
+            if not hasattr(chunk, 'choices') or not chunk.choices:
+                continue
+
+            choice = chunk.choices[0]
+            delta = choice.get("delta", {})
+
+            # Aggregate content
+            if "content" in delta and delta["content"] is not None:
+                if "content" not in final_message:
+                    final_message["content"] = ""
+                final_message["content"] += delta["content"]
+
+            # Aggregate reasoning content (if supported by iFlow)
+            if "reasoning_content" in delta and delta["reasoning_content"] is not None:
+                if "reasoning_content" not in final_message:
+                    final_message["reasoning_content"] = ""
+                final_message["reasoning_content"] += delta["reasoning_content"]
+
+            # Aggregate tool calls
+            if "tool_calls" in delta and delta["tool_calls"]:
+                for tc_chunk in delta["tool_calls"]:
+                    index = tc_chunk["index"]
+                    if index not in aggregated_tool_calls:
+                        aggregated_tool_calls[index] = {"function": {"name": "", "arguments": ""}}
+                    if "id" in tc_chunk:
+                        aggregated_tool_calls[index]["id"] = tc_chunk["id"]
+                    if "type" in tc_chunk:
+                        aggregated_tool_calls[index]["type"] = tc_chunk["type"]
+                    if "function" in tc_chunk:
+                        if "name" in tc_chunk["function"] and tc_chunk["function"]["name"] is not None:
+                            aggregated_tool_calls[index]["function"]["name"] += tc_chunk["function"]["name"]
+                        if "arguments" in tc_chunk["function"] and tc_chunk["function"]["arguments"] is not None:
+                            aggregated_tool_calls[index]["function"]["arguments"] += tc_chunk["function"]["arguments"]
+
+            # Aggregate function calls (legacy format)
+            if "function_call" in delta and delta["function_call"] is not None:
+                if "function_call" not in final_message:
+                    final_message["function_call"] = {"name": "", "arguments": ""}
+                if "name" in delta["function_call"] and delta["function_call"]["name"] is not None:
+                    final_message["function_call"]["name"] += delta["function_call"]["name"]
+                if "arguments" in delta["function_call"] and delta["function_call"]["arguments"] is not None:
+                    final_message["function_call"]["arguments"] += delta["function_call"]["arguments"]
+
+            # Get finish reason from the last chunk that has it
+            if choice.get("finish_reason"):
+                finish_reason = choice["finish_reason"]
+
+        # Handle usage data from the last chunk that has it
+        for chunk in reversed(chunks):
+            if hasattr(chunk, 'usage') and chunk.usage:
+                usage_data = chunk.usage
+                break
+
+        # Add tool calls to final message if any
+        if aggregated_tool_calls:
+            final_message["tool_calls"] = list(aggregated_tool_calls.values())
+
+        # Ensure standard fields are present for consistent logging
+        for field in ["content", "tool_calls", "function_call"]:
+            if field not in final_message:
+                final_message[field] = None
+
+        # Construct the final response
+        final_choice = {
+            "index": 0,
+            "message": final_message,
+            "finish_reason": finish_reason
+        }
+
+        # Create the final ModelResponse
+        final_response_data = {
+            "id": first_chunk.id,
+            "object": "chat.completion",
+            "created": first_chunk.created,
+            "model": first_chunk.model,
+            "choices": [final_choice],
+            "usage": usage_data
+        }
+
+        return litellm.ModelResponse(**final_response_data)
+
+    async def acompletion(self, client: httpx.AsyncClient, **kwargs) -> Union[litellm.ModelResponse, AsyncGenerator[litellm.ModelResponse, None]]:
+        credential_path = kwargs.pop("credential_identifier")
+        enable_request_logging = kwargs.pop("enable_request_logging", False)
+        model = kwargs["model"]
+
+        async def make_request():
+            """Prepares and makes the actual API call."""
+            # CRITICAL: get_api_details returns api_key, NOT access_token
+            api_base, api_key = await self.get_api_details(credential_path)
+
+            # Build clean payload with only supported parameters
+            payload = self._build_request_payload(**kwargs)
+
+            headers = {
+                "Authorization": f"Bearer {api_key}",  # Uses api_key from user info
+                "Content-Type": "application/json",
+                "Accept": "text/event-stream",
+                "User-Agent": "iFlow-Cli"
+            }
+
+            url = f"{api_base.rstrip('/')}/chat/completions"
+
+            if enable_request_logging:
+                lib_logger.info(f"iFlow Request URL: {url}")
+                lib_logger.info(f"iFlow Request Payload: {json.dumps(payload, indent=2)}")
+            else:
+                lib_logger.debug(f"iFlow Request URL: {url}")
+
+            return client.stream("POST", url, headers=headers, json=payload, timeout=600)
+
+        async def stream_handler(response_stream, attempt=1):
+            """Handles the streaming response and converts chunks."""
+            try:
+                async with response_stream as response:
+                    # Check for HTTP errors before processing stream
+                    if response.status_code >= 400:
+                        error_text = await response.aread()
+                        error_text = error_text.decode('utf-8') if isinstance(error_text, bytes) else error_text
+
+                        # Handle 401: Force token refresh and retry once
+                        if response.status_code == 401 and attempt == 1:
+                            lib_logger.warning("iFlow returned 401. Forcing token refresh and retrying once.")
+                            await self._refresh_token(credential_path, force=True)
+                            retry_stream = await make_request()
+                            async for chunk in stream_handler(retry_stream, attempt=2):
+                                yield chunk
+                            return
+
+                        # Handle 429: Rate limit
+                        elif response.status_code == 429 or "slow_down" in error_text.lower():
+                            raise RateLimitError(
+                                f"iFlow rate limit exceeded: {error_text}",
+                                llm_provider="iflow",
+                                model=model,
+                                response=response
+                            )
+
+                        # Handle other errors
+                        else:
+                            if enable_request_logging:
+                                lib_logger.error(f"iFlow HTTP {response.status_code} error: {error_text}")
+                            raise httpx.HTTPStatusError(
+                                f"HTTP {response.status_code}: {error_text}",
+                                request=response.request,
+                                response=response
+                            )
+
+                    # Process successful streaming response
+                    async for line in response.aiter_lines():
+                        if line.startswith('data: '):
+                            data_str = line[6:]
+                            if data_str == "[DONE]":
+                                break
+                            try:
+                                chunk = json.loads(data_str)
+                                for openai_chunk in self._convert_chunk_to_openai(chunk, model):
+                                    yield litellm.ModelResponse(**openai_chunk)
+                            except json.JSONDecodeError:
+                                lib_logger.warning(f"Could not decode JSON from iFlow: {line}")
+
+            except httpx.HTTPStatusError:
+                raise  # Re-raise HTTP errors we already handled
+            except Exception as e:
+                if enable_request_logging:
+                    lib_logger.error(f"Error during iFlow stream processing: {e}", exc_info=True)
+                raise
+
+        http_response_stream = await make_request()
+        response_generator = stream_handler(http_response_stream)
+
+        if kwargs.get("stream"):
+            return response_generator
+        else:
+            async def non_stream_wrapper():
+                chunks = [chunk async for chunk in response_generator]
+                return self._stream_to_completion_response(chunks)
+            return await non_stream_wrapper()
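The recursive schema stripping added above (`_clean_tool_schemas` / `_clean_schema_properties`) can be exercised in isolation. The sketch below is an illustrative stand-alone version of the same stripping logic, not the provider class itself; the function names here are mine:

```python
import copy
from typing import Any, Dict, List


def clean_tool_schemas(tools: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Strip `strict` and `additionalProperties` recursively, mirroring the diff."""

    def clean_properties(properties: Dict[str, Any]) -> None:
        for schema in properties.values():
            if isinstance(schema, dict):
                schema.pop("strict", None)
                schema.pop("additionalProperties", None)
                if "properties" in schema:
                    clean_properties(schema["properties"])
                # Array items are cleaned by wrapping them in a one-entry dict
                if isinstance(schema.get("items"), dict):
                    clean_properties({"item": schema["items"]})

    cleaned = []
    for tool in tools:
        tool = copy.deepcopy(tool)  # never mutate the caller's schema
        func = tool.get("function", {})
        func.pop("strict", None)
        params = func.get("parameters")
        if isinstance(params, dict):
            params.pop("additionalProperties", None)
            if "properties" in params:
                clean_properties(params["properties"])
        cleaned.append(tool)
    return cleaned
```

Running it on a tool whose nested array-item schemas carry `strict`/`additionalProperties` leaves only the fields the API accepts, while the input list is untouched thanks to the deep copy.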
src/rotator_library/providers/qwen_auth_base.py CHANGED
@@ -110,8 +110,8 @@ class QwenAuthBase:
         lib_logger.info(f"Successfully refreshed Qwen OAuth token for '{Path(path).name}'.")
         return creds_from_file
 
-    def get_api_details(self, credential_path: str) -> Tuple[str, str]:
-        creds = self._credentials_cache[credential_path]
+    async def get_api_details(self, credential_path: str) -> Tuple[str, str]:
+        creds = await self._load_credentials(credential_path)
         base_url = creds.get("resource_url", "https://portal.qwen.ai/v1")
         if not base_url.startswith("http"):
            base_url = f"https://{base_url}"
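Besides making `get_api_details` async, the hunk above shows how the base URL is derived from the credential file's `resource_url` field. A minimal stand-alone sketch of that normalization (the default endpoint is taken from the diff; the helper name is illustrative):

```python
def normalize_base_url(creds: dict) -> str:
    """Fall back to the portal endpoint and ensure an explicit scheme."""
    base_url = creds.get("resource_url", "https://portal.qwen.ai/v1")
    if not base_url.startswith("http"):
        base_url = f"https://{base_url}"
    return base_url
```

So a credential file that stores a bare host like `portal.qwen.ai/v1` and one that stores a full `https://` URL both resolve to the same endpoint.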
src/rotator_library/providers/qwen_code_provider.py CHANGED
@@ -17,8 +17,16 @@ HARDCODED_MODELS = [
     "qwen3-coder-flash"
 ]
 
+# OpenAI-compatible parameters supported by Qwen Code API
+SUPPORTED_PARAMS = {
+    'model', 'messages', 'temperature', 'top_p', 'max_tokens',
+    'stream', 'tools', 'tool_choice', 'presence_penalty',
+    'frequency_penalty', 'n', 'stop', 'seed', 'response_format'
+}
+
 class QwenCodeProvider(QwenAuthBase, ProviderInterface):
     skip_cost_calculation = True
+    REASONING_START_MARKER = 'THINK||'
 
     def __init__(self):
         super().__init__()
@@ -30,6 +38,87 @@ class QwenCodeProvider(QwenAuthBase, ProviderInterface):
         """Returns a hardcoded list of known compatible Qwen models."""
         return [f"qwen_code/{model_id}" for model_id in HARDCODED_MODELS]
 
+    def _clean_tool_schemas(self, tools: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+        """
+        Removes unsupported properties from tool schemas to prevent API errors.
+        Based on Gemini CLI's approach but adapted for Qwen's API requirements.
+        """
+        import copy
+        cleaned_tools = []
+
+        for tool in tools:
+            cleaned_tool = copy.deepcopy(tool)
+
+            if "function" in cleaned_tool:
+                func = cleaned_tool["function"]
+
+                # Remove strict mode (not supported by Qwen)
+                func.pop("strict", None)
+
+                # Clean parameter schema if present
+                if "parameters" in func and isinstance(func["parameters"], dict):
+                    params = func["parameters"]
+
+                    # Remove additionalProperties if present
+                    params.pop("additionalProperties", None)
+
+                    # Recursively clean nested properties
+                    if "properties" in params:
+                        self._clean_schema_properties(params["properties"])
+
+            cleaned_tools.append(cleaned_tool)
+
+        return cleaned_tools
+
+    def _clean_schema_properties(self, properties: Dict[str, Any]) -> None:
+        """Recursively cleans schema properties."""
+        for prop_name, prop_schema in properties.items():
+            if isinstance(prop_schema, dict):
+                # Remove unsupported fields
+                prop_schema.pop("strict", None)
+                prop_schema.pop("additionalProperties", None)
+
+                # Recurse into nested properties
+                if "properties" in prop_schema:
+                    self._clean_schema_properties(prop_schema["properties"])
+
+                # Recurse into array items
+                if "items" in prop_schema and isinstance(prop_schema["items"], dict):
+                    self._clean_schema_properties({"item": prop_schema["items"]})
+
+    def _build_request_payload(self, **kwargs) -> Dict[str, Any]:
+        """
+        Builds a clean request payload with only supported parameters.
+        This prevents 400 Bad Request errors from litellm-internal parameters.
+        """
+        # Extract only supported OpenAI parameters
+        payload = {k: v for k, v in kwargs.items() if k in SUPPORTED_PARAMS}
+
+        # Always force streaming for internal processing
+        payload['stream'] = True
+
+        # Always include usage data in stream
+        payload['stream_options'] = {"include_usage": True}
+
+        # Handle tool schema cleaning
+        if "tools" in payload and payload["tools"]:
+            payload["tools"] = self._clean_tool_schemas(payload["tools"])
+            lib_logger.debug(f"Cleaned {len(payload['tools'])} tool schemas")
+        elif not payload.get("tools"):
+            # Per Qwen Code API bug (see: https://github.com/qianwen-team/flash-dance/issues/2),
+            # injecting a dummy tool prevents stream corruption when no tools are provided
+            payload["tools"] = [{
+                "type": "function",
+                "function": {
+                    "name": "do_not_call_me",
+                    "description": "Do not call this tool.",
+                    "parameters": {"type": "object", "properties": {}}
+                }
+            }]
+            lib_logger.debug("Injected dummy tool to prevent Qwen API stream corruption")
+
+        return payload
+
     def _convert_chunk_to_openai(self, chunk: Dict[str, Any], model_id: str):
         """Converts a raw Qwen SSE chunk to an OpenAI-compatible chunk."""
         if not isinstance(chunk, dict):
@@ -60,14 +149,14 @@ class QwenCodeProvider(QwenAuthBase, ProviderInterface):
         # Handle <think> tags for reasoning content
         content = delta.get("content")
         if content and ("<think>" in content or "</think>" in content):
-            parts = content.replace("<think>", "||THINK||").replace("</think>", "||/THINK||").split("||")
+            parts = content.replace("<think>", f"||{self.REASONING_START_MARKER}").replace("</think>", f"||/{self.REASONING_START_MARKER}").split("||")
             for part in parts:
                 if not part: continue
 
                 new_delta = {}
-                if part.startswith("THINK||"):
-                    new_delta['reasoning_content'] = part.replace("THINK||", "")
-                elif part.startswith("/THINK||"):
+                if part.startswith(self.REASONING_START_MARKER):
+                    new_delta['reasoning_content'] = part.replace(self.REASONING_START_MARKER, "")
+                elif part.startswith(f"/{self.REASONING_START_MARKER}"):
                     continue
                 else:
                     new_delta['content'] = part
@@ -85,71 +174,199 @@ class QwenCodeProvider(QwenAuthBase, ProviderInterface):
             "id": f"chatcmpl-qwen-{time.time()}", "created": int(time.time())
         }
 
+    def _stream_to_completion_response(self, chunks: List[litellm.ModelResponse]) -> litellm.ModelResponse:
+        """
+        Manually reassembles streaming chunks into a complete response.
+        This replaces the non-existent litellm.utils.stream_to_completion_response function.
+        """
+        if not chunks:
+            raise ValueError("No chunks provided for reassembly")
+
+        # Initialize the final response structure
+        final_message = {"role": "assistant"}
+        aggregated_tool_calls = {}
+        usage_data = None
+        finish_reason = None
+
+        # Get the first chunk for basic response metadata
+        first_chunk = chunks[0]
+
+        # Process each chunk to aggregate content
+        for chunk in chunks:
+            if not hasattr(chunk, 'choices') or not chunk.choices:
+                continue
+
+            choice = chunk.choices[0]
+            delta = choice.get("delta", {})
+
+            # Aggregate content
+            if "content" in delta and delta["content"] is not None:
+                if "content" not in final_message:
+                    final_message["content"] = ""
+                final_message["content"] += delta["content"]
+
+            # Aggregate reasoning content
+            if "reasoning_content" in delta and delta["reasoning_content"] is not None:
+                if "reasoning_content" not in final_message:
+                    final_message["reasoning_content"] = ""
+                final_message["reasoning_content"] += delta["reasoning_content"]
+
+            # Aggregate tool calls
+            if "tool_calls" in delta and delta["tool_calls"]:
+                for tc_chunk in delta["tool_calls"]:
+                    index = tc_chunk["index"]
+                    if index not in aggregated_tool_calls:
+                        aggregated_tool_calls[index] = {"function": {"name": "", "arguments": ""}}
+                    if "id" in tc_chunk:
+                        aggregated_tool_calls[index]["id"] = tc_chunk["id"]
+                    if "function" in tc_chunk:
+                        if "name" in tc_chunk["function"] and tc_chunk["function"]["name"] is not None:
+                            aggregated_tool_calls[index]["function"]["name"] += tc_chunk["function"]["name"]
+                        if "arguments" in tc_chunk["function"] and tc_chunk["function"]["arguments"] is not None:
+                            aggregated_tool_calls[index]["function"]["arguments"] += tc_chunk["function"]["arguments"]
+
+            # Aggregate function calls (legacy format)
+            if "function_call" in delta and delta["function_call"] is not None:
+                if "function_call" not in final_message:
+                    final_message["function_call"] = {"name": "", "arguments": ""}
+                if "name" in delta["function_call"] and delta["function_call"]["name"] is not None:
+                    final_message["function_call"]["name"] += delta["function_call"]["name"]
+                if "arguments" in delta["function_call"] and delta["function_call"]["arguments"] is not None:
+                    final_message["function_call"]["arguments"] += delta["function_call"]["arguments"]
+
+            # Get finish reason from the last chunk that has it
+            if choice.get("finish_reason"):
+                finish_reason = choice["finish_reason"]
+
+        # Handle usage data from the last chunk that has it
+        for chunk in reversed(chunks):
+            if hasattr(chunk, 'usage') and chunk.usage:
+                usage_data = chunk.usage
+                break
+
+        # Add tool calls to final message if any
+        if aggregated_tool_calls:
+            final_message["tool_calls"] = list(aggregated_tool_calls.values())
+
+        # Ensure standard fields are present for consistent logging
+        for field in ["content", "tool_calls", "function_call"]:
+            if field not in final_message:
+                final_message[field] = None
+
+        # Construct the final response
+        final_choice = {
+            "index": 0,
+            "message": final_message,
+            "finish_reason": finish_reason
+        }
+
+        # Create the final ModelResponse
+        final_response_data = {
+            "id": first_chunk.id,
+            "object": "chat.completion",
+            "created": first_chunk.created,
+            "model": first_chunk.model,
+            "choices": [final_choice],
+            "usage": usage_data
+        }
+
+        return litellm.ModelResponse(**final_response_data)
+
     async def acompletion(self, client: httpx.AsyncClient, **kwargs) -> Union[litellm.ModelResponse, AsyncGenerator[litellm.ModelResponse, None]]:
         credential_path = kwargs.pop("credential_identifier")
+        enable_request_logging = kwargs.pop("enable_request_logging", False)
         model = kwargs["model"]
 
-        async def do_call():
+        async def make_request():
+            """Prepares and makes the actual API call."""
            api_base, access_token = await self.get_api_details(credential_path)
-
-            # Prepare payload
-            payload = kwargs.copy()
-            payload.pop("litellm_params", None) # Clean up internal params
-
-            # Per Go example, inject dummy tool to prevent stream corruption
-            if not payload.get("tools"):
-                payload["tools"] = [{"type": "function", "function": {"name": "do_not_call_me", "description": "Do not call this tool under any circumstances.", "parameters": {"type": "object", "properties": {}}}}]
-
-            # Ensure usage is included in stream
-            payload["stream_options"] = {"include_usage": True}
+
+            # Build clean payload with only supported parameters
+            payload = self._build_request_payload(**kwargs)
 
             headers = {
                 "Authorization": f"Bearer {access_token}",
                 "Content-Type": "application/json",
-                "Accept": "text/event-stream" if kwargs.get("stream") else "application/json",
+                "Accept": "text/event-stream",
                 "User-Agent": "google-api-nodejs-client/9.15.1",
                 "X-Goog-Api-Client": "gl-node/22.17.0",
                 "Client-Metadata": "ideType=IDE_UNSPECIFIED,platform=PLATFORM_UNSPECIFIED,pluginType=GEMINI",
             }
-
-            url = f"{api_base.rstrip('/')}/chat/completions"
-            lib_logger.debug(f"Qwen Code Request URL: {url}")
-            lib_logger.debug(f"Qwen Code Request Payload: {json.dumps(payload, indent=2)}")
-
-            async def stream_handler():
-                async with client.stream("POST", url, headers=headers, json=payload, timeout=600) as response:
-                    response.raise_for_status()
+
+            url = f"{api_base.rstrip('/')}/v1/chat/completions"
+
+            if enable_request_logging:
+                lib_logger.info(f"Qwen Code Request URL: {url}")
+                lib_logger.info(f"Qwen Code Request Payload: {json.dumps(payload, indent=2)}")
+            else:
+                lib_logger.debug(f"Qwen Code Request URL: {url}")
+
+            return client.stream("POST", url, headers=headers, json=payload, timeout=600)
+
+        async def stream_handler(response_stream, attempt=1):
+            """Handles the streaming response and converts chunks."""
+            try:
+                async with response_stream as response:
+                    # Check for HTTP errors before processing stream
+                    if response.status_code >= 400:
+                        error_text = await response.aread()
+                        error_text = error_text.decode('utf-8') if isinstance(error_text, bytes) else error_text
+
+                        # Handle 401: Force token refresh and retry once
+                        if response.status_code == 401 and attempt == 1:
+                            lib_logger.warning("Qwen Code returned 401. Forcing token refresh and retrying once.")
+                            await self._refresh_token(credential_path, force=True)
+                            retry_stream = await make_request()
+                            async for chunk in stream_handler(retry_stream, attempt=2):
+                                yield chunk
+                            return
+
+                        # Handle 429: Rate limit
+                        elif response.status_code == 429 or "slow_down" in error_text.lower():
+                            raise RateLimitError(
+                                f"Qwen Code rate limit exceeded: {error_text}",
+                                llm_provider="qwen_code",
+                                model=model,
+                                response=response
+                            )
+
+                        # Handle other errors
+                        else:
+                            if enable_request_logging:
+                                lib_logger.error(f"Qwen Code HTTP {response.status_code} error: {error_text}")
+                            raise httpx.HTTPStatusError(
+                                f"HTTP {response.status_code}: {error_text}",
+                                request=response.request,
+                                response=response
+                            )
+
+                    # Process successful streaming response
                     async for line in response.aiter_lines():
                         if line.startswith('data: '):
                             data_str = line[6:]
-                            if data_str == "[DONE]": break
+                            if data_str == "[DONE]":
+                                break
                             try:
                                 chunk = json.loads(data_str)
                                 for openai_chunk in self._convert_chunk_to_openai(chunk, model):
                                     yield litellm.ModelResponse(**openai_chunk)
                             except json.JSONDecodeError:
                                 lib_logger.warning(f"Could not decode JSON from Qwen Code: {line}")
-
-            return stream_handler()
-
-        try:
-            response_gen = await do_call()
-        except httpx.HTTPStatusError as e:
-            if e.response.status_code == 401:
-                lib_logger.warning("Qwen Code returned 401. Forcing token refresh and retrying once.")
-                await self._refresh_token(credential_path, force=True)
-                response_gen = await do_call()
-            elif e.response.status_code == 429 or "slow_down" in e.response.text.lower():
-                raise RateLimitError(
-                    message=f"Qwen Code rate limit exceeded: {e.response.text}",
-                    llm_provider="qwen_code",
-                    response=e.response
-                )
-            else:
-                raise e
+
+            except httpx.HTTPStatusError:
+                raise  # Re-raise HTTP errors we already handled
+            except Exception as e:
+                if enable_request_logging:
+                    lib_logger.error(f"Error during Qwen Code stream processing: {e}", exc_info=True)
+                raise
+
+        http_response_stream = await make_request()
+        response_generator = stream_handler(http_response_stream)
 
         if kwargs.get("stream"):
-            return response_gen
+            return response_generator
         else:
-            chunks = [chunk async for chunk in response_gen]
-            return litellm.utils.stream_to_completion_response(chunks)
+            async def non_stream_wrapper():
+                chunks = [chunk async for chunk in response_generator]
+                return self._stream_to_completion_response(chunks)
+            return await non_stream_wrapper()
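The `_stream_to_completion_response` method added in both providers exists because `litellm.utils.stream_to_completion_response` does not exist. Its core aggregation can be sketched independently; in this illustrative version, plain dicts stand in for `litellm.ModelResponse` chunks and the function name is mine:

```python
from typing import Any, Dict, List


def reassemble_stream(chunks: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Concatenate streamed deltas into one final assistant message."""
    message: Dict[str, Any] = {"role": "assistant", "content": None}
    tool_calls: Dict[int, Dict[str, Any]] = {}
    finish_reason = None

    for chunk in chunks:
        for choice in chunk.get("choices", []):
            delta = choice.get("delta", {})

            # Text deltas are appended in arrival order
            if delta.get("content") is not None:
                message["content"] = (message["content"] or "") + delta["content"]

            # Tool-call fragments are grouped by their `index` slot
            for tc in delta.get("tool_calls", []):
                slot = tool_calls.setdefault(
                    tc["index"], {"function": {"name": "", "arguments": ""}}
                )
                if "id" in tc:
                    slot["id"] = tc["id"]
                fn = tc.get("function", {})
                slot["function"]["name"] += fn.get("name") or ""
                slot["function"]["arguments"] += fn.get("arguments") or ""

            # The last non-empty finish_reason wins
            if choice.get("finish_reason"):
                finish_reason = choice["finish_reason"]

    if tool_calls:
        message["tool_calls"] = [tool_calls[i] for i in sorted(tool_calls)]
    return {"message": message, "finish_reason": finish_reason}
```

This is why the non-streaming path can simply collect every chunk from the stream handler and hand the list to the reassembler: the deltas carry everything needed to rebuild the complete message.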