Mirrowel commited on
Commit
ed4dd55
·
2 Parent(s): 0c82aac c745d73

Merge branch 'dev' into main

Browse files
DOCUMENTATION.md CHANGED
@@ -856,6 +856,142 @@ class AntigravityAuthBase(GoogleOAuthBase):
856
  - Headless environment detection
857
  - Sequential refresh queue processing
858
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
859
  ---
860
 
861
 
@@ -877,8 +1013,8 @@ The `GeminiCliProvider` is the most complex implementation, mimicking the Google
877
 
878
  #### Authentication (`gemini_auth_base.py`)
879
 
880
- * **Device Flow**: Uses a standard OAuth 2.0 flow. The `credential_tool` spins up a local web server (`localhost:8085`) to capture the callback from Google's auth page.
881
- * **Token Lifecycle**:
882
  * **Proactive Refresh**: Tokens are refreshed 5 minutes before expiry.
883
  * **Atomic Writes**: Credential files are updated using a temp-file-and-move strategy to prevent corruption during writes.
884
  * **Revocation Handling**: If a `400` or `401` occurs during refresh, the token is marked as revoked, preventing infinite retry loops.
@@ -907,7 +1043,7 @@ The provider employs a sophisticated, cached discovery mechanism to find a valid
907
  ### 3.3. iFlow (`iflow_provider.py`)
908
 
909
  * **Hybrid Auth**: Uses a custom OAuth flow (Authorization Code) to obtain an `access_token`. However, the *actual* API calls use a separate `apiKey` that is retrieved from the user's profile (`/api/oauth/getUserInfo`) using the access token.
910
- * **Callback Server**: The auth flow spins up a local server on port `11451` to capture the redirect.
911
  * **Token Management**: Automatically refreshes the OAuth token and re-fetches the API key if needed.
912
  * **Schema Cleaning**: Similar to Qwen, it aggressively sanitizes tool schemas to prevent 400 errors.
913
  * **Dedicated Logging**: Implements `_IFlowFileLogger` to capture raw chunks for debugging proprietary API behaviors.
@@ -935,4 +1071,177 @@ To facilitate robust debugging, the proxy includes a comprehensive transaction l
935
 
936
  This level of detail allows developers to trace exactly why a request failed or why a specific key was rotated.
937
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
938
 
 
856
  - Headless environment detection
857
  - Sequential refresh queue processing
858
 
859
+ #### OAuth Callback Port Configuration
860
+
861
+ Each OAuth provider uses a local callback server during authentication. The callback port can be customized via environment variables to avoid conflicts with other services.
862
+
863
+ **Default Ports:**
864
+
865
+ | Provider | Default Port | Environment Variable |
866
+ |----------|-------------|---------------------|
867
+ | Gemini CLI | 8085 | `GEMINI_CLI_OAUTH_PORT` |
868
+ | Antigravity | 51121 | `ANTIGRAVITY_OAUTH_PORT` |
869
+ | iFlow | 11451 | `IFLOW_OAUTH_PORT` |
870
+
871
+ **Configuration Methods:**
872
+
873
+ 1. **Via TUI Settings Menu:**
874
+ - Main Menu → `4. View Provider & Advanced Settings` → `1. Launch Settings Tool`
875
+ - Select the provider (Gemini CLI, Antigravity, or iFlow)
876
+ - Modify the `*_OAUTH_PORT` setting
877
+ - Use "Reset to Default" to restore the original port
878
+
879
+ 2. **Via `.env` file:**
880
+ ```env
881
+ # Custom OAuth callback ports (optional)
882
+ GEMINI_CLI_OAUTH_PORT=8085
883
+ ANTIGRAVITY_OAUTH_PORT=51121
884
+ IFLOW_OAUTH_PORT=11451
885
+ ```
886
+
887
+ **When to Change Ports:**
888
+
889
+ - If the default port conflicts with another service on your system
890
+ - If running multiple proxy instances on the same machine
891
+ - If firewall rules require specific port ranges
892
+
893
+ **Note:** Port changes take effect on the next OAuth authentication attempt. Existing tokens are not affected.
894
+
895
+ ---
896
+
897
+ ### 2.14. HTTP Timeout Configuration (`timeout_config.py`)
898
+
899
+ Centralized timeout configuration for all HTTP requests to LLM providers.
900
+
901
+ #### Purpose
902
+
903
+ The `TimeoutConfig` class provides fine-grained control over HTTP timeouts for streaming and non-streaming LLM requests. This addresses the common issue of proxy hangs when upstream providers stall during connection establishment or response generation.
904
+
905
+ #### Timeout Types Explained
906
+
907
+ | Timeout | Description |
908
+ |---------|-------------|
909
+ | **connect** | Maximum time to establish a TCP/TLS connection to the upstream server |
910
+ | **read** | Maximum time to wait between receiving data chunks (resets on each chunk for streaming) |
911
+ | **write** | Maximum time to wait while sending the request body |
912
+ | **pool** | Maximum time to wait for a connection from the connection pool |
913
+
914
+ #### Default Values
915
+
916
+ | Setting | Streaming | Non-Streaming | Rationale |
917
+ |---------|-----------|---------------|-----------|
918
+ | **connect** | 30s | 30s | Fast fail if server is unreachable |
919
+ | **read** | 180s (3 min) | 600s (10 min) | Streaming expects periodic chunks; non-streaming may wait for full generation |
920
+ | **write** | 30s | 30s | Request bodies are typically small |
921
+ | **pool** | 60s | 60s | Reasonable wait for connection pool |
922
+
923
+ #### Environment Variable Overrides
924
+
925
+ All timeout values can be customized via environment variables:
926
+
927
+ ```env
928
+ # Connection establishment timeout (seconds)
929
+ TIMEOUT_CONNECT=30
930
+
931
+ # Request body send timeout (seconds)
932
+ TIMEOUT_WRITE=30
933
+
934
+ # Connection pool acquisition timeout (seconds)
935
+ TIMEOUT_POOL=60
936
+
937
+ # Read timeout between chunks for streaming requests (seconds)
938
+ # If no data arrives for this duration, the connection is considered stalled
939
+ TIMEOUT_READ_STREAMING=180
940
+
941
+ # Read timeout for non-streaming responses (seconds)
942
+ # Longer to accommodate models that take time to generate full responses
943
+ TIMEOUT_READ_NON_STREAMING=600
944
+ ```
945
+
946
+ #### Streaming vs Non-Streaming Behavior
947
+
948
+ **Streaming Requests** (`TimeoutConfig.streaming()`):
949
+ - Uses shorter read timeout (default 3 minutes)
950
+ - Timer resets every time a chunk arrives
951
+ - If no data for 3 minutes → connection considered dead → failover to next credential
952
+ - Appropriate for chat completions where tokens should arrive periodically
953
+
954
+ **Non-Streaming Requests** (`TimeoutConfig.non_streaming()`):
955
+ - Uses longer read timeout (default 10 minutes)
956
+ - Server may take significant time to generate the complete response before sending anything
957
+ - Complex reasoning tasks or large outputs may legitimately take several minutes
958
+ - Only used by Antigravity provider's `_handle_non_streaming()` method
959
+
960
+ #### Provider Usage
961
+
962
+ The following providers use `TimeoutConfig`:
963
+
964
+ | Provider | Method | Timeout Type |
965
+ |----------|--------|--------------|
966
+ | `antigravity_provider.py` | `_handle_non_streaming()` | `non_streaming()` |
967
+ | `antigravity_provider.py` | `_handle_streaming()` | `streaming()` |
968
+ | `gemini_cli_provider.py` | `acompletion()` | `streaming()` |
969
+ | `iflow_provider.py` | `acompletion()` | `streaming()` |
970
+ | `qwen_code_provider.py` | `acompletion()` | `streaming()` |
971
+
972
+ **Note:** iFlow, Qwen Code, and Gemini CLI providers always use streaming internally (even for non-streaming requests), aggregating chunks into a complete response. Only Antigravity has a true non-streaming path.
973
+
974
+ #### Tuning Recommendations
975
+
976
+ | Use Case | Recommendation |
977
+ |----------|----------------|
978
+ | **Long thinking tasks** | Increase `TIMEOUT_READ_STREAMING` to 300-360s |
979
+ | **Unstable network** | Increase `TIMEOUT_CONNECT` to 60s |
980
+ | **High concurrency** | Increase `TIMEOUT_POOL` if seeing pool exhaustion |
981
+ | **Large context/output** | Increase `TIMEOUT_READ_NON_STREAMING` to 900s+ |
982
+
983
+ #### Example Configuration
984
+
985
+ ```env
986
+ # For environments with complex reasoning tasks
987
+ TIMEOUT_READ_STREAMING=300
988
+ TIMEOUT_READ_NON_STREAMING=900
989
+
990
+ # For unstable network conditions
991
+ TIMEOUT_CONNECT=60
992
+ TIMEOUT_POOL=120
993
+ ```
994
+
995
  ---
996
 
997
 
 
1013
 
1014
  #### Authentication (`gemini_auth_base.py`)
1015
 
1016
+ * **Device Flow**: Uses a standard OAuth 2.0 flow. The `credential_tool` spins up a local web server (default: `localhost:8085`, configurable via `GEMINI_CLI_OAUTH_PORT`) to capture the callback from Google's auth page.
1017
+ * **Token Lifecycle**:
1018
  * **Proactive Refresh**: Tokens are refreshed 5 minutes before expiry.
1019
  * **Atomic Writes**: Credential files are updated using a temp-file-and-move strategy to prevent corruption during writes.
1020
  * **Revocation Handling**: If a `400` or `401` occurs during refresh, the token is marked as revoked, preventing infinite retry loops.
 
1043
  ### 3.3. iFlow (`iflow_provider.py`)
1044
 
1045
  * **Hybrid Auth**: Uses a custom OAuth flow (Authorization Code) to obtain an `access_token`. However, the *actual* API calls use a separate `apiKey` that is retrieved from the user's profile (`/api/oauth/getUserInfo`) using the access token.
1046
+ * **Callback Server**: The auth flow spins up a local server (default: port `11451`, configurable via `IFLOW_OAUTH_PORT`) to capture the redirect.
1047
  * **Token Management**: Automatically refreshes the OAuth token and re-fetches the API key if needed.
1048
  * **Schema Cleaning**: Similar to Qwen, it aggressively sanitizes tool schemas to prevent 400 errors.
1049
  * **Dedicated Logging**: Implements `_IFlowFileLogger` to capture raw chunks for debugging proprietary API behaviors.
 
1071
 
1072
  This level of detail allows developers to trace exactly why a request failed or why a specific key was rotated.
1073
 
1074
+ ---
1075
+
1076
+ ## 5. Runtime Resilience
1077
+
1078
+ The proxy is engineered to maintain high availability even in the face of runtime filesystem disruptions. This "Runtime Resilience" capability ensures that the service continues to process API requests even if data files or directories are deleted while the application is running.
1079
+
1080
+ ### 5.1. Centralized Resilient I/O (`resilient_io.py`)
1081
+
1082
+ All file operations are centralized in a single utility module that provides consistent error handling, graceful degradation, and automatic retry with shutdown flush:
1083
+
1084
+ #### `BufferedWriteRegistry` (Singleton)
1085
+
1086
+ Global registry for buffered writes with periodic retry and shutdown flush. Ensures critical data is saved even if disk writes fail temporarily:
1087
+
1088
+ - **Per-file buffering**: Each file path has its own pending write (latest data always wins)
1089
+ - **Periodic retries**: Background thread retries failed writes every 30 seconds
1090
+ - **Shutdown flush**: `atexit` hook ensures final write attempt on app exit (Ctrl+C)
1091
+ - **Thread-safe**: Safe for concurrent access from multiple threads
1092
+
1093
+ ```python
1094
+ # Get the singleton instance
1095
+ registry = BufferedWriteRegistry.get_instance()
1096
+
1097
+ # Check pending writes (for monitoring)
1098
+ pending_count = registry.get_pending_count()
1099
+ pending_files = registry.get_pending_paths()
1100
+
1101
+ # Manual flush (optional - atexit handles this automatically)
1102
+ results = registry.flush_all() # Returns {path: success_bool}
1103
+
1104
+ # Manual shutdown (if needed before atexit)
1105
+ results = registry.shutdown()
1106
+ ```
1107
+
1108
+ #### `ResilientStateWriter`
1109
+
1110
+ For stateful files that must persist (usage stats):
1111
+ - **Memory-first**: Always updates in-memory state before attempting disk write
1112
+ - **Atomic writes**: Uses tempfile + move pattern to prevent corruption
1113
+ - **Automatic retry with backoff**: If disk fails, waits `retry_interval` seconds before trying again
1114
+ - **Shutdown integration**: Registers with `BufferedWriteRegistry` on failure for final flush
1115
+ - **Health monitoring**: Exposes `is_healthy` property for monitoring
1116
+
1117
+ ```python
1118
+ writer = ResilientStateWriter("data.json", logger, retry_interval=30.0)
1119
+ writer.write({"key": "value"}) # Always succeeds (memory update)
1120
+ if not writer.is_healthy:
1121
+ logger.warning("Disk writes failing, data in memory only")
1122
+ # On next write() call after retry_interval, disk write is attempted again
1123
+ # On app exit (Ctrl+C), BufferedWriteRegistry attempts final save
1124
+ ```
1125
+
1126
+ #### `safe_write_json()`
1127
+
1128
+ For JSON writes with configurable options (credentials, cache):
1129
+
1130
+ | Parameter | Default | Description |
1131
+ |-----------|---------|-------------|
1132
+ | `path` | required | File path to write to |
1133
+ | `data` | required | JSON-serializable data |
1134
+ | `logger` | required | Logger for warnings |
1135
+ | `atomic` | `True` | Use atomic write pattern (tempfile + move) |
1136
+ | `indent` | `2` | JSON indentation level |
1137
+ | `ensure_ascii` | `True` | Escape non-ASCII characters |
1138
+ | `secure_permissions` | `False` | Set file permissions to 0o600 |
1139
+ | `buffer_on_failure` | `False` | Register with BufferedWriteRegistry on failure |
1140
+
1141
+ When `buffer_on_failure=True`:
1142
+ - Failed writes are registered with `BufferedWriteRegistry`
1143
+ - Data is retried every 30 seconds in background
1144
+ - On app exit, final write attempt is made automatically
1145
+ - Success unregisters the pending write
1146
+
1147
+ ```python
1148
+ # For critical data (auth tokens) - use buffer_on_failure
1149
+ safe_write_json(path, creds, logger, secure_permissions=True, buffer_on_failure=True)
1150
+
1151
+ # For non-critical data (logs) - no buffering needed
1152
+ safe_write_json(path, data, logger)
1153
+ ```
1154
+
1155
+ #### `safe_log_write()`
1156
+
1157
+ For log files where occasional loss is acceptable:
1158
+ - Fire-and-forget pattern
1159
+ - Creates parent directories if needed
1160
+ - Returns `True`/`False`, never raises
1161
+ - **No buffering** - logs are dropped on failure
1162
+
1163
+ #### `safe_mkdir()`
1164
+
1165
+ For directory creation with error handling.
1166
+
1167
+ ### 5.2. Resilience Hierarchy
1168
+
1169
+ The system follows a strict hierarchy of survival:
1170
+
1171
+ 1. **Core API Handling (Level 1)**: The Python runtime keeps all necessary code in memory. Deleting source code files while the proxy is running will **not** crash active requests.
1172
+
1173
+ 2. **Credential Management (Level 2)**: OAuth tokens are cached in memory first. If credential files are deleted, the proxy continues using cached tokens. If a token refresh succeeds but the file cannot be written, the new token is buffered for retry and saved on shutdown.
1174
+
1175
+ 3. **Usage Tracking (Level 3)**: Usage statistics (`key_usage.json`) are maintained in memory via `ResilientStateWriter`. If the file is deleted, the system tracks usage internally and attempts to recreate the file on the next save interval. Pending writes are flushed on shutdown.
1176
+
1177
+ 4. **Provider Cache (Level 4)**: The provider cache tracks disk health and continues operating in memory-only mode if disk writes fail. Has its own shutdown mechanism.
1178
+
1179
+ 5. **Logging (Level 5)**: Logging is treated as non-critical. If the `logs/` directory is removed, the system attempts to recreate it. If creation fails, logging degrades gracefully without interrupting the request flow. **No buffering or retry**.
1180
+
1181
+ ### 5.3. Component Integration
1182
+
1183
+ | Component | Utility Used | Behavior on Disk Failure | Shutdown Flush |
1184
+ |-----------|--------------|--------------------------|----------------|
1185
+ | `UsageManager` | `ResilientStateWriter` | Continues in memory, retries after 30s | Yes (via registry) |
1186
+ | `GoogleOAuthBase` | `safe_write_json(buffer_on_failure=True)` | Memory cache preserved, buffered for retry | Yes (via registry) |
1187
+ | `QwenAuthBase` | `safe_write_json(buffer_on_failure=True)` | Memory cache preserved, buffered for retry | Yes (via registry) |
1188
+ | `IFlowAuthBase` | `safe_write_json(buffer_on_failure=True)` | Memory cache preserved, buffered for retry | Yes (via registry) |
1189
+ | `ProviderCache` | `safe_write_json` + own shutdown | Retries via own background loop | Yes (own mechanism) |
1190
+ | `DetailedLogger` | `safe_write_json` | Logs dropped, no crash | No |
1191
+ | `failure_logger` | Python `logging.RotatingFileHandler` | Falls back to NullHandler | No |
1192
+
1193
+ ### 5.4. Shutdown Behavior
1194
+
1195
+ When the application exits (including Ctrl+C):
1196
+
1197
+ 1. **atexit handler fires**: `BufferedWriteRegistry._atexit_handler()` is called
1198
+ 2. **Pending writes counted**: Registry checks how many files have pending writes
1199
+ 3. **Flush attempted**: Each pending file gets a final write attempt
1200
+ 4. **Results logged**:
1201
+ - Success: `"Shutdown flush: all N write(s) succeeded"`
1202
+ - Partial: `"Shutdown flush: X succeeded, Y failed"` with failed file names
1203
+
1204
+ **Console output example:**
1205
+ ```
1206
+ INFO:rotator_library.resilient_io:Flushing 2 pending write(s) on shutdown...
1207
+ INFO:rotator_library.resilient_io:Shutdown flush: all 2 write(s) succeeded
1208
+ ```
1209
+
1210
+ ### 5.5. "Develop While Running"
1211
+
1212
+ This architecture supports a robust development workflow:
1213
+
1214
+ - **Log Cleanup**: You can safely run `rm -rf logs/` while the proxy is serving traffic. The system will recreate the directory structure on the next request.
1215
+ - **Config Reset**: Deleting `key_usage.json` resets the persistence layer, but the running instance preserves its current in-memory counts for load balancing consistency.
1216
+ - **File Recovery**: If you delete a critical file, the system attempts directory auto-recreation before every write operation.
1217
+ - **Safe Exit**: Ctrl+C triggers graceful shutdown with final data flush attempt.
1218
+
1219
+ ### 5.6. Graceful Degradation & Data Loss
1220
+
1221
+ While functionality is preserved, persistence may be compromised during filesystem failures:
1222
+
1223
+ - **Logs**: If disk writes fail, detailed request logs may be lost (no buffering).
1224
+ - **Usage Stats**: Buffered in memory and flushed on shutdown. Data loss only if shutdown flush also fails.
1225
+ - **Credentials**: Buffered in memory and flushed on shutdown. Re-authentication only needed if shutdown flush fails.
1226
+ - **Cache**: Provider cache entries may need to be regenerated after restart if its own shutdown mechanism fails.
1227
+
1228
+ ### 5.7. Monitoring Disk Health
1229
+
1230
+ Components expose health information for monitoring:
1231
+
1232
+ ```python
1233
+ # BufferedWriteRegistry
1234
+ registry = BufferedWriteRegistry.get_instance()
1235
+ pending = registry.get_pending_count() # Number of files with pending writes
1236
+ files = registry.get_pending_paths() # List of pending file names
1237
+
1238
+ # UsageManager
1239
+ writer = usage_manager._state_writer
1240
+ health = writer.get_health_info()
1241
+ # Returns: {"healthy": True, "failure_count": 0, "last_success": 1234567890.0, ...}
1242
+
1243
+ # ProviderCache
1244
+ stats = cache.get_stats()
1245
+ # Includes: {"disk_available": True, "disk_errors": 0, ...}
1246
+ ```
1247
 
README.md CHANGED
@@ -1,755 +1,763 @@
1
- # Universal LLM API Proxy & Resilience Library [![ko-fi](https://ko-fi.com/img/githubbutton_sm.svg)](https://ko-fi.com/C0C0UZS4P)
 
2
  [![Ask DeepWiki](https://deepwiki.com/badge.svg)](https://deepwiki.com/Mirrowel/LLM-API-Key-Proxy) [![zread](https://img.shields.io/badge/Ask_Zread-_.svg?style=flat&color=00b0aa&labelColor=000000&logo=data%3Aimage%2Fsvg%2Bxml%3Bbase64%2CPHN2ZyB3aWR0aD0iMTYiIGhlaWdodD0iMTYiIHZpZXdCb3g9IjAgMCAxNiAxNiIgZmlsbD0ibm9uZSIgeG1sbnM9Imh0dHA6Ly93d3cudzMub3JnLzIwMDAvc3ZnIj4KPHBhdGggZD0iTTQuOTYxNTYgMS42MDAxSDIuMjQxNTZDMS44ODgxIDEuNjAwMSAxLjYwMTU2IDEuODg2NjQgMS42MDE1NiAyLjI0MDFWNC45NjAxQzEuNjAxNTYgNS4zMTM1NiAxLjg4ODEgNS42MDAxIDIuMjQxNTYgNS42MDAxSDQuOTYxNTZDNS4zMTUwMiA1LjYwMDEgNS42MDE1NiA1LjMxMzU2IDUuNjAxNTYgNC45NjAxVjIuMjQwMUM1LjYwMTU2IDEuODg2NjQgNS4zMTUwMiAxLjYwMDEgNC45NjE1NiAxLjYwMDFaIiBmaWxsPSIjZmZmIi8%2BCjxwYXRoIGQ9Ik00Ljk2MTU2IDEwLjM5OTlIMi4yNDE1NkMxLjg4ODEgMTAuMzk5OSAxLjYwMTU2IDEwLjY4NjQgMS42MDE1NiAxMS4wMzk5VjEzLjc1OTlDMS42MDE1NiAxNC4xMTM0IDEuODg4MSAxNC4zOTk5IDIuMjQxNTYgMTQuMzk5OUg0Ljk2MTU2QzUuMzE1MDIgMTQuMzk5OSA1LjYwMTU2IDE0LjExMzQgNS42MDE1NiAxMy43NTk5VjExLjAzOTlDNS42MDE1NiAxMC42ODY0IDUuMzE1MDIgMTAuMzk5OSA0Ljk2MTU2IDEwLjM5OTlaIiBmaWxsPSIjZmZmIi8%2BCjxwYXRoIGQ9Ik0xMy43NTg0IDEuNjAwMUgxMS4wMzg0QzEwLjY4NSAxLjYwMDEgMTAuMzk4NCAxLjg4NjY0IDEwLjM5ODQgMi4yNDAxVjQuOTYwMUMxMC4zOTg0IDUuMzEzNTYgMTAuNjg1IDUuNjAwMSAxMS4wMzg0IDUuNjAwMUgxMy43NTg0QzE0LjExMTkgNS42MDAxIDE0LjM5ODQgNS4zMTM1NiAxNC4zOTg0IDQuOTYwMVYyLjI0MDFDMTQuMzk4NCAxLjg4NjY0IDE0LjExMTkgMS42MDAxIDEzLjc1ODQgMS42MDAxWiIgZmlsbD0iI2ZmZiIvPgo8cGF0aCBkPSJNNCAxMkwxMiA0TDQgMTJaIiBmaWxsPSIjZmZmIi8%2BCjxwYXRoIGQ9Ik00IDEyTDEyIDQiIHN0cm9rZT0iI2ZmZiIgc3Ryb2tlLXdpZHRoPSIxLjUiIHN0cm9rZS1saW5lY2FwPSJyb3VuZCIvPgo8L3N2Zz4K&logoColor=ffffff)](https://zread.ai/Mirrowel/LLM-API-Key-Proxy)
3
 
 
4
 
5
- ## Detailed Setup and Features
6
 
7
- This project provides a powerful solution for developers building complex applications, such as agentic systems, that interact with multiple Large Language Model (LLM) providers. It consists of two distinct but complementary components:
 
 
8
 
9
- 1. **A Universal API Proxy**: A self-hosted FastAPI application that provides a single, OpenAI-compatible endpoint for all your LLM requests. Powered by `litellm`, it allows you to seamlessly switch between different providers and models without altering your application's code.
10
- 2. **A Resilience & Key Management Library**: The core engine that powers the proxy. This reusable Python library intelligently manages a pool of API keys to ensure your application is highly available and resilient to transient provider errors or performance issues.
11
-
12
- ## Features
13
 
14
- - **Universal API Endpoint**: Simplifies development by providing a single, OpenAI-compatible interface for diverse LLM providers.
15
- - **High Availability**: The underlying library ensures your application remains operational by gracefully handling transient provider errors and API key-specific issues.
16
- - **Resilient Performance**: A global timeout on all requests prevents your application from hanging on unresponsive provider APIs.
17
- - **Advanced Concurrency Control**: A single API key can be used for multiple concurrent requests. By default, it supports concurrent requests to *different* models. With configuration (`MAX_CONCURRENT_REQUESTS_PER_KEY_<PROVIDER>`), it can also support multiple concurrent requests to the *same* model using the same key.
18
- - **Intelligent Key Management**: Optimizes request distribution across your pool of keys by selecting the best available one for each call.
19
- - **Automated OAuth Discovery**: Automatically discovers, validates, and manages OAuth credentials from standard provider directories (e.g., `~/.gemini/`, `~/.qwen/`, `~/.iflow/`).
20
- - **Stateless Deployment Support**: Deploy easily to platforms like Railway, Render, or Vercel. The new export tool converts complex OAuth credentials (Gemini CLI, Qwen, iFlow) into simple environment variables, removing the need for persistent storage or file uploads.
21
- - **Batch Request Processing**: Efficiently aggregates multiple embedding requests into single batch API calls, improving throughput and reducing rate limit hits.
22
- - **New Provider Support**: Full support for **iFlow** (API Key & OAuth), **Qwen Code** (API Key & OAuth), and **NVIDIA NIM** with DeepSeek thinking support, including special handling for their API quirks (tool schema cleaning, reasoning support, dedicated logging).
23
- - **Duplicate Credential Detection**: Intelligently detects if multiple local credential files belong to the same user account and logs a warning, preventing redundancy in your key pool.
24
- - **Escalating Per-Model Cooldowns**: If a key fails for a specific model, it's placed on a temporary, escalating cooldown for that model, allowing it to be used with others.
25
- - **Automatic Daily Resets**: Cooldowns and usage statistics are automatically reset daily, making the system self-maintaining.
26
- - **Detailed Request Logging**: Enable comprehensive logging for debugging. Each request gets its own directory with full request/response details, streaming chunks, and performance metadata.
27
- - **Provider Agnostic**: Compatible with any provider supported by `litellm`.
28
- - **OpenAI-Compatible Proxy**: Offers a familiar API interface with additional endpoints for model and provider discovery.
29
- - **Advanced Model Filtering**: Supports both blacklists and whitelists to give you fine-grained control over which models are available through the proxy.
30
-
31
- - **🆕 Antigravity Provider**: Full support for Google's internal Antigravity API, providing access to Gemini 3 and Claude models with advanced features:
32
- - **🚀 Claude Opus 4.5** - Anthropic's most powerful model (thinking mode only)
33
- - **Claude Sonnet 4.5** - Supports both thinking and non-thinking modes
34
- - **Gemini 3 Pro** - With thinkingLevel support (low/high)
35
- - Credential prioritization with automatic paid/free tier detection
36
- - Thought signature caching for multi-turn conversations
37
- - Tool hallucination prevention via parameter signature injection
38
- - Automatic thinking block sanitization for Claude models (with recovery strategies)
39
- - Note: Claude thinking mode requires careful conversation state management (see [Antigravity documentation](DOCUMENTATION.md#antigravity-claude-extended-thinking-sanitization) for details)
40
- - **🆕 Credential Prioritization**: Automatic tier detection and priority-based credential selection ensures paid-tier credentials are used for premium models that require them.
41
- - **🆕 Sequential Rotation Mode**: Choose between balanced (distribute load evenly) or sequential (use until exhausted) credential rotation strategies. Sequential mode maximizes cache hit rates for providers like Antigravity.
42
- - **🆕 Per-Model Quota Tracking**: Granular per-model usage tracking with authoritative quota reset timestamps from provider error responses. Each model maintains its own window with `window_start_ts` and `quota_reset_ts`.
43
- - **🆕 Model Quota Groups**: Group models that share quota limits (e.g., Claude Sonnet and Opus). When one model in a group hits quota, all receive the same cooldown timestamp.
44
- - **🆕 Priority-Based Concurrency**: Assign credentials to priority tiers (1=highest) with configurable concurrency multipliers. Paid-tier credentials can handle more concurrent requests than free-tier ones.
45
- - **🆕 Provider-Specific Quota Parsing**: Extended provider interface with `parse_quota_error()` method to extract precise retry-after times from provider-specific error formats (e.g., Google RPC format).
46
- - **🆕 Flexible Rolling Windows**: Support for provider-specific quota reset configurations (5-hour, 7-day, etc.) replacing hardcoded daily resets.
47
- - **🆕 Weighted Random Rotation**: Configurable credential rotation strategy - choose between deterministic (perfect balance) or weighted random (unpredictable, harder to fingerprint) selection.
48
- - **🆕 Enhanced Gemini CLI**: Improved project discovery, paid vs free tier detection, and Gemini 3 support with thoughtSignature caching.
49
- - **🆕 Temperature Override**: Global temperature=0 override option to prevent tool hallucination issues with low-temperature settings.
50
- - **🆕 Provider Cache System**: Modular caching system for preserving conversation state (thought signatures, thinking content) across requests.
51
- - **🆕 Refactored OAuth Base**: Shared [`GoogleOAuthBase`](src/rotator_library/providers/google_oauth_base.py) class eliminates code duplication across OAuth providers.
52
-
53
- - **🆕 Interactive Launcher TUI**: Beautiful, cross-platform TUI for configuration and management with an integrated settings tool for advanced configuration.
54
 
 
 
 
 
55
 
56
  ---
57
 
58
- ## 1. Quick Start
59
 
60
- ### Windows (Simplest)
61
 
62
- 1. **Download the latest release** from the [GitHub Releases page](https://github.com/Mirrowel/LLM-API-Key-Proxy/releases/latest).
63
- 2. Unzip the downloaded file.
64
- 3. **Run the executable** (run without arguments). This launches the **interactive TUI launcher** which allows you to:
65
- - 🚀 Run the proxy server with your configured settings
66
- - ⚙️ Configure proxy settings (Host, Port, PROXY_API_KEY, Request Logging)
67
- - 🔑 Manage credentials (add/edit API keys & OAuth credentials)
68
- - 📊 View provider status and advanced settings
69
- - 🔧 Configure advanced settings interactively (custom API bases, model definitions, concurrency limits)
70
- - 🔄 Reload configuration without restarting
71
 
72
- > **Note:** The legacy `launcher.bat` is deprecated.
73
 
74
  ### macOS / Linux
75
 
76
- **Option A: Using the Executable (Recommended)**
77
- If you downloaded the pre-compiled binary for your platform, no Python installation is required.
78
-
79
- 1. **Download the latest release** from the GitHub Releases page.
80
- 2. Open a terminal and make the binary executable:
81
- ```bash
82
- chmod +x proxy_app
83
- ```
84
- 3. **Run the Interactive Launcher**:
85
- ```bash
86
- ./proxy_app
87
- ```
88
- This launches the TUI where you can configure and run the proxy.
89
-
90
- 4. **Or run directly with arguments** to bypass the launcher:
91
- ```bash
92
- ./proxy_app --host 0.0.0.0 --port 8000
93
- ```
94
-
95
- **Option B: Manual Setup (Source Code)**
96
- If you are running from source, use these commands:
97
-
98
- **1. Install Dependencies**
99
  ```bash
100
- # Ensure you have Python 3.10+ installed
101
- python3 -m venv venv
102
- source venv/bin/activate
103
- pip install -r requirements.txt
104
  ```
105
 
106
- **2. Launch the Interactive TUI**
 
107
  ```bash
108
- export PYTHONPATH=$PYTHONPATH:$(pwd)/src
 
 
 
 
109
  python src/proxy_app/main.py
110
  ```
111
 
112
- **3. Or run directly with arguments to bypass the launcher**
113
- ```bash
114
- export PYTHONPATH=$PYTHONPATH:$(pwd)/src
115
- python src/proxy_app/main.py --host 0.0.0.0 --port 8000
116
- ```
117
- *To enable logging, add `--enable-request-logging` to the command.*
118
 
119
  ---
120
 
121
- ## 2. Interactive TUI Launcher
122
-
123
- The proxy now includes a powerful **interactive Text User Interface (TUI)** that makes configuration and management effortless.
124
-
125
- ### Features
126
-
127
- - **🎯 Main Menu**:
128
- - Run proxy server with saved settings
129
- - Configure proxy settings (host, port, API key, logging)
130
- - Manage credentials (API keys & OAuth)
131
- - View provider & advanced settings status
132
- - Reload configuration
133
-
134
- - **🔧 Advanced Settings Tool**:
135
- - Configure custom OpenAI-compatible providers
136
- - Define provider models (simple or advanced JSON format)
137
- - Set concurrency limits per provider
138
- - Configure rotation modes (balanced vs sequential)
139
- - Manage priority-based concurrency multipliers
140
- - Interactive numbered menus for easy selection
141
- - Pending changes system with save/discard options
142
-
143
- - **📊 Status Dashboard**:
144
- - Shows configured providers and credential counts
145
- - Displays custom providers and API bases
146
- - Shows active advanced settings
147
- - Real-time configuration status
148
-
149
- ### How to Use
150
-
151
- **Running without arguments launches the TUI:**
152
- ```bash
153
- # Windows
154
- proxy_app.exe
155
 
156
- # macOS/Linux
157
- ./proxy_app
158
 
159
- # From source
160
- python src/proxy_app/main.py
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
161
  ```
162
 
163
- **Running with arguments bypasses the TUI:**
 
 
 
 
164
  ```bash
165
- # Direct startup (skips TUI)
166
- proxy_app.exe --host 0.0.0.0 --port 8000
 
 
 
 
 
167
  ```
168
 
169
- ### Configuration Files
170
 
171
- The TUI manages two configuration files:
172
- - **`launcher_config.json`**: Stores launcher-specific settings (host, port, logging preference)
173
- - **`.env`**: Stores all credentials and advanced settings (PROXY_API_KEY, provider credentials, custom settings)
174
 
175
- All advanced settings configured through the TUI are stored in `.env` for compatibility with manual editing and deployment platforms.
 
 
 
 
 
 
176
 
177
- ---
178
 
179
- ## 3. Detailed Setup (From Source)
 
180
 
181
- This guide is for users who want to run the proxy from the source code on any operating system.
182
 
183
- ### Step 1: Clone and Install
 
 
 
 
 
 
 
 
 
 
184
 
185
- First, clone the repository and install the required dependencies into a virtual environment.
186
 
187
- **Linux/macOS:**
188
- ```bash
189
- # Clone the repository
190
- git clone https://github.com/Mirrowel/LLM-API-Key-Proxy.git
191
- cd LLM-API-Key-Proxy
192
 
193
- # Create and activate a virtual environment
194
- python3 -m venv venv
195
- source venv/bin/activate
 
 
 
 
 
 
 
196
 
197
- # Install dependencies
198
- pip install -r requirements.txt
199
- ```
200
 
201
- **Windows:**
202
- ```powershell
203
- # Clone the repository
204
- git clone https://github.com/Mirrowel/LLM-API-Key-Proxy.git
205
- cd LLM-API-Key-Proxy
206
 
207
- # Create and activate a virtual environment
208
- python -m venv venv
209
- .\venv\Scripts\Activate.ps1
210
 
211
- # Install dependencies
212
- pip install -r requirements.txt
213
- ```
 
 
214
 
215
- ### Step 2: Configure API Keys
 
 
216
 
217
- Create a `.env` file to store your secret keys. You can do this by copying the example file.
218
 
219
- **Linux/macOS:**
220
  ```bash
221
- cp .env.example .env
222
  ```
223
 
224
- **Windows:**
225
- ```powershell
226
- copy .env.example .env
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
227
  ```
228
 
229
- Now, open the new `.env` file and add your keys.
230
 
231
- **Refer to the `.env.example` file for the correct format and a full list of supported providers.**
232
 
233
- The proxy supports two types of credentials:
234
 
235
- 1. **API Keys**: Standard secret keys from providers like OpenAI, Anthropic, etc.
236
- 2. **OAuth Credentials**: For services that use OAuth 2.0, like the Gemini CLI.
237
 
238
- #### Automated Credential Discovery (Recommended)
239
 
240
- For many providers, **no configuration is necessary**. The proxy automatically discovers and manages credentials from their default locations:
241
- - **API Keys**: Scans your environment variables for keys matching the format `PROVIDER_API_KEY_1` (e.g., `GEMINI_API_KEY_1`).
242
- - **OAuth Credentials**: Scans default system directories (e.g., `~/.gemini/`, `~/.qwen/`, `~/.iflow/`) for all `*.json` credential files.
 
 
 
243
 
244
- You only need to create a `.env` file to set your `PROXY_API_KEY` and to override or add credentials if the automatic discovery doesn't suit your needs.
245
 
246
- #### Interactive Credential Management Tool
 
247
 
248
- The proxy includes a powerful interactive CLI tool for managing all your credentials. This is the recommended way to set up credentials:
 
 
 
 
249
 
250
- ```bash
251
- python -m rotator_library.credential_tool
 
 
 
252
  ```
253
 
254
- **Or use the TUI Launcher** (recommended):
255
- ```bash
256
- python src/proxy_app/main.py
257
- # Then select "3. 🔑 Manage Credentials"
258
- ```
259
 
260
- **Main Menu Features:**
 
 
 
 
 
261
 
262
- 1. **Add OAuth Credential** - Interactive OAuth flow for Gemini CLI, Antigravity, Qwen Code, and iFlow
263
- - Automatically opens your browser for authentication
264
- - Handles the entire OAuth flow including callbacks
265
- - Saves credentials to the local `oauth_creds/` directory
266
- - For Gemini CLI: Automatically discovers or creates a Google Cloud project
267
- - For Antigravity: Similar to Gemini CLI with Antigravity-specific scopes
268
- - For Qwen Code: Uses Device Code flow (you'll enter a code in your browser)
269
- - For iFlow: Starts a local callback server on port 11451
270
 
271
- 2. **Add API Key** - Add standard API keys for any LiteLLM-supported provider
272
- - Interactive prompts guide you through the process
273
- - Automatically saves to your `.env` file
274
- - Supports multiple keys per provider (numbered automatically)
275
 
276
- 3. **Export Credentials to .env** - The "Stateless Deployment" feature
277
- - Converts file-based OAuth credentials into environment variables
278
- - Essential for platforms without persistent file storage
279
- - Generates a ready-to-paste `.env` block for each credential
280
 
281
- **Stateless Deployment Workflow (Railway, Render, Vercel, etc.):**
282
 
283
- If you're deploying to a platform without persistent file storage:
284
 
285
- 1. **Setup credentials locally first**:
286
- ```bash
287
- python -m rotator_library.credential_tool
288
- # Select "Add OAuth Credential" and complete the flow
289
- ```
290
 
291
- 2. **Export to environment variables**:
292
- ```bash
293
- python -m rotator_library.credential_tool
294
- # Select "Export Gemini CLI to .env" (or Qwen/iFlow)
295
- # Choose your credential file
296
- ```
297
 
298
- 3. **Copy the generated output**:
299
- - The tool creates a file like `gemini_cli_credential_1.env`
300
- - Contains all necessary `GEMINI_CLI_*` variables
 
301
 
302
- 4. **Paste into your hosting platform**:
303
- - Add each variable to your platform's environment settings
304
- - Set `SKIP_OAUTH_INIT_CHECK=true` to skip interactive validation
305
- - No credential files needed; everything loads from environment variables
306
 
307
- **Local-First OAuth Management:**
308
 
309
- The proxy uses a "local-first" approach for OAuth credentials:
310
 
311
- - **Local Storage**: All OAuth credentials are stored in `oauth_creds/` directory
312
- - **Automatic Discovery**: On first run, the proxy scans system paths (`~/.gemini/`, `~/.qwen/`, `~/.iflow/`) and imports found credentials
313
- - **Deduplication**: Intelligently detects duplicate accounts (by email/user ID) and warns you
314
- - **Priority**: Local files take priority over system-wide credentials
315
- - **No System Pollution**: Your project's credentials are isolated from global system credentials
316
 
317
- **Example `.env` configuration:**
318
- ```env
319
- # A secret key for your proxy server to authenticate requests.
320
- # This can be any secret string you choose.
321
- PROXY_API_KEY="a-very-secret-and-unique-key"
322
-
323
- # --- Provider API Keys (Optional) ---
324
- # The proxy automatically finds keys in your environment variables.
325
- # You can also define them here. Add multiple keys by numbering them (_1, _2).
326
- GEMINI_API_KEY_1="YOUR_GEMINI_API_KEY_1"
327
- GEMINI_API_KEY_2="YOUR_GEMINI_API_KEY_2"
328
- OPENROUTER_API_KEY_1="YOUR_OPENROUTER_API_KEY_1"
329
-
330
- # --- OAuth Credentials (Optional) ---
331
- # The proxy automatically finds credentials in standard system paths.
332
- # You can override this by specifying a path to your credential file.
333
- GEMINI_CLI_OAUTH_1="/path/to/your/specific/gemini_creds.json"
334
-
335
- # --- Gemini CLI: Stateless Deployment Support ---
336
- # For hosts without file persistence (Railway, Render, etc.), you can provide
337
- # Gemini CLI credentials directly via environment variables:
338
- GEMINI_CLI_ACCESS_TOKEN="ya29.your-access-token"
339
- GEMINI_CLI_REFRESH_TOKEN="1//your-refresh-token"
340
- GEMINI_CLI_EXPIRY_DATE="1234567890000"
341
- GEMINI_CLI_EMAIL="your-email@gmail.com"
342
- # Optional: GEMINI_CLI_PROJECT_ID, GEMINI_CLI_CLIENT_ID, etc.
343
- # See IMPLEMENTATION_SUMMARY.md for full list of supported variables
344
-
345
- # --- Dual Authentication Support ---
346
- # Some providers (qwen_code, iflow) support BOTH OAuth and direct API keys.
347
- # You can use either method, or mix both for credential rotation:
348
- QWEN_CODE_API_KEY_1="your-qwen-api-key" # Direct API key
349
- # AND/OR use OAuth: oauth_creds/qwen_code_oauth_1.json
350
- IFLOW_API_KEY_1="sk-your-iflow-key" # Direct API key
351
- # AND/OR use OAuth: oauth_creds/iflow_oauth_1.json
352
- ```
353
 
354
- ### 4. Run the Proxy
 
 
 
 
 
355
 
356
- You can run the proxy in two ways:
357
 
358
- **A) Using the Compiled Executable (Recommended)**
 
359
 
360
- A pre-compiled, standalone executable for Windows is available on the [latest GitHub Release](https://github.com/Mirrowel/LLM-API-Key-Proxy/releases/latest). This is the easiest way to get started as it requires no setup.
 
 
 
 
 
361
 
362
- For the simplest experience, follow the **Quick Start** guide at the top of this document.
363
 
364
- **B) Running from Source**
 
365
 
366
- Start the server by running the `main.py` script
 
 
 
 
 
 
367
 
368
- ```bash
369
- python src/proxy_app/main.py
370
- ```
371
- This launches the interactive TUI launcher by default. To run the proxy directly, use:
372
 
373
- ```bash
374
- python src/proxy_app/main.py --host 0.0.0.0 --port 8000
375
- ```
376
 
377
- The proxy is now running and available at `http://127.0.0.1:8000`.
 
 
 
 
378
 
379
- ### 5. Make a Request
 
 
 
 
 
380
 
381
- You can now send requests to the proxy. The endpoint is `http://127.0.0.1:8000/v1/chat/completions`.
 
 
 
382
 
383
- Remember to:
384
- 1. Set the `Authorization` header to `Bearer your-super-secret-proxy-key`.
385
- 2. Specify the `model` in the format `provider/model_name`.
 
386
 
387
- Here is an example using `curl`:
388
- ```bash
389
- curl -X POST http://127.0.0.1:8000/v1/chat/completions \
390
- -H "Content-Type: application/json" \
391
- -H "Authorization: Bearer your-super-secret-proxy-key" \
392
- -d '{
393
- "model": "gemini/gemini-2.5-flash",
394
- "messages": [{"role": "user", "content": "What is the capital of France?"}]
395
- }'
396
- ```
 
 
 
 
 
 
397
 
398
  ---
399
 
400
- ## Advanced Usage
401
 
402
- ### Using with the OpenAI Python Library (Recommended)
 
403
 
404
- The proxy is OpenAI-compatible, so you can use it directly with the `openai` Python client.
405
 
406
- ```python
407
- import openai
 
 
 
408
 
409
- # Point the client to your local proxy
410
- client = openai.OpenAI(
411
- base_url="http://127.0.0.1:8000/v1",
412
- api_key="a-very-secret-and-unique-key" # Use your PROXY_API_KEY here
413
- )
414
 
415
- # Make a request
416
- response = client.chat.completions.create(
417
- model="gemini/gemini-2.5-flash", # Specify provider and model
418
- messages=[
419
- {"role": "user", "content": "Write a short poem about space."}
420
- ]
421
- )
422
 
423
- print(response.choices[0].message.content)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
424
  ```
425
 
426
- ### Using with `curl`
 
 
 
 
 
427
 
428
- ```bash
429
- You can also send requests directly using tools like `curl`.
 
 
 
 
430
 
431
- ```bash
432
- curl -X POST http://127.0.0.1:8000/v1/chat/completions \
433
- -H "Content-Type: application/json" \
434
- -H "Authorization: Bearer a-very-secret-and-unique-key" \
435
- -d '{
436
- "model": "gemini/gemini-2.5-flash",
437
- "messages": [{"role": "user", "content": "What is the capital of France?"}]
438
- }'
 
 
 
 
 
 
 
439
  ```
440
 
441
- ### Available API Endpoints
442
 
443
- - `POST /v1/chat/completions`: The main endpoint for making chat requests.
444
- - `POST /v1/embeddings`: The endpoint for creating embeddings.
445
- - `GET /v1/models`: Returns a list of all available models from your configured providers.
446
- - `GET /v1/providers`: Returns a list of all configured providers.
447
- - `POST /v1/token-count`: Calculates the token count for a given message payload.
448
 
449
- ---
 
 
450
 
451
- ## 4. Advanced Topics
452
 
453
- ### Batch Request Processing
454
 
455
- The proxy includes a `Batch Manager` that optimizes high-volume embedding requests.
456
- - **Automatic Aggregation**: Multiple individual embedding requests are automatically collected into a single batch API call.
457
- - **Configurable**: Works out of the box, but can be tuned for specific needs.
458
- - **Benefits**: Significantly reduces the number of HTTP requests to providers, helping you stay within rate limits while improving throughput.
459
 
460
- ### How It Works
 
 
461
 
462
- The proxy is built on a robust architecture:
463
 
464
- 1. **Intelligent Routing**: The `UsageManager` selects the best available key from your pool. It prioritizes idle keys first, then keys that can handle concurrency, ensuring optimal load balancing.
465
- 2. **Resilience & Deadlines**: Every request has a strict deadline (`global_timeout`). If a provider is slow or fails, the proxy retries with a different key immediately, ensuring your application never hangs.
466
- 3. **Batching**: High-volume embedding requests are automatically aggregated into optimized batches, reducing API calls and staying within rate limits.
467
- 4. **Deep Observability**: (Optional) Detailed logs capture every byte of the transaction, including raw streaming chunks, for precise debugging of complex agentic interactions.
468
 
469
- ### Command-Line Arguments and Scripts
 
 
 
470
 
471
- The proxy server can be configured at runtime using the following command-line arguments:
472
 
473
- - `--host`: The IP address to bind the server to. Defaults to `0.0.0.0` (accessible from your local network).
474
- - `--port`: The port to run the server on. Defaults to `8000`.
475
- - `--enable-request-logging`: A flag to enable detailed, per-request logging. When active, the proxy creates a unique directory for each transaction in the `logs/detailed_logs/` folder, containing the full request, response, streaming chunks, and performance metadata. This is highly recommended for debugging.
476
 
477
- ### New Provider Highlights
478
 
479
- #### **Gemini CLI (Advanced)**
480
- A powerful provider that mimics the Google Cloud Code extension.
481
- - **Zero-Config Project Discovery**: Automatically finds your Google Cloud Project ID or onboards you to a free-tier project if none exists.
482
- - **Internal API Access**: Uses high-limit internal endpoints (`cloudcode-pa.googleapis.com`) rather than the public Vertex AI API.
483
- - **Smart Rate Limiting**: Automatically falls back to preview models (e.g., `gemini-2.5-pro-preview`) if the main model hits a rate limit.
 
 
484
 
485
- #### **Qwen Code**
486
- - **Dual Authentication**: Use either standard API keys or OAuth 2.0 Device Flow credentials.
487
- - **Schema Cleaning**: Automatically removes `strict` and `additionalProperties` from tool schemas to prevent API errors.
488
- - **Stream Stability**: Injects a dummy `do_not_call_me` tool to prevent stream corruption issues when no tools are provided.
489
- - **Reasoning Support**: Parses `<think>` tags in responses and exposes them as `reasoning_content` (similar to OpenAI's o1 format).
490
- - **Dedicated Logging**: Optional per-request file logging to `logs/qwen_code_logs/` for debugging.
491
- - **Custom Models**: Define additional models via `QWEN_CODE_MODELS` environment variable (JSON array format).
492
 
493
- #### **iFlow**
494
- - **Dual Authentication**: Use either standard API keys or OAuth 2.0 Authorization Code Flow.
495
- - **Hybrid Auth**: OAuth flow provides an access token, but actual API calls use a separate `apiKey` retrieved from user profile.
496
- - **Local Callback Server**: OAuth flow runs a temporary server on port 11451 to capture the redirect.
497
- - **Schema Cleaning**: Same as Qwen Code - removes unsupported properties from tool schemas.
498
- - **Stream Stability**: Injects placeholder tools to stabilize streaming for empty tool lists.
499
- - **Dedicated Logging**: Optional per-request file logging to `logs/iflow_logs/` for debugging proprietary API behaviors.
500
- - **Custom Models**: Define additional models via `IFLOW_MODELS` environment variable (JSON array format).
501
 
 
 
 
 
 
 
502
 
503
- ### Advanced Configuration
504
 
505
- The following advanced settings can be added to your `.env` file (or configured interactively via the TUI Settings Tool):
 
 
 
 
506
 
507
- #### OAuth and Refresh Settings
 
 
 
 
508
 
509
- - **`OAUTH_REFRESH_INTERVAL`**: Controls how often (in seconds) the background refresher checks for expired OAuth tokens. Default is `600` (10 minutes).
510
- ```env
511
- OAUTH_REFRESH_INTERVAL=600 # Check every 10 minutes
512
- ```
 
 
 
 
513
 
514
- - **`SKIP_OAUTH_INIT_CHECK`**: Set to `true` to skip the interactive OAuth setup/validation check on startup. Essential for non-interactive environments like Docker containers or CI/CD pipelines.
515
- ```env
516
- SKIP_OAUTH_INIT_CHECK=true
517
 
 
 
518
 
519
- #### **Antigravity (Advanced - Gemini 3 \ Claude Opus 4.5 / Sonnet 4.5 Access)**
520
- The newest and most sophisticated provider, offering access to cutting-edge models via Google's internal Antigravity API.
521
 
522
  **Supported Models:**
523
- - Gemini 2.5 (Pro/Flash) with `thinkingBudget` parameter
524
- - **Gemini 3 Pro (High/Low)** - Latest preview models
525
- - **🆕 Claude Opus 4.5 + Thinking** - Anthropic's most powerful model via Antigravity proxy
526
- - **Claude Sonnet 4.5 + Thinking** via Antigravity proxy
527
 
528
- **Advanced Features:**
529
- - **Thought Signature Caching**: Preserves encrypted signatures for multi-turn Gemini 3 conversations
530
- - **Tool Hallucination Prevention**: Automatic system instruction and parameter signature injection for Gemini 3 to prevent tools from being called with incorrect parameters
531
- - **Thinking Preservation**: Caches Claude thinking content for consistency across conversation turns
532
- - **Automatic Fallback**: Tries sandbox endpoints before falling back to production
533
- - **Schema Cleaning**: Handles Claude-specific tool schema requirements
534
 
535
- **Configuration:**
536
- - **OAuth Setup**: Uses Google OAuth similar to Gemini CLI (separate scopes)
537
- - **Stateless Deployment**: Full environment variable support
538
- - **Paid Tier Recommended**: Gemini 3 models require a paid Google Cloud project
 
539
 
540
  **Environment Variables:**
541
  ```env
542
- # Stateless deployment
543
- ANTIGRAVITY_ACCESS_TOKEN="..."
544
- ANTIGRAVITY_REFRESH_TOKEN="..."
545
- ANTIGRAVITY_EXPIRY_DATE="..."
546
- ANTIGRAVITY_EMAIL="user@gmail.com"
547
 
548
  # Feature toggles
549
- ANTIGRAVITY_ENABLE_SIGNATURE_CACHE=true # Multi-turn conversation support
550
- ANTIGRAVITY_GEMINI3_TOOL_FIX=true # Prevent tool hallucination
551
  ```
552
 
 
553
 
554
- ```
555
-
556
- #### Credential Rotation Modes
557
-
558
- - **`ROTATION_MODE_<PROVIDER>`**: Controls how credentials are rotated when multiple are available. Default: `balanced` (except Antigravity which defaults to `sequential`).
559
- - `balanced`: Rotate credentials evenly across requests to distribute load. Best for per-minute rate limits.
560
- - `sequential`: Use one credential until exhausted (429 error), then switch to next. Best for daily/weekly quotas.
561
- ```env
562
- ROTATION_MODE_GEMINI=sequential # Use Gemini keys until quota exhausted
563
- ROTATION_MODE_OPENAI=balanced # Distribute load across OpenAI keys (default)
564
- ROTATION_MODE_ANTIGRAVITY=balanced # Override Antigravity's sequential default
565
- ```
566
-
567
- #### Priority-Based Concurrency Multipliers
568
-
569
- - **`CONCURRENCY_MULTIPLIER_<PROVIDER>_PRIORITY_<N>`**: Assign concurrency multipliers to priority tiers. Higher-tier credentials handle more concurrent requests.
570
- ```env
571
- # Universal multipliers (apply to all rotation modes)
572
- CONCURRENCY_MULTIPLIER_ANTIGRAVITY_PRIORITY_1=10 # 10x for paid ultra tier
573
- CONCURRENCY_MULTIPLIER_ANTIGRAVITY_PRIORITY_3=1 # 1x for lower tiers
574
-
575
- # Mode-specific overrides
576
- CONCURRENCY_MULTIPLIER_ANTIGRAVITY_PRIORITY_2_BALANCED=1 # P2 = 1x in balanced mode only
577
- ```
578
-
579
- **Provider Defaults** (built into provider classes):
580
- - **Antigravity**: Priority 1: 5x, Priority 2: 3x, Priority 3+: 2x (sequential) or 1x (balanced)
581
- - **Gemini CLI**: Priority 1: 5x, Priority 2: 3x, Others: 1x
582
-
583
- #### Model Quota Groups
584
-
585
- - **`QUOTA_GROUPS_<PROVIDER>_<GROUP>`**: Define models that share quota/cooldown timing. When one model hits quota, all in the group receive the same cooldown timestamp.
586
- ```env
587
- QUOTA_GROUPS_ANTIGRAVITY_CLAUDE="claude-sonnet-4-5,claude-opus-4-5"
588
- QUOTA_GROUPS_ANTIGRAVITY_GEMINI="gemini-3-pro-preview,gemini-3-pro-image-preview"
589
-
590
- # To disable a default group:
591
- QUOTA_GROUPS_ANTIGRAVITY_CLAUDE=""
592
- ```
593
-
594
- **Default Groups**:
595
- - **Antigravity**: Claude group (Sonnet 4.5 + Opus 4.5) with Opus counting 2x vs Sonnet
596
-
597
- #### Concurrency Control
598
-
599
- - **`MAX_CONCURRENT_REQUESTS_PER_KEY_<PROVIDER>`**: Set the maximum number of simultaneous requests allowed per API key for a specific provider. Default is `1` (no concurrency). Useful for high-throughput providers.
600
- ```env
601
- MAX_CONCURRENT_REQUESTS_PER_KEY_OPENAI=3
602
- MAX_CONCURRENT_REQUESTS_PER_KEY_ANTHROPIC=2
603
- MAX_CONCURRENT_REQUESTS_PER_KEY_GEMINI=1
604
- ```
605
-
606
- #### Custom Model Lists
607
-
608
- For providers that support custom model definitions (Qwen Code, iFlow), you can override the default model list:
609
-
610
- - **`QWEN_CODE_MODELS`**: JSON array of custom Qwen Code models. These models take priority over hardcoded defaults.
611
- ```env
612
- QWEN_CODE_MODELS='["qwen3-coder-plus", "qwen3-coder-flash", "custom-model-id"]'
613
- ```
614
-
615
- - **`IFLOW_MODELS`**: JSON array of custom iFlow models. These models take priority over hardcoded defaults.
616
- ```env
617
- IFLOW_MODELS='["glm-4.6", "qwen3-coder-plus", "deepseek-v3.2"]'
618
- ```
619
-
620
- #### Provider-Specific Settings
621
-
622
- - **`GEMINI_CLI_PROJECT_ID`**: Manually specify a Google Cloud Project ID for Gemini CLI OAuth. Only needed if automatic discovery fails.
623
-
624
-
625
- #### Antigravity Provider
626
-
627
- - **`ANTIGRAVITY_OAUTH_1`**: Path to Antigravity OAuth credential file (auto-discovered from `~/.antigravity/` or use the credential tool).
628
- ```env
629
- ANTIGRAVITY_OAUTH_1="/path/to/your/antigravity_creds.json"
630
- ```
631
-
632
- - **Stateless Deployment** (Environment Variables):
633
- ```env
634
- ANTIGRAVITY_ACCESS_TOKEN="ya29.your-access-token"
635
-
636
-
637
- #### Credential Rotation Strategy
638
-
639
- - **`ROTATION_TOLERANCE`**: Controls how credentials are selected for requests. Set via environment variable or programmatically.
640
- - `0.0`: **Deterministic** - Always selects the least-used credential for perfect load balance
641
- - `3.0` (default, recommended): **Weighted Random** - Randomly selects with bias toward less-used credentials. Provides unpredictability (harder to fingerprint/detect) while maintaining good balance
642
- - `5.0+`: **High Randomness** - Maximum unpredictability, even heavily-used credentials can be selected
643
-
644
- ```env
645
- # For maximum security/unpredictability (recommended for production)
646
- ROTATION_TOLERANCE=3.0
647
-
648
- # For perfect load balancing (default)
649
- ROTATION_TOLERANCE=0.0
650
- ```
651
-
652
- **Why use weighted random?**
653
- - Makes traffic patterns less predictable
654
- - Still maintains good load distribution across keys
655
- - Recommended for production environments with multiple credentials
656
-
657
-
658
- ANTIGRAVITY_REFRESH_TOKEN="1//your-refresh-token"
659
- ANTIGRAVITY_EXPIRY_DATE="1234567890000"
660
- ANTIGRAVITY_EMAIL="your-email@gmail.com"
661
- ```
662
-
663
- - **`ANTIGRAVITY_ENABLE_SIGNATURE_CACHE`**: Enable/disable thought signature caching for Gemini 3 multi-turn conversations. Default: `true`.
664
- ```env
665
- ANTIGRAVITY_ENABLE_SIGNATURE_CACHE=true
666
- ```
667
-
668
- - **`ANTIGRAVITY_GEMINI3_TOOL_FIX`**: Enable/disable tool hallucination prevention for Gemini 3 models. Default: `true`.
669
- ```env
670
- ANTIGRAVITY_GEMINI3_TOOL_FIX=true
671
- ```
672
-
673
- #### Temperature Override (Global)
674
-
675
- - **`OVERRIDE_TEMPERATURE_ZERO`**: Prevents tool hallucination caused by temperature=0 settings. Modes:
676
- - `"remove"`: Deletes temperature=0 from requests (lets provider use default)
677
- - `"set"`: Changes temperature=0 to temperature=1.0
678
- - `"false"` or unset: Disabled (default)
679
 
680
- #### Credential Prioritization
 
681
 
682
- - **`GEMINI_CLI_PROJECT_ID`**: Manually specify a Google Cloud Project ID for Gemini CLI OAuth. Auto-discovered unless unexpected failure occurs.
683
- ```env
684
- GEMINI_CLI_PROJECT_ID="your-gcp-project-id"
685
- ```
686
 
 
 
 
 
 
687
 
688
- ```env
689
- GEMINI_CLI_PROJECT_ID="your-gcp-project-id"
690
- ```
 
 
691
 
692
- **Example:**
693
- ```bash
694
- python src/proxy_app/main.py --host 127.0.0.1 --port 9999 --enable-request-logging
695
- ```
696
 
 
 
697
 
698
- #### Windows Batch Scripts
699
 
700
- For convenience on Windows, you can use the provided `.bat` scripts in the root directory:
 
 
 
 
701
 
702
- - **`launcher.bat`** *(deprecated)*: Legacy launcher with manual menu system. Still functional but superseded by the new TUI.
 
 
 
 
703
 
704
- ### Troubleshooting
705
 
706
- - **`401 Unauthorized`**: Ensure your `PROXY_API_KEY` is set correctly in the `.env` file and included in the `Authorization: Bearer <key>` header of your request.
707
- - **`500 Internal Server Error`**: Check the console logs of the `uvicorn` server for detailed error messages. This could indicate an issue with one of your provider API keys (e.g., it's invalid or has been revoked) or a problem with the provider's service. If you have logging enabled (`--enable-request-logging`), inspect the `final_response.json` and `metadata.json` files in the corresponding log directory under `logs/detailed_logs/` for the specific error returned by the upstream provider.
708
- - **All keys on cooldown**: If you see a message that all keys are on cooldown, it means all your keys for a specific provider have recently failed. If you have logging enabled (`--enable-request-logging`), check the `logs/detailed_logs/` directory to find the logs for the failed requests and inspect the `final_response.json` to see the underlying error from the provider.
709
 
710
- ---
711
 
712
- ## Library and Technical Docs
 
 
 
 
713
 
714
- - **Using the Library**: For documentation on how to use the `api-key-manager` library directly in your own Python projects, please refer to its [README.md](src/rotator_library/README.md).
715
- - **Technical Details**: For a more in-depth technical explanation of the library's architecture, components, and internal workings, please refer to the [Technical Documentation](DOCUMENTATION.md).
 
 
 
716
 
717
- ### Advanced Model Filtering (Whitelists & Blacklists)
 
718
 
719
- The proxy provides a powerful way to control which models are available to your applications using environment variables in your `.env` file.
720
 
721
- #### How It Works
722
 
723
- The filtering logic is applied in this order:
 
724
 
725
- 1. **Whitelist Check**: If a provider has a whitelist defined (`WHITELIST_MODELS_<PROVIDER>`), any model on that list will **always be available**, even if it's on the blacklist.
726
- 2. **Blacklist Check**: For any model *not* on the whitelist, the proxy checks the blacklist (`IGNORE_MODELS_<PROVIDER>`). If the model is on the blacklist, it will be hidden.
727
- 3. **Default**: If a model is on neither list, it will be available.
728
 
729
- This allows for two powerful patterns:
 
 
 
 
730
 
731
- #### Use Case 1: Pure Whitelist Mode
732
 
733
- You can expose *only* the specific models you want. To do this, set the blacklist to `*` to block all models by default, and then add the desired models to the whitelist.
734
 
735
- **Example `.env`:**
736
- ```env
737
- # Block all Gemini models by default
738
- IGNORE_MODELS_GEMINI="*"
739
 
740
- # Only allow gemini-1.5-pro and gemini-1.5-flash
741
- WHITELIST_MODELS_GEMINI="gemini-1.5-pro-latest,gemini-1.5-flash-latest"
 
 
 
 
 
 
742
  ```
743
 
744
- #### Use Case 2: Exemption Mode
 
 
 
745
 
746
- You can block a broad category of models and then use the whitelist to make specific exceptions.
 
747
 
748
- **Example `.env`:**
749
- ```env
750
- # Block all preview models from OpenAI
751
- IGNORE_MODELS_OPENAI="*-preview*"
752
 
753
- # But make an exception for a specific preview model you want to test
754
- WHITELIST_MODELS_OPENAI="gpt-4o-2024-08-06-preview"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
755
  ```
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Universal LLM API Proxy & Resilience Library
2
+ [![ko-fi](https://ko-fi.com/img/githubbutton_sm.svg)](https://ko-fi.com/C0C0UZS4P)
3
  [![Ask DeepWiki](https://deepwiki.com/badge.svg)](https://deepwiki.com/Mirrowel/LLM-API-Key-Proxy) [![zread](https://img.shields.io/badge/Ask_Zread-_.svg?style=flat&color=00b0aa&labelColor=000000&logo=data%3Aimage%2Fsvg%2Bxml%3Bbase64%2CPHN2ZyB3aWR0aD0iMTYiIGhlaWdodD0iMTYiIHZpZXdCb3g9IjAgMCAxNiAxNiIgZmlsbD0ibm9uZSIgeG1sbnM9Imh0dHA6Ly93d3cudzMub3JnLzIwMDAvc3ZnIj4KPHBhdGggZD0iTTQuOTYxNTYgMS42MDAxSDIuMjQxNTZDMS44ODgxIDEuNjAwMSAxLjYwMTU2IDEuODg2NjQgMS42MDE1NiAyLjI0MDFWNC45NjAxQzEuNjAxNTYgNS4zMTM1NiAxLjg4ODEgNS42MDAxIDIuMjQxNTYgNS42MDAxSDQuOTYxNTZDNS4zMTUwMiA1LjYwMDEgNS42MDE1NiA1LjMxMzU2IDUuNjAxNTYgNC45NjAxVjIuMjQwMUM1LjYwMTU2IDEuODg2NjQgNS4zMTUwMiAxLjYwMDEgNC45NjE1NiAxLjYwMDFaIiBmaWxsPSIjZmZmIi8%2BCjxwYXRoIGQ9Ik00Ljk2MTU2IDEwLjM5OTlIMi4yNDE1NkMxLjg4ODEgMTAuMzk5OSAxLjYwMTU2IDEwLjY4NjQgMS42MDE1NiAxMS4wMzk5VjEzLjc1OTlDMS42MDE1NiAxNC4xMTM0IDEuODg4MSAxNC4zOTk5IDIuMjQxNTYgMTQuMzk5OUg0Ljk2MTU2QzUuMzE1MDIgMTQuMzk5OSA1LjYwMTU2IDE0LjExMzQgNS42MDE1NiAxMy43NTk5VjExLjAzOTlDNS42MDE1NiAxMC42ODY0IDUuMzE1MDIgMTAuMzk5OSA0Ljk2MTU2IDEwLjM5OTlaIiBmaWxsPSIjZmZmIi8%2BCjxwYXRoIGQ9Ik0xMy43NTg0IDEuNjAwMUgxMS4wMzg0QzEwLjY4NSAxLjYwMDEgMTAuMzk4NCAxLjg4NjY0IDEwLjM5ODQgMi4yNDAxVjQuOTYwMUMxMC4zOTg0IDUuMzEzNTYgMTAuNjg1IDUuNjAwMSAxMS4wMzg0IDUuNjAwMUgxMy43NTg0QzE0LjExMTkgNS42MDAxIDE0LjM5ODQgNS4zMTM1NiAxNC4zOTg0IDQuOTYwMVYyLjI0MDFDMTQuMzk4NCAxLjg4NjY0IDE0LjExMTkgMS42MDAxIDEzLjc1ODQgMS42MDAxWiIgZmlsbD0iI2ZmZiIvPgo8cGF0aCBkPSJNNCAxMkwxMiA0TDQgMTJaIiBmaWxsPSIjZmZmIi8%2BCjxwYXRoIGQ9Ik00IDEyTDEyIDQiIHN0cm9rZT0iI2ZmZiIgc3Ryb2tlLXdpZHRoPSIxLjUiIHN0cm9rZS1saW5lY2FwPSJyb3VuZCIvPgo8L3N2Zz4K&logoColor=ffffff)](https://zread.ai/Mirrowel/LLM-API-Key-Proxy)
4
 
5
+ **One proxy. Any LLM provider. Zero code changes.**
6
 
7
+ A self-hosted proxy that provides a single, OpenAI-compatible API endpoint for all your LLM providers. Works with any application that supports custom OpenAI base URLs—no code changes required in your existing tools.
8
 
9
+ This project consists of two components:
10
+ 1. **The API Proxy** — A FastAPI application providing a universal `/v1/chat/completions` endpoint
11
+ 2. **The Resilience Library** — A reusable Python library for intelligent API key management, rotation, and failover
12
 
13
+ ---
 
 
 
14
 
15
+ ## Why Use This?
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
+ - **Universal Compatibility** — Works with any app supporting OpenAI-compatible APIs: Opencode, Continue, Roo/Kilo Code, JanitorAI, SillyTavern, custom applications, and more
18
+ - **One Endpoint, Many Providers** — Configure Gemini, OpenAI, Anthropic, and [any LiteLLM-supported provider](https://docs.litellm.ai/docs/providers) once. Access them all through a single API key
19
+ - **Built-in Resilience** — Automatic key rotation, failover on errors, rate limit handling, and intelligent cooldowns
20
+ - **Exclusive Provider Support** — Includes custom providers not available elsewhere: **Antigravity** (Gemini 3 + Claude Sonnet/Opus 4.5), **Gemini CLI**, **Qwen Code**, and **iFlow**
21
 
22
  ---
23
 
24
+ ## Quick Start
25
 
26
+ ### Windows
27
 
28
+ 1. **Download** the latest release from [GitHub Releases](https://github.com/Mirrowel/LLM-API-Key-Proxy/releases/latest)
29
+ 2. **Unzip** the downloaded file
30
+ 3. **Run** `proxy_app.exe` the interactive TUI launcher opens
 
 
 
 
 
 
31
 
32
+ <!-- TODO: Add TUI main menu screenshot here -->
33
 
34
  ### macOS / Linux
35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  ```bash
37
+ # Download and extract the release for your platform
38
+ chmod +x proxy_app
39
+ ./proxy_app
 
40
  ```
41
 
42
+ ### From Source
43
+
44
  ```bash
45
+ git clone https://github.com/Mirrowel/LLM-API-Key-Proxy.git
46
+ cd LLM-API-Key-Proxy
47
+ python3 -m venv venv
48
+ source venv/bin/activate # Windows: venv\Scripts\activate
49
+ pip install -r requirements.txt
50
  python src/proxy_app/main.py
51
  ```
52
 
53
+ > **Tip:** Running with command-line arguments (e.g., `--host 0.0.0.0 --port 8000`) bypasses the TUI and starts the proxy directly.
 
 
 
 
 
54
 
55
  ---
56
 
57
+ ## Connecting to the Proxy
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
 
59
+ Once the proxy is running, configure your application with these settings:
 
60
 
61
+ | Setting | Value |
62
+ |---------|-------|
63
+ | **Base URL / API Endpoint** | `http://127.0.0.1:8000/v1` |
64
+ | **API Key** | Your `PROXY_API_KEY` |
65
+
66
+ ### Model Format: `provider/model_name`
67
+
68
+ **Important:** Models must be specified in the format `provider/model_name`. The `provider/` prefix tells the proxy which backend to route the request to.
69
+
70
+ ```
71
+ gemini/gemini-2.5-flash ← Gemini API
72
+ openai/gpt-4o ← OpenAI API
73
+ anthropic/claude-3-5-sonnet ← Anthropic API
74
+ openrouter/anthropic/claude-3-opus ← OpenRouter
75
+ gemini_cli/gemini-2.5-pro ← Gemini CLI (OAuth)
76
+ antigravity/gemini-3-pro-preview ← Antigravity (Gemini 3, Claude Opus 4.5)
77
+ ```
78
+
79
+ ### Usage Examples
80
+
81
+ <details>
82
+ <summary><b>Python (OpenAI Library)</b></summary>
83
+
84
+ ```python
85
+ from openai import OpenAI
86
+
87
+ client = OpenAI(
88
+ base_url="http://127.0.0.1:8000/v1",
89
+ api_key="your-proxy-api-key"
90
+ )
91
+
92
+ response = client.chat.completions.create(
93
+ model="gemini/gemini-2.5-flash", # provider/model format
94
+ messages=[{"role": "user", "content": "Hello!"}]
95
+ )
96
+ print(response.choices[0].message.content)
97
  ```
98
 
99
+ </details>
100
+
101
+ <details>
102
+ <summary><b>curl</b></summary>
103
+
104
  ```bash
105
+ curl -X POST http://127.0.0.1:8000/v1/chat/completions \
106
+ -H "Content-Type: application/json" \
107
+ -H "Authorization: Bearer your-proxy-api-key" \
108
+ -d '{
109
+ "model": "gemini/gemini-2.5-flash",
110
+ "messages": [{"role": "user", "content": "What is the capital of France?"}]
111
+ }'
112
  ```
113
 
114
+ </details>
115
 
116
+ <details>
117
+ <summary><b>JanitorAI / SillyTavern / Other Chat UIs</b></summary>
 
118
 
119
+ 1. Go to **API Settings**
120
+ 2. Select **"Proxy"** or **"Custom OpenAI"** mode
121
+ 3. Configure:
122
+ - **API URL:** `http://127.0.0.1:8000/v1`
123
+ - **API Key:** Your `PROXY_API_KEY`
124
+ - **Model:** `provider/model_name` (e.g., `gemini/gemini-2.5-flash`)
125
+ 4. Save and start chatting
126
 
127
+ </details>
128
 
129
+ <details>
130
+ <summary><b>Continue / Cursor / IDE Extensions</b></summary>
131
 
132
+ In your configuration file (e.g., `config.json`):
133
 
134
+ ```json
135
+ {
136
+ "models": [{
137
+ "title": "Gemini via Proxy",
138
+ "provider": "openai",
139
+ "model": "gemini/gemini-2.5-flash",
140
+ "apiBase": "http://127.0.0.1:8000/v1",
141
+ "apiKey": "your-proxy-api-key"
142
+ }]
143
+ }
144
+ ```
145
 
146
+ </details>
147
 
148
+ ### API Endpoints
 
 
 
 
149
 
150
+ | Endpoint | Description |
151
+ |----------|-------------|
152
+ | `GET /` | Status check — confirms proxy is running |
153
+ | `POST /v1/chat/completions` | Chat completions (main endpoint) |
154
+ | `POST /v1/embeddings` | Text embeddings |
155
+ | `GET /v1/models` | List all available models with pricing & capabilities |
156
+ | `GET /v1/models/{model_id}` | Get details for a specific model |
157
+ | `GET /v1/providers` | List configured providers |
158
+ | `POST /v1/token-count` | Calculate token count for a payload |
159
+ | `POST /v1/cost-estimate` | Estimate cost based on token counts |
160
 
161
+ > **Tip:** The `/v1/models` endpoint is useful for discovering available models in your client. Many apps can fetch this list automatically. Add `?enriched=false` for a minimal response without pricing data.
 
 
162
 
163
+ ---
 
 
 
 
164
 
165
+ ## Managing Credentials
 
 
166
 
167
+ The proxy includes an interactive tool for managing all your API keys and OAuth credentials.
168
+
169
+ ### Using the TUI
170
+
171
+ <!-- TODO: Add TUI credentials menu screenshot here -->
172
 
173
+ 1. Run the proxy without arguments to open the TUI
174
+ 2. Select **"🔑 Manage Credentials"**
175
+ 3. Choose to add API keys or OAuth credentials
176
 
177
+ ### Using the Command Line
178
 
 
179
  ```bash
180
+ python -m rotator_library.credential_tool
181
  ```
182
 
183
+ ### Credential Types
184
+
185
+ | Type | Providers | How to Add |
186
+ |------|-----------|------------|
187
+ | **API Keys** | Gemini, OpenAI, Anthropic, OpenRouter, Groq, Mistral, NVIDIA, Cohere, Chutes | Enter key in TUI or add to `.env` |
188
+ | **OAuth** | Gemini CLI, Antigravity, Qwen Code, iFlow | Interactive browser login via credential tool |
189
+
190
+ ### The `.env` File
191
+
192
+ Credentials are stored in a `.env` file. You can edit it directly or use the TUI:
193
+
194
+ ```env
195
+ # Required: Authentication key for YOUR proxy
196
+ PROXY_API_KEY="your-secret-proxy-key"
197
+
198
+ # Provider API Keys (add multiple with _1, _2, etc.)
199
+ GEMINI_API_KEY_1="your-gemini-key"
200
+ GEMINI_API_KEY_2="another-gemini-key"
201
+ OPENAI_API_KEY_1="your-openai-key"
202
+ ANTHROPIC_API_KEY_1="your-anthropic-key"
203
  ```
204
 
205
+ > Copy `.env.example` to `.env` as a starting point.
206
 
207
+ ---
208
 
209
+ ## The Resilience Library
210
 
211
+ The proxy is powered by a standalone Python library that you can use directly in your own applications.
 
212
 
213
+ ### Key Features
214
 
215
+ - **Async-native** with `asyncio` and `httpx`
216
+ - **Intelligent key selection** with tiered, model-aware locking
217
+ - **Deadline-driven requests** with configurable global timeout
218
+ - **Automatic failover** between keys on errors
219
+ - **OAuth support** for Gemini CLI, Antigravity, Qwen, iFlow
220
+ - **Stateless deployment ready** — load credentials from environment variables
221
 
222
+ ### Basic Usage
223
 
224
+ ```python
225
+ from rotator_library import RotatingClient
226
 
227
+ client = RotatingClient(
228
+ api_keys={"gemini": ["key1", "key2"], "openai": ["key3"]},
229
+ global_timeout=30,
230
+ max_retries=2
231
+ )
232
 
233
+ async with client:
234
+ response = await client.acompletion(
235
+ model="gemini/gemini-2.5-flash",
236
+ messages=[{"role": "user", "content": "Hello!"}]
237
+ )
238
  ```
239
 
240
+ ### Library Documentation
 
 
 
 
241
 
242
+ See the [Library README](src/rotator_library/README.md) for complete documentation including:
243
+ - All initialization parameters
244
+ - Streaming support
245
+ - Error handling and cooldown strategies
246
+ - Provider plugin system
247
+ - Credential prioritization
248
 
249
+ ---
 
 
 
 
 
 
 
250
 
251
+ ## Interactive TUI
 
 
 
252
 
253
+ The proxy includes a powerful text-based UI for configuration and management.
 
 
 
254
 
255
+ <!-- TODO: Add TUI main menu screenshot here -->
256
 
257
+ ### TUI Features
258
 
259
+ - **🚀 Run Proxy** — Start the server with saved settings
260
+ - **⚙️ Configure Settings** — Host, port, API key, request logging
261
+ - **🔑 Manage Credentials** — Add/edit API keys and OAuth credentials
262
+ - **📊 View Status** See configured providers and credential counts
263
+ - **🔧 Advanced Settings** — Custom providers, model definitions, concurrency
264
 
265
+ ### Configuration Files
 
 
 
 
 
266
 
267
+ | File | Contents |
268
+ |------|----------|
269
+ | `.env` | All credentials and advanced settings |
270
+ | `launcher_config.json` | TUI-specific settings (host, port, logging) |
271
 
272
+ ---
 
 
 
273
 
274
+ ## Features
275
 
276
+ ### Core Capabilities
277
 
278
+ - **Universal OpenAI-compatible endpoint** for all providers
279
+ - **Multi-provider support** via [LiteLLM](https://docs.litellm.ai/docs/providers) fallback
280
+ - **Automatic key rotation** and load balancing
281
+ - **Interactive TUI** for easy configuration
282
+ - **Detailed request logging** for debugging
283
 
284
+ <details>
285
+ <summary><b>🛡️ Resilience & High Availability</b></summary>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
286
 
287
+ - **Global timeout** with deadline-driven retries
288
+ - **Escalating cooldowns** per model (10s → 30s → 60s → 120s)
289
+ - **Key-level lockouts** for consistently failing keys
290
+ - **Stream error detection** and graceful recovery
291
+ - **Batch embedding aggregation** for improved throughput
292
+ - **Automatic daily resets** for cooldowns and usage stats
293
 
294
+ </details>
295
 
296
+ <details>
297
+ <summary><b>🔑 Credential Management</b></summary>
298
 
299
+ - **Auto-discovery** of API keys from environment variables
300
+ - **OAuth discovery** from standard paths (`~/.gemini/`, `~/.qwen/`, `~/.iflow/`)
301
+ - **Duplicate detection** warns when same account added multiple times
302
+ - **Credential prioritization** — paid tier used before free tier
303
+ - **Stateless deployment** — export OAuth to environment variables
304
+ - **Local-first storage** — credentials isolated in `oauth_creds/` directory
305
 
306
+ </details>
307
 
308
+ <details>
309
+ <summary><b>⚙️ Advanced Configuration</b></summary>
310
 
311
+ - **Model whitelists/blacklists** with wildcard support
312
+ - **Per-provider concurrency limits** (`MAX_CONCURRENT_REQUESTS_PER_KEY_<PROVIDER>`)
313
+ - **Rotation modes** — balanced (distribute load) or sequential (use until exhausted)
314
+ - **Priority multipliers** — higher concurrency for paid credentials
315
+ - **Model quota groups** — shared cooldowns for related models
316
+ - **Temperature override** — prevent tool hallucination issues
317
+ - **Weighted random rotation** — unpredictable selection patterns
318
 
319
+ </details>
 
 
 
320
 
321
+ <details>
322
+ <summary><b>🔌 Provider-Specific Features</b></summary>
 
323
 
324
+ **Gemini CLI:**
325
+ - Zero-config Google Cloud project discovery
326
+ - Internal API access with higher rate limits
327
+ - Automatic fallback to preview models on rate limit
328
+ - Paid vs free tier detection
329
 
330
+ **Antigravity:**
331
+ - Gemini 3 Pro with `thinkingLevel` support
332
+ - Claude Opus 4.5 (thinking mode)
333
+ - Claude Sonnet 4.5 (thinking and non-thinking)
334
+ - Thought signature caching for multi-turn conversations
335
+ - Tool hallucination prevention
336
 
337
+ **Qwen Code:**
338
+ - Dual auth (API key + OAuth Device Flow)
339
+ - `<think>` tag parsing as `reasoning_content`
340
+ - Tool schema cleaning
341
 
342
+ **iFlow:**
343
+ - Dual auth (API key + OAuth Authorization Code)
344
+ - Hybrid auth with separate API key fetch
345
+ - Tool schema cleaning
346
 
347
+ **NVIDIA NIM:**
348
+ - Dynamic model discovery
349
+ - DeepSeek thinking support
350
+
351
+ </details>
352
+
353
+ <details>
354
+ <summary><b>📝 Logging & Debugging</b></summary>
355
+
356
+ - **Per-request file logging** with `--enable-request-logging`
357
+ - **Unique request directories** with full transaction details
358
+ - **Streaming chunk capture** for debugging
359
+ - **Performance metadata** (duration, tokens, model used)
360
+ - **Provider-specific logs** for Qwen, iFlow, Antigravity
361
+
362
+ </details>
363
 
364
  ---
365
 
366
+ ## Advanced Configuration
367
 
368
+ <details>
369
+ <summary><b>Environment Variables Reference</b></summary>
370
 
371
+ ### Proxy Settings
372
 
373
+ | Variable | Description | Default |
374
+ |----------|-------------|---------|
375
+ | `PROXY_API_KEY` | Authentication key for your proxy | Required |
376
+ | `OAUTH_REFRESH_INTERVAL` | Token refresh check interval (seconds) | `600` |
377
+ | `SKIP_OAUTH_INIT_CHECK` | Skip interactive OAuth setup on startup | `false` |
378
 
379
+ ### Per-Provider Settings
 
 
 
 
380
 
381
+ | Pattern | Description | Example |
382
+ |---------|-------------|---------|
383
+ | `<PROVIDER>_API_KEY_<N>` | API key for provider | `GEMINI_API_KEY_1` |
384
+ | `MAX_CONCURRENT_REQUESTS_PER_KEY_<PROVIDER>` | Concurrent request limit | `MAX_CONCURRENT_REQUESTS_PER_KEY_OPENAI=3` |
385
+ | `ROTATION_MODE_<PROVIDER>` | `balanced` or `sequential` | `ROTATION_MODE_GEMINI=sequential` |
386
+ | `IGNORE_MODELS_<PROVIDER>` | Blacklist (comma-separated, supports `*`) | `IGNORE_MODELS_OPENAI=*-preview*` |
387
+ | `WHITELIST_MODELS_<PROVIDER>` | Whitelist (overrides blacklist) | `WHITELIST_MODELS_GEMINI=gemini-2.5-pro` |
388
 
389
+ ### Advanced Features
390
+
391
+ | Variable | Description |
392
+ |----------|-------------|
393
+ | `ROTATION_TOLERANCE` | `0.0`=deterministic, `3.0`=weighted random (default) |
394
+ | `CONCURRENCY_MULTIPLIER_<PROVIDER>_PRIORITY_<N>` | Concurrency multiplier per priority tier |
395
+ | `QUOTA_GROUPS_<PROVIDER>_<GROUP>` | Models sharing quota limits |
396
+ | `OVERRIDE_TEMPERATURE_ZERO` | `remove` or `set` to prevent tool hallucination |
397
+
398
+ </details>
399
+
400
+ <details>
401
+ <summary><b>Model Filtering (Whitelists & Blacklists)</b></summary>
402
+
403
+ Control which models are exposed through your proxy.
404
+
405
+ ### Blacklist Only
406
+ ```env
407
+ # Hide all preview models
408
+ IGNORE_MODELS_OPENAI="*-preview*"
409
  ```
410
 
411
+ ### Pure Whitelist Mode
412
+ ```env
413
+ # Block all, then allow specific models
414
+ IGNORE_MODELS_GEMINI="*"
415
+ WHITELIST_MODELS_GEMINI="gemini-2.5-pro,gemini-2.5-flash"
416
+ ```
417
 
418
+ ### Exemption Mode
419
+ ```env
420
+ # Block preview models, but allow one specific preview
421
+ IGNORE_MODELS_OPENAI="*-preview*"
422
+ WHITELIST_MODELS_OPENAI="gpt-4o-2024-08-06-preview"
423
+ ```
424
 
425
+ **Logic order:** Whitelist check → Blacklist check → Default allow
426
+
427
+ </details>
428
+
429
+ <details>
430
+ <summary><b>Concurrency & Rotation Settings</b></summary>
431
+
432
+ ### Concurrency Limits
433
+
434
+ ```env
435
+ # Allow 3 concurrent requests per OpenAI key
436
+ MAX_CONCURRENT_REQUESTS_PER_KEY_OPENAI=3
437
+
438
+ # Default is 1 (no concurrency)
439
+ MAX_CONCURRENT_REQUESTS_PER_KEY_GEMINI=1
440
  ```
441
 
442
+ ### Rotation Modes
443
 
444
+ ```env
445
+ # balanced (default): Distribute load evenly - best for per-minute rate limits
446
+ ROTATION_MODE_OPENAI=balanced
 
 
447
 
448
+ # sequential: Use until exhausted - best for daily/weekly quotas
449
+ ROTATION_MODE_GEMINI=sequential
450
+ ```
451
 
452
+ ### Priority Multipliers
453
 
454
+ Paid credentials can handle more concurrent requests:
455
 
456
+ ```env
457
+ # Priority 1 (paid ultra): 10x concurrency
458
+ CONCURRENCY_MULTIPLIER_ANTIGRAVITY_PRIORITY_1=10
 
459
 
460
+ # Priority 2 (standard paid): 3x
461
+ CONCURRENCY_MULTIPLIER_ANTIGRAVITY_PRIORITY_2=3
462
+ ```
463
 
464
+ ### Model Quota Groups
465
 
466
+ Models sharing quota limits:
 
 
 
467
 
468
+ ```env
469
+ # Claude models share quota - when one hits limit, both cool down
470
+ QUOTA_GROUPS_ANTIGRAVITY_CLAUDE="claude-sonnet-4-5,claude-opus-4-5"
471
+ ```
472
 
473
+ </details>
474
 
475
+ <details>
476
+ <summary><b>Timeout Configuration</b></summary>
 
477
 
478
+ Fine-grained control over HTTP timeouts:
479
 
480
+ ```env
481
+ TIMEOUT_CONNECT=30 # Connection establishment
482
+ TIMEOUT_WRITE=30 # Request body send
483
+ TIMEOUT_POOL=60 # Connection pool acquisition
484
+ TIMEOUT_READ_STREAMING=180 # Between streaming chunks (3 min)
485
+ TIMEOUT_READ_NON_STREAMING=600 # Full response wait (10 min)
486
+ ```
487
 
488
+ **Recommendations:**
489
+ - Long thinking tasks: Increase `TIMEOUT_READ_STREAMING` to 300-360s
490
+ - Unstable network: Increase `TIMEOUT_CONNECT` to 60s
491
+ - Large outputs: Increase `TIMEOUT_READ_NON_STREAMING` to 900s+
 
 
 
492
 
493
+ </details>
 
 
 
 
 
 
 
494
 
495
+ ---
496
+
497
+ ## OAuth Providers
498
+
499
+ <details>
500
+ <summary><b>Gemini CLI</b></summary>
501
 
502
+ Uses Google OAuth to access internal Gemini endpoints with higher rate limits.
503
 
504
+ **Setup:**
505
+ 1. Run `python -m rotator_library.credential_tool`
506
+ 2. Select "Add OAuth Credential" → "Gemini CLI"
507
+ 3. Complete browser authentication
508
+ 4. Credentials saved to `oauth_creds/gemini_cli_oauth_1.json`
509
 
510
+ **Features:**
511
+ - Zero-config project discovery
512
+ - Automatic free-tier project onboarding
513
+ - Paid vs free tier detection
514
+ - Smart fallback on rate limits
515
 
516
+ **Environment Variables (for stateless deployment):**
517
+ ```env
518
+ GEMINI_CLI_ACCESS_TOKEN="ya29.your-access-token"
519
+ GEMINI_CLI_REFRESH_TOKEN="1//your-refresh-token"
520
+ GEMINI_CLI_EXPIRY_DATE="1234567890000"
521
+ GEMINI_CLI_EMAIL="your-email@gmail.com"
522
+ GEMINI_CLI_PROJECT_ID="your-gcp-project-id" # Optional
523
+ ```
524
 
525
+ </details>
 
 
526
 
527
+ <details>
528
+ <summary><b>Antigravity (Gemini 3 + Claude Opus 4.5)</b></summary>
529
 
530
+ Access Google's internal Antigravity API for cutting-edge models.
 
531
 
532
  **Supported Models:**
533
+ - **Gemini 3 Pro** with `thinkingLevel` support (low/high)
534
+ - **Claude Opus 4.5** Anthropic's most powerful model (thinking mode only)
535
+ - **Claude Sonnet 4.5** supports both thinking and non-thinking modes
536
+ - Gemini 2.5 Pro/Flash
537
 
538
+ **Setup:**
539
+ 1. Run `python -m rotator_library.credential_tool`
540
+ 2. Select "Add OAuth Credential" "Antigravity"
541
+ 3. Complete browser authentication
 
 
542
 
543
+ **Advanced Features:**
544
+ - Thought signature caching for multi-turn conversations
545
+ - Tool hallucination prevention via parameter signature injection
546
+ - Automatic thinking block sanitization for Claude
547
+ - Credential prioritization (paid resets every 5 hours, free weekly)
548
 
549
  **Environment Variables:**
550
  ```env
551
+ ANTIGRAVITY_ACCESS_TOKEN="ya29.your-access-token"
552
+ ANTIGRAVITY_REFRESH_TOKEN="1//your-refresh-token"
553
+ ANTIGRAVITY_EXPIRY_DATE="1234567890000"
554
+ ANTIGRAVITY_EMAIL="your-email@gmail.com"
 
555
 
556
  # Feature toggles
557
+ ANTIGRAVITY_ENABLE_SIGNATURE_CACHE=true
558
+ ANTIGRAVITY_GEMINI3_TOOL_FIX=true
559
  ```
560
 
561
+ > **Note:** Gemini 3 models require a paid-tier Google Cloud project.
562
 
563
+ </details>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
564
 
565
+ <details>
566
+ <summary><b>Qwen Code</b></summary>
567
 
568
+ Uses OAuth Device Flow for Qwen/Dashscope APIs.
 
 
 
569
 
570
+ **Setup:**
571
+ 1. Run the credential tool
572
+ 2. Select "Add OAuth Credential" → "Qwen Code"
573
+ 3. Enter the code displayed in your browser
574
+ 4. Or add API key directly: `QWEN_CODE_API_KEY_1="your-key"`
575
 
576
+ **Features:**
577
+ - Dual auth (API key or OAuth)
578
+ - `<think>` tag parsing as `reasoning_content`
579
+ - Automatic tool schema cleaning
580
+ - Custom models via `QWEN_CODE_MODELS` env var
581
 
582
+ </details>
 
 
 
583
 
584
+ <details>
585
+ <summary><b>iFlow</b></summary>
586
 
587
+ Uses OAuth Authorization Code flow with local callback server.
588
 
589
+ **Setup:**
590
+ 1. Run the credential tool
591
+ 2. Select "Add OAuth Credential" → "iFlow"
592
+ 3. Complete browser authentication (callback on port 11451)
593
+ 4. Or add API key directly: `IFLOW_API_KEY_1="sk-your-key"`
594
 
595
+ **Features:**
596
+ - Dual auth (API key or OAuth)
597
+ - Hybrid auth (OAuth token fetches separate API key)
598
+ - Automatic tool schema cleaning
599
+ - Custom models via `IFLOW_MODELS` env var
600
 
601
+ </details>
602
 
603
+ <details>
604
+ <summary><b>Stateless Deployment (Export to Environment Variables)</b></summary>
 
605
 
606
+ For platforms without file persistence (Railway, Render, Vercel):
607
 
608
+ 1. **Set up credentials locally:**
609
+ ```bash
610
+ python -m rotator_library.credential_tool
611
+ # Complete OAuth flows
612
+ ```
613
 
614
+ 2. **Export to environment variables:**
615
+ ```bash
616
+ python -m rotator_library.credential_tool
617
+ # Select "Export [Provider] to .env"
618
+ ```
619
 
620
+ 3. **Copy generated variables to your platform:**
621
+ The tool creates files like `gemini_cli_credential_1.env` containing all necessary variables.
622
 
623
+ 4. **Set `SKIP_OAUTH_INIT_CHECK=true`** to skip interactive validation on startup.
624
 
625
+ </details>
626
 
627
+ <details>
628
+ <summary><b>OAuth Callback Port Configuration</b></summary>
629
 
630
+ Customize OAuth callback ports if defaults conflict:
 
 
631
 
632
+ | Provider | Default Port | Environment Variable |
633
+ |----------|-------------|---------------------|
634
+ | Gemini CLI | 8085 | `GEMINI_CLI_OAUTH_PORT` |
635
+ | Antigravity | 51121 | `ANTIGRAVITY_OAUTH_PORT` |
636
+ | iFlow | 11451 | `IFLOW_OAUTH_PORT` |
637
 
638
+ </details>
639
 
640
+ ---
641
 
642
+ ## Deployment
643
+
644
+ <details>
645
+ <summary><b>Command-Line Arguments</b></summary>
646
 
647
+ ```bash
648
+ python src/proxy_app/main.py [OPTIONS]
649
+
650
+ Options:
651
+ --host TEXT Host to bind (default: 0.0.0.0)
652
+ --port INTEGER Port to run on (default: 8000)
653
+ --enable-request-logging Enable detailed per-request logging
654
+ --add-credential Launch interactive credential setup tool
655
  ```
656
 
657
+ **Examples:**
658
+ ```bash
659
+ # Run on custom port
660
+ python src/proxy_app/main.py --host 127.0.0.1 --port 9000
661
 
662
+ # Run with logging
663
+ python src/proxy_app/main.py --enable-request-logging
664
 
665
+ # Add credentials without starting proxy
666
+ python src/proxy_app/main.py --add-credential
667
+ ```
 
668
 
669
+ </details>
670
+
671
+ <details>
672
+ <summary><b>Render / Railway / Vercel</b></summary>
673
+
674
+ See the [Deployment Guide](Deployment%20guide.md) for complete instructions.
675
+
676
+ **Quick Setup:**
677
+ 1. Fork the repository
678
+ 2. Create a `.env` file with your credentials
679
+ 3. Create a new Web Service pointing to your repo
680
+ 4. Set build command: `pip install -r requirements.txt`
681
+ 5. Set start command: `uvicorn src.proxy_app.main:app --host 0.0.0.0 --port $PORT`
682
+ 6. Upload `.env` as a secret file
683
+
684
+ **OAuth Credentials:**
685
+ Export OAuth credentials to environment variables using the credential tool, then add them to your platform's environment settings.
686
+
687
+ </details>
688
+
689
+ <details>
690
+ <summary><b>Custom VPS / Docker</b></summary>
691
+
692
+ **Option 1: Authenticate locally, deploy credentials**
693
+ 1. Complete OAuth flows on your local machine
694
+ 2. Export to environment variables
695
+ 3. Deploy `.env` to your server
696
+
697
+ **Option 2: SSH Port Forwarding**
698
+ ```bash
699
+ # Forward callback ports through SSH
700
+ ssh -L 51121:localhost:51121 -L 8085:localhost:8085 user@your-vps
701
+
702
+ # Then run credential tool on the VPS
703
+ ```
704
+
705
+ **Systemd Service:**
706
+ ```ini
707
+ [Unit]
708
+ Description=LLM API Key Proxy
709
+ After=network.target
710
+
711
+ [Service]
712
+ Type=simple
713
+ WorkingDirectory=/path/to/LLM-API-Key-Proxy
714
+ ExecStart=/path/to/python -m uvicorn src.proxy_app.main:app --host 0.0.0.0 --port 8000
715
+ Restart=always
716
+
717
+ [Install]
718
+ WantedBy=multi-user.target
719
  ```
720
+
721
+ See [VPS Deployment](Deployment%20guide.md#appendix-deploying-to-a-custom-vps) for complete guide.
722
+
723
+ </details>
724
+
725
+ ---
726
+
727
+ ## Troubleshooting
728
+
729
+ | Issue | Solution |
730
+ |-------|----------|
731
+ | `401 Unauthorized` | Verify `PROXY_API_KEY` matches your `Authorization: Bearer` header exactly |
732
+ | `500 Internal Server Error` | Check provider key validity; enable `--enable-request-logging` for details |
733
+ | All keys on cooldown | All keys failed recently; check `logs/detailed_logs/` for upstream errors |
734
+ | Model not found | Verify format is `provider/model_name` (e.g., `gemini/gemini-2.5-flash`) |
735
+ | OAuth callback failed | Ensure callback port (8085, 51121, 11451) isn't blocked by firewall |
736
+ | Streaming hangs | Increase `TIMEOUT_READ_STREAMING`; check provider status |
737
+
738
+ **Detailed Logs:**
739
+
740
+ When `--enable-request-logging` is enabled, check `logs/detailed_logs/` for:
741
+ - `request.json` — Exact request payload
742
+ - `final_response.json` — Complete response or error
743
+ - `streaming_chunks.jsonl` — All SSE chunks received
744
+ - `metadata.json` — Performance metrics
745
+
746
+ ---
747
+
748
+ ## Documentation
749
+
750
+ | Document | Description |
751
+ |----------|-------------|
752
+ | [Technical Documentation](DOCUMENTATION.md) | Architecture, internals, provider implementations |
753
+ | [Library README](src/rotator_library/README.md) | Using the resilience library directly |
754
+ | [Deployment Guide](Deployment%20guide.md) | Hosting on Render, Railway, VPS |
755
+ | [.env.example](.env.example) | Complete environment variable reference |
756
+
757
+ ---
758
+
759
+ ## License
760
+
761
+ This project is dual-licensed:
762
+ - **Proxy Application** (`src/proxy_app/`) — [MIT License](src/proxy_app/LICENSE)
763
+ - **Resilience Library** (`src/rotator_library/`) — [LGPL-3.0](src/rotator_library/COPYING.LESSER)
src/proxy_app/detailed_logger.py CHANGED
@@ -3,16 +3,33 @@ import time
3
  import uuid
4
  from datetime import datetime
5
  from pathlib import Path
6
- from typing import Any, Dict, Optional, List
7
  import logging
8
 
9
- LOGS_DIR = Path(__file__).resolve().parent.parent.parent / "logs"
10
- DETAILED_LOGS_DIR = LOGS_DIR / "detailed_logs"
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
  class DetailedLogger:
13
  """
14
  Logs comprehensive details of each API transaction to a unique, timestamped directory.
 
 
 
15
  """
 
16
  def __init__(self):
17
  """
18
  Initializes the logger for a single request, creating a unique directory to store all related log files.
@@ -20,17 +37,26 @@ class DetailedLogger:
20
  self.start_time = time.time()
21
  self.request_id = str(uuid.uuid4())
22
  timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
23
- self.log_dir = DETAILED_LOGS_DIR / f"{timestamp}_{self.request_id}"
24
- self.log_dir.mkdir(parents=True, exist_ok=True)
25
  self.streaming = False
 
26
 
27
  def _write_json(self, filename: str, data: Dict[str, Any]):
28
  """Helper to write data to a JSON file in the log directory."""
29
- try:
30
- with open(self.log_dir / filename, "w", encoding="utf-8") as f:
31
- json.dump(data, f, indent=4, ensure_ascii=False)
32
- except Exception as e:
33
- logging.error(f"[{self.request_id}] Failed to write to {filename}: {e}")
 
 
 
 
 
 
 
 
 
34
 
35
  def log_request(self, headers: Dict[str, Any], body: Dict[str, Any]):
36
  """Logs the initial request details."""
@@ -39,23 +65,22 @@ class DetailedLogger:
39
  "request_id": self.request_id,
40
  "timestamp_utc": datetime.utcnow().isoformat(),
41
  "headers": dict(headers),
42
- "body": body
43
  }
44
  self._write_json("request.json", request_data)
45
 
46
  def log_stream_chunk(self, chunk: Dict[str, Any]):
47
  """Logs an individual chunk from a streaming response to a JSON Lines file."""
48
- try:
49
- log_entry = {
50
- "timestamp_utc": datetime.utcnow().isoformat(),
51
- "chunk": chunk
52
- }
53
- with open(self.log_dir / "streaming_chunks.jsonl", "a", encoding="utf-8") as f:
54
- f.write(json.dumps(log_entry, ensure_ascii=False) + "\n")
55
- except Exception as e:
56
- logging.error(f"[{self.request_id}] Failed to write stream chunk: {e}")
57
-
58
- def log_final_response(self, status_code: int, headers: Optional[Dict[str, Any]], body: Dict[str, Any]):
59
  """Logs the complete final response, either from a non-streaming call or after reassembling a stream."""
60
  end_time = time.time()
61
  duration_ms = (end_time - self.start_time) * 1000
@@ -66,7 +91,7 @@ class DetailedLogger:
66
  "status_code": status_code,
67
  "duration_ms": round(duration_ms),
68
  "headers": dict(headers) if headers else None,
69
- "body": body
70
  }
71
  self._write_json("final_response.json", response_data)
72
  self._log_metadata(response_data)
@@ -75,10 +100,10 @@ class DetailedLogger:
75
  """Recursively searches for and extracts 'reasoning' fields from the response body."""
76
  if not isinstance(response_body, dict):
77
  return None
78
-
79
  if "reasoning" in response_body:
80
  return response_body["reasoning"]
81
-
82
  if "choices" in response_body and response_body["choices"]:
83
  message = response_body["choices"][0].get("message", {})
84
  if "reasoning" in message:
@@ -93,8 +118,13 @@ class DetailedLogger:
93
  usage = response_data.get("body", {}).get("usage") or {}
94
  model = response_data.get("body", {}).get("model", "N/A")
95
  finish_reason = "N/A"
96
- if "choices" in response_data.get("body", {}) and response_data["body"]["choices"]:
97
- finish_reason = response_data["body"]["choices"][0].get("finish_reason", "N/A")
 
 
 
 
 
98
 
99
  metadata = {
100
  "request_id": self.request_id,
@@ -110,12 +140,12 @@ class DetailedLogger:
110
  },
111
  "finish_reason": finish_reason,
112
  "reasoning_found": False,
113
- "reasoning_content": None
114
  }
115
 
116
  reasoning = self._extract_reasoning(response_data.get("body", {}))
117
  if reasoning:
118
  metadata["reasoning_found"] = True
119
  metadata["reasoning_content"] = reasoning
120
-
121
- self._write_json("metadata.json", metadata)
 
3
  import uuid
4
  from datetime import datetime
5
  from pathlib import Path
6
+ from typing import Any, Dict, Optional
7
  import logging
8
 
9
+ from rotator_library.utils.resilient_io import (
10
+ safe_write_json,
11
+ safe_log_write,
12
+ safe_mkdir,
13
+ )
14
+ from rotator_library.utils.paths import get_logs_dir
15
+
16
+
17
+ def _get_detailed_logs_dir() -> Path:
18
+ """Get the detailed logs directory, creating it if needed."""
19
+ logs_dir = get_logs_dir()
20
+ detailed_dir = logs_dir / "detailed_logs"
21
+ detailed_dir.mkdir(parents=True, exist_ok=True)
22
+ return detailed_dir
23
+
24
 
25
  class DetailedLogger:
26
  """
27
  Logs comprehensive details of each API transaction to a unique, timestamped directory.
28
+
29
+ Uses fire-and-forget logging - if disk writes fail, logs are dropped (not buffered)
30
+ to prevent memory issues, especially with streaming responses.
31
  """
32
+
33
  def __init__(self):
34
  """
35
  Initializes the logger for a single request, creating a unique directory to store all related log files.
 
37
  self.start_time = time.time()
38
  self.request_id = str(uuid.uuid4())
39
  timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
40
+ self.log_dir = _get_detailed_logs_dir() / f"{timestamp}_{self.request_id}"
 
41
  self.streaming = False
42
+ self._dir_available = safe_mkdir(self.log_dir, logging)
43
 
44
  def _write_json(self, filename: str, data: Dict[str, Any]):
45
  """Helper to write data to a JSON file in the log directory."""
46
+ if not self._dir_available:
47
+ # Try to create directory again in case it was recreated
48
+ self._dir_available = safe_mkdir(self.log_dir, logging)
49
+ if not self._dir_available:
50
+ return
51
+
52
+ safe_write_json(
53
+ self.log_dir / filename,
54
+ data,
55
+ logging,
56
+ atomic=False,
57
+ indent=4,
58
+ ensure_ascii=False,
59
+ )
60
 
61
  def log_request(self, headers: Dict[str, Any], body: Dict[str, Any]):
62
  """Logs the initial request details."""
 
65
  "request_id": self.request_id,
66
  "timestamp_utc": datetime.utcnow().isoformat(),
67
  "headers": dict(headers),
68
+ "body": body,
69
  }
70
  self._write_json("request.json", request_data)
71
 
72
  def log_stream_chunk(self, chunk: Dict[str, Any]):
73
  """Logs an individual chunk from a streaming response to a JSON Lines file."""
74
+ if not self._dir_available:
75
+ return
76
+
77
+ log_entry = {"timestamp_utc": datetime.utcnow().isoformat(), "chunk": chunk}
78
+ content = json.dumps(log_entry, ensure_ascii=False) + "\n"
79
+ safe_log_write(self.log_dir / "streaming_chunks.jsonl", content, logging)
80
+
81
+ def log_final_response(
82
+ self, status_code: int, headers: Optional[Dict[str, Any]], body: Dict[str, Any]
83
+ ):
 
84
  """Logs the complete final response, either from a non-streaming call or after reassembling a stream."""
85
  end_time = time.time()
86
  duration_ms = (end_time - self.start_time) * 1000
 
91
  "status_code": status_code,
92
  "duration_ms": round(duration_ms),
93
  "headers": dict(headers) if headers else None,
94
+ "body": body,
95
  }
96
  self._write_json("final_response.json", response_data)
97
  self._log_metadata(response_data)
 
100
  """Recursively searches for and extracts 'reasoning' fields from the response body."""
101
  if not isinstance(response_body, dict):
102
  return None
103
+
104
  if "reasoning" in response_body:
105
  return response_body["reasoning"]
106
+
107
  if "choices" in response_body and response_body["choices"]:
108
  message = response_body["choices"][0].get("message", {})
109
  if "reasoning" in message:
 
118
  usage = response_data.get("body", {}).get("usage") or {}
119
  model = response_data.get("body", {}).get("model", "N/A")
120
  finish_reason = "N/A"
121
+ if (
122
+ "choices" in response_data.get("body", {})
123
+ and response_data["body"]["choices"]
124
+ ):
125
+ finish_reason = response_data["body"]["choices"][0].get(
126
+ "finish_reason", "N/A"
127
+ )
128
 
129
  metadata = {
130
  "request_id": self.request_id,
 
140
  },
141
  "finish_reason": finish_reason,
142
  "reasoning_found": False,
143
+ "reasoning_content": None,
144
  }
145
 
146
  reasoning = self._extract_reasoning(response_data.get("body", {}))
147
  if reasoning:
148
  metadata["reasoning_found"] = True
149
  metadata["reasoning_content"] = reasoning
150
+
151
+ self._write_json("metadata.json", metadata)
src/proxy_app/launcher_tui.py CHANGED
@@ -16,6 +16,20 @@ from dotenv import load_dotenv, set_key
16
  console = Console()
17
 
18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  def clear_screen():
20
  """
21
  Cross-platform terminal clear that works robustly on both
@@ -74,7 +88,7 @@ class LauncherConfig:
74
  @staticmethod
75
  def update_proxy_api_key(new_key: str):
76
  """Update PROXY_API_KEY in .env only"""
77
- env_file = Path.cwd() / ".env"
78
  set_key(str(env_file), "PROXY_API_KEY", new_key)
79
  load_dotenv(dotenv_path=env_file, override=True)
80
 
@@ -85,7 +99,7 @@ class SettingsDetector:
85
  @staticmethod
86
  def _load_local_env() -> dict:
87
  """Load environment variables from local .env file only"""
88
- env_file = Path.cwd() / ".env"
89
  env_dict = {}
90
  if not env_file.exists():
91
  return env_dict
@@ -107,7 +121,7 @@ class SettingsDetector:
107
 
108
  @staticmethod
109
  def get_all_settings() -> dict:
110
- """Returns comprehensive settings overview"""
111
  return {
112
  "credentials": SettingsDetector.detect_credentials(),
113
  "custom_bases": SettingsDetector.detect_custom_api_bases(),
@@ -117,6 +131,17 @@ class SettingsDetector:
117
  "provider_settings": SettingsDetector.detect_provider_settings(),
118
  }
119
 
 
 
 
 
 
 
 
 
 
 
 
120
  @staticmethod
121
  def detect_credentials() -> dict:
122
  """Detect API keys and OAuth credentials"""
@@ -260,7 +285,7 @@ class LauncherTUI:
260
  self.console = Console()
261
  self.config = LauncherConfig()
262
  self.running = True
263
- self.env_file = Path.cwd() / ".env"
264
  # Load .env file to ensure environment variables are available
265
  load_dotenv(dotenv_path=self.env_file, override=True)
266
 
@@ -277,8 +302,8 @@ class LauncherTUI:
277
  """Display main menu and handle selection"""
278
  clear_screen()
279
 
280
- # Detect all settings
281
- settings = SettingsDetector.get_all_settings()
282
  credentials = settings["credentials"]
283
  custom_bases = settings["custom_bases"]
284
 
@@ -363,18 +388,17 @@ class LauncherTUI:
363
  self.console.print("━" * 70)
364
  provider_count = len(credentials)
365
  custom_count = len(custom_bases)
366
- provider_settings = settings.get("provider_settings", {})
 
 
 
367
  has_advanced = bool(
368
  settings["model_definitions"]
369
  or settings["concurrency_limits"]
370
  or settings["model_filters"]
371
- or provider_settings
372
  )
373
-
374
- self.console.print(f" Providers: {provider_count} configured")
375
- self.console.print(f" Custom Providers: {custom_count} configured")
376
  self.console.print(
377
- f" Advanced Settings: {'Active (view in menu 4)' if has_advanced else 'None'}"
378
  )
379
 
380
  # Show menu
@@ -418,7 +442,7 @@ class LauncherTUI:
418
  elif choice == "4":
419
  self.show_provider_settings_menu()
420
  elif choice == "5":
421
- load_dotenv(dotenv_path=Path.cwd() / ".env", override=True)
422
  self.config = LauncherConfig() # Reload config
423
  self.console.print("\n[green]✅ Configuration reloaded![/green]")
424
  elif choice == "6":
@@ -659,13 +683,14 @@ class LauncherTUI:
659
  """Display provider/advanced settings (read-only + launch tool)"""
660
  clear_screen()
661
 
662
- settings = SettingsDetector.get_all_settings()
 
 
663
  credentials = settings["credentials"]
664
  custom_bases = settings["custom_bases"]
665
  model_defs = settings["model_definitions"]
666
  concurrency = settings["concurrency_limits"]
667
  filters = settings["model_filters"]
668
- provider_settings = settings.get("provider_settings", {})
669
 
670
  self.console.print(
671
  Panel.fit(
@@ -740,23 +765,13 @@ class LauncherTUI:
740
  status = " + ".join(status_parts) if status_parts else "None"
741
  self.console.print(f" • {provider:15} ✅ {status}")
742
 
743
- # Provider-Specific Settings
744
  self.console.print()
745
  self.console.print("[bold]🔬 Provider-Specific Settings[/bold]")
746
  self.console.print("━" * 70)
747
- try:
748
- from proxy_app.settings_tool import PROVIDER_SETTINGS_MAP
749
- except ImportError:
750
- from .settings_tool import PROVIDER_SETTINGS_MAP
751
- for provider in PROVIDER_SETTINGS_MAP.keys():
752
- display_name = provider.replace("_", " ").title()
753
- modified = provider_settings.get(provider, 0)
754
- if modified > 0:
755
- self.console.print(
756
- f" • {display_name:20} [yellow]{modified} setting{'s' if modified > 1 else ''} modified[/yellow]"
757
- )
758
- else:
759
- self.console.print(f" • {display_name:20} [dim]using defaults[/dim]")
760
 
761
  # Actions
762
  self.console.print()
@@ -823,15 +838,31 @@ class LauncherTUI:
823
  # Run the tool with from_launcher=True to skip duplicate loading screen
824
  run_credential_tool(from_launcher=True)
825
  # Reload environment after credential tool
826
- load_dotenv(dotenv_path=Path.cwd() / ".env", override=True)
827
 
828
  def launch_settings_tool(self):
829
  """Launch settings configuration tool"""
830
- from proxy_app.settings_tool import run_settings_tool
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
831
 
832
  run_settings_tool()
833
  # Reload environment after settings tool
834
- load_dotenv(dotenv_path=Path.cwd() / ".env", override=True)
835
 
836
  def show_about(self):
837
  """Display About page with project information"""
@@ -919,9 +950,9 @@ class LauncherTUI:
919
  )
920
 
921
  ensure_env_defaults()
922
- load_dotenv(dotenv_path=Path.cwd() / ".env", override=True)
923
  run_credential_tool()
924
- load_dotenv(dotenv_path=Path.cwd() / ".env", override=True)
925
 
926
  # Check again after credential tool
927
  if not os.getenv("PROXY_API_KEY"):
 
16
  console = Console()
17
 
18
 
19
+ def _get_env_file() -> Path:
20
+ """
21
+ Get .env file path (lightweight - no heavy imports).
22
+
23
+ Returns:
24
+ Path to .env file - EXE directory if frozen, else current working directory
25
+ """
26
+ if getattr(sys, "frozen", False):
27
+ # Running as PyInstaller EXE - use EXE's directory
28
+ return Path(sys.executable).parent / ".env"
29
+ # Running as script - use current working directory
30
+ return Path.cwd() / ".env"
31
+
32
+
33
  def clear_screen():
34
  """
35
  Cross-platform terminal clear that works robustly on both
 
88
  @staticmethod
89
  def update_proxy_api_key(new_key: str):
90
  """Update PROXY_API_KEY in .env only"""
91
+ env_file = _get_env_file()
92
  set_key(str(env_file), "PROXY_API_KEY", new_key)
93
  load_dotenv(dotenv_path=env_file, override=True)
94
 
 
99
  @staticmethod
100
  def _load_local_env() -> dict:
101
  """Load environment variables from local .env file only"""
102
+ env_file = _get_env_file()
103
  env_dict = {}
104
  if not env_file.exists():
105
  return env_dict
 
121
 
122
  @staticmethod
123
  def get_all_settings() -> dict:
124
+ """Returns comprehensive settings overview (includes provider_settings which triggers heavy imports)"""
125
  return {
126
  "credentials": SettingsDetector.detect_credentials(),
127
  "custom_bases": SettingsDetector.detect_custom_api_bases(),
 
131
  "provider_settings": SettingsDetector.detect_provider_settings(),
132
  }
133
 
134
+ @staticmethod
135
+ def get_basic_settings() -> dict:
136
+ """Returns basic settings overview without provider_settings (avoids heavy imports)"""
137
+ return {
138
+ "credentials": SettingsDetector.detect_credentials(),
139
+ "custom_bases": SettingsDetector.detect_custom_api_bases(),
140
+ "model_definitions": SettingsDetector.detect_model_definitions(),
141
+ "concurrency_limits": SettingsDetector.detect_concurrency_limits(),
142
+ "model_filters": SettingsDetector.detect_model_filters(),
143
+ }
144
+
145
  @staticmethod
146
  def detect_credentials() -> dict:
147
  """Detect API keys and OAuth credentials"""
 
285
  self.console = Console()
286
  self.config = LauncherConfig()
287
  self.running = True
288
+ self.env_file = _get_env_file()
289
  # Load .env file to ensure environment variables are available
290
  load_dotenv(dotenv_path=self.env_file, override=True)
291
 
 
302
  """Display main menu and handle selection"""
303
  clear_screen()
304
 
305
+ # Detect basic settings (excludes provider_settings to avoid heavy imports)
306
+ settings = SettingsDetector.get_basic_settings()
307
  credentials = settings["credentials"]
308
  custom_bases = settings["custom_bases"]
309
 
 
388
  self.console.print("━" * 70)
389
  provider_count = len(credentials)
390
  custom_count = len(custom_bases)
391
+
392
+ self.console.print(f" Providers: {provider_count} configured")
393
+ self.console.print(f" Custom Providers: {custom_count} configured")
394
+ # Note: provider_settings detection is deferred to avoid heavy imports on startup
395
  has_advanced = bool(
396
  settings["model_definitions"]
397
  or settings["concurrency_limits"]
398
  or settings["model_filters"]
 
399
  )
 
 
 
400
  self.console.print(
401
+ f" Advanced Settings: {'Active (view in menu 4)' if has_advanced else 'None (view menu 4 for details)'}"
402
  )
403
 
404
  # Show menu
 
442
  elif choice == "4":
443
  self.show_provider_settings_menu()
444
  elif choice == "5":
445
+ load_dotenv(dotenv_path=_get_env_file(), override=True)
446
  self.config = LauncherConfig() # Reload config
447
  self.console.print("\n[green]✅ Configuration reloaded![/green]")
448
  elif choice == "6":
 
683
  """Display provider/advanced settings (read-only + launch tool)"""
684
  clear_screen()
685
 
686
+ # Use basic settings to avoid heavy imports - provider_settings deferred to Settings Tool
687
+ settings = SettingsDetector.get_basic_settings()
688
+
689
  credentials = settings["credentials"]
690
  custom_bases = settings["custom_bases"]
691
  model_defs = settings["model_definitions"]
692
  concurrency = settings["concurrency_limits"]
693
  filters = settings["model_filters"]
 
694
 
695
  self.console.print(
696
  Panel.fit(
 
765
  status = " + ".join(status_parts) if status_parts else "None"
766
  self.console.print(f" • {provider:15} ✅ {status}")
767
 
768
+ # Provider-Specific Settings (deferred to Settings Tool to avoid heavy imports)
769
  self.console.print()
770
  self.console.print("[bold]🔬 Provider-Specific Settings[/bold]")
771
  self.console.print("━" * 70)
772
+ self.console.print(
773
+ " [dim]Launch Settings Tool to view/configure provider-specific settings[/dim]"
774
+ )
 
 
 
 
 
 
 
 
 
 
775
 
776
  # Actions
777
  self.console.print()
 
838
  # Run the tool with from_launcher=True to skip duplicate loading screen
839
  run_credential_tool(from_launcher=True)
840
  # Reload environment after credential tool
841
+ load_dotenv(dotenv_path=_get_env_file(), override=True)
842
 
843
  def launch_settings_tool(self):
844
  """Launch settings configuration tool"""
845
+ import time
846
+
847
+ clear_screen()
848
+
849
+ self.console.print("━" * 70)
850
+ self.console.print("Advanced Settings Configuration Tool")
851
+ self.console.print("━" * 70)
852
+
853
+ _start_time = time.time()
854
+
855
+ with self.console.status("Initializing settings tool...", spinner="dots"):
856
+ from proxy_app.settings_tool import run_settings_tool
857
+
858
+ _elapsed = time.time() - _start_time
859
+ self.console.print(f"✓ Settings tool ready in {_elapsed:.2f}s")
860
+
861
+ time.sleep(0.3)
862
 
863
  run_settings_tool()
864
  # Reload environment after settings tool
865
+ load_dotenv(dotenv_path=_get_env_file(), override=True)
866
 
867
  def show_about(self):
868
  """Display About page with project information"""
 
950
  )
951
 
952
  ensure_env_defaults()
953
+ load_dotenv(dotenv_path=_get_env_file(), override=True)
954
  run_credential_tool()
955
+ load_dotenv(dotenv_path=_get_env_file(), override=True)
956
 
957
  # Check again after credential tool
958
  if not os.getenv("PROXY_API_KEY"):
src/proxy_app/main.py CHANGED
@@ -52,11 +52,17 @@ _start_time = time.time()
52
  from dotenv import load_dotenv
53
  from glob import glob
54
 
 
 
 
 
 
 
 
55
  # Load main .env first
56
- load_dotenv()
57
 
58
  # Load any additional .env files (e.g., antigravity_all_combined.env, gemini_cli_all_combined.env)
59
- _root_dir = Path.cwd()
60
  _env_files_found = list(_root_dir.glob("*.env"))
61
  for _env_file in sorted(_root_dir.glob("*.env")):
62
  if _env_file.name != ".env": # Skip main .env (already loaded)
@@ -234,8 +240,10 @@ print(
234
  # Note: Debug logging will be added after logging configuration below
235
 
236
  # --- Logging Configuration ---
237
- LOG_DIR = Path(__file__).resolve().parent.parent.parent / "logs"
238
- LOG_DIR.mkdir(exist_ok=True)
 
 
239
 
240
  # Configure a console handler with color (INFO and above only, no DEBUG)
241
  console_handler = colorlog.StreamHandler(sys.stdout)
@@ -324,7 +332,7 @@ litellm_logger.propagate = False
324
  logging.debug(f"Modules loaded in {_elapsed:.2f}s")
325
 
326
  # Load environment variables from .env file
327
- load_dotenv()
328
 
329
  # --- Configuration ---
330
  USE_EMBEDDING_BATCHER = False
@@ -570,11 +578,11 @@ async def lifespan(app: FastAPI):
570
  )
571
 
572
  # Log loaded credentials summary (compact, always visible for deployment verification)
573
- #_api_summary = ', '.join([f"{p}:{len(c)}" for p, c in api_keys.items()]) if api_keys else "none"
574
- #_oauth_summary = ', '.join([f"{p}:{len(c)}" for p, c in oauth_credentials.items()]) if oauth_credentials else "none"
575
- #_total_summary = ', '.join([f"{p}:{len(c)}" for p, c in client.all_credentials.items()])
576
- #print(f"🔑 Credentials loaded: {_total_summary} (API: {_api_summary} | OAuth: {_oauth_summary})")
577
- client.background_refresher.start() # Start the background task
578
  app.state.rotating_client = client
579
 
580
  # Warn if no provider credentials are configured
@@ -1263,8 +1271,8 @@ async def cost_estimate(request: Request, _=Depends(verify_api_key)):
1263
 
1264
 
1265
  if __name__ == "__main__":
1266
- # Define ENV_FILE for onboarding checks
1267
- ENV_FILE = Path.cwd() / ".env"
1268
 
1269
  # Check if launcher TUI should be shown (no arguments provided)
1270
  if len(sys.argv) == 1:
@@ -1331,7 +1339,7 @@ if __name__ == "__main__":
1331
 
1332
  ensure_env_defaults()
1333
  # Reload environment variables after ensure_env_defaults creates/updates .env
1334
- load_dotenv(override=True)
1335
  run_credential_tool()
1336
  else:
1337
  # Check if onboarding is needed
@@ -1349,11 +1357,11 @@ if __name__ == "__main__":
1349
  from rotator_library.credential_tool import ensure_env_defaults
1350
 
1351
  ensure_env_defaults()
1352
- load_dotenv(override=True)
1353
  run_credential_tool()
1354
 
1355
  # After credential tool exits, reload and re-check
1356
- load_dotenv(override=True)
1357
  # Re-read PROXY_API_KEY from environment
1358
  PROXY_API_KEY = os.getenv("PROXY_API_KEY")
1359
 
 
52
  from dotenv import load_dotenv
53
  from glob import glob
54
 
55
+ # Get the application root directory (EXE dir if frozen, else CWD)
56
+ # Inlined here to avoid triggering heavy rotator_library imports before loading screen
57
+ if getattr(sys, "frozen", False):
58
+ _root_dir = Path(sys.executable).parent
59
+ else:
60
+ _root_dir = Path.cwd()
61
+
62
  # Load main .env first
63
+ load_dotenv(_root_dir / ".env")
64
 
65
  # Load any additional .env files (e.g., antigravity_all_combined.env, gemini_cli_all_combined.env)
 
66
  _env_files_found = list(_root_dir.glob("*.env"))
67
  for _env_file in sorted(_root_dir.glob("*.env")):
68
  if _env_file.name != ".env": # Skip main .env (already loaded)
 
240
  # Note: Debug logging will be added after logging configuration below
241
 
242
  # --- Logging Configuration ---
243
+ # Import path utilities here (after loading screen) to avoid triggering heavy imports early
244
+ from rotator_library.utils.paths import get_logs_dir, get_data_file
245
+
246
+ LOG_DIR = get_logs_dir(_root_dir)
247
 
248
  # Configure a console handler with color (INFO and above only, no DEBUG)
249
  console_handler = colorlog.StreamHandler(sys.stdout)
 
332
  logging.debug(f"Modules loaded in {_elapsed:.2f}s")
333
 
334
  # Load environment variables from .env file
335
+ load_dotenv(_root_dir / ".env")
336
 
337
  # --- Configuration ---
338
  USE_EMBEDDING_BATCHER = False
 
578
  )
579
 
580
  # Log loaded credentials summary (compact, always visible for deployment verification)
581
+ # _api_summary = ', '.join([f"{p}:{len(c)}" for p, c in api_keys.items()]) if api_keys else "none"
582
+ # _oauth_summary = ', '.join([f"{p}:{len(c)}" for p, c in oauth_credentials.items()]) if oauth_credentials else "none"
583
+ # _total_summary = ', '.join([f"{p}:{len(c)}" for p, c in client.all_credentials.items()])
584
+ # print(f"🔑 Credentials loaded: {_total_summary} (API: {_api_summary} | OAuth: {_oauth_summary})")
585
+ client.background_refresher.start() # Start the background task
586
  app.state.rotating_client = client
587
 
588
  # Warn if no provider credentials are configured
 
1271
 
1272
 
1273
  if __name__ == "__main__":
1274
+ # Define ENV_FILE for onboarding checks using centralized path
1275
+ ENV_FILE = get_data_file(".env")
1276
 
1277
  # Check if launcher TUI should be shown (no arguments provided)
1278
  if len(sys.argv) == 1:
 
1339
 
1340
  ensure_env_defaults()
1341
  # Reload environment variables after ensure_env_defaults creates/updates .env
1342
+ load_dotenv(ENV_FILE, override=True)
1343
  run_credential_tool()
1344
  else:
1345
  # Check if onboarding is needed
 
1357
  from rotator_library.credential_tool import ensure_env_defaults
1358
 
1359
  ensure_env_defaults()
1360
+ load_dotenv(ENV_FILE, override=True)
1361
  run_credential_tool()
1362
 
1363
  # After credential tool exits, reload and re-check
1364
+ load_dotenv(ENV_FILE, override=True)
1365
  # Re-read PROXY_API_KEY from environment
1366
  PROXY_API_KEY = os.getenv("PROXY_API_KEY")
1367
 
src/proxy_app/settings_tool.py CHANGED
@@ -12,8 +12,36 @@ from rich.prompt import Prompt, IntPrompt, Confirm
12
  from rich.panel import Panel
13
  from dotenv import set_key, unset_key
14
 
 
 
15
  console = Console()
16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
  def clear_screen():
19
  """
@@ -31,7 +59,7 @@ class AdvancedSettings:
31
  """Manages pending changes to .env"""
32
 
33
  def __init__(self):
34
- self.env_file = Path.cwd() / ".env"
35
  self.pending_changes = {} # key -> value (None means delete)
36
  self.load_current_settings()
37
 
@@ -39,7 +67,7 @@ class AdvancedSettings:
39
  """Load current .env values into env vars"""
40
  from dotenv import load_dotenv
41
 
42
- load_dotenv(override=True)
43
 
44
  def set(self, key: str, value: str):
45
  """Stage a change"""
@@ -70,6 +98,70 @@ class AdvancedSettings:
70
  """Check if there are pending changes"""
71
  return bool(self.pending_changes)
72
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
 
74
  class CustomProviderManager:
75
  """Manages custom provider API bases"""
@@ -383,6 +475,11 @@ ANTIGRAVITY_SETTINGS = {
383
  "default": "\n\nSTRICT PARAMETERS: {params}.",
384
  "description": "Template for Claude strict parameter hints in tool descriptions",
385
  },
 
 
 
 
 
386
  }
387
 
388
  # Gemini CLI provider environment variables
@@ -427,12 +524,27 @@ GEMINI_CLI_SETTINGS = {
427
  "default": "",
428
  "description": "GCP Project ID for paid tier users (required for paid tiers)",
429
  },
 
 
 
 
 
 
 
 
 
 
 
 
 
 
430
  }
431
 
432
  # Map provider names to their settings definitions
433
  PROVIDER_SETTINGS_MAP = {
434
  "antigravity": ANTIGRAVITY_SETTINGS,
435
  "gemini_cli": GEMINI_CLI_SETTINGS,
 
436
  }
437
 
438
 
@@ -516,9 +628,61 @@ class SettingsTool:
516
  self.provider_settings_mgr = ProviderSettingsManager(self.settings)
517
  self.running = True
518
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
519
  def get_available_providers(self) -> List[str]:
520
  """Get list of providers that have credentials configured"""
521
- env_file = Path.cwd() / ".env"
522
  providers = set()
523
 
524
  # Scan for providers with API keys from local .env
@@ -541,7 +705,9 @@ class SettingsTool:
541
  pass
542
 
543
  # Also check for OAuth providers from files
544
- oauth_dir = Path("oauth_creds")
 
 
545
  if oauth_dir.exists():
546
  for file in oauth_dir.glob("*_oauth_*.json"):
547
  provider = file.name.split("_oauth_")[0]
@@ -579,12 +745,7 @@ class SettingsTool:
579
  self.console.print()
580
  self.console.print("━" * 70)
581
 
582
- if self.settings.has_pending():
583
- self.console.print(
584
- '[yellow]ℹ️ Changes are pending until you select "Save & Exit"[/yellow]'
585
- )
586
- else:
587
- self.console.print("[dim]ℹ️ No pending changes[/dim]")
588
 
589
  self.console.print()
590
  self.console.print(
@@ -618,6 +779,7 @@ class SettingsTool:
618
  while True:
619
  clear_screen()
620
 
 
621
  providers = self.provider_mgr.get_current_providers()
622
 
623
  self.console.print(
@@ -631,9 +793,48 @@ class SettingsTool:
631
  self.console.print("[bold]📋 Configured Custom Providers[/bold]")
632
  self.console.print("━" * 70)
633
 
634
- if providers:
635
- for name, base in providers.items():
636
- self.console.print(f" • {name:15} {base}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
637
  else:
638
  self.console.print(" [dim]No custom providers configured[/dim]")
639
 
@@ -662,7 +863,7 @@ class SettingsTool:
662
  if api_base:
663
  self.provider_mgr.add_provider(name, api_base)
664
  self.console.print(
665
- f"\n[green]✅ Custom provider '{name}' configured![/green]"
666
  )
667
  self.console.print(
668
  f" To use: set {name.upper()}_API_KEY in credentials"
@@ -670,14 +871,18 @@ class SettingsTool:
670
  input("\nPress Enter to continue...")
671
 
672
  elif choice == "2":
673
- if not providers:
 
 
 
 
674
  self.console.print("\n[yellow]No providers to edit[/yellow]")
675
  input("\nPress Enter to continue...")
676
  continue
677
 
678
  # Show numbered list
679
  self.console.print("\n[bold]Select provider to edit:[/bold]")
680
- providers_list = list(providers.keys())
681
  for idx, prov in enumerate(providers_list, 1):
682
  self.console.print(f" {idx}. {prov}")
683
 
@@ -686,7 +891,9 @@ class SettingsTool:
686
  choices=[str(i) for i in range(1, len(providers_list) + 1)],
687
  )
688
  name = providers_list[choice_idx - 1]
689
- current_base = providers.get(name, "")
 
 
690
 
691
  self.console.print(f"\nCurrent API Base: {current_base}")
692
  new_base = Prompt.ask(
@@ -703,16 +910,33 @@ class SettingsTool:
703
  input("\nPress Enter to continue...")
704
 
705
  elif choice == "3":
706
- if not providers:
 
 
 
 
 
 
 
 
 
 
 
707
  self.console.print("\n[yellow]No providers to remove[/yellow]")
708
  input("\nPress Enter to continue...")
709
  continue
710
 
711
  # Show numbered list
712
  self.console.print("\n[bold]Select provider to remove:[/bold]")
713
- providers_list = list(providers.keys())
 
714
  for idx, prov in enumerate(providers_list, 1):
715
- self.console.print(f" {idx}. {prov}")
 
 
 
 
 
716
 
717
  choice_idx = IntPrompt.ask(
718
  "Select option",
@@ -721,10 +945,18 @@ class SettingsTool:
721
  name = providers_list[choice_idx - 1]
722
 
723
  if Confirm.ask(f"Remove '{name}'?"):
724
- self.provider_mgr.remove_provider(name)
725
- self.console.print(
726
- f"\n[green]✅ Provider '{name}' removed![/green]"
727
- )
 
 
 
 
 
 
 
 
728
  input("\nPress Enter to continue...")
729
 
730
  elif choice == "4":
@@ -735,7 +967,8 @@ class SettingsTool:
735
  while True:
736
  clear_screen()
737
 
738
- all_providers = self.model_mgr.get_all_providers_with_models()
 
739
 
740
  self.console.print(
741
  Panel.fit(
@@ -748,10 +981,69 @@ class SettingsTool:
748
  self.console.print("[bold]📋 Configured Provider Models[/bold]")
749
  self.console.print("━" * 70)
750
 
751
- if all_providers:
752
- for provider, count in all_providers.items():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
753
  self.console.print(
754
- f" • {provider:15} {count} model{'s' if count > 1 else ''}"
 
 
755
  )
756
  else:
757
  self.console.print(" [dim]No model definitions configured[/dim]")
@@ -778,19 +1070,36 @@ class SettingsTool:
778
  if choice == "1":
779
  self.add_model_definitions()
780
  elif choice == "2":
781
- if not all_providers:
 
 
 
 
782
  self.console.print("\n[yellow]No providers to edit[/yellow]")
783
  input("\nPress Enter to continue...")
784
  continue
785
- self.edit_model_definitions(list(all_providers.keys()))
786
  elif choice == "3":
787
- if not all_providers:
 
 
 
788
  self.console.print("\n[yellow]No providers to view[/yellow]")
789
  input("\nPress Enter to continue...")
790
  continue
791
- self.view_model_definitions(list(all_providers.keys()))
792
  elif choice == "4":
793
- if not all_providers:
 
 
 
 
 
 
 
 
 
 
794
  self.console.print("\n[yellow]No providers to remove[/yellow]")
795
  input("\nPress Enter to continue...")
796
  continue
@@ -799,9 +1108,14 @@ class SettingsTool:
799
  self.console.print(
800
  "\n[bold]Select provider to remove models from:[/bold]"
801
  )
802
- providers_list = list(all_providers.keys())
803
  for idx, prov in enumerate(providers_list, 1):
804
- self.console.print(f" {idx}. {prov}")
 
 
 
 
 
805
 
806
  choice_idx = IntPrompt.ask(
807
  "Select option",
@@ -810,10 +1124,18 @@ class SettingsTool:
810
  provider = providers_list[choice_idx - 1]
811
 
812
  if Confirm.ask(f"Remove all model definitions for '{provider}'?"):
813
- self.model_mgr.remove_models(provider)
814
- self.console.print(
815
- f"\n[green]✅ Model definitions removed for '{provider}'![/green]"
816
- )
 
 
 
 
 
 
 
 
817
  input("\nPress Enter to continue...")
818
  elif choice == "5":
819
  break
@@ -1140,7 +1462,7 @@ class SettingsTool:
1140
  self.console.print("[bold]📋 Current Settings[/bold]")
1141
  self.console.print("━" * 70)
1142
 
1143
- # Display all settings with current values
1144
  settings_list = list(definitions.keys())
1145
  for idx, key in enumerate(settings_list, 1):
1146
  definition = definitions[key]
@@ -1149,37 +1471,88 @@ class SettingsTool:
1149
  setting_type = definition.get("type", "str")
1150
  description = definition.get("description", "")
1151
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1152
  # Format value display
1153
  if setting_type == "bool":
1154
  value_display = (
1155
  "[green]✓ Enabled[/green]"
1156
- if current
1157
  else "[red]✗ Disabled[/red]"
1158
  )
 
 
 
 
 
 
 
 
 
1159
  elif setting_type == "int":
1160
- value_display = f"[cyan]{current}[/cyan]"
 
1161
  else:
1162
  value_display = (
1163
- f"[cyan]{current or '(not set)'}[/cyan]"
1164
- if current
1165
  else "[dim](not set)[/dim]"
1166
  )
1167
-
1168
- # Check if modified from default
1169
- modified = current != default
1170
- mod_marker = "[yellow]*[/yellow]" if modified else " "
1171
 
1172
  # Short key name for display (strip provider prefix)
1173
  short_key = key.replace(f"{provider.upper()}_", "")
1174
 
1175
- self.console.print(
1176
- f" {mod_marker}{idx:2}. {short_key:35} {value_display}"
1177
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1178
  self.console.print(f" [dim]{description}[/dim]")
1179
 
1180
  self.console.print()
1181
  self.console.print("━" * 70)
1182
- self.console.print("[dim]* = modified from default[/dim]")
 
 
1183
  self.console.print()
1184
  self.console.print("[bold]⚙️ Actions[/bold]")
1185
  self.console.print()
@@ -1299,6 +1672,7 @@ class SettingsTool:
1299
  while True:
1300
  clear_screen()
1301
 
 
1302
  modes = self.rotation_mgr.get_current_modes()
1303
  available_providers = self.get_available_providers()
1304
 
@@ -1322,20 +1696,78 @@ class SettingsTool:
1322
  self.console.print("[bold]📋 Current Rotation Mode Settings[/bold]")
1323
  self.console.print("━" * 70)
1324
 
1325
- if modes:
1326
- for provider, mode in modes.items():
1327
- default_mode = self.rotation_mgr.get_default_mode(provider)
1328
- is_custom = mode != default_mode
1329
- marker = "[yellow]*[/yellow]" if is_custom else " "
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1330
  mode_display = (
1331
  f"[green]{mode}[/green]"
1332
  if mode == "sequential"
1333
  else f"[blue]{mode}[/blue]"
1334
  )
1335
- self.console.print(f" {marker}• {provider:20} {mode_display}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1336
 
1337
  # Show providers with default modes
1338
- providers_with_defaults = [p for p in available_providers if p not in modes]
 
 
1339
  if providers_with_defaults:
1340
  self.console.print()
1341
  self.console.print("[dim]Providers using default modes:[/dim]")
@@ -1423,12 +1855,16 @@ class SettingsTool:
1423
 
1424
  self.rotation_mgr.set_mode(provider, new_mode)
1425
  self.console.print(
1426
- f"\n[green]✅ Rotation mode for '{provider}' set to {new_mode}![/green]"
1427
  )
1428
  input("\nPress Enter to continue...")
1429
 
1430
  elif choice == "2":
1431
- if not modes:
 
 
 
 
1432
  self.console.print(
1433
  "\n[yellow]No custom rotation modes to reset[/yellow]"
1434
  )
@@ -1439,12 +1875,18 @@ class SettingsTool:
1439
  self.console.print(
1440
  "\n[bold]Select provider to reset to default:[/bold]"
1441
  )
1442
- modes_list = list(modes.keys())
1443
  for idx, prov in enumerate(modes_list, 1):
1444
  default_mode = self.rotation_mgr.get_default_mode(prov)
1445
- self.console.print(
1446
- f" {idx}. {prov} (will reset to: {default_mode})"
1447
- )
 
 
 
 
 
 
1448
 
1449
  choice_idx = IntPrompt.ask(
1450
  "Select option",
@@ -1452,12 +1894,21 @@ class SettingsTool:
1452
  )
1453
  provider = modes_list[choice_idx - 1]
1454
  default_mode = self.rotation_mgr.get_default_mode(provider)
 
1455
 
1456
  if Confirm.ask(f"Reset '{provider}' to default mode ({default_mode})?"):
1457
- self.rotation_mgr.remove_mode(provider)
1458
- self.console.print(
1459
- f"\n[green]✅ Rotation mode for '{provider}' reset to default ({default_mode})![/green]"
1460
- )
 
 
 
 
 
 
 
 
1461
  input("\nPress Enter to continue...")
1462
 
1463
  elif choice == "3":
@@ -1630,6 +2081,7 @@ class SettingsTool:
1630
  while True:
1631
  clear_screen()
1632
 
 
1633
  limits = self.concurrency_mgr.get_current_limits()
1634
 
1635
  self.console.print(
@@ -1643,10 +2095,57 @@ class SettingsTool:
1643
  self.console.print("[bold]📋 Current Concurrency Settings[/bold]")
1644
  self.console.print("━" * 70)
1645
 
1646
- if limits:
1647
- for provider, limit in limits.items():
1648
- self.console.print(f" • {provider:15} {limit} requests/key")
1649
- self.console.print(f" • Default: 1 request/key (all others)")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1650
  else:
1651
  self.console.print(" • Default: 1 request/key (all providers)")
1652
 
@@ -1704,7 +2203,7 @@ class SettingsTool:
1704
  if 1 <= limit <= 100:
1705
  self.concurrency_mgr.set_limit(provider, limit)
1706
  self.console.print(
1707
- f"\n[green]✅ Concurrency limit set for '{provider}': {limit} requests/key[/green]"
1708
  )
1709
  else:
1710
  self.console.print(
@@ -1713,14 +2212,18 @@ class SettingsTool:
1713
  input("\nPress Enter to continue...")
1714
 
1715
  elif choice == "2":
1716
- if not limits:
 
 
 
 
1717
  self.console.print("\n[yellow]No limits to edit[/yellow]")
1718
  input("\nPress Enter to continue...")
1719
  continue
1720
 
1721
  # Show numbered list
1722
  self.console.print("\n[bold]Select provider to edit:[/bold]")
1723
- limits_list = list(limits.keys())
1724
  for idx, prov in enumerate(limits_list, 1):
1725
  self.console.print(f" {idx}. {prov}")
1726
 
@@ -1729,7 +2232,8 @@ class SettingsTool:
1729
  choices=[str(i) for i in range(1, len(limits_list) + 1)],
1730
  )
1731
  provider = limits_list[choice_idx - 1]
1732
- current_limit = limits.get(provider, 1)
 
1733
 
1734
  self.console.print(f"\nCurrent limit: {current_limit} requests/key")
1735
  new_limit = IntPrompt.ask(
@@ -1750,7 +2254,18 @@ class SettingsTool:
1750
  input("\nPress Enter to continue...")
1751
 
1752
  elif choice == "3":
1753
- if not limits:
 
 
 
 
 
 
 
 
 
 
 
1754
  self.console.print("\n[yellow]No limits to remove[/yellow]")
1755
  input("\nPress Enter to continue...")
1756
  continue
@@ -1759,9 +2274,14 @@ class SettingsTool:
1759
  self.console.print(
1760
  "\n[bold]Select provider to remove limit from:[/bold]"
1761
  )
1762
- limits_list = list(limits.keys())
1763
  for idx, prov in enumerate(limits_list, 1):
1764
- self.console.print(f" {idx}. {prov}")
 
 
 
 
 
1765
 
1766
  choice_idx = IntPrompt.ask(
1767
  "Select option",
@@ -1772,18 +2292,118 @@ class SettingsTool:
1772
  if Confirm.ask(
1773
  f"Remove concurrency limit for '{provider}' (reset to default 1)?"
1774
  ):
1775
- self.concurrency_mgr.remove_limit(provider)
1776
- self.console.print(
1777
- f"\n[green]✅ Limit removed for '{provider}' - using default (1 request/key)[/green]"
1778
- )
 
 
 
 
 
 
 
 
1779
  input("\nPress Enter to continue...")
1780
 
1781
  elif choice == "4":
1782
  break
1783
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1784
  def save_and_exit(self):
1785
  """Save pending changes and exit"""
1786
  if self.settings.has_pending():
 
 
 
1787
  if Confirm.ask("\n[bold yellow]Save all pending changes?[/bold yellow]"):
1788
  self.settings.save()
1789
  self.console.print("\n[green]✅ All changes saved to .env![/green]")
@@ -1801,6 +2421,9 @@ class SettingsTool:
1801
  def exit_without_saving(self):
1802
  """Exit without saving"""
1803
  if self.settings.has_pending():
 
 
 
1804
  if Confirm.ask("\n[bold red]Discard all pending changes?[/bold red]"):
1805
  self.settings.discard()
1806
  self.console.print("\n[yellow]Changes discarded[/yellow]")
 
12
  from rich.panel import Panel
13
  from dotenv import set_key, unset_key
14
 
15
+ from rotator_library.utils.paths import get_data_file
16
+
17
  console = Console()
18
 
19
+ # Sentinel value for distinguishing "no pending change" from "pending change to None"
20
+ _NOT_FOUND = object()
21
+
22
+ # Import default OAuth port values from provider modules
23
+ # These serve as the source of truth for default port values
24
+ try:
25
+ from rotator_library.providers.gemini_auth_base import GeminiAuthBase
26
+
27
+ GEMINI_CLI_DEFAULT_OAUTH_PORT = GeminiAuthBase.CALLBACK_PORT
28
+ except ImportError:
29
+ GEMINI_CLI_DEFAULT_OAUTH_PORT = 8085
30
+
31
+ try:
32
+ from rotator_library.providers.antigravity_auth_base import AntigravityAuthBase
33
+
34
+ ANTIGRAVITY_DEFAULT_OAUTH_PORT = AntigravityAuthBase.CALLBACK_PORT
35
+ except ImportError:
36
+ ANTIGRAVITY_DEFAULT_OAUTH_PORT = 51121
37
+
38
+ try:
39
+ from rotator_library.providers.iflow_auth_base import (
40
+ CALLBACK_PORT as IFLOW_DEFAULT_OAUTH_PORT,
41
+ )
42
+ except ImportError:
43
+ IFLOW_DEFAULT_OAUTH_PORT = 11451
44
+
45
 
46
  def clear_screen():
47
  """
 
59
  """Manages pending changes to .env"""
60
 
61
  def __init__(self):
62
+ self.env_file = get_data_file(".env")
63
  self.pending_changes = {} # key -> value (None means delete)
64
  self.load_current_settings()
65
 
 
67
  """Load current .env values into env vars"""
68
  from dotenv import load_dotenv
69
 
70
+ load_dotenv(self.env_file, override=True)
71
 
72
  def set(self, key: str, value: str):
73
  """Stage a change"""
 
98
  """Check if there are pending changes"""
99
  return bool(self.pending_changes)
100
 
101
+ def get_pending_value(self, key: str):
102
+ """Get pending value for a key. Returns sentinel _NOT_FOUND if no pending change."""
103
+ return self.pending_changes.get(key, _NOT_FOUND)
104
+
105
+ def get_original_value(self, key: str) -> Optional[str]:
106
+ """Get the current .env value (before pending changes)"""
107
+ return os.getenv(key)
108
+
109
+ def get_change_type(self, key: str) -> Optional[str]:
110
+ """Returns 'add', 'edit', 'remove', or None if no pending change"""
111
+ if key not in self.pending_changes:
112
+ return None
113
+ if self.pending_changes[key] is None:
114
+ return "remove"
115
+ elif os.getenv(key) is not None:
116
+ return "edit"
117
+ else:
118
+ return "add"
119
+
120
+ def get_pending_keys_by_pattern(
121
+ self, prefix: str = "", suffix: str = ""
122
+ ) -> List[str]:
123
+ """Get all pending change keys that match prefix and/or suffix"""
124
+ return [
125
+ k
126
+ for k in self.pending_changes.keys()
127
+ if k.startswith(prefix) and k.endswith(suffix)
128
+ ]
129
+
130
+ def get_changes_summary(self) -> Dict[str, List[tuple]]:
131
+ """Get categorized summary of all pending changes.
132
+ Returns dict with 'add', 'edit', 'remove' keys,
133
+ each containing list of (key, old_val, new_val) tuples.
134
+ """
135
+ summary: Dict[str, List[tuple]] = {"add": [], "edit": [], "remove": []}
136
+ for key, new_val in self.pending_changes.items():
137
+ old_val = os.getenv(key)
138
+ change_type = self.get_change_type(key)
139
+ if change_type:
140
+ summary[change_type].append((key, old_val, new_val))
141
+ # Sort each list alphabetically by key
142
+ for change_type in summary:
143
+ summary[change_type].sort(key=lambda x: x[0])
144
+ return summary
145
+
146
+ def get_pending_counts(self) -> Dict[str, int]:
147
+ """Get counts of pending changes by type"""
148
+ adds = len(
149
+ [
150
+ k
151
+ for k, v in self.pending_changes.items()
152
+ if v is not None and os.getenv(k) is None
153
+ ]
154
+ )
155
+ edits = len(
156
+ [
157
+ k
158
+ for k, v in self.pending_changes.items()
159
+ if v is not None and os.getenv(k) is not None
160
+ ]
161
+ )
162
+ removes = len([k for k, v in self.pending_changes.items() if v is None])
163
+ return {"add": adds, "edit": edits, "remove": removes}
164
+
165
 
166
  class CustomProviderManager:
167
  """Manages custom provider API bases"""
 
475
  "default": "\n\nSTRICT PARAMETERS: {params}.",
476
  "description": "Template for Claude strict parameter hints in tool descriptions",
477
  },
478
+ "ANTIGRAVITY_OAUTH_PORT": {
479
+ "type": "int",
480
+ "default": ANTIGRAVITY_DEFAULT_OAUTH_PORT,
481
+ "description": "Local port for OAuth callback server during authentication",
482
+ },
483
  }
484
 
485
  # Gemini CLI provider environment variables
 
524
  "default": "",
525
  "description": "GCP Project ID for paid tier users (required for paid tiers)",
526
  },
527
+ "GEMINI_CLI_OAUTH_PORT": {
528
+ "type": "int",
529
+ "default": GEMINI_CLI_DEFAULT_OAUTH_PORT,
530
+ "description": "Local port for OAuth callback server during authentication",
531
+ },
532
+ }
533
+
534
+ # iFlow provider environment variables
535
+ IFLOW_SETTINGS = {
536
+ "IFLOW_OAUTH_PORT": {
537
+ "type": "int",
538
+ "default": IFLOW_DEFAULT_OAUTH_PORT,
539
+ "description": "Local port for OAuth callback server during authentication",
540
+ },
541
  }
542
 
543
  # Map provider names to their settings definitions
544
  PROVIDER_SETTINGS_MAP = {
545
  "antigravity": ANTIGRAVITY_SETTINGS,
546
  "gemini_cli": GEMINI_CLI_SETTINGS,
547
+ "iflow": IFLOW_SETTINGS,
548
  }
549
 
550
 
 
628
  self.provider_settings_mgr = ProviderSettingsManager(self.settings)
629
  self.running = True
630
 
631
+ def _format_item(
632
+ self,
633
+ name: str,
634
+ value: str,
635
+ change_type: Optional[str],
636
+ old_value: Optional[str] = None,
637
+ width: int = 15,
638
+ ) -> str:
639
+ """Format a list item with change indicator.
640
+
641
+ change_type: None, 'add', 'edit', 'remove'
642
+ Returns formatted string like:
643
+ " + myapi https://api.example.com" (green)
644
+ " ~ openai 1 → 5 requests/key" (yellow)
645
+ " - oldapi https://old.api.com" (red)
646
+ " • groq 3 requests/key" (normal)
647
+ """
648
+ if change_type == "add":
649
+ return f" [green]+ {name:{width}} {value}[/green]"
650
+ elif change_type == "edit":
651
+ if old_value is not None:
652
+ return f" [yellow]~ {name:{width}} {old_value} → {value}[/yellow]"
653
+ else:
654
+ return f" [yellow]~ {name:{width}} {value}[/yellow]"
655
+ elif change_type == "remove":
656
+ return f" [red]- {name:{width}} {value}[/red]"
657
+ else:
658
+ return f" • {name:{width}} {value}"
659
+
660
+ def _get_pending_status_text(self) -> str:
661
+ """Get formatted pending changes status text for main menu."""
662
+ if not self.settings.has_pending():
663
+ return "[dim]ℹ️ No pending changes[/dim]"
664
+
665
+ counts = self.settings.get_pending_counts()
666
+ parts = []
667
+ if counts["add"]:
668
+ parts.append(
669
+ f"[green]{counts['add']} addition{'s' if counts['add'] > 1 else ''}[/green]"
670
+ )
671
+ if counts["edit"]:
672
+ parts.append(
673
+ f"[yellow]{counts['edit']} modification{'s' if counts['edit'] > 1 else ''}[/yellow]"
674
+ )
675
+ if counts["remove"]:
676
+ parts.append(
677
+ f"[red]{counts['remove']} removal{'s' if counts['remove'] > 1 else ''}[/red]"
678
+ )
679
+
680
+ return f"[bold]ℹ️ Pending changes: {', '.join(parts)}[/bold]"
681
+ self.running = True
682
+
683
  def get_available_providers(self) -> List[str]:
684
  """Get list of providers that have credentials configured"""
685
+ env_file = get_data_file(".env")
686
  providers = set()
687
 
688
  # Scan for providers with API keys from local .env
 
705
  pass
706
 
707
  # Also check for OAuth providers from files
708
+ from rotator_library.utils.paths import get_oauth_dir
709
+
710
+ oauth_dir = get_oauth_dir()
711
  if oauth_dir.exists():
712
  for file in oauth_dir.glob("*_oauth_*.json"):
713
  provider = file.name.split("_oauth_")[0]
 
745
  self.console.print()
746
  self.console.print("━" * 70)
747
 
748
+ self.console.print(self._get_pending_status_text())
 
 
 
 
 
749
 
750
  self.console.print()
751
  self.console.print(
 
779
  while True:
780
  clear_screen()
781
 
782
+ # Get current providers from env
783
  providers = self.provider_mgr.get_current_providers()
784
 
785
  self.console.print(
 
793
  self.console.print("[bold]📋 Configured Custom Providers[/bold]")
794
  self.console.print("━" * 70)
795
 
796
+ # Build combined view with pending changes
797
+ all_providers: Dict[str, Dict[str, Any]] = {}
798
+
799
+ # Add current providers (from env)
800
+ for name, base in providers.items():
801
+ key = f"{name.upper()}_API_BASE"
802
+ change_type = self.settings.get_change_type(key)
803
+ if change_type == "remove":
804
+ all_providers[name] = {"value": base, "type": "remove", "old": None}
805
+ elif change_type == "edit":
806
+ new_val = self.settings.pending_changes[key]
807
+ all_providers[name] = {
808
+ "value": new_val,
809
+ "type": "edit",
810
+ "old": base,
811
+ }
812
+ else:
813
+ all_providers[name] = {"value": base, "type": None, "old": None}
814
+
815
+ # Add pending new providers (additions)
816
+ for key in self.settings.get_pending_keys_by_pattern(suffix="_API_BASE"):
817
+ if self.settings.get_change_type(key) == "add":
818
+ name = key.replace("_API_BASE", "").lower()
819
+ if name not in all_providers:
820
+ all_providers[name] = {
821
+ "value": self.settings.pending_changes[key],
822
+ "type": "add",
823
+ "old": None,
824
+ }
825
+
826
+ if all_providers:
827
+ # Sort alphabetically
828
+ for name in sorted(all_providers.keys()):
829
+ info = all_providers[name]
830
+ self.console.print(
831
+ self._format_item(
832
+ name,
833
+ info["value"],
834
+ info["type"],
835
+ info["old"],
836
+ )
837
+ )
838
  else:
839
  self.console.print(" [dim]No custom providers configured[/dim]")
840
 
 
863
  if api_base:
864
  self.provider_mgr.add_provider(name, api_base)
865
  self.console.print(
866
+ f"\n[green]✅ Custom provider '{name}' staged![/green]"
867
  )
868
  self.console.print(
869
  f" To use: set {name.upper()}_API_KEY in credentials"
 
871
  input("\nPress Enter to continue...")
872
 
873
  elif choice == "2":
874
+ # Get editable providers (existing + pending additions, excluding pending removals)
875
+ editable = {
876
+ k: v for k, v in all_providers.items() if v["type"] != "remove"
877
+ }
878
+ if not editable:
879
  self.console.print("\n[yellow]No providers to edit[/yellow]")
880
  input("\nPress Enter to continue...")
881
  continue
882
 
883
  # Show numbered list
884
  self.console.print("\n[bold]Select provider to edit:[/bold]")
885
+ providers_list = sorted(editable.keys())
886
  for idx, prov in enumerate(providers_list, 1):
887
  self.console.print(f" {idx}. {prov}")
888
 
 
891
  choices=[str(i) for i in range(1, len(providers_list) + 1)],
892
  )
893
  name = providers_list[choice_idx - 1]
894
+ info = editable[name]
895
+ # Get effective current value (could be pending or from env)
896
+ current_base = info["value"]
897
 
898
  self.console.print(f"\nCurrent API Base: {current_base}")
899
  new_base = Prompt.ask(
 
910
  input("\nPress Enter to continue...")
911
 
912
  elif choice == "3":
913
+ # Get removable providers (existing ones not already pending removal)
914
+ removable = {
915
+ k: v
916
+ for k, v in all_providers.items()
917
+ if v["type"] != "remove" and v["type"] != "add"
918
+ }
919
+ # For pending additions, we can "undo" by removing from pending
920
+ pending_adds = {
921
+ k: v for k, v in all_providers.items() if v["type"] == "add"
922
+ }
923
+
924
+ if not removable and not pending_adds:
925
  self.console.print("\n[yellow]No providers to remove[/yellow]")
926
  input("\nPress Enter to continue...")
927
  continue
928
 
929
  # Show numbered list
930
  self.console.print("\n[bold]Select provider to remove:[/bold]")
931
+ # Show existing providers first, then pending additions
932
+ providers_list = sorted(removable.keys()) + sorted(pending_adds.keys())
933
  for idx, prov in enumerate(providers_list, 1):
934
+ if prov in pending_adds:
935
+ self.console.print(
936
+ f" {idx}. {prov} [green](pending add)[/green]"
937
+ )
938
+ else:
939
+ self.console.print(f" {idx}. {prov}")
940
 
941
  choice_idx = IntPrompt.ask(
942
  "Select option",
 
945
  name = providers_list[choice_idx - 1]
946
 
947
  if Confirm.ask(f"Remove '{name}'?"):
948
+ if name in pending_adds:
949
+ # Undo pending addition - remove from pending_changes
950
+ key = f"{name.upper()}_API_BASE"
951
+ del self.settings.pending_changes[key]
952
+ self.console.print(
953
+ f"\n[green]✅ Pending addition of '{name}' cancelled![/green]"
954
+ )
955
+ else:
956
+ self.provider_mgr.remove_provider(name)
957
+ self.console.print(
958
+ f"\n[green]✅ Provider '{name}' marked for removal![/green]"
959
+ )
960
  input("\nPress Enter to continue...")
961
 
962
  elif choice == "4":
 
967
  while True:
968
  clear_screen()
969
 
970
+ # Get current providers with models from env
971
+ all_providers_env = self.model_mgr.get_all_providers_with_models()
972
 
973
  self.console.print(
974
  Panel.fit(
 
981
  self.console.print("[bold]📋 Configured Provider Models[/bold]")
982
  self.console.print("━" * 70)
983
 
984
+ # Build combined view with pending changes
985
+ all_models: Dict[str, Dict[str, Any]] = {}
986
+ suffix = "_MODELS"
987
+
988
+ # Add current providers (from env)
989
+ for provider, count in all_providers_env.items():
990
+ key = f"{provider.upper()}{suffix}"
991
+ change_type = self.settings.get_change_type(key)
992
+ if change_type == "remove":
993
+ all_models[provider] = {
994
+ "value": f"{count} model{'s' if count > 1 else ''}",
995
+ "type": "remove",
996
+ "old": None,
997
+ }
998
+ elif change_type == "edit":
999
+ # Get new model count from pending
1000
+ new_val = self.settings.pending_changes[key]
1001
+ try:
1002
+ parsed = json.loads(new_val)
1003
+ new_count = (
1004
+ len(parsed) if isinstance(parsed, (dict, list)) else 0
1005
+ )
1006
+ except (json.JSONDecodeError, ValueError):
1007
+ new_count = 0
1008
+ all_models[provider] = {
1009
+ "value": f"{new_count} model{'s' if new_count > 1 else ''}",
1010
+ "type": "edit",
1011
+ "old": f"{count} model{'s' if count > 1 else ''}",
1012
+ }
1013
+ else:
1014
+ all_models[provider] = {
1015
+ "value": f"{count} model{'s' if count > 1 else ''}",
1016
+ "type": None,
1017
+ "old": None,
1018
+ }
1019
+
1020
+ # Add pending new model definitions (additions)
1021
+ for key in self.settings.get_pending_keys_by_pattern(suffix=suffix):
1022
+ if self.settings.get_change_type(key) == "add":
1023
+ provider = key.replace(suffix, "").lower()
1024
+ if provider not in all_models:
1025
+ new_val = self.settings.pending_changes[key]
1026
+ try:
1027
+ parsed = json.loads(new_val)
1028
+ new_count = (
1029
+ len(parsed) if isinstance(parsed, (dict, list)) else 0
1030
+ )
1031
+ except (json.JSONDecodeError, ValueError):
1032
+ new_count = 0
1033
+ all_models[provider] = {
1034
+ "value": f"{new_count} model{'s' if new_count > 1 else ''}",
1035
+ "type": "add",
1036
+ "old": None,
1037
+ }
1038
+
1039
+ if all_models:
1040
+ # Sort alphabetically
1041
+ for provider in sorted(all_models.keys()):
1042
+ info = all_models[provider]
1043
  self.console.print(
1044
+ self._format_item(
1045
+ provider, info["value"], info["type"], info["old"]
1046
+ )
1047
  )
1048
  else:
1049
  self.console.print(" [dim]No model definitions configured[/dim]")
 
1070
  if choice == "1":
1071
  self.add_model_definitions()
1072
  elif choice == "2":
1073
+ # Get editable models (existing + pending additions, excluding pending removals)
1074
+ editable = {
1075
+ k: v for k, v in all_models.items() if v["type"] != "remove"
1076
+ }
1077
+ if not editable:
1078
  self.console.print("\n[yellow]No providers to edit[/yellow]")
1079
  input("\nPress Enter to continue...")
1080
  continue
1081
+ self.edit_model_definitions(sorted(editable.keys()))
1082
  elif choice == "3":
1083
+ viewable = {
1084
+ k: v for k, v in all_models.items() if v["type"] != "remove"
1085
+ }
1086
+ if not viewable:
1087
  self.console.print("\n[yellow]No providers to view[/yellow]")
1088
  input("\nPress Enter to continue...")
1089
  continue
1090
+ self.view_model_definitions(sorted(viewable.keys()))
1091
  elif choice == "4":
1092
+ # Get removable models (existing ones not already pending removal)
1093
+ removable = {
1094
+ k: v
1095
+ for k, v in all_models.items()
1096
+ if v["type"] != "remove" and v["type"] != "add"
1097
+ }
1098
+ pending_adds = {
1099
+ k: v for k, v in all_models.items() if v["type"] == "add"
1100
+ }
1101
+
1102
+ if not removable and not pending_adds:
1103
  self.console.print("\n[yellow]No providers to remove[/yellow]")
1104
  input("\nPress Enter to continue...")
1105
  continue
 
1108
  self.console.print(
1109
  "\n[bold]Select provider to remove models from:[/bold]"
1110
  )
1111
+ providers_list = sorted(removable.keys()) + sorted(pending_adds.keys())
1112
  for idx, prov in enumerate(providers_list, 1):
1113
+ if prov in pending_adds:
1114
+ self.console.print(
1115
+ f" {idx}. {prov} [green](pending add)[/green]"
1116
+ )
1117
+ else:
1118
+ self.console.print(f" {idx}. {prov}")
1119
 
1120
  choice_idx = IntPrompt.ask(
1121
  "Select option",
 
1124
  provider = providers_list[choice_idx - 1]
1125
 
1126
  if Confirm.ask(f"Remove all model definitions for '{provider}'?"):
1127
+ if provider in pending_adds:
1128
+ # Undo pending addition
1129
+ key = f"{provider.upper()}{suffix}"
1130
+ del self.settings.pending_changes[key]
1131
+ self.console.print(
1132
+ f"\n[green]✅ Pending models for '{provider}' cancelled![/green]"
1133
+ )
1134
+ else:
1135
+ self.model_mgr.remove_models(provider)
1136
+ self.console.print(
1137
+ f"\n[green]✅ Model definitions marked for removal for '{provider}'![/green]"
1138
+ )
1139
  input("\nPress Enter to continue...")
1140
  elif choice == "5":
1141
  break
 
1462
  self.console.print("[bold]📋 Current Settings[/bold]")
1463
  self.console.print("━" * 70)
1464
 
1465
+ # Display all settings with current values and pending changes
1466
  settings_list = list(definitions.keys())
1467
  for idx, key in enumerate(settings_list, 1):
1468
  definition = definitions[key]
 
1471
  setting_type = definition.get("type", "str")
1472
  description = definition.get("description", "")
1473
 
1474
+ # Check for pending changes
1475
+ change_type = self.settings.get_change_type(key)
1476
+ pending_val = self.settings.get_pending_value(key)
1477
+
1478
+ # Determine effective value to display
1479
+ if pending_val is not _NOT_FOUND and pending_val is not None:
1480
+ # Has pending change - convert to proper type for display
1481
+ if setting_type == "bool":
1482
+ effective = pending_val.lower() in ("true", "1", "yes")
1483
+ elif setting_type == "int":
1484
+ try:
1485
+ effective = int(pending_val)
1486
+ except (ValueError, TypeError):
1487
+ effective = pending_val
1488
+ else:
1489
+ effective = pending_val
1490
+ elif pending_val is None and change_type == "remove":
1491
+ # Pending removal - will revert to default
1492
+ effective = default
1493
+ else:
1494
+ effective = current
1495
+
1496
  # Format value display
1497
  if setting_type == "bool":
1498
  value_display = (
1499
  "[green]✓ Enabled[/green]"
1500
+ if effective
1501
  else "[red]✗ Disabled[/red]"
1502
  )
1503
+ old_display = (
1504
+ (
1505
+ "[green]✓ Enabled[/green]"
1506
+ if current
1507
+ else "[red]✗ Disabled[/red]"
1508
+ )
1509
+ if change_type
1510
+ else None
1511
+ )
1512
  elif setting_type == "int":
1513
+ value_display = f"[cyan]{effective}[/cyan]"
1514
+ old_display = f"[cyan]{current}[/cyan]" if change_type else None
1515
  else:
1516
  value_display = (
1517
+ f"[cyan]{effective or '(not set)'}[/cyan]"
1518
+ if effective
1519
  else "[dim](not set)[/dim]"
1520
  )
1521
+ old_display = (
1522
+ f"[cyan]{current}[/cyan]" if change_type and current else None
1523
+ )
 
1524
 
1525
  # Short key name for display (strip provider prefix)
1526
  short_key = key.replace(f"{provider.upper()}_", "")
1527
 
1528
+ # Determine display marker based on pending change type
1529
+ if change_type == "add":
1530
+ self.console.print(
1531
+ f" [green]+{idx:2}. {short_key:35} {value_display}[/green]"
1532
+ )
1533
+ elif change_type == "edit":
1534
+ self.console.print(
1535
+ f" [yellow]~{idx:2}. {short_key:35} {old_display} → {value_display}[/yellow]"
1536
+ )
1537
+ elif change_type == "remove":
1538
+ self.console.print(
1539
+ f" [red]-{idx:2}. {short_key:35} {old_display} → [dim](default: {default})[/dim][/red]"
1540
+ )
1541
+ else:
1542
+ # Check if modified from default (in env, not pending)
1543
+ modified = current != default
1544
+ mod_marker = "[yellow]*[/yellow]" if modified else " "
1545
+ self.console.print(
1546
+ f" {mod_marker}{idx:2}. {short_key:35} {value_display}"
1547
+ )
1548
+
1549
  self.console.print(f" [dim]{description}[/dim]")
1550
 
1551
  self.console.print()
1552
  self.console.print("━" * 70)
1553
+ self.console.print(
1554
+ "[dim]* = modified from default, + = pending add, ~ = pending edit, - = pending reset[/dim]"
1555
+ )
1556
  self.console.print()
1557
  self.console.print("[bold]⚙️ Actions[/bold]")
1558
  self.console.print()
 
1672
  while True:
1673
  clear_screen()
1674
 
1675
+ # Get current modes from env
1676
  modes = self.rotation_mgr.get_current_modes()
1677
  available_providers = self.get_available_providers()
1678
 
 
1696
  self.console.print("[bold]📋 Current Rotation Mode Settings[/bold]")
1697
  self.console.print("━" * 70)
1698
 
1699
+ # Build combined view with pending changes
1700
+ all_modes: Dict[str, Dict[str, Any]] = {}
1701
+ prefix = "ROTATION_MODE_"
1702
+
1703
+ # Add current modes (from env)
1704
+ for provider, mode in modes.items():
1705
+ key = f"{prefix}{provider.upper()}"
1706
+ change_type = self.settings.get_change_type(key)
1707
+ default_mode = self.rotation_mgr.get_default_mode(provider)
1708
+ if change_type == "remove":
1709
+ all_modes[provider] = {"value": mode, "type": "remove", "old": None}
1710
+ elif change_type == "edit":
1711
+ new_val = self.settings.pending_changes[key]
1712
+ all_modes[provider] = {
1713
+ "value": new_val,
1714
+ "type": "edit",
1715
+ "old": mode,
1716
+ }
1717
+ else:
1718
+ all_modes[provider] = {"value": mode, "type": None, "old": None}
1719
+
1720
+ # Add pending new modes (additions)
1721
+ for key in self.settings.get_pending_keys_by_pattern(prefix=prefix):
1722
+ if self.settings.get_change_type(key) == "add":
1723
+ provider = key.replace(prefix, "").lower()
1724
+ if provider not in all_modes:
1725
+ all_modes[provider] = {
1726
+ "value": self.settings.pending_changes[key],
1727
+ "type": "add",
1728
+ "old": None,
1729
+ }
1730
+
1731
+ if all_modes:
1732
+ # Sort alphabetically
1733
+ for provider in sorted(all_modes.keys()):
1734
+ info = all_modes[provider]
1735
+ mode = info["value"]
1736
  mode_display = (
1737
  f"[green]{mode}[/green]"
1738
  if mode == "sequential"
1739
  else f"[blue]{mode}[/blue]"
1740
  )
1741
+ old_display = None
1742
+ if info["old"]:
1743
+ old_display = (
1744
+ f"[green]{info['old']}[/green]"
1745
+ if info["old"] == "sequential"
1746
+ else f"[blue]{info['old']}[/blue]"
1747
+ )
1748
+
1749
+ if info["type"] == "add":
1750
+ self.console.print(
1751
+ f" [green]+ {provider:20} {mode_display}[/green]"
1752
+ )
1753
+ elif info["type"] == "edit":
1754
+ self.console.print(
1755
+ f" [yellow]~ {provider:20} {old_display} → {mode_display}[/yellow]"
1756
+ )
1757
+ elif info["type"] == "remove":
1758
+ self.console.print(
1759
+ f" [red]- {provider:20} {mode_display}[/red]"
1760
+ )
1761
+ else:
1762
+ default_mode = self.rotation_mgr.get_default_mode(provider)
1763
+ is_custom = mode != default_mode
1764
+ marker = "[yellow]*[/yellow]" if is_custom else " "
1765
+ self.console.print(f" {marker}• {provider:20} {mode_display}")
1766
 
1767
  # Show providers with default modes
1768
+ providers_with_defaults = [
1769
+ p for p in available_providers if p not in modes and p not in all_modes
1770
+ ]
1771
  if providers_with_defaults:
1772
  self.console.print()
1773
  self.console.print("[dim]Providers using default modes:[/dim]")
 
1855
 
1856
  self.rotation_mgr.set_mode(provider, new_mode)
1857
  self.console.print(
1858
+ f"\n[green]✅ Rotation mode for '{provider}' staged as {new_mode}![/green]"
1859
  )
1860
  input("\nPress Enter to continue...")
1861
 
1862
  elif choice == "2":
1863
+ # Get resettable modes (existing + pending adds, excluding pending removes)
1864
+ resettable = {
1865
+ k: v for k, v in all_modes.items() if v["type"] != "remove"
1866
+ }
1867
+ if not resettable:
1868
  self.console.print(
1869
  "\n[yellow]No custom rotation modes to reset[/yellow]"
1870
  )
 
1875
  self.console.print(
1876
  "\n[bold]Select provider to reset to default:[/bold]"
1877
  )
1878
+ modes_list = sorted(resettable.keys())
1879
  for idx, prov in enumerate(modes_list, 1):
1880
  default_mode = self.rotation_mgr.get_default_mode(prov)
1881
+ info = resettable[prov]
1882
+ if info["type"] == "add":
1883
+ self.console.print(
1884
+ f" {idx}. {prov} [green](pending add)[/green] - will cancel"
1885
+ )
1886
+ else:
1887
+ self.console.print(
1888
+ f" {idx}. {prov} (will reset to: {default_mode})"
1889
+ )
1890
 
1891
  choice_idx = IntPrompt.ask(
1892
  "Select option",
 
1894
  )
1895
  provider = modes_list[choice_idx - 1]
1896
  default_mode = self.rotation_mgr.get_default_mode(provider)
1897
+ info = resettable[provider]
1898
 
1899
  if Confirm.ask(f"Reset '{provider}' to default mode ({default_mode})?"):
1900
+ if info["type"] == "add":
1901
+ # Undo pending addition
1902
+ key = f"{prefix}{provider.upper()}"
1903
+ del self.settings.pending_changes[key]
1904
+ self.console.print(
1905
+ f"\n[green]✅ Pending mode for '{provider}' cancelled![/green]"
1906
+ )
1907
+ else:
1908
+ self.rotation_mgr.remove_mode(provider)
1909
+ self.console.print(
1910
+ f"\n[green]✅ Rotation mode for '{provider}' marked for reset to default ({default_mode})![/green]"
1911
+ )
1912
  input("\nPress Enter to continue...")
1913
 
1914
  elif choice == "3":
 
2081
  while True:
2082
  clear_screen()
2083
 
2084
+ # Get current limits from env
2085
  limits = self.concurrency_mgr.get_current_limits()
2086
 
2087
  self.console.print(
 
2095
  self.console.print("[bold]📋 Current Concurrency Settings[/bold]")
2096
  self.console.print("━" * 70)
2097
 
2098
+ # Build combined view with pending changes
2099
+ all_limits: Dict[str, Dict[str, Any]] = {}
2100
+ prefix = "MAX_CONCURRENT_REQUESTS_PER_KEY_"
2101
+
2102
+ # Add current limits (from env)
2103
+ for provider, limit in limits.items():
2104
+ key = f"{prefix}{provider.upper()}"
2105
+ change_type = self.settings.get_change_type(key)
2106
+ if change_type == "remove":
2107
+ all_limits[provider] = {
2108
+ "value": str(limit),
2109
+ "type": "remove",
2110
+ "old": None,
2111
+ }
2112
+ elif change_type == "edit":
2113
+ new_val = self.settings.pending_changes[key]
2114
+ all_limits[provider] = {
2115
+ "value": new_val,
2116
+ "type": "edit",
2117
+ "old": str(limit),
2118
+ }
2119
+ else:
2120
+ all_limits[provider] = {
2121
+ "value": str(limit),
2122
+ "type": None,
2123
+ "old": None,
2124
+ }
2125
+
2126
+ # Add pending new limits (additions)
2127
+ for key in self.settings.get_pending_keys_by_pattern(prefix=prefix):
2128
+ if self.settings.get_change_type(key) == "add":
2129
+ provider = key.replace(prefix, "").lower()
2130
+ if provider not in all_limits:
2131
+ all_limits[provider] = {
2132
+ "value": self.settings.pending_changes[key],
2133
+ "type": "add",
2134
+ "old": None,
2135
+ }
2136
+
2137
+ if all_limits:
2138
+ # Sort alphabetically
2139
+ for provider in sorted(all_limits.keys()):
2140
+ info = all_limits[provider]
2141
+ value_display = f"{info['value']} requests/key"
2142
+ old_display = f"{info['old']} requests/key" if info["old"] else None
2143
+ self.console.print(
2144
+ self._format_item(
2145
+ provider, value_display, info["type"], old_display
2146
+ )
2147
+ )
2148
+ self.console.print(" • Default: 1 request/key (all others)")
2149
  else:
2150
  self.console.print(" • Default: 1 request/key (all providers)")
2151
 
 
2203
  if 1 <= limit <= 100:
2204
  self.concurrency_mgr.set_limit(provider, limit)
2205
  self.console.print(
2206
+ f"\n[green]✅ Concurrency limit staged for '{provider}': {limit} requests/key[/green]"
2207
  )
2208
  else:
2209
  self.console.print(
 
2212
  input("\nPress Enter to continue...")
2213
 
2214
  elif choice == "2":
2215
+ # Get editable limits (existing + pending additions, excluding pending removals)
2216
+ editable = {
2217
+ k: v for k, v in all_limits.items() if v["type"] != "remove"
2218
+ }
2219
+ if not editable:
2220
  self.console.print("\n[yellow]No limits to edit[/yellow]")
2221
  input("\nPress Enter to continue...")
2222
  continue
2223
 
2224
  # Show numbered list
2225
  self.console.print("\n[bold]Select provider to edit:[/bold]")
2226
+ limits_list = sorted(editable.keys())
2227
  for idx, prov in enumerate(limits_list, 1):
2228
  self.console.print(f" {idx}. {prov}")
2229
 
 
2232
  choices=[str(i) for i in range(1, len(limits_list) + 1)],
2233
  )
2234
  provider = limits_list[choice_idx - 1]
2235
+ info = editable[provider]
2236
+ current_limit = int(info["value"])
2237
 
2238
  self.console.print(f"\nCurrent limit: {current_limit} requests/key")
2239
  new_limit = IntPrompt.ask(
 
2254
  input("\nPress Enter to continue...")
2255
 
2256
  elif choice == "3":
2257
+ # Get removable limits (existing ones not already pending removal)
2258
+ removable = {
2259
+ k: v
2260
+ for k, v in all_limits.items()
2261
+ if v["type"] != "remove" and v["type"] != "add"
2262
+ }
2263
+ # For pending additions, we can "undo" by removing from pending
2264
+ pending_adds = {
2265
+ k: v for k, v in all_limits.items() if v["type"] == "add"
2266
+ }
2267
+
2268
+ if not removable and not pending_adds:
2269
  self.console.print("\n[yellow]No limits to remove[/yellow]")
2270
  input("\nPress Enter to continue...")
2271
  continue
 
2274
  self.console.print(
2275
  "\n[bold]Select provider to remove limit from:[/bold]"
2276
  )
2277
+ limits_list = sorted(removable.keys()) + sorted(pending_adds.keys())
2278
  for idx, prov in enumerate(limits_list, 1):
2279
+ if prov in pending_adds:
2280
+ self.console.print(
2281
+ f" {idx}. {prov} [green](pending add)[/green]"
2282
+ )
2283
+ else:
2284
+ self.console.print(f" {idx}. {prov}")
2285
 
2286
  choice_idx = IntPrompt.ask(
2287
  "Select option",
 
2292
  if Confirm.ask(
2293
  f"Remove concurrency limit for '{provider}' (reset to default 1)?"
2294
  ):
2295
+ if provider in pending_adds:
2296
+ # Undo pending addition
2297
+ key = f"{prefix}{provider.upper()}"
2298
+ del self.settings.pending_changes[key]
2299
+ self.console.print(
2300
+ f"\n[green]✅ Pending limit for '{provider}' cancelled![/green]"
2301
+ )
2302
+ else:
2303
+ self.concurrency_mgr.remove_limit(provider)
2304
+ self.console.print(
2305
+ f"\n[green]✅ Limit marked for removal for '{provider}'[/green]"
2306
+ )
2307
  input("\nPress Enter to continue...")
2308
 
2309
  elif choice == "4":
2310
  break
2311
 
2312
+ def _show_changes_summary(self):
2313
+ """Display categorized summary of all pending changes."""
2314
+ self.console.print(
2315
+ Panel.fit(
2316
+ "[bold cyan]📋 Pending Changes Summary[/bold cyan]",
2317
+ border_style="cyan",
2318
+ )
2319
+ )
2320
+ self.console.print()
2321
+
2322
+ # Define categories with their key patterns
2323
+ categories = [
2324
+ ("Custom Provider API Bases", "_API_BASE", "suffix"),
2325
+ ("Model Definitions", "_MODELS", "suffix"),
2326
+ ("Concurrency Limits", "MAX_CONCURRENT_REQUESTS_PER_KEY_", "prefix"),
2327
+ ("Rotation Modes", "ROTATION_MODE_", "prefix"),
2328
+ ("Priority Multipliers", "CONCURRENCY_MULTIPLIER_", "prefix"),
2329
+ ]
2330
+
2331
+ # Get provider-specific settings keys
2332
+ provider_settings_keys = set()
2333
+ for provider_settings in PROVIDER_SETTINGS_MAP.values():
2334
+ provider_settings_keys.update(provider_settings.keys())
2335
+
2336
+ changes = self.settings.get_changes_summary()
2337
+ displayed_keys = set()
2338
+
2339
+ for category_name, pattern, pattern_type in categories:
2340
+ category_changes = {"add": [], "edit": [], "remove": []}
2341
+
2342
+ for change_type in ["add", "edit", "remove"]:
2343
+ for key, old_val, new_val in changes[change_type]:
2344
+ matches = False
2345
+ if pattern_type == "suffix" and key.endswith(pattern):
2346
+ matches = True
2347
+ elif pattern_type == "prefix" and key.startswith(pattern):
2348
+ matches = True
2349
+
2350
+ if matches:
2351
+ category_changes[change_type].append((key, old_val, new_val))
2352
+ displayed_keys.add(key)
2353
+
2354
+ # Check if this category has any changes
2355
+ has_changes = any(category_changes[t] for t in ["add", "edit", "remove"])
2356
+ if has_changes:
2357
+ self.console.print(f"[bold]{category_name}:[/bold]")
2358
+ # Sort: additions, modifications, removals (alphabetically within each)
2359
+ for change_type in ["add", "edit", "remove"]:
2360
+ for key, old_val, new_val in sorted(
2361
+ category_changes[change_type], key=lambda x: x[0]
2362
+ ):
2363
+ if change_type == "add":
2364
+ self.console.print(f" [green]+ {key} = {new_val}[/green]")
2365
+ elif change_type == "edit":
2366
+ self.console.print(
2367
+ f" [yellow]~ {key}: {old_val} → {new_val}[/yellow]"
2368
+ )
2369
+ else:
2370
+ self.console.print(f" [red]- {key}[/red]")
2371
+ self.console.print()
2372
+
2373
+ # Handle provider-specific settings that don't match the patterns above
2374
+ provider_changes = {"add": [], "edit": [], "remove": []}
2375
+ for change_type in ["add", "edit", "remove"]:
2376
+ for key, old_val, new_val in changes[change_type]:
2377
+ if key not in displayed_keys and key in provider_settings_keys:
2378
+ provider_changes[change_type].append((key, old_val, new_val))
2379
+
2380
+ has_provider_changes = any(
2381
+ provider_changes[t] for t in ["add", "edit", "remove"]
2382
+ )
2383
+ if has_provider_changes:
2384
+ self.console.print("[bold]Provider-Specific Settings:[/bold]")
2385
+ for change_type in ["add", "edit", "remove"]:
2386
+ for key, old_val, new_val in sorted(
2387
+ provider_changes[change_type], key=lambda x: x[0]
2388
+ ):
2389
+ if change_type == "add":
2390
+ self.console.print(f" [green]+ {key} = {new_val}[/green]")
2391
+ elif change_type == "edit":
2392
+ self.console.print(
2393
+ f" [yellow]~ {key}: {old_val} → {new_val}[/yellow]"
2394
+ )
2395
+ else:
2396
+ self.console.print(f" [red]- {key}[/red]")
2397
+ self.console.print()
2398
+
2399
+ self.console.print("━" * 70)
2400
+
2401
  def save_and_exit(self):
2402
  """Save pending changes and exit"""
2403
  if self.settings.has_pending():
2404
+ clear_screen()
2405
+ self._show_changes_summary()
2406
+
2407
  if Confirm.ask("\n[bold yellow]Save all pending changes?[/bold yellow]"):
2408
  self.settings.save()
2409
  self.console.print("\n[green]✅ All changes saved to .env![/green]")
 
2421
  def exit_without_saving(self):
2422
  """Exit without saving"""
2423
  if self.settings.has_pending():
2424
+ clear_screen()
2425
+ self._show_changes_summary()
2426
+
2427
  if Confirm.ask("\n[bold red]Discard all pending changes?[/bold red]"):
2428
  self.settings.discard()
2429
  self.console.print("\n[yellow]Changes discarded[/yellow]")
src/rotator_library/client.py CHANGED
@@ -10,6 +10,7 @@ import litellm
10
  from litellm.exceptions import APIConnectionError
11
  from litellm.litellm_core_utils.token_counter import token_counter
12
  import logging
 
13
  from typing import List, Dict, Any, AsyncGenerator, Optional, Union
14
 
15
  lib_logger = logging.getLogger("rotator_library")
@@ -19,7 +20,7 @@ lib_logger = logging.getLogger("rotator_library")
19
  lib_logger.propagate = False
20
 
21
  from .usage_manager import UsageManager
22
- from .failure_logger import log_failure
23
  from .error_handler import (
24
  PreRequestCallbackError,
25
  classify_error,
@@ -37,6 +38,7 @@ from .cooldown_manager import CooldownManager
37
  from .credential_manager import CredentialManager
38
  from .background_refresher import BackgroundRefresher
39
  from .model_definitions import ModelDefinitions
 
40
 
41
 
42
  class StreamedAPIError(Exception):
@@ -58,7 +60,7 @@ class RotatingClient:
58
  api_keys: Optional[Dict[str, List[str]]] = None,
59
  oauth_credentials: Optional[Dict[str, List[str]]] = None,
60
  max_retries: int = 2,
61
- usage_file_path: str = "key_usage.json",
62
  configure_logging: bool = True,
63
  global_timeout: int = 30,
64
  abort_on_callback_error: bool = True,
@@ -68,6 +70,7 @@ class RotatingClient:
68
  enable_request_logging: bool = False,
69
  max_concurrent_requests_per_key: Optional[Dict[str, int]] = None,
70
  rotation_tolerance: float = 3.0,
 
71
  ):
72
  """
73
  Initialize the RotatingClient with intelligent credential rotation.
@@ -76,7 +79,7 @@ class RotatingClient:
76
  api_keys: Dictionary mapping provider names to lists of API keys
77
  oauth_credentials: Dictionary mapping provider names to OAuth credential paths
78
  max_retries: Maximum number of retry attempts per credential
79
- usage_file_path: Path to store usage statistics
80
  configure_logging: Whether to configure library logging
81
  global_timeout: Global timeout for requests in seconds
82
  abort_on_callback_error: Whether to abort on pre-request callback errors
@@ -89,7 +92,18 @@ class RotatingClient:
89
  - 0.0: Deterministic, least-used credential always selected
90
  - 2.0 - 4.0 (default, recommended): Balanced randomness, can pick credentials within 2 uses of max
91
  - 5.0+: High randomness, more unpredictable selection patterns
 
 
92
  """
 
 
 
 
 
 
 
 
 
93
  os.environ["LITELLM_LOG"] = "ERROR"
94
  litellm.set_verbose = False
95
  litellm.drop_params = True
@@ -124,7 +138,9 @@ class RotatingClient:
124
  if oauth_credentials:
125
  self.oauth_credentials = oauth_credentials
126
  else:
127
- self.credential_manager = CredentialManager(os.environ)
 
 
128
  self.oauth_credentials = self.credential_manager.discover_and_prepare()
129
  self.background_refresher = BackgroundRefresher(self)
130
  self.oauth_providers = set(self.oauth_credentials.keys())
@@ -242,8 +258,14 @@ class RotatingClient:
242
  f"Provider '{provider}' sequential fallback multiplier: {fallback}x"
243
  )
244
 
 
 
 
 
 
 
245
  self.usage_manager = UsageManager(
246
- file_path=usage_file_path,
247
  rotation_tolerance=rotation_tolerance,
248
  provider_rotation_modes=provider_rotation_modes,
249
  provider_plugins=PROVIDER_PLUGINS,
 
10
  from litellm.exceptions import APIConnectionError
11
  from litellm.litellm_core_utils.token_counter import token_counter
12
  import logging
13
+ from pathlib import Path
14
  from typing import List, Dict, Any, AsyncGenerator, Optional, Union
15
 
16
  lib_logger = logging.getLogger("rotator_library")
 
20
  lib_logger.propagate = False
21
 
22
  from .usage_manager import UsageManager
23
+ from .failure_logger import log_failure, configure_failure_logger
24
  from .error_handler import (
25
  PreRequestCallbackError,
26
  classify_error,
 
38
  from .credential_manager import CredentialManager
39
  from .background_refresher import BackgroundRefresher
40
  from .model_definitions import ModelDefinitions
41
+ from .utils.paths import get_default_root, get_logs_dir, get_oauth_dir, get_data_file
42
 
43
 
44
  class StreamedAPIError(Exception):
 
60
  api_keys: Optional[Dict[str, List[str]]] = None,
61
  oauth_credentials: Optional[Dict[str, List[str]]] = None,
62
  max_retries: int = 2,
63
+ usage_file_path: Optional[Union[str, Path]] = None,
64
  configure_logging: bool = True,
65
  global_timeout: int = 30,
66
  abort_on_callback_error: bool = True,
 
70
  enable_request_logging: bool = False,
71
  max_concurrent_requests_per_key: Optional[Dict[str, int]] = None,
72
  rotation_tolerance: float = 3.0,
73
+ data_dir: Optional[Union[str, Path]] = None,
74
  ):
75
  """
76
  Initialize the RotatingClient with intelligent credential rotation.
 
79
  api_keys: Dictionary mapping provider names to lists of API keys
80
  oauth_credentials: Dictionary mapping provider names to OAuth credential paths
81
  max_retries: Maximum number of retry attempts per credential
82
+ usage_file_path: Path to store usage statistics. If None, uses data_dir/key_usage.json
83
  configure_logging: Whether to configure library logging
84
  global_timeout: Global timeout for requests in seconds
85
  abort_on_callback_error: Whether to abort on pre-request callback errors
 
92
  - 0.0: Deterministic, least-used credential always selected
93
  - 2.0 - 4.0 (default, recommended): Balanced randomness, can pick credentials within 2 uses of max
94
  - 5.0+: High randomness, more unpredictable selection patterns
95
+ data_dir: Root directory for all data files (logs, cache, oauth_creds, key_usage.json).
96
+ If None, auto-detects: EXE directory if frozen, else current working directory.
97
  """
98
+ # Resolve data_dir early - this becomes the root for all file operations
99
+ if data_dir is not None:
100
+ self.data_dir = Path(data_dir).resolve()
101
+ else:
102
+ self.data_dir = get_default_root()
103
+
104
+ # Configure failure logger to use correct logs directory
105
+ configure_failure_logger(get_logs_dir(self.data_dir))
106
+
107
  os.environ["LITELLM_LOG"] = "ERROR"
108
  litellm.set_verbose = False
109
  litellm.drop_params = True
 
138
  if oauth_credentials:
139
  self.oauth_credentials = oauth_credentials
140
  else:
141
+ self.credential_manager = CredentialManager(
142
+ os.environ, oauth_dir=get_oauth_dir(self.data_dir)
143
+ )
144
  self.oauth_credentials = self.credential_manager.discover_and_prepare()
145
  self.background_refresher = BackgroundRefresher(self)
146
  self.oauth_providers = set(self.oauth_credentials.keys())
 
258
  f"Provider '{provider}' sequential fallback multiplier: {fallback}x"
259
  )
260
 
261
+ # Resolve usage file path - use provided path or default to data_dir
262
+ if usage_file_path is not None:
263
+ resolved_usage_path = Path(usage_file_path)
264
+ else:
265
+ resolved_usage_path = self.data_dir / "key_usage.json"
266
+
267
  self.usage_manager = UsageManager(
268
+ file_path=resolved_usage_path,
269
  rotation_tolerance=rotation_tolerance,
270
  provider_rotation_modes=provider_rotation_modes,
271
  provider_plugins=PROVIDER_PLUGINS,
src/rotator_library/credential_manager.py CHANGED
@@ -3,12 +3,11 @@ import re
3
  import shutil
4
  import logging
5
  from pathlib import Path
6
- from typing import Dict, List, Optional, Set
7
 
8
- lib_logger = logging.getLogger('rotator_library')
9
 
10
- OAUTH_BASE_DIR = Path.cwd() / "oauth_creds"
11
- OAUTH_BASE_DIR.mkdir(exist_ok=True)
12
 
13
  # Standard directories where tools like `gemini login` store credentials.
14
  DEFAULT_OAUTH_DIRS = {
@@ -33,38 +32,53 @@ class CredentialManager:
33
  """
34
  Discovers OAuth credential files from standard locations, copies them locally,
35
  and updates the configuration to use the local paths.
36
-
37
  Also discovers environment variable-based OAuth credentials for stateless deployments.
38
  Supports two env var formats:
39
-
40
  1. Single credential (legacy): PROVIDER_ACCESS_TOKEN, PROVIDER_REFRESH_TOKEN
41
  2. Multiple credentials (numbered): PROVIDER_1_ACCESS_TOKEN, PROVIDER_2_ACCESS_TOKEN, etc.
42
-
43
  When env-based credentials are detected, virtual paths like "env://provider/1" are created.
44
  """
45
- def __init__(self, env_vars: Dict[str, str]):
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  self.env_vars = env_vars
 
 
47
 
48
  def _discover_env_oauth_credentials(self) -> Dict[str, List[str]]:
49
  """
50
  Discover OAuth credentials defined via environment variables.
51
-
52
  Supports two formats:
53
  1. Single credential: ANTIGRAVITY_ACCESS_TOKEN + ANTIGRAVITY_REFRESH_TOKEN
54
  2. Multiple credentials: ANTIGRAVITY_1_ACCESS_TOKEN + ANTIGRAVITY_1_REFRESH_TOKEN, etc.
55
-
56
  Returns:
57
  Dict mapping provider name to list of virtual paths (e.g., "env://antigravity/1")
58
  """
59
  env_credentials: Dict[str, Set[str]] = {}
60
-
61
  for provider, env_prefix in ENV_OAUTH_PROVIDERS.items():
62
  found_indices: Set[str] = set()
63
-
64
  # Check for numbered credentials (PROVIDER_N_ACCESS_TOKEN pattern)
65
  # Pattern: ANTIGRAVITY_1_ACCESS_TOKEN, ANTIGRAVITY_2_ACCESS_TOKEN, etc.
66
  numbered_pattern = re.compile(rf"^{env_prefix}_(\d+)_ACCESS_TOKEN$")
67
-
68
  for key in self.env_vars.keys():
69
  match = numbered_pattern.match(key)
70
  if match:
@@ -73,28 +87,34 @@ class CredentialManager:
73
  refresh_key = f"{env_prefix}_{index}_REFRESH_TOKEN"
74
  if refresh_key in self.env_vars and self.env_vars[refresh_key]:
75
  found_indices.add(index)
76
-
77
  # Check for legacy single credential (PROVIDER_ACCESS_TOKEN pattern)
78
  # Only use this if no numbered credentials exist
79
  if not found_indices:
80
  access_key = f"{env_prefix}_ACCESS_TOKEN"
81
  refresh_key = f"{env_prefix}_REFRESH_TOKEN"
82
- if (access_key in self.env_vars and self.env_vars[access_key] and
83
- refresh_key in self.env_vars and self.env_vars[refresh_key]):
 
 
 
 
84
  # Use "0" as the index for legacy single credential
85
  found_indices.add("0")
86
-
87
  if found_indices:
88
  env_credentials[provider] = found_indices
89
- lib_logger.info(f"Found {len(found_indices)} env-based credential(s) for {provider}")
90
-
 
 
91
  # Convert to virtual paths
92
  result: Dict[str, List[str]] = {}
93
  for provider, indices in env_credentials.items():
94
  # Sort indices numerically for consistent ordering
95
  sorted_indices = sorted(indices, key=lambda x: int(x))
96
  result[provider] = [f"env://{provider}/{idx}" for idx in sorted_indices]
97
-
98
  return result
99
 
100
  def discover_and_prepare(self) -> Dict[str, List[str]]:
@@ -105,7 +125,9 @@ class CredentialManager:
105
  # These take priority for stateless deployments
106
  env_oauth_creds = self._discover_env_oauth_credentials()
107
  for provider, virtual_paths in env_oauth_creds.items():
108
- lib_logger.info(f"Using {len(virtual_paths)} env-based credential(s) for {provider}")
 
 
109
  final_config[provider] = virtual_paths
110
 
111
  # Extract OAuth file paths from environment variables
@@ -115,21 +137,29 @@ class CredentialManager:
115
  provider = key.split("_OAUTH_")[0].lower()
116
  if provider not in env_oauth_paths:
117
  env_oauth_paths[provider] = []
118
- if value: # Only consider non-empty values
119
  env_oauth_paths[provider].append(value)
120
 
121
  # PHASE 2: Discover file-based OAuth credentials
122
  for provider, default_dir in DEFAULT_OAUTH_DIRS.items():
123
  # Skip if already discovered from environment variables
124
  if provider in final_config:
125
- lib_logger.debug(f"Skipping file discovery for {provider} - using env-based credentials")
 
 
126
  continue
127
-
128
  # Check for existing local credentials first. If found, use them and skip discovery.
129
- local_provider_creds = sorted(list(OAUTH_BASE_DIR.glob(f"{provider}_oauth_*.json")))
 
 
130
  if local_provider_creds:
131
- lib_logger.info(f"Found {len(local_provider_creds)} existing local credential(s) for {provider}. Skipping discovery.")
132
- final_config[provider] = [str(p.resolve()) for p in local_provider_creds]
 
 
 
 
133
  continue
134
 
135
  # If no local credentials exist, proceed with a one-time discovery and copy.
@@ -140,13 +170,13 @@ class CredentialManager:
140
  path = Path(path_str).expanduser()
141
  if path.exists():
142
  discovered_paths.add(path)
143
-
144
  # 2. If no overrides are provided via .env, scan the default directory
145
  # [MODIFIED] This logic is now disabled to prefer local-first credential management.
146
  # if not discovered_paths and default_dir.exists():
147
  # for json_file in default_dir.glob('*.json'):
148
  # discovered_paths.add(json_file)
149
-
150
  if not discovered_paths:
151
  lib_logger.debug(f"No credential files found for provider: {provider}")
152
  continue
@@ -156,18 +186,24 @@ class CredentialManager:
156
  for i, source_path in enumerate(sorted(list(discovered_paths))):
157
  account_id = i + 1
158
  local_filename = f"{provider}_oauth_{account_id}.json"
159
- local_path = OAUTH_BASE_DIR / local_filename
160
 
161
  try:
162
  # Since we've established no local files exist, we can copy directly.
163
  shutil.copy(source_path, local_path)
164
- lib_logger.info(f"Copied '{source_path.name}' to local pool at '{local_path}'.")
 
 
165
  prepared_paths.append(str(local_path.resolve()))
166
  except Exception as e:
167
- lib_logger.error(f"Failed to process OAuth file from '{source_path}': {e}")
168
-
 
 
169
  if prepared_paths:
170
- lib_logger.info(f"Discovered and prepared {len(prepared_paths)} credential(s) for provider: {provider}")
 
 
171
  final_config[provider] = prepared_paths
172
 
173
  lib_logger.info("OAuth credential discovery complete.")
 
3
  import shutil
4
  import logging
5
  from pathlib import Path
6
+ from typing import Dict, List, Optional, Set, Union
7
 
8
+ from .utils.paths import get_oauth_dir
9
 
10
+ lib_logger = logging.getLogger("rotator_library")
 
11
 
12
  # Standard directories where tools like `gemini login` store credentials.
13
  DEFAULT_OAUTH_DIRS = {
 
32
  """
33
  Discovers OAuth credential files from standard locations, copies them locally,
34
  and updates the configuration to use the local paths.
35
+
36
  Also discovers environment variable-based OAuth credentials for stateless deployments.
37
  Supports two env var formats:
38
+
39
  1. Single credential (legacy): PROVIDER_ACCESS_TOKEN, PROVIDER_REFRESH_TOKEN
40
  2. Multiple credentials (numbered): PROVIDER_1_ACCESS_TOKEN, PROVIDER_2_ACCESS_TOKEN, etc.
41
+
42
  When env-based credentials are detected, virtual paths like "env://provider/1" are created.
43
  """
44
+
45
+ def __init__(
46
+ self,
47
+ env_vars: Dict[str, str],
48
+ oauth_dir: Optional[Union[Path, str]] = None,
49
+ ):
50
+ """
51
+ Initialize the CredentialManager.
52
+
53
+ Args:
54
+ env_vars: Dictionary of environment variables (typically os.environ).
55
+ oauth_dir: Directory for storing OAuth credentials.
56
+ If None, uses get_oauth_dir() which respects EXE vs script mode.
57
+ """
58
  self.env_vars = env_vars
59
+ self.oauth_base_dir = Path(oauth_dir) if oauth_dir else get_oauth_dir()
60
+ self.oauth_base_dir.mkdir(parents=True, exist_ok=True)
61
 
62
  def _discover_env_oauth_credentials(self) -> Dict[str, List[str]]:
63
  """
64
  Discover OAuth credentials defined via environment variables.
65
+
66
  Supports two formats:
67
  1. Single credential: ANTIGRAVITY_ACCESS_TOKEN + ANTIGRAVITY_REFRESH_TOKEN
68
  2. Multiple credentials: ANTIGRAVITY_1_ACCESS_TOKEN + ANTIGRAVITY_1_REFRESH_TOKEN, etc.
69
+
70
  Returns:
71
  Dict mapping provider name to list of virtual paths (e.g., "env://antigravity/1")
72
  """
73
  env_credentials: Dict[str, Set[str]] = {}
74
+
75
  for provider, env_prefix in ENV_OAUTH_PROVIDERS.items():
76
  found_indices: Set[str] = set()
77
+
78
  # Check for numbered credentials (PROVIDER_N_ACCESS_TOKEN pattern)
79
  # Pattern: ANTIGRAVITY_1_ACCESS_TOKEN, ANTIGRAVITY_2_ACCESS_TOKEN, etc.
80
  numbered_pattern = re.compile(rf"^{env_prefix}_(\d+)_ACCESS_TOKEN$")
81
+
82
  for key in self.env_vars.keys():
83
  match = numbered_pattern.match(key)
84
  if match:
 
87
  refresh_key = f"{env_prefix}_{index}_REFRESH_TOKEN"
88
  if refresh_key in self.env_vars and self.env_vars[refresh_key]:
89
  found_indices.add(index)
90
+
91
  # Check for legacy single credential (PROVIDER_ACCESS_TOKEN pattern)
92
  # Only use this if no numbered credentials exist
93
  if not found_indices:
94
  access_key = f"{env_prefix}_ACCESS_TOKEN"
95
  refresh_key = f"{env_prefix}_REFRESH_TOKEN"
96
+ if (
97
+ access_key in self.env_vars
98
+ and self.env_vars[access_key]
99
+ and refresh_key in self.env_vars
100
+ and self.env_vars[refresh_key]
101
+ ):
102
  # Use "0" as the index for legacy single credential
103
  found_indices.add("0")
104
+
105
  if found_indices:
106
  env_credentials[provider] = found_indices
107
+ lib_logger.info(
108
+ f"Found {len(found_indices)} env-based credential(s) for {provider}"
109
+ )
110
+
111
  # Convert to virtual paths
112
  result: Dict[str, List[str]] = {}
113
  for provider, indices in env_credentials.items():
114
  # Sort indices numerically for consistent ordering
115
  sorted_indices = sorted(indices, key=lambda x: int(x))
116
  result[provider] = [f"env://{provider}/{idx}" for idx in sorted_indices]
117
+
118
  return result
119
 
120
  def discover_and_prepare(self) -> Dict[str, List[str]]:
 
125
  # These take priority for stateless deployments
126
  env_oauth_creds = self._discover_env_oauth_credentials()
127
  for provider, virtual_paths in env_oauth_creds.items():
128
+ lib_logger.info(
129
+ f"Using {len(virtual_paths)} env-based credential(s) for {provider}"
130
+ )
131
  final_config[provider] = virtual_paths
132
 
133
  # Extract OAuth file paths from environment variables
 
137
  provider = key.split("_OAUTH_")[0].lower()
138
  if provider not in env_oauth_paths:
139
  env_oauth_paths[provider] = []
140
+ if value: # Only consider non-empty values
141
  env_oauth_paths[provider].append(value)
142
 
143
  # PHASE 2: Discover file-based OAuth credentials
144
  for provider, default_dir in DEFAULT_OAUTH_DIRS.items():
145
  # Skip if already discovered from environment variables
146
  if provider in final_config:
147
+ lib_logger.debug(
148
+ f"Skipping file discovery for {provider} - using env-based credentials"
149
+ )
150
  continue
151
+
152
  # Check for existing local credentials first. If found, use them and skip discovery.
153
+ local_provider_creds = sorted(
154
+ list(self.oauth_base_dir.glob(f"{provider}_oauth_*.json"))
155
+ )
156
  if local_provider_creds:
157
+ lib_logger.info(
158
+ f"Found {len(local_provider_creds)} existing local credential(s) for {provider}. Skipping discovery."
159
+ )
160
+ final_config[provider] = [
161
+ str(p.resolve()) for p in local_provider_creds
162
+ ]
163
  continue
164
 
165
  # If no local credentials exist, proceed with a one-time discovery and copy.
 
170
  path = Path(path_str).expanduser()
171
  if path.exists():
172
  discovered_paths.add(path)
173
+
174
  # 2. If no overrides are provided via .env, scan the default directory
175
  # [MODIFIED] This logic is now disabled to prefer local-first credential management.
176
  # if not discovered_paths and default_dir.exists():
177
  # for json_file in default_dir.glob('*.json'):
178
  # discovered_paths.add(json_file)
179
+
180
  if not discovered_paths:
181
  lib_logger.debug(f"No credential files found for provider: {provider}")
182
  continue
 
186
  for i, source_path in enumerate(sorted(list(discovered_paths))):
187
  account_id = i + 1
188
  local_filename = f"{provider}_oauth_{account_id}.json"
189
+ local_path = self.oauth_base_dir / local_filename
190
 
191
  try:
192
  # Since we've established no local files exist, we can copy directly.
193
  shutil.copy(source_path, local_path)
194
+ lib_logger.info(
195
+ f"Copied '{source_path.name}' to local pool at '{local_path}'."
196
+ )
197
  prepared_paths.append(str(local_path.resolve()))
198
  except Exception as e:
199
+ lib_logger.error(
200
+ f"Failed to process OAuth file from '{source_path}': {e}"
201
+ )
202
+
203
  if prepared_paths:
204
+ lib_logger.info(
205
+ f"Discovered and prepared {len(prepared_paths)} credential(s) for provider: {provider}"
206
+ )
207
  final_config[provider] = prepared_paths
208
 
209
  lib_logger.info("OAuth credential discovery complete.")
src/rotator_library/credential_tool.py CHANGED
@@ -3,22 +3,31 @@
3
  import asyncio
4
  import json
5
  import os
6
- import re
7
  import time
8
  from pathlib import Path
9
  from dotenv import set_key, get_key
10
 
11
- # NOTE: Heavy imports (provider_factory, PROVIDER_PLUGINS) are deferred
12
  # to avoid 6-7 second delay before showing loading screen
13
  from rich.console import Console
14
  from rich.panel import Panel
15
  from rich.prompt import Prompt
16
  from rich.text import Text
17
 
18
- OAUTH_BASE_DIR = Path.cwd() / "oauth_creds"
19
- OAUTH_BASE_DIR.mkdir(exist_ok=True)
20
- # Use a direct path to the .env file in the project root
21
- ENV_FILE = Path.cwd() / ".env"
 
 
 
 
 
 
 
 
 
 
22
 
23
  console = Console()
24
 
@@ -26,12 +35,14 @@ console = Console()
26
  _provider_factory = None
27
  _provider_plugins = None
28
 
 
29
  def _ensure_providers_loaded():
30
  """Lazy load provider modules only when needed"""
31
  global _provider_factory, _provider_plugins
32
  if _provider_factory is None:
33
  from . import provider_factory as pf
34
  from .providers import PROVIDER_PLUGINS as pp
 
35
  _provider_factory = pf
36
  _provider_plugins = pp
37
  return _provider_factory, _provider_plugins
@@ -39,99 +50,34 @@ def _ensure_providers_loaded():
39
 
40
  def clear_screen():
41
  """
42
- Cross-platform terminal clear that works robustly on both
43
  classic Windows conhost and modern terminals (Windows Terminal, Linux, Mac).
44
-
45
  Uses native OS commands instead of ANSI escape sequences:
46
  - Windows (conhost & Windows Terminal): cls
47
  - Unix-like systems (Linux, Mac): clear
48
  """
49
- os.system('cls' if os.name == 'nt' else 'clear')
50
-
51
 
52
- def _get_credential_number_from_filename(filename: str) -> int:
53
- """
54
- Extract credential number from filename like 'provider_oauth_1.json' -> 1
55
- """
56
- match = re.search(r'_oauth_(\d+)\.json$', filename)
57
- if match:
58
- return int(match.group(1))
59
- return 1
60
-
61
-
62
- def _build_env_export_content(
63
- provider_prefix: str,
64
- cred_number: int,
65
- creds: dict,
66
- email: str,
67
- extra_fields: dict = None,
68
- include_client_creds: bool = True
69
- ) -> tuple[list[str], str]:
70
- """
71
- Build .env content for OAuth credential export with numbered format.
72
- Exports all fields from the JSON file as a 1-to-1 mirror.
73
-
74
- Args:
75
- provider_prefix: Environment variable prefix (e.g., "ANTIGRAVITY", "GEMINI_CLI")
76
- cred_number: Credential number for this export (1, 2, 3, etc.)
77
- creds: The credential dictionary loaded from JSON
78
- email: User email for comments
79
- extra_fields: Optional dict of additional fields to include
80
- include_client_creds: Whether to include client_id/secret (Google OAuth providers)
81
-
82
- Returns:
83
- Tuple of (env_lines list, numbered_prefix string for display)
84
- """
85
- # Use numbered format: PROVIDER_N_ACCESS_TOKEN
86
- numbered_prefix = f"{provider_prefix}_{cred_number}"
87
-
88
- env_lines = [
89
- f"# {provider_prefix} Credential #{cred_number} for: {email}",
90
- f"# Exported from: {provider_prefix.lower()}_oauth_{cred_number}.json",
91
- f"# Generated at: {time.strftime('%Y-%m-%d %H:%M:%S')}",
92
- f"# ",
93
- f"# To combine multiple credentials into one .env file, copy these lines",
94
- f"# and ensure each credential has a unique number (1, 2, 3, etc.)",
95
- "",
96
- f"{numbered_prefix}_ACCESS_TOKEN={creds.get('access_token', '')}",
97
- f"{numbered_prefix}_REFRESH_TOKEN={creds.get('refresh_token', '')}",
98
- f"{numbered_prefix}_SCOPE={creds.get('scope', '')}",
99
- f"{numbered_prefix}_TOKEN_TYPE={creds.get('token_type', 'Bearer')}",
100
- f"{numbered_prefix}_ID_TOKEN={creds.get('id_token', '')}",
101
- f"{numbered_prefix}_EXPIRY_DATE={creds.get('expiry_date', 0)}",
102
- ]
103
-
104
- if include_client_creds:
105
- env_lines.extend([
106
- f"{numbered_prefix}_CLIENT_ID={creds.get('client_id', '')}",
107
- f"{numbered_prefix}_CLIENT_SECRET={creds.get('client_secret', '')}",
108
- f"{numbered_prefix}_TOKEN_URI={creds.get('token_uri', 'https://oauth2.googleapis.com/token')}",
109
- f"{numbered_prefix}_UNIVERSE_DOMAIN={creds.get('universe_domain', 'googleapis.com')}",
110
- ])
111
-
112
- env_lines.append(f"{numbered_prefix}_EMAIL={email}")
113
-
114
- # Add extra provider-specific fields
115
- if extra_fields:
116
- for key, value in extra_fields.items():
117
- if value: # Only add non-empty values
118
- env_lines.append(f"{numbered_prefix}_{key}={value}")
119
-
120
- return env_lines, numbered_prefix
121
 
122
  def ensure_env_defaults():
123
  """
124
  Ensures the .env file exists and contains essential default values like PROXY_API_KEY.
125
  """
126
- if not ENV_FILE.is_file():
127
- ENV_FILE.touch()
128
- console.print(f"Creating a new [bold yellow]{ENV_FILE.name}[/bold yellow] file...")
 
 
129
 
130
  # Check for PROXY_API_KEY, similar to setup_env.bat
131
- if get_key(str(ENV_FILE), "PROXY_API_KEY") is None:
132
  default_key = "VerysecretKey"
133
- console.print(f"Adding default [bold cyan]PROXY_API_KEY[/bold cyan] to [bold yellow]{ENV_FILE.name}[/bold yellow]...")
134
- set_key(str(ENV_FILE), "PROXY_API_KEY", default_key)
 
 
 
135
 
136
  async def setup_api_key():
137
  """
@@ -144,41 +90,74 @@ async def setup_api_key():
144
 
145
  # Verified list of LiteLLM providers with their friendly names and API key variables
146
  LITELLM_PROVIDERS = {
147
- "OpenAI": "OPENAI_API_KEY", "Anthropic": "ANTHROPIC_API_KEY",
148
- "Google AI Studio (Gemini)": "GEMINI_API_KEY", "Azure OpenAI": "AZURE_API_KEY",
149
- "Vertex AI": "GOOGLE_API_KEY", "AWS Bedrock": "AWS_ACCESS_KEY_ID",
150
- "Cohere": "COHERE_API_KEY", "Chutes": "CHUTES_API_KEY",
 
 
 
 
151
  "Mistral AI": "MISTRAL_API_KEY",
152
- "Codestral (Mistral)": "CODESTRAL_API_KEY", "Groq": "GROQ_API_KEY",
153
- "Perplexity": "PERPLEXITYAI_API_KEY", "xAI": "XAI_API_KEY",
154
- "Together AI": "TOGETHERAI_API_KEY", "Fireworks AI": "FIREWORKS_AI_API_KEY",
155
- "Replicate": "REPLICATE_API_KEY", "Hugging Face": "HUGGINGFACE_API_KEY",
156
- "Anyscale": "ANYSCALE_API_KEY", "NVIDIA NIM": "NVIDIA_NIM_API_KEY",
157
- "Deepseek": "DEEPSEEK_API_KEY", "AI21": "AI21_API_KEY",
158
- "Cerebras": "CEREBRAS_API_KEY", "Moonshot": "MOONSHOT_API_KEY",
159
- "Ollama": "OLLAMA_API_KEY", "Xinference": "XINFERENCE_API_KEY",
160
- "Infinity": "INFINITY_API_KEY", "OpenRouter": "OPENROUTER_API_KEY",
161
- "Deepinfra": "DEEPINFRA_API_KEY", "Cloudflare": "CLOUDFLARE_API_KEY",
162
- "Baseten": "BASETEN_API_KEY", "Modal": "MODAL_API_KEY",
163
- "Databricks": "DATABRICKS_API_KEY", "AWS SageMaker": "AWS_ACCESS_KEY_ID",
164
- "IBM watsonx.ai": "WATSONX_APIKEY", "Predibase": "PREDIBASE_API_KEY",
165
- "Clarifai": "CLARIFAI_API_KEY", "NLP Cloud": "NLP_CLOUD_API_KEY",
166
- "Voyage AI": "VOYAGE_API_KEY", "Jina AI": "JINA_API_KEY",
167
- "Hyperbolic": "HYPERBOLIC_API_KEY", "Morph": "MORPH_API_KEY",
168
- "Lambda AI": "LAMBDA_API_KEY", "Novita AI": "NOVITA_API_KEY",
169
- "Aleph Alpha": "ALEPH_ALPHA_API_KEY", "SambaNova": "SAMBANOVA_API_KEY",
170
- "FriendliAI": "FRIENDLI_TOKEN", "Galadriel": "GALADRIEL_API_KEY",
171
- "CompactifAI": "COMPACTIFAI_API_KEY", "Lemonade": "LEMONADE_API_KEY",
172
- "GradientAI": "GRADIENTAI_API_KEY", "Featherless AI": "FEATHERLESS_AI_API_KEY",
173
- "Nebius AI Studio": "NEBIUS_API_KEY", "Dashscope (Qwen)": "DASHSCOPE_API_KEY",
174
- "Bytez": "BYTEZ_API_KEY", "Oracle OCI": "OCI_API_KEY",
175
- "DataRobot": "DATAROBOT_API_KEY", "OVHCloud": "OVHCLOUD_API_KEY",
176
- "Volcengine": "VOLCENGINE_API_KEY", "Snowflake": "SNOWFLAKE_API_KEY",
177
- "Nscale": "NSCALE_API_KEY", "Recraft": "RECRAFT_API_KEY",
178
- "v0": "V0_API_KEY", "Vercel": "VERCEL_AI_GATEWAY_API_KEY",
179
- "Topaz": "TOPAZ_API_KEY", "ElevenLabs": "ELEVENLABS_API_KEY",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
180
  "Deepgram": "DEEPGRAM_API_KEY",
181
- "GitHub Models": "GITHUB_TOKEN", "GitHub Copilot": "GITHUB_COPILOT_API_KEY",
 
182
  }
183
 
184
  # Discover custom providers and add them to the list
@@ -186,37 +165,37 @@ async def setup_api_key():
186
  # qwen_code API key support is a fallback
187
  # iflow API key support is a feature
188
  _, PROVIDER_PLUGINS = _ensure_providers_loaded()
189
-
190
  # Build a set of environment variables already in LITELLM_PROVIDERS
191
  # to avoid duplicates based on the actual API key names
192
  litellm_env_vars = set(LITELLM_PROVIDERS.values())
193
-
194
  # Providers to exclude from API key list
195
  exclude_providers = {
196
- 'gemini_cli', # OAuth-only
197
- 'antigravity', # OAuth-only
198
- 'qwen_code', # API key is fallback, OAuth is primary - don't advertise
199
- 'openai_compatible', # Base class, not a real provider
200
  }
201
-
202
  discovered_providers = {}
203
  for provider_key in PROVIDER_PLUGINS.keys():
204
  if provider_key in exclude_providers:
205
  continue
206
-
207
  # Create environment variable name
208
  env_var = provider_key.upper() + "_API_KEY"
209
-
210
  # Check if this env var already exists in LITELLM_PROVIDERS
211
  # This catches duplicates like GEMINI_API_KEY, MISTRAL_API_KEY, etc.
212
  if env_var in litellm_env_vars:
213
  # Already in LITELLM_PROVIDERS with better name, skip this one
214
  continue
215
-
216
  # Create display name for this custom provider
217
- display_name = provider_key.replace('_', ' ').title()
218
  discovered_providers[display_name] = env_var
219
-
220
  # LITELLM_PROVIDERS takes precedence (comes first in merge)
221
  combined_providers = {**LITELLM_PROVIDERS, **discovered_providers}
222
  provider_display_list = sorted(combined_providers.keys())
@@ -231,15 +210,19 @@ async def setup_api_key():
231
  else:
232
  provider_text.append(f" {i + 1}. {provider_name}\n")
233
 
234
- console.print(Panel(provider_text, title="Available Providers for API Key", style="bold blue"))
 
 
235
 
236
  choice = Prompt.ask(
237
- Text.from_markup("[bold]Please select a provider or type [red]'b'[/red] to go back[/bold]"),
 
 
238
  choices=[str(i + 1) for i in range(len(provider_display_list))] + ["b"],
239
- show_choices=False
240
  )
241
 
242
- if choice.lower() == 'b':
243
  return
244
 
245
  try:
@@ -251,59 +234,88 @@ async def setup_api_key():
251
  api_key = Prompt.ask(f"Enter the API key for {display_name}")
252
 
253
  # Check for duplicate API key value
254
- if ENV_FILE.is_file():
255
- with open(ENV_FILE, "r") as f:
256
  for line in f:
257
  line = line.strip()
258
  if line.startswith(api_var_base) and "=" in line:
259
- existing_key_name, _, existing_key_value = line.partition("=")
 
 
260
  if existing_key_value == api_key:
261
- warning_text = Text.from_markup(f"This API key already exists as [bold yellow]'{existing_key_name}'[/bold yellow]. Overwriting...")
262
- console.print(Panel(warning_text, style="bold yellow", title="Updating API Key"))
263
-
264
- set_key(str(ENV_FILE), existing_key_name, api_key)
265
-
266
- success_text = Text.from_markup(f"Successfully updated existing key [bold yellow]'{existing_key_name}'[/bold yellow].")
267
- console.print(Panel(success_text, style="bold green", title="Success"))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
268
  return
269
 
270
  # Special handling for AWS
271
  if display_name in ["AWS Bedrock", "AWS SageMaker"]:
272
- console.print(Panel(
273
- Text.from_markup(
274
- "This provider requires both an Access Key ID and a Secret Access Key.\n"
275
- f"The key you entered will be saved as [bold yellow]{api_var_base}_1[/bold yellow].\n"
276
- "Please manually add the [bold cyan]AWS_SECRET_ACCESS_KEY_1[/bold cyan] to your .env file."
277
- ),
278
- title="[bold yellow]Additional Step Required[/bold yellow]",
279
- border_style="yellow"
280
- ))
 
 
281
 
282
  key_index = 1
283
  while True:
284
  key_name = f"{api_var_base}_{key_index}"
285
- if ENV_FILE.is_file():
286
- with open(ENV_FILE, "r") as f:
287
  if not any(line.startswith(f"{key_name}=") for line in f):
288
  break
289
  else:
290
  break
291
  key_index += 1
292
-
293
  key_name = f"{api_var_base}_{key_index}"
294
- set_key(str(ENV_FILE), key_name, api_key)
295
-
296
- success_text = Text.from_markup(f"Successfully added {display_name} API key as [bold yellow]'{key_name}'[/bold yellow].")
 
 
297
  console.print(Panel(success_text, style="bold green", title="Success"))
298
 
299
  else:
300
  console.print("[bold red]Invalid choice. Please try again.[/bold red]")
301
  except ValueError:
302
- console.print("[bold red]Invalid input. Please enter a number or 'b'.[/bold red]")
 
 
 
303
 
304
  async def setup_new_credential(provider_name: str):
305
  """
306
  Interactively sets up a new OAuth credential for a given provider.
 
 
307
  """
308
  try:
309
  provider_factory, _ = _ensure_providers_loaded()
@@ -315,668 +327,602 @@ async def setup_new_credential(provider_name: str):
315
  "gemini_cli": "Gemini CLI (OAuth)",
316
  "qwen_code": "Qwen Code (OAuth - also supports API keys)",
317
  "iflow": "iFlow (OAuth - also supports API keys)",
318
- "antigravity": "Antigravity (OAuth)"
319
  }
320
- display_name = oauth_friendly_names.get(provider_name, provider_name.replace('_', ' ').title())
321
-
322
- # Pass provider metadata to auth classes for better display
323
- temp_creds = {
324
- "_proxy_metadata": {
325
- "provider_name": provider_name,
326
- "display_name": display_name
327
- }
328
- }
329
- initialized_creds = await auth_instance.initialize_token(temp_creds)
330
-
331
- user_info = await auth_instance.get_user_info(initialized_creds)
332
- email = user_info.get("email")
333
 
334
- if not email:
335
- console.print(Panel(f"Could not retrieve a unique identifier for {provider_name}. Aborting.", style="bold red", title="Error"))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
336
  return
337
 
338
- for cred_file in OAUTH_BASE_DIR.glob(f"{provider_name}_oauth_*.json"):
339
- with open(cred_file, 'r') as f:
340
- existing_creds = json.load(f)
341
-
342
- metadata = existing_creds.get("_proxy_metadata", {})
343
- if metadata.get("email") == email:
344
- warning_text = Text.from_markup(f"Found existing credential for [bold cyan]'{email}'[/bold cyan] at [bold yellow]'{cred_file.name}'[/bold yellow]. Overwriting...")
345
- console.print(Panel(warning_text, style="bold yellow", title="Updating Credential"))
 
 
 
346
 
347
- # Overwrite the existing file in-place
348
- with open(cred_file, 'w') as f:
349
- json.dump(initialized_creds, f, indent=2)
 
 
350
 
351
- success_text = Text.from_markup(f"Successfully updated credential at [bold yellow]'{cred_file.name}'[/bold yellow] for user [bold cyan]'{email}'[/bold cyan].")
352
- console.print(Panel(success_text, style="bold green", title="Success"))
353
- return
354
-
355
- existing_files = list(OAUTH_BASE_DIR.glob(f"{provider_name}_oauth_*.json"))
356
- next_num = 1
357
- if existing_files:
358
- nums = [int(re.search(r'_(\d+)\.json$', f.name).group(1)) for f in existing_files if re.search(r'_(\d+)\.json$', f.name)]
359
- if nums:
360
- next_num = max(nums) + 1
361
-
362
- new_filename = f"{provider_name}_oauth_{next_num}.json"
363
- new_filepath = OAUTH_BASE_DIR / new_filename
364
-
365
- with open(new_filepath, 'w') as f:
366
- json.dump(initialized_creds, f, indent=2)
367
-
368
- success_text = Text.from_markup(f"Successfully created new credential at [bold yellow]'{new_filepath.name}'[/bold yellow] for user [bold cyan]'{email}'[/bold cyan].")
369
  console.print(Panel(success_text, style="bold green", title="Success"))
370
 
371
  except Exception as e:
372
- console.print(Panel(f"An error occurred during setup for {provider_name}: {e}", style="bold red", title="Error"))
 
 
 
 
 
 
373
 
374
 
375
  async def export_gemini_cli_to_env():
376
  """
377
  Export a Gemini CLI credential JSON file to .env format.
378
- Uses numbered format (GEMINI_CLI_1_*, GEMINI_CLI_2_*) for multiple credential support.
379
  """
380
- console.print(Panel("[bold cyan]Export Gemini CLI Credential to .env[/bold cyan]", expand=False))
 
 
 
 
381
 
382
- # Find all gemini_cli credentials
383
- gemini_cli_files = sorted(list(OAUTH_BASE_DIR.glob("gemini_cli_oauth_*.json")))
 
 
384
 
385
- if not gemini_cli_files:
386
- console.print(Panel("No Gemini CLI credentials found. Please add one first using 'Add OAuth Credential'.",
387
- style="bold red", title="No Credentials"))
 
 
 
 
 
 
 
 
388
  return
389
 
390
  # Display available credentials
391
  cred_text = Text()
392
- for i, cred_file in enumerate(gemini_cli_files):
393
- try:
394
- with open(cred_file, 'r') as f:
395
- creds = json.load(f)
396
- email = creds.get("_proxy_metadata", {}).get("email", "unknown")
397
- cred_text.append(f" {i + 1}. {cred_file.name} ({email})\n")
398
- except Exception as e:
399
- cred_text.append(f" {i + 1}. {cred_file.name} (error reading: {e})\n")
400
 
401
- console.print(Panel(cred_text, title="Available Gemini CLI Credentials", style="bold blue"))
 
 
402
 
403
  choice = Prompt.ask(
404
- Text.from_markup("[bold]Please select a credential to export or type [red]'b'[/red] to go back[/bold]"),
405
- choices=[str(i + 1) for i in range(len(gemini_cli_files))] + ["b"],
406
- show_choices=False
 
 
407
  )
408
 
409
- if choice.lower() == 'b':
410
  return
411
 
412
  try:
413
  choice_index = int(choice) - 1
414
- if 0 <= choice_index < len(gemini_cli_files):
415
- cred_file = gemini_cli_files[choice_index]
416
-
417
- # Load the credential
418
- with open(cred_file, 'r') as f:
419
- creds = json.load(f)
420
 
421
- # Extract metadata
422
- email = creds.get("_proxy_metadata", {}).get("email", "unknown")
423
- project_id = creds.get("_proxy_metadata", {}).get("project_id", "")
424
- tier = creds.get("_proxy_metadata", {}).get("tier", "")
425
-
426
- # Get credential number from filename
427
- cred_number = _get_credential_number_from_filename(cred_file.name)
428
-
429
- # Generate .env file name with credential number
430
- safe_email = email.replace("@", "_at_").replace(".", "_")
431
- env_filename = f"gemini_cli_{cred_number}_{safe_email}.env"
432
- env_filepath = OAUTH_BASE_DIR / env_filename
433
-
434
- # Build extra fields
435
- extra_fields = {}
436
- if project_id:
437
- extra_fields["PROJECT_ID"] = project_id
438
- if tier:
439
- extra_fields["TIER"] = tier
440
-
441
- # Build .env content using helper
442
- env_lines, numbered_prefix = _build_env_export_content(
443
- provider_prefix="GEMINI_CLI",
444
- cred_number=cred_number,
445
- creds=creds,
446
- email=email,
447
- extra_fields=extra_fields,
448
- include_client_creds=True
449
  )
450
 
451
- # Write to .env file
452
- with open(env_filepath, 'w') as f:
453
- f.write('\n'.join(env_lines))
454
-
455
- success_text = Text.from_markup(
456
- f"Successfully exported credential to [bold yellow]'{env_filepath}'[/bold yellow]\n\n"
457
- f"[bold]Environment variable prefix:[/bold] [cyan]{numbered_prefix}_*[/cyan]\n\n"
458
- f"[bold]To use this credential:[/bold]\n"
459
- f"1. Copy the contents to your main .env file, OR\n"
460
- f"2. Source it: [bold cyan]source {env_filepath.name}[/bold cyan] (Linux/Mac)\n"
461
- f"3. Or on Windows: [bold cyan]Get-Content {env_filepath.name} | ForEach-Object {{ $_ -replace '^([^#].*)$', 'set $1' }} | cmd[/bold cyan]\n\n"
462
- f"[bold]To combine multiple credentials:[/bold]\n"
463
- f"Copy lines from multiple .env files into one file.\n"
464
- f"Each credential uses a unique number ({numbered_prefix}_*)."
465
- )
466
- console.print(Panel(success_text, style="bold green", title="Success"))
 
 
 
 
467
  else:
468
  console.print("[bold red]Invalid choice. Please try again.[/bold red]")
469
  except ValueError:
470
- console.print("[bold red]Invalid input. Please enter a number or 'b'.[/bold red]")
 
 
471
  except Exception as e:
472
- console.print(Panel(f"An error occurred during export: {e}", style="bold red", title="Error"))
 
 
 
 
473
 
474
 
475
  async def export_qwen_code_to_env():
476
  """
477
  Export a Qwen Code credential JSON file to .env format.
478
- Generates one .env file per credential.
479
  """
480
- console.print(Panel("[bold cyan]Export Qwen Code Credential to .env[/bold cyan]", expand=False))
 
 
 
 
 
 
 
 
 
481
 
482
- # Find all qwen_code credentials
483
- qwen_code_files = list(OAUTH_BASE_DIR.glob("qwen_code_oauth_*.json"))
484
 
485
- if not qwen_code_files:
486
- console.print(Panel("No Qwen Code credentials found. Please add one first using 'Add OAuth Credential'.",
487
- style="bold red", title="No Credentials"))
 
 
 
 
 
488
  return
489
 
490
  # Display available credentials
491
  cred_text = Text()
492
- for i, cred_file in enumerate(qwen_code_files):
493
- try:
494
- with open(cred_file, 'r') as f:
495
- creds = json.load(f)
496
- email = creds.get("_proxy_metadata", {}).get("email", "unknown")
497
- cred_text.append(f" {i + 1}. {cred_file.name} ({email})\n")
498
- except Exception as e:
499
- cred_text.append(f" {i + 1}. {cred_file.name} (error reading: {e})\n")
500
 
501
- console.print(Panel(cred_text, title="Available Qwen Code Credentials", style="bold blue"))
 
 
502
 
503
  choice = Prompt.ask(
504
- Text.from_markup("[bold]Please select a credential to export or type [red]'b'[/red] to go back[/bold]"),
505
- choices=[str(i + 1) for i in range(len(qwen_code_files))] + ["b"],
506
- show_choices=False
 
 
507
  )
508
 
509
- if choice.lower() == 'b':
510
  return
511
 
512
  try:
513
  choice_index = int(choice) - 1
514
- if 0 <= choice_index < len(qwen_code_files):
515
- cred_file = qwen_code_files[choice_index]
516
-
517
- # Load the credential
518
- with open(cred_file, 'r') as f:
519
- creds = json.load(f)
520
-
521
- # Extract metadata
522
- email = creds.get("_proxy_metadata", {}).get("email", "unknown")
523
-
524
- # Get credential number from filename
525
- cred_number = _get_credential_number_from_filename(cred_file.name)
526
-
527
- # Generate .env file name with credential number
528
- safe_email = email.replace("@", "_at_").replace(".", "_")
529
- env_filename = f"qwen_code_{cred_number}_{safe_email}.env"
530
- env_filepath = OAUTH_BASE_DIR / env_filename
531
-
532
- # Use numbered format: QWEN_CODE_N_*
533
- numbered_prefix = f"QWEN_CODE_{cred_number}"
534
-
535
- # Build .env content (Qwen has different structure)
536
- env_lines = [
537
- f"# QWEN_CODE Credential #{cred_number} for: {email}",
538
- f"# Generated at: {time.strftime('%Y-%m-%d %H:%M:%S')}",
539
- f"# ",
540
- f"# To combine multiple credentials into one .env file, copy these lines",
541
- f"# and ensure each credential has a unique number (1, 2, 3, etc.)",
542
- "",
543
- f"{numbered_prefix}_ACCESS_TOKEN={creds.get('access_token', '')}",
544
- f"{numbered_prefix}_REFRESH_TOKEN={creds.get('refresh_token', '')}",
545
- f"{numbered_prefix}_EXPIRY_DATE={creds.get('expiry_date', 0)}",
546
- f"{numbered_prefix}_RESOURCE_URL={creds.get('resource_url', 'https://portal.qwen.ai/v1')}",
547
- f"{numbered_prefix}_EMAIL={email}",
548
- ]
549
-
550
- # Write to .env file
551
- with open(env_filepath, 'w') as f:
552
- f.write('\n'.join(env_lines))
553
 
554
- success_text = Text.from_markup(
555
- f"Successfully exported credential to [bold yellow]'{env_filepath}'[/bold yellow]\n\n"
556
- f"[bold]Environment variable prefix:[/bold] [cyan]{numbered_prefix}_*[/cyan]\n\n"
557
- f"[bold]To use this credential:[/bold]\n"
558
- f"1. Copy the contents to your main .env file, OR\n"
559
- f"2. Source it: [bold cyan]source {env_filepath.name}[/bold cyan] (Linux/Mac)\n\n"
560
- f"[bold]To combine multiple credentials:[/bold]\n"
561
- f"Copy lines from multiple .env files into one file.\n"
562
- f"Each credential uses a unique number ({numbered_prefix}_*)."
563
  )
564
- console.print(Panel(success_text, style="bold green", title="Success"))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
565
  else:
566
  console.print("[bold red]Invalid choice. Please try again.[/bold red]")
567
  except ValueError:
568
- console.print("[bold red]Invalid input. Please enter a number or 'b'.[/bold red]")
 
 
569
  except Exception as e:
570
- console.print(Panel(f"An error occurred during export: {e}", style="bold red", title="Error"))
 
 
 
 
571
 
572
 
573
  async def export_iflow_to_env():
574
  """
575
  Export an iFlow credential JSON file to .env format.
576
- Uses numbered format (IFLOW_1_*, IFLOW_2_*) for multiple credential support.
577
  """
578
- console.print(Panel("[bold cyan]Export iFlow Credential to .env[/bold cyan]", expand=False))
 
 
579
 
580
- # Find all iflow credentials
581
- iflow_files = sorted(list(OAUTH_BASE_DIR.glob("iflow_oauth_*.json")))
 
 
582
 
583
- if not iflow_files:
584
- console.print(Panel("No iFlow credentials found. Please add one first using 'Add OAuth Credential'.",
585
- style="bold red", title="No Credentials"))
 
 
 
 
 
 
 
 
586
  return
587
 
588
  # Display available credentials
589
  cred_text = Text()
590
- for i, cred_file in enumerate(iflow_files):
591
- try:
592
- with open(cred_file, 'r') as f:
593
- creds = json.load(f)
594
- email = creds.get("_proxy_metadata", {}).get("email", "unknown")
595
- cred_text.append(f" {i + 1}. {cred_file.name} ({email})\n")
596
- except Exception as e:
597
- cred_text.append(f" {i + 1}. {cred_file.name} (error reading: {e})\n")
598
 
599
- console.print(Panel(cred_text, title="Available iFlow Credentials", style="bold blue"))
 
 
600
 
601
  choice = Prompt.ask(
602
- Text.from_markup("[bold]Please select a credential to export or type [red]'b'[/red] to go back[/bold]"),
603
- choices=[str(i + 1) for i in range(len(iflow_files))] + ["b"],
604
- show_choices=False
 
 
605
  )
606
 
607
- if choice.lower() == 'b':
608
  return
609
 
610
  try:
611
  choice_index = int(choice) - 1
612
- if 0 <= choice_index < len(iflow_files):
613
- cred_file = iflow_files[choice_index]
614
 
615
- # Load the credential
616
- with open(cred_file, 'r') as f:
617
- creds = json.load(f)
618
-
619
- # Extract metadata
620
- email = creds.get("_proxy_metadata", {}).get("email", "unknown")
621
-
622
- # Get credential number from filename
623
- cred_number = _get_credential_number_from_filename(cred_file.name)
624
-
625
- # Generate .env file name with credential number
626
- safe_email = email.replace("@", "_at_").replace(".", "_")
627
- env_filename = f"iflow_{cred_number}_{safe_email}.env"
628
- env_filepath = OAUTH_BASE_DIR / env_filename
629
-
630
- # Use numbered format: IFLOW_N_*
631
- numbered_prefix = f"IFLOW_{cred_number}"
632
-
633
- # Build .env content (iFlow has different structure with API key)
634
- env_lines = [
635
- f"# IFLOW Credential #{cred_number} for: {email}",
636
- f"# Generated at: {time.strftime('%Y-%m-%d %H:%M:%S')}",
637
- f"# ",
638
- f"# To combine multiple credentials into one .env file, copy these lines",
639
- f"# and ensure each credential has a unique number (1, 2, 3, etc.)",
640
- "",
641
- f"{numbered_prefix}_ACCESS_TOKEN={creds.get('access_token', '')}",
642
- f"{numbered_prefix}_REFRESH_TOKEN={creds.get('refresh_token', '')}",
643
- f"{numbered_prefix}_API_KEY={creds.get('api_key', '')}",
644
- f"{numbered_prefix}_EXPIRY_DATE={creds.get('expiry_date', '')}",
645
- f"{numbered_prefix}_EMAIL={email}",
646
- f"{numbered_prefix}_TOKEN_TYPE={creds.get('token_type', 'Bearer')}",
647
- f"{numbered_prefix}_SCOPE={creds.get('scope', 'read write')}",
648
- ]
649
-
650
- # Write to .env file
651
- with open(env_filepath, 'w') as f:
652
- f.write('\n'.join(env_lines))
653
-
654
- success_text = Text.from_markup(
655
- f"Successfully exported credential to [bold yellow]'{env_filepath}'[/bold yellow]\n\n"
656
- f"[bold]Environment variable prefix:[/bold] [cyan]{numbered_prefix}_*[/cyan]\n\n"
657
- f"[bold]To use this credential:[/bold]\n"
658
- f"1. Copy the contents to your main .env file, OR\n"
659
- f"2. Source it: [bold cyan]source {env_filepath.name}[/bold cyan] (Linux/Mac)\n\n"
660
- f"[bold]To combine multiple credentials:[/bold]\n"
661
- f"Copy lines from multiple .env files into one file.\n"
662
- f"Each credential uses a unique number ({numbered_prefix}_*)."
663
  )
664
- console.print(Panel(success_text, style="bold green", title="Success"))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
665
  else:
666
  console.print("[bold red]Invalid choice. Please try again.[/bold red]")
667
  except ValueError:
668
- console.print("[bold red]Invalid input. Please enter a number or 'b'.[/bold red]")
 
 
669
  except Exception as e:
670
- console.print(Panel(f"An error occurred during export: {e}", style="bold red", title="Error"))
 
 
 
 
671
 
672
 
673
  async def export_antigravity_to_env():
674
  """
675
  Export an Antigravity credential JSON file to .env format.
676
- Uses numbered format (ANTIGRAVITY_1_*, ANTIGRAVITY_2_*) for multiple credential support.
677
  """
678
- console.print(Panel("[bold cyan]Export Antigravity Credential to .env[/bold cyan]", expand=False))
 
 
 
 
 
 
 
 
 
679
 
680
- # Find all antigravity credentials
681
- antigravity_files = sorted(list(OAUTH_BASE_DIR.glob("antigravity_oauth_*.json")))
682
 
683
- if not antigravity_files:
684
- console.print(Panel("No Antigravity credentials found. Please add one first using 'Add OAuth Credential'.",
685
- style="bold red", title="No Credentials"))
 
 
 
 
 
686
  return
687
 
688
  # Display available credentials
689
  cred_text = Text()
690
- for i, cred_file in enumerate(antigravity_files):
691
- try:
692
- with open(cred_file, 'r') as f:
693
- creds = json.load(f)
694
- email = creds.get("_proxy_metadata", {}).get("email", "unknown")
695
- cred_text.append(f" {i + 1}. {cred_file.name} ({email})\n")
696
- except Exception as e:
697
- cred_text.append(f" {i + 1}. {cred_file.name} (error reading: {e})\n")
698
 
699
- console.print(Panel(cred_text, title="Available Antigravity Credentials", style="bold blue"))
 
 
700
 
701
  choice = Prompt.ask(
702
- Text.from_markup("[bold]Please select a credential to export or type [red]'b'[/red] to go back[/bold]"),
703
- choices=[str(i + 1) for i in range(len(antigravity_files))] + ["b"],
704
- show_choices=False
 
 
705
  )
706
 
707
- if choice.lower() == 'b':
708
  return
709
 
710
  try:
711
  choice_index = int(choice) - 1
712
- if 0 <= choice_index < len(antigravity_files):
713
- cred_file = antigravity_files[choice_index]
714
-
715
- # Load the credential
716
- with open(cred_file, 'r') as f:
717
- creds = json.load(f)
718
 
719
- # Extract metadata
720
- email = creds.get("_proxy_metadata", {}).get("email", "unknown")
721
-
722
- # Get credential number from filename
723
- cred_number = _get_credential_number_from_filename(cred_file.name)
724
-
725
- # Generate .env file name with credential number
726
- safe_email = email.replace("@", "_at_").replace(".", "_")
727
- env_filename = f"antigravity_{cred_number}_{safe_email}.env"
728
- env_filepath = OAUTH_BASE_DIR / env_filename
729
-
730
- # Build .env content using helper
731
- env_lines, numbered_prefix = _build_env_export_content(
732
- provider_prefix="ANTIGRAVITY",
733
- cred_number=cred_number,
734
- creds=creds,
735
- email=email,
736
- extra_fields=None,
737
- include_client_creds=True
738
  )
739
 
740
- # Write to .env file
741
- with open(env_filepath, 'w') as f:
742
- f.write('\n'.join(env_lines))
743
-
744
- success_text = Text.from_markup(
745
- f"Successfully exported credential to [bold yellow]'{env_filepath}'[/bold yellow]\n\n"
746
- f"[bold]Environment variable prefix:[/bold] [cyan]{numbered_prefix}_*[/cyan]\n\n"
747
- f"[bold]To use this credential:[/bold]\n"
748
- f"1. Copy the contents to your main .env file, OR\n"
749
- f"2. Source it: [bold cyan]source {env_filepath.name}[/bold cyan] (Linux/Mac)\n"
750
- f"3. Or on Windows: [bold cyan]Get-Content {env_filepath.name} | ForEach-Object {{ $_ -replace '^([^#].*)$', 'set $1' }} | cmd[/bold cyan]\n\n"
751
- f"[bold]To combine multiple credentials:[/bold]\n"
752
- f"Copy lines from multiple .env files into one file.\n"
753
- f"Each credential uses a unique number ({numbered_prefix}_*)."
754
- )
755
- console.print(Panel(success_text, style="bold green", title="Success"))
 
 
 
 
756
  else:
757
  console.print("[bold red]Invalid choice. Please try again.[/bold red]")
758
  except ValueError:
759
- console.print("[bold red]Invalid input. Please enter a number or 'b'.[/bold red]")
 
 
760
  except Exception as e:
761
- console.print(Panel(f"An error occurred during export: {e}", style="bold red", title="Error"))
762
-
763
-
764
- def _build_gemini_cli_env_lines(creds: dict, cred_number: int) -> list[str]:
765
- """Build .env lines for a Gemini CLI credential."""
766
- email = creds.get("_proxy_metadata", {}).get("email", "unknown")
767
- project_id = creds.get("_proxy_metadata", {}).get("project_id", "")
768
- tier = creds.get("_proxy_metadata", {}).get("tier", "")
769
-
770
- extra_fields = {}
771
- if project_id:
772
- extra_fields["PROJECT_ID"] = project_id
773
- if tier:
774
- extra_fields["TIER"] = tier
775
-
776
- env_lines, _ = _build_env_export_content(
777
- provider_prefix="GEMINI_CLI",
778
- cred_number=cred_number,
779
- creds=creds,
780
- email=email,
781
- extra_fields=extra_fields,
782
- include_client_creds=True
783
- )
784
- return env_lines
785
-
786
-
787
- def _build_qwen_code_env_lines(creds: dict, cred_number: int) -> list[str]:
788
- """Build .env lines for a Qwen Code credential."""
789
- email = creds.get("_proxy_metadata", {}).get("email", "unknown")
790
- numbered_prefix = f"QWEN_CODE_{cred_number}"
791
-
792
- env_lines = [
793
- f"# QWEN_CODE Credential #{cred_number} for: {email}",
794
- f"# Generated at: {time.strftime('%Y-%m-%d %H:%M:%S')}",
795
- "",
796
- f"{numbered_prefix}_ACCESS_TOKEN={creds.get('access_token', '')}",
797
- f"{numbered_prefix}_REFRESH_TOKEN={creds.get('refresh_token', '')}",
798
- f"{numbered_prefix}_EXPIRY_DATE={creds.get('expiry_date', 0)}",
799
- f"{numbered_prefix}_RESOURCE_URL={creds.get('resource_url', 'https://portal.qwen.ai/v1')}",
800
- f"{numbered_prefix}_EMAIL={email}",
801
- ]
802
- return env_lines
803
-
804
-
805
- def _build_iflow_env_lines(creds: dict, cred_number: int) -> list[str]:
806
- """Build .env lines for an iFlow credential."""
807
- email = creds.get("_proxy_metadata", {}).get("email", "unknown")
808
- numbered_prefix = f"IFLOW_{cred_number}"
809
-
810
- env_lines = [
811
- f"# IFLOW Credential #{cred_number} for: {email}",
812
- f"# Generated at: {time.strftime('%Y-%m-%d %H:%M:%S')}",
813
- "",
814
- f"{numbered_prefix}_ACCESS_TOKEN={creds.get('access_token', '')}",
815
- f"{numbered_prefix}_REFRESH_TOKEN={creds.get('refresh_token', '')}",
816
- f"{numbered_prefix}_API_KEY={creds.get('api_key', '')}",
817
- f"{numbered_prefix}_EXPIRY_DATE={creds.get('expiry_date', '')}",
818
- f"{numbered_prefix}_EMAIL={email}",
819
- f"{numbered_prefix}_TOKEN_TYPE={creds.get('token_type', 'Bearer')}",
820
- f"{numbered_prefix}_SCOPE={creds.get('scope', 'read write')}",
821
- ]
822
- return env_lines
823
-
824
-
825
- def _build_antigravity_env_lines(creds: dict, cred_number: int) -> list[str]:
826
- """Build .env lines for an Antigravity credential."""
827
- email = creds.get("_proxy_metadata", {}).get("email", "unknown")
828
-
829
- env_lines, _ = _build_env_export_content(
830
- provider_prefix="ANTIGRAVITY",
831
- cred_number=cred_number,
832
- creds=creds,
833
- email=email,
834
- extra_fields=None,
835
- include_client_creds=True
836
- )
837
- return env_lines
838
 
839
 
840
  async def export_all_provider_credentials(provider_name: str):
841
  """
842
  Export all credentials for a specific provider to individual .env files.
 
843
  """
844
- provider_config = {
845
- "gemini_cli": ("GEMINI_CLI", _build_gemini_cli_env_lines),
846
- "qwen_code": ("QWEN_CODE", _build_qwen_code_env_lines),
847
- "iflow": ("IFLOW", _build_iflow_env_lines),
848
- "antigravity": ("ANTIGRAVITY", _build_antigravity_env_lines),
849
- }
850
-
851
- if provider_name not in provider_config:
852
  console.print(f"[bold red]Unknown provider: {provider_name}[/bold red]")
853
  return
854
-
855
- prefix, build_func = provider_config[provider_name]
856
- display_name = prefix.replace("_", " ").title()
857
-
858
- console.print(Panel(f"[bold cyan]Export All {display_name} Credentials[/bold cyan]", expand=False))
859
-
860
- # Find all credentials for this provider
861
- cred_files = sorted(list(OAUTH_BASE_DIR.glob(f"{provider_name}_oauth_*.json")))
862
-
863
- if not cred_files:
864
- console.print(Panel(f"No {display_name} credentials found.", style="bold red", title="No Credentials"))
 
 
 
 
 
 
 
 
 
 
865
  return
866
-
867
  exported_count = 0
868
- for cred_file in cred_files:
869
  try:
870
- with open(cred_file, 'r') as f:
871
- creds = json.load(f)
872
-
873
- email = creds.get("_proxy_metadata", {}).get("email", "unknown")
874
- cred_number = _get_credential_number_from_filename(cred_file.name)
875
-
876
- # Generate .env file name
877
- safe_email = email.replace("@", "_at_").replace(".", "_")
878
- env_filename = f"{provider_name}_{cred_number}_{safe_email}.env"
879
- env_filepath = OAUTH_BASE_DIR / env_filename
880
-
881
- # Build and write .env content
882
- env_lines = build_func(creds, cred_number)
883
- with open(env_filepath, 'w') as f:
884
- f.write('\n'.join(env_lines))
885
-
886
- console.print(f" ✓ Exported [cyan]{cred_file.name}[/cyan] → [yellow]{env_filename}[/yellow]")
887
- exported_count += 1
888
-
889
  except Exception as e:
890
- console.print(f" ✗ Failed to export {cred_file.name}: {e}")
891
-
892
- console.print(Panel(
893
- f"Successfully exported {exported_count}/{len(cred_files)} {display_name} credentials to individual .env files.",
894
- style="bold green", title="Export Complete"
895
- ))
 
 
 
 
 
896
 
897
 
898
  async def combine_provider_credentials(provider_name: str):
899
  """
900
  Combine all credentials for a specific provider into a single .env file.
 
901
  """
902
- provider_config = {
903
- "gemini_cli": ("GEMINI_CLI", _build_gemini_cli_env_lines),
904
- "qwen_code": ("QWEN_CODE", _build_qwen_code_env_lines),
905
- "iflow": ("IFLOW", _build_iflow_env_lines),
906
- "antigravity": ("ANTIGRAVITY", _build_antigravity_env_lines),
907
- }
908
-
909
- if provider_name not in provider_config:
910
  console.print(f"[bold red]Unknown provider: {provider_name}[/bold red]")
911
  return
912
-
913
- prefix, build_func = provider_config[provider_name]
914
- display_name = prefix.replace("_", " ").title()
915
-
916
- console.print(Panel(f"[bold cyan]Combine All {display_name} Credentials[/bold cyan]", expand=False))
917
-
918
- # Find all credentials for this provider
919
- cred_files = sorted(list(OAUTH_BASE_DIR.glob(f"{provider_name}_oauth_*.json")))
920
-
921
- if not cred_files:
922
- console.print(Panel(f"No {display_name} credentials found.", style="bold red", title="No Credentials"))
 
 
 
 
 
 
 
 
 
 
923
  return
924
-
925
  combined_lines = [
926
  f"# Combined {display_name} Credentials",
927
  f"# Generated at: {time.strftime('%Y-%m-%d %H:%M:%S')}",
928
- f"# Total credentials: {len(cred_files)}",
929
  "#",
930
  "# Copy all lines below into your main .env file",
931
  "",
932
  ]
933
-
934
  combined_count = 0
935
- for cred_file in cred_files:
936
  try:
937
- with open(cred_file, 'r') as f:
 
938
  creds = json.load(f)
939
-
940
- cred_number = _get_credential_number_from_filename(cred_file.name)
941
- env_lines = build_func(creds, cred_number)
942
-
943
  combined_lines.extend(env_lines)
944
  combined_lines.append("") # Blank line between credentials
945
  combined_count += 1
946
-
947
  except Exception as e:
948
- console.print(f" ✗ Failed to process {cred_file.name}: {e}")
949
-
 
 
950
  # Write combined file
951
  combined_filename = f"{provider_name}_all_combined.env"
952
- combined_filepath = OAUTH_BASE_DIR / combined_filename
953
-
954
- with open(combined_filepath, 'w') as f:
955
- f.write('\n'.join(combined_lines))
956
-
957
- console.print(Panel(
958
- Text.from_markup(
959
- f"Successfully combined {combined_count} {display_name} credentials into:\n"
960
- f"[bold yellow]{combined_filepath}[/bold yellow]\n\n"
961
- f"[bold]To use:[/bold] Copy the contents into your main .env file."
962
- ),
963
- style="bold green", title="Combine Complete"
964
- ))
 
 
 
965
 
966
 
967
  async def combine_all_credentials():
968
  """
969
  Combine ALL credentials from ALL providers into a single .env file.
 
970
  """
971
- console.print(Panel("[bold cyan]Combine All Provider Credentials[/bold cyan]", expand=False))
972
-
973
- provider_config = {
974
- "gemini_cli": ("GEMINI_CLI", _build_gemini_cli_env_lines),
975
- "qwen_code": ("QWEN_CODE", _build_qwen_code_env_lines),
976
- "iflow": ("IFLOW", _build_iflow_env_lines),
977
- "antigravity": ("ANTIGRAVITY", _build_antigravity_env_lines),
978
- }
979
-
980
  combined_lines = [
981
  "# Combined All Provider Credentials",
982
  f"# Generated at: {time.strftime('%Y-%m-%d %H:%M:%S')}",
@@ -984,63 +930,83 @@ async def combine_all_credentials():
984
  "# Copy all lines below into your main .env file",
985
  "",
986
  ]
987
-
988
  total_count = 0
989
  provider_counts = {}
990
-
991
- for provider_name, (prefix, build_func) in provider_config.items():
992
- cred_files = sorted(list(OAUTH_BASE_DIR.glob(f"{provider_name}_oauth_*.json")))
993
-
994
- if not cred_files:
 
 
 
 
 
 
995
  continue
996
-
997
- display_name = prefix.replace("_", " ").title()
998
  combined_lines.append(f"# ===== {display_name} Credentials =====")
999
  combined_lines.append("")
1000
-
1001
  provider_count = 0
1002
- for cred_file in cred_files:
1003
  try:
1004
- with open(cred_file, 'r') as f:
 
1005
  creds = json.load(f)
1006
-
1007
- cred_number = _get_credential_number_from_filename(cred_file.name)
1008
- env_lines = build_func(creds, cred_number)
1009
-
1010
  combined_lines.extend(env_lines)
1011
  combined_lines.append("")
1012
  provider_count += 1
1013
  total_count += 1
1014
-
1015
  except Exception as e:
1016
- console.print(f" ✗ Failed to process {cred_file.name}: {e}")
1017
-
 
 
1018
  provider_counts[display_name] = provider_count
1019
-
1020
  if total_count == 0:
1021
- console.print(Panel("No credentials found to combine.", style="bold red", title="No Credentials"))
 
 
 
 
 
 
1022
  return
1023
-
1024
  # Write combined file
1025
  combined_filename = "all_providers_combined.env"
1026
- combined_filepath = OAUTH_BASE_DIR / combined_filename
1027
-
1028
- with open(combined_filepath, 'w') as f:
1029
- f.write('\n'.join(combined_lines))
1030
-
1031
  # Build summary
1032
- summary_lines = [f" • {name}: {count} credential(s)" for name, count in provider_counts.items()]
 
 
1033
  summary = "\n".join(summary_lines)
1034
-
1035
- console.print(Panel(
1036
- Text.from_markup(
1037
- f"Successfully combined {total_count} credentials from {len(provider_counts)} providers:\n"
1038
- f"{summary}\n\n"
1039
- f"[bold]Output file:[/bold] [yellow]{combined_filepath}[/yellow]\n\n"
1040
- f"[bold]To use:[/bold] Copy the contents into your main .env file."
1041
- ),
1042
- style="bold green", title="Combine Complete"
1043
- ))
 
 
 
1044
 
1045
 
1046
  async def export_credentials_submenu():
@@ -1049,40 +1015,65 @@ async def export_credentials_submenu():
1049
  """
1050
  while True:
1051
  clear_screen()
1052
- console.print(Panel("[bold cyan]Export Credentials to .env[/bold cyan]", title="--- API Key Proxy ---", expand=False))
1053
-
1054
- console.print(Panel(
1055
- Text.from_markup(
1056
- "[bold]Individual Exports:[/bold]\n"
1057
- "1. Export Gemini CLI credential\n"
1058
- "2. Export Qwen Code credential\n"
1059
- "3. Export iFlow credential\n"
1060
- "4. Export Antigravity credential\n"
1061
- "\n"
1062
- "[bold]Bulk Exports (per provider):[/bold]\n"
1063
- "5. Export ALL Gemini CLI credentials\n"
1064
- "6. Export ALL Qwen Code credentials\n"
1065
- "7. Export ALL iFlow credentials\n"
1066
- "8. Export ALL Antigravity credentials\n"
1067
- "\n"
1068
- "[bold]Combine Credentials:[/bold]\n"
1069
- "9. Combine all Gemini CLI into one file\n"
1070
- "10. Combine all Qwen Code into one file\n"
1071
- "11. Combine all iFlow into one file\n"
1072
- "12. Combine all Antigravity into one file\n"
1073
- "13. Combine ALL providers into one file"
1074
- ),
1075
- title="Choose export option",
1076
- style="bold blue"
1077
- ))
 
 
 
 
 
 
 
 
1078
 
1079
  export_choice = Prompt.ask(
1080
- Text.from_markup("[bold]Please select an option or type [red]'b'[/red] to go back[/bold]"),
1081
- choices=["1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11", "12", "13", "b"],
1082
- show_choices=False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1083
  )
1084
 
1085
- if export_choice.lower() == 'b':
1086
  break
1087
 
1088
  # Individual exports
@@ -1146,39 +1137,53 @@ async def export_credentials_submenu():
1146
  async def main(clear_on_start=True):
1147
  """
1148
  An interactive CLI tool to add new credentials.
1149
-
1150
  Args:
1151
- clear_on_start: If False, skip initial screen clear (used when called from launcher
1152
  to preserve the loading screen)
1153
  """
1154
  ensure_env_defaults()
1155
-
1156
  # Only show header if we're clearing (standalone mode)
1157
  if clear_on_start:
1158
- console.print(Panel("[bold cyan]Interactive Credential Setup[/bold cyan]", title="--- API Key Proxy ---", expand=False))
1159
-
 
 
 
 
 
 
1160
  while True:
1161
  # Clear screen between menu selections for cleaner UX
1162
  clear_screen()
1163
- console.print(Panel("[bold cyan]Interactive Credential Setup[/bold cyan]", title="--- API Key Proxy ---", expand=False))
1164
-
1165
- console.print(Panel(
1166
- Text.from_markup(
1167
- "1. Add OAuth Credential\n"
1168
- "2. Add API Key\n"
1169
- "3. Export Credentials"
1170
- ),
1171
- title="Choose credential type",
1172
- style="bold blue"
1173
- ))
 
 
 
 
 
 
1174
 
1175
  setup_type = Prompt.ask(
1176
- Text.from_markup("[bold]Please select an option or type [red]'q'[/red] to quit[/bold]"),
 
 
1177
  choices=["1", "2", "3", "q"],
1178
- show_choices=False
1179
  )
1180
 
1181
- if setup_type.lower() == 'q':
1182
  break
1183
 
1184
  if setup_type == "1":
@@ -1190,69 +1195,88 @@ async def main(clear_on_start=True):
1190
  "iflow": "iFlow (OAuth - also supports API keys)",
1191
  "antigravity": "Antigravity (OAuth)",
1192
  }
1193
-
1194
  provider_text = Text()
1195
  for i, provider in enumerate(available_providers):
1196
- display_name = oauth_friendly_names.get(provider, provider.replace('_', ' ').title())
 
 
1197
  provider_text.append(f" {i + 1}. {display_name}\n")
1198
-
1199
- console.print(Panel(provider_text, title="Available Providers for OAuth", style="bold blue"))
 
 
 
 
 
 
1200
 
1201
  choice = Prompt.ask(
1202
- Text.from_markup("[bold]Please select a provider or type [red]'b'[/red] to go back[/bold]"),
 
 
1203
  choices=[str(i + 1) for i in range(len(available_providers))] + ["b"],
1204
- show_choices=False
1205
  )
1206
 
1207
- if choice.lower() == 'b':
1208
  continue
1209
-
1210
  try:
1211
  choice_index = int(choice) - 1
1212
  if 0 <= choice_index < len(available_providers):
1213
  provider_name = available_providers[choice_index]
1214
- display_name = oauth_friendly_names.get(provider_name, provider_name.replace('_', ' ').title())
1215
- console.print(f"\nStarting OAuth setup for [bold cyan]{display_name}[/bold cyan]...")
 
 
 
 
1216
  await setup_new_credential(provider_name)
1217
  # Don't clear after OAuth - user needs to see full flow
1218
  console.print("\n[dim]Press Enter to return to main menu...[/dim]")
1219
  input()
1220
  else:
1221
- console.print("[bold red]Invalid choice. Please try again.[/bold red]")
 
 
1222
  await asyncio.sleep(1.5)
1223
  except ValueError:
1224
- console.print("[bold red]Invalid input. Please enter a number or 'b'.[/bold red]")
 
 
1225
  await asyncio.sleep(1.5)
1226
 
1227
  elif setup_type == "2":
1228
  await setup_api_key()
1229
- #console.print("\n[dim]Press Enter to return to main menu...[/dim]")
1230
- #input()
1231
 
1232
  elif setup_type == "3":
1233
  await export_credentials_submenu()
1234
 
 
1235
  def run_credential_tool(from_launcher=False):
1236
  """
1237
  Entry point for credential tool.
1238
-
1239
  Args:
1240
  from_launcher: If True, skip loading screen (launcher already showed it)
1241
  """
1242
  # Check if we need to show loading screen
1243
  if not from_launcher:
1244
  # Standalone mode - show full loading UI
1245
- os.system('cls' if os.name == 'nt' else 'clear')
1246
-
1247
  _start_time = time.time()
1248
-
1249
  # Phase 1: Show initial message
1250
  print("━" * 70)
1251
  print("Interactive Credential Setup Tool")
1252
  print("GitHub: https://github.com/Mirrowel/LLM-API-Key-Proxy")
1253
  print("━" * 70)
1254
  print("Loading credential management components...")
1255
-
1256
  # Phase 2: Load dependencies with spinner
1257
  with console.status("Loading authentication providers...", spinner="dots"):
1258
  _ensure_providers_loaded()
@@ -1261,14 +1285,16 @@ def run_credential_tool(from_launcher=False):
1261
  with console.status("Initializing credential tool...", spinner="dots"):
1262
  time.sleep(0.2) # Brief pause for UI consistency
1263
  console.print("✓ Credential tool initialized")
1264
-
1265
  _elapsed = time.time() - _start_time
1266
  _, PROVIDER_PLUGINS = _ensure_providers_loaded()
1267
- print(f"✓ Tool ready in {_elapsed:.2f}s ({len(PROVIDER_PLUGINS)} providers available)")
1268
-
 
 
1269
  # Small delay to let user see the ready message
1270
  time.sleep(0.5)
1271
-
1272
  # Run the main async event loop
1273
  # If from launcher, don't clear screen at start to preserve loading messages
1274
  try:
 
3
  import asyncio
4
  import json
5
  import os
 
6
  import time
7
  from pathlib import Path
8
  from dotenv import set_key, get_key
9
 
10
+ # NOTE: Heavy imports (provider_factory, PROVIDER_PLUGINS) are deferred
11
  # to avoid 6-7 second delay before showing loading screen
12
  from rich.console import Console
13
  from rich.panel import Panel
14
  from rich.prompt import Prompt
15
  from rich.text import Text
16
 
17
+ from .utils.paths import get_oauth_dir, get_data_file
18
+
19
+
20
+ def _get_oauth_base_dir() -> Path:
21
+ """Get the OAuth base directory (lazy, respects EXE vs script mode)."""
22
+ oauth_dir = get_oauth_dir()
23
+ oauth_dir.mkdir(parents=True, exist_ok=True)
24
+ return oauth_dir
25
+
26
+
27
+ def _get_env_file() -> Path:
28
+ """Get the .env file path (lazy, respects EXE vs script mode)."""
29
+ return get_data_file(".env")
30
+
31
 
32
  console = Console()
33
 
 
35
  _provider_factory = None
36
  _provider_plugins = None
37
 
38
+
39
  def _ensure_providers_loaded():
40
  """Lazy load provider modules only when needed"""
41
  global _provider_factory, _provider_plugins
42
  if _provider_factory is None:
43
  from . import provider_factory as pf
44
  from .providers import PROVIDER_PLUGINS as pp
45
+
46
  _provider_factory = pf
47
  _provider_plugins = pp
48
  return _provider_factory, _provider_plugins
 
50
 
51
  def clear_screen():
52
  """
53
+ Cross-platform terminal clear that works robustly on both
54
  classic Windows conhost and modern terminals (Windows Terminal, Linux, Mac).
55
+
56
  Uses native OS commands instead of ANSI escape sequences:
57
  - Windows (conhost & Windows Terminal): cls
58
  - Unix-like systems (Linux, Mac): clear
59
  """
60
+ os.system("cls" if os.name == "nt" else "clear")
 
61
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
 
63
  def ensure_env_defaults():
64
  """
65
  Ensures the .env file exists and contains essential default values like PROXY_API_KEY.
66
  """
67
+ if not _get_env_file().is_file():
68
+ _get_env_file().touch()
69
+ console.print(
70
+ f"Creating a new [bold yellow]{_get_env_file().name}[/bold yellow] file..."
71
+ )
72
 
73
  # Check for PROXY_API_KEY, similar to setup_env.bat
74
+ if get_key(str(_get_env_file()), "PROXY_API_KEY") is None:
75
  default_key = "VerysecretKey"
76
+ console.print(
77
+ f"Adding default [bold cyan]PROXY_API_KEY[/bold cyan] to [bold yellow]{_get_env_file().name}[/bold yellow]..."
78
+ )
79
+ set_key(str(_get_env_file()), "PROXY_API_KEY", default_key)
80
+
81
 
82
  async def setup_api_key():
83
  """
 
90
 
91
  # Verified list of LiteLLM providers with their friendly names and API key variables
92
  LITELLM_PROVIDERS = {
93
+ "OpenAI": "OPENAI_API_KEY",
94
+ "Anthropic": "ANTHROPIC_API_KEY",
95
+ "Google AI Studio (Gemini)": "GEMINI_API_KEY",
96
+ "Azure OpenAI": "AZURE_API_KEY",
97
+ "Vertex AI": "GOOGLE_API_KEY",
98
+ "AWS Bedrock": "AWS_ACCESS_KEY_ID",
99
+ "Cohere": "COHERE_API_KEY",
100
+ "Chutes": "CHUTES_API_KEY",
101
  "Mistral AI": "MISTRAL_API_KEY",
102
+ "Codestral (Mistral)": "CODESTRAL_API_KEY",
103
+ "Groq": "GROQ_API_KEY",
104
+ "Perplexity": "PERPLEXITYAI_API_KEY",
105
+ "xAI": "XAI_API_KEY",
106
+ "Together AI": "TOGETHERAI_API_KEY",
107
+ "Fireworks AI": "FIREWORKS_AI_API_KEY",
108
+ "Replicate": "REPLICATE_API_KEY",
109
+ "Hugging Face": "HUGGINGFACE_API_KEY",
110
+ "Anyscale": "ANYSCALE_API_KEY",
111
+ "NVIDIA NIM": "NVIDIA_NIM_API_KEY",
112
+ "Deepseek": "DEEPSEEK_API_KEY",
113
+ "AI21": "AI21_API_KEY",
114
+ "Cerebras": "CEREBRAS_API_KEY",
115
+ "Moonshot": "MOONSHOT_API_KEY",
116
+ "Ollama": "OLLAMA_API_KEY",
117
+ "Xinference": "XINFERENCE_API_KEY",
118
+ "Infinity": "INFINITY_API_KEY",
119
+ "OpenRouter": "OPENROUTER_API_KEY",
120
+ "Deepinfra": "DEEPINFRA_API_KEY",
121
+ "Cloudflare": "CLOUDFLARE_API_KEY",
122
+ "Baseten": "BASETEN_API_KEY",
123
+ "Modal": "MODAL_API_KEY",
124
+ "Databricks": "DATABRICKS_API_KEY",
125
+ "AWS SageMaker": "AWS_ACCESS_KEY_ID",
126
+ "IBM watsonx.ai": "WATSONX_APIKEY",
127
+ "Predibase": "PREDIBASE_API_KEY",
128
+ "Clarifai": "CLARIFAI_API_KEY",
129
+ "NLP Cloud": "NLP_CLOUD_API_KEY",
130
+ "Voyage AI": "VOYAGE_API_KEY",
131
+ "Jina AI": "JINA_API_KEY",
132
+ "Hyperbolic": "HYPERBOLIC_API_KEY",
133
+ "Morph": "MORPH_API_KEY",
134
+ "Lambda AI": "LAMBDA_API_KEY",
135
+ "Novita AI": "NOVITA_API_KEY",
136
+ "Aleph Alpha": "ALEPH_ALPHA_API_KEY",
137
+ "SambaNova": "SAMBANOVA_API_KEY",
138
+ "FriendliAI": "FRIENDLI_TOKEN",
139
+ "Galadriel": "GALADRIEL_API_KEY",
140
+ "CompactifAI": "COMPACTIFAI_API_KEY",
141
+ "Lemonade": "LEMONADE_API_KEY",
142
+ "GradientAI": "GRADIENTAI_API_KEY",
143
+ "Featherless AI": "FEATHERLESS_AI_API_KEY",
144
+ "Nebius AI Studio": "NEBIUS_API_KEY",
145
+ "Dashscope (Qwen)": "DASHSCOPE_API_KEY",
146
+ "Bytez": "BYTEZ_API_KEY",
147
+ "Oracle OCI": "OCI_API_KEY",
148
+ "DataRobot": "DATAROBOT_API_KEY",
149
+ "OVHCloud": "OVHCLOUD_API_KEY",
150
+ "Volcengine": "VOLCENGINE_API_KEY",
151
+ "Snowflake": "SNOWFLAKE_API_KEY",
152
+ "Nscale": "NSCALE_API_KEY",
153
+ "Recraft": "RECRAFT_API_KEY",
154
+ "v0": "V0_API_KEY",
155
+ "Vercel": "VERCEL_AI_GATEWAY_API_KEY",
156
+ "Topaz": "TOPAZ_API_KEY",
157
+ "ElevenLabs": "ELEVENLABS_API_KEY",
158
  "Deepgram": "DEEPGRAM_API_KEY",
159
+ "GitHub Models": "GITHUB_TOKEN",
160
+ "GitHub Copilot": "GITHUB_COPILOT_API_KEY",
161
  }
162
 
163
  # Discover custom providers and add them to the list
 
165
  # qwen_code API key support is a fallback
166
  # iflow API key support is a feature
167
  _, PROVIDER_PLUGINS = _ensure_providers_loaded()
168
+
169
  # Build a set of environment variables already in LITELLM_PROVIDERS
170
  # to avoid duplicates based on the actual API key names
171
  litellm_env_vars = set(LITELLM_PROVIDERS.values())
172
+
173
  # Providers to exclude from API key list
174
  exclude_providers = {
175
+ "gemini_cli", # OAuth-only
176
+ "antigravity", # OAuth-only
177
+ "qwen_code", # API key is fallback, OAuth is primary - don't advertise
178
+ "openai_compatible", # Base class, not a real provider
179
  }
180
+
181
  discovered_providers = {}
182
  for provider_key in PROVIDER_PLUGINS.keys():
183
  if provider_key in exclude_providers:
184
  continue
185
+
186
  # Create environment variable name
187
  env_var = provider_key.upper() + "_API_KEY"
188
+
189
  # Check if this env var already exists in LITELLM_PROVIDERS
190
  # This catches duplicates like GEMINI_API_KEY, MISTRAL_API_KEY, etc.
191
  if env_var in litellm_env_vars:
192
  # Already in LITELLM_PROVIDERS with better name, skip this one
193
  continue
194
+
195
  # Create display name for this custom provider
196
+ display_name = provider_key.replace("_", " ").title()
197
  discovered_providers[display_name] = env_var
198
+
199
  # LITELLM_PROVIDERS takes precedence (comes first in merge)
200
  combined_providers = {**LITELLM_PROVIDERS, **discovered_providers}
201
  provider_display_list = sorted(combined_providers.keys())
 
210
  else:
211
  provider_text.append(f" {i + 1}. {provider_name}\n")
212
 
213
+ console.print(
214
+ Panel(provider_text, title="Available Providers for API Key", style="bold blue")
215
+ )
216
 
217
  choice = Prompt.ask(
218
+ Text.from_markup(
219
+ "[bold]Please select a provider or type [red]'b'[/red] to go back[/bold]"
220
+ ),
221
  choices=[str(i + 1) for i in range(len(provider_display_list))] + ["b"],
222
+ show_choices=False,
223
  )
224
 
225
+ if choice.lower() == "b":
226
  return
227
 
228
  try:
 
234
  api_key = Prompt.ask(f"Enter the API key for {display_name}")
235
 
236
  # Check for duplicate API key value
237
+ if _get_env_file().is_file():
238
+ with open(_get_env_file(), "r") as f:
239
  for line in f:
240
  line = line.strip()
241
  if line.startswith(api_var_base) and "=" in line:
242
+ existing_key_name, _, existing_key_value = line.partition(
243
+ "="
244
+ )
245
  if existing_key_value == api_key:
246
+ warning_text = Text.from_markup(
247
+ f"This API key already exists as [bold yellow]'{existing_key_name}'[/bold yellow]. Overwriting..."
248
+ )
249
+ console.print(
250
+ Panel(
251
+ warning_text,
252
+ style="bold yellow",
253
+ title="Updating API Key",
254
+ )
255
+ )
256
+
257
+ set_key(
258
+ str(_get_env_file()), existing_key_name, api_key
259
+ )
260
+
261
+ success_text = Text.from_markup(
262
+ f"Successfully updated existing key [bold yellow]'{existing_key_name}'[/bold yellow]."
263
+ )
264
+ console.print(
265
+ Panel(
266
+ success_text,
267
+ style="bold green",
268
+ title="Success",
269
+ )
270
+ )
271
  return
272
 
273
  # Special handling for AWS
274
  if display_name in ["AWS Bedrock", "AWS SageMaker"]:
275
+ console.print(
276
+ Panel(
277
+ Text.from_markup(
278
+ "This provider requires both an Access Key ID and a Secret Access Key.\n"
279
+ f"The key you entered will be saved as [bold yellow]{api_var_base}_1[/bold yellow].\n"
280
+ "Please manually add the [bold cyan]AWS_SECRET_ACCESS_KEY_1[/bold cyan] to your .env file."
281
+ ),
282
+ title="[bold yellow]Additional Step Required[/bold yellow]",
283
+ border_style="yellow",
284
+ )
285
+ )
286
 
287
  key_index = 1
288
  while True:
289
  key_name = f"{api_var_base}_{key_index}"
290
+ if _get_env_file().is_file():
291
+ with open(_get_env_file(), "r") as f:
292
  if not any(line.startswith(f"{key_name}=") for line in f):
293
  break
294
  else:
295
  break
296
  key_index += 1
297
+
298
  key_name = f"{api_var_base}_{key_index}"
299
+ set_key(str(_get_env_file()), key_name, api_key)
300
+
301
+ success_text = Text.from_markup(
302
+ f"Successfully added {display_name} API key as [bold yellow]'{key_name}'[/bold yellow]."
303
+ )
304
  console.print(Panel(success_text, style="bold green", title="Success"))
305
 
306
  else:
307
  console.print("[bold red]Invalid choice. Please try again.[/bold red]")
308
  except ValueError:
309
+ console.print(
310
+ "[bold red]Invalid input. Please enter a number or 'b'.[/bold red]"
311
+ )
312
+
313
 
314
  async def setup_new_credential(provider_name: str):
315
  """
316
  Interactively sets up a new OAuth credential for a given provider.
317
+
318
+ Delegates all credential management logic to the auth class's setup_credential() method.
319
  """
320
  try:
321
  provider_factory, _ = _ensure_providers_loaded()
 
327
  "gemini_cli": "Gemini CLI (OAuth)",
328
  "qwen_code": "Qwen Code (OAuth - also supports API keys)",
329
  "iflow": "iFlow (OAuth - also supports API keys)",
330
+ "antigravity": "Antigravity (OAuth)",
331
  }
332
+ display_name = oauth_friendly_names.get(
333
+ provider_name, provider_name.replace("_", " ").title()
334
+ )
 
 
 
 
 
 
 
 
 
 
335
 
336
+ # Call the auth class's setup_credential() method which handles the entire flow:
337
+ # - OAuth authentication
338
+ # - Email extraction for deduplication
339
+ # - File path determination (new or existing)
340
+ # - Credential file saving
341
+ # - Post-auth discovery (tier/project for Google OAuth providers)
342
+ result = await auth_instance.setup_credential(_get_oauth_base_dir())
343
+
344
+ if not result.success:
345
+ console.print(
346
+ Panel(
347
+ f"Credential setup failed: {result.error}",
348
+ style="bold red",
349
+ title="Error",
350
+ )
351
+ )
352
  return
353
 
354
+ # Display success message with details
355
+ if result.is_update:
356
+ success_text = Text.from_markup(
357
+ f"Successfully updated credential at [bold yellow]'{Path(result.file_path).name}'[/bold yellow] "
358
+ f"for user [bold cyan]'{result.email}'[/bold cyan]."
359
+ )
360
+ else:
361
+ success_text = Text.from_markup(
362
+ f"Successfully created new credential at [bold yellow]'{Path(result.file_path).name}'[/bold yellow] "
363
+ f"for user [bold cyan]'{result.email}'[/bold cyan]."
364
+ )
365
 
366
+ # Add tier/project info if available (Google OAuth providers)
367
+ if hasattr(result, "tier") and result.tier:
368
+ success_text.append(f"\nTier: {result.tier}")
369
+ if hasattr(result, "project_id") and result.project_id:
370
+ success_text.append(f"\nProject: {result.project_id}")
371
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
372
  console.print(Panel(success_text, style="bold green", title="Success"))
373
 
374
  except Exception as e:
375
+ console.print(
376
+ Panel(
377
+ f"An error occurred during setup for {provider_name}: {e}",
378
+ style="bold red",
379
+ title="Error",
380
+ )
381
+ )
382
 
383
 
384
  async def export_gemini_cli_to_env():
385
  """
386
  Export a Gemini CLI credential JSON file to .env format.
387
+ Uses the auth class's build_env_lines() and list_credentials() methods.
388
  """
389
+ console.print(
390
+ Panel(
391
+ "[bold cyan]Export Gemini CLI Credential to .env[/bold cyan]", expand=False
392
+ )
393
+ )
394
 
395
+ # Get auth instance for this provider
396
+ provider_factory, _ = _ensure_providers_loaded()
397
+ auth_class = provider_factory.get_provider_auth_class("gemini_cli")
398
+ auth_instance = auth_class()
399
 
400
+ # List available credentials using auth class
401
+ credentials = auth_instance.list_credentials(_get_oauth_base_dir())
402
+
403
+ if not credentials:
404
+ console.print(
405
+ Panel(
406
+ "No Gemini CLI credentials found. Please add one first using 'Add OAuth Credential'.",
407
+ style="bold red",
408
+ title="No Credentials",
409
+ )
410
+ )
411
  return
412
 
413
  # Display available credentials
414
  cred_text = Text()
415
+ for i, cred_info in enumerate(credentials):
416
+ cred_text.append(
417
+ f" {i + 1}. {Path(cred_info['file_path']).name} ({cred_info['email']})\n"
418
+ )
 
 
 
 
419
 
420
+ console.print(
421
+ Panel(cred_text, title="Available Gemini CLI Credentials", style="bold blue")
422
+ )
423
 
424
  choice = Prompt.ask(
425
+ Text.from_markup(
426
+ "[bold]Please select a credential to export or type [red]'b'[/red] to go back[/bold]"
427
+ ),
428
+ choices=[str(i + 1) for i in range(len(credentials))] + ["b"],
429
+ show_choices=False,
430
  )
431
 
432
+ if choice.lower() == "b":
433
  return
434
 
435
  try:
436
  choice_index = int(choice) - 1
437
+ if 0 <= choice_index < len(credentials):
438
+ cred_info = credentials[choice_index]
 
 
 
 
439
 
440
+ # Use auth class to export
441
+ env_path = auth_instance.export_credential_to_env(
442
+ cred_info["file_path"], _get_oauth_base_dir()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
443
  )
444
 
445
+ if env_path:
446
+ numbered_prefix = f"GEMINI_CLI_{cred_info['number']}"
447
+ success_text = Text.from_markup(
448
+ f"Successfully exported credential to [bold yellow]'{Path(env_path).name}'[/bold yellow]\n\n"
449
+ f"[bold]Environment variable prefix:[/bold] [cyan]{numbered_prefix}_*[/cyan]\n\n"
450
+ f"[bold]To use this credential:[/bold]\n"
451
+ f"1. Copy the contents to your main .env file, OR\n"
452
+ f"2. Source it: [bold cyan]source {Path(env_path).name}[/bold cyan] (Linux/Mac)\n"
453
+ f"3. Or on Windows: [bold cyan]Get-Content {Path(env_path).name} | ForEach-Object {{ $_ -replace '^([^#].*)$', 'set $1' }} | cmd[/bold cyan]\n\n"
454
+ f"[bold]To combine multiple credentials:[/bold]\n"
455
+ f"Copy lines from multiple .env files into one file.\n"
456
+ f"Each credential uses a unique number ({numbered_prefix}_*)."
457
+ )
458
+ console.print(Panel(success_text, style="bold green", title="Success"))
459
+ else:
460
+ console.print(
461
+ Panel(
462
+ "Failed to export credential", style="bold red", title="Error"
463
+ )
464
+ )
465
  else:
466
  console.print("[bold red]Invalid choice. Please try again.[/bold red]")
467
  except ValueError:
468
+ console.print(
469
+ "[bold red]Invalid input. Please enter a number or 'b'.[/bold red]"
470
+ )
471
  except Exception as e:
472
+ console.print(
473
+ Panel(
474
+ f"An error occurred during export: {e}", style="bold red", title="Error"
475
+ )
476
+ )
477
 
478
 
479
  async def export_qwen_code_to_env():
480
  """
481
  Export a Qwen Code credential JSON file to .env format.
482
+ Uses the auth class's build_env_lines() and list_credentials() methods.
483
  """
484
+ console.print(
485
+ Panel(
486
+ "[bold cyan]Export Qwen Code Credential to .env[/bold cyan]", expand=False
487
+ )
488
+ )
489
+
490
+ # Get auth instance for this provider
491
+ provider_factory, _ = _ensure_providers_loaded()
492
+ auth_class = provider_factory.get_provider_auth_class("qwen_code")
493
+ auth_instance = auth_class()
494
 
495
+ # List available credentials using auth class
496
+ credentials = auth_instance.list_credentials(_get_oauth_base_dir())
497
 
498
+ if not credentials:
499
+ console.print(
500
+ Panel(
501
+ "No Qwen Code credentials found. Please add one first using 'Add OAuth Credential'.",
502
+ style="bold red",
503
+ title="No Credentials",
504
+ )
505
+ )
506
  return
507
 
508
  # Display available credentials
509
  cred_text = Text()
510
+ for i, cred_info in enumerate(credentials):
511
+ cred_text.append(
512
+ f" {i + 1}. {Path(cred_info['file_path']).name} ({cred_info['email']})\n"
513
+ )
 
 
 
 
514
 
515
+ console.print(
516
+ Panel(cred_text, title="Available Qwen Code Credentials", style="bold blue")
517
+ )
518
 
519
  choice = Prompt.ask(
520
+ Text.from_markup(
521
+ "[bold]Please select a credential to export or type [red]'b'[/red] to go back[/bold]"
522
+ ),
523
+ choices=[str(i + 1) for i in range(len(credentials))] + ["b"],
524
+ show_choices=False,
525
  )
526
 
527
+ if choice.lower() == "b":
528
  return
529
 
530
  try:
531
  choice_index = int(choice) - 1
532
+ if 0 <= choice_index < len(credentials):
533
+ cred_info = credentials[choice_index]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
534
 
535
+ # Use auth class to export
536
+ env_path = auth_instance.export_credential_to_env(
537
+ cred_info["file_path"], _get_oauth_base_dir()
 
 
 
 
 
 
538
  )
539
+
540
+ if env_path:
541
+ numbered_prefix = f"QWEN_CODE_{cred_info['number']}"
542
+ success_text = Text.from_markup(
543
+ f"Successfully exported credential to [bold yellow]'{Path(env_path).name}'[/bold yellow]\n\n"
544
+ f"[bold]Environment variable prefix:[/bold] [cyan]{numbered_prefix}_*[/cyan]\n\n"
545
+ f"[bold]To use this credential:[/bold]\n"
546
+ f"1. Copy the contents to your main .env file, OR\n"
547
+ f"2. Source it: [bold cyan]source {Path(env_path).name}[/bold cyan] (Linux/Mac)\n\n"
548
+ f"[bold]To combine multiple credentials:[/bold]\n"
549
+ f"Copy lines from multiple .env files into one file.\n"
550
+ f"Each credential uses a unique number ({numbered_prefix}_*)."
551
+ )
552
+ console.print(Panel(success_text, style="bold green", title="Success"))
553
+ else:
554
+ console.print(
555
+ Panel(
556
+ "Failed to export credential", style="bold red", title="Error"
557
+ )
558
+ )
559
  else:
560
  console.print("[bold red]Invalid choice. Please try again.[/bold red]")
561
  except ValueError:
562
+ console.print(
563
+ "[bold red]Invalid input. Please enter a number or 'b'.[/bold red]"
564
+ )
565
  except Exception as e:
566
+ console.print(
567
+ Panel(
568
+ f"An error occurred during export: {e}", style="bold red", title="Error"
569
+ )
570
+ )
571
 
572
 
573
  async def export_iflow_to_env():
574
  """
575
  Export an iFlow credential JSON file to .env format.
576
+ Uses the auth class's build_env_lines() and list_credentials() methods.
577
  """
578
+ console.print(
579
+ Panel("[bold cyan]Export iFlow Credential to .env[/bold cyan]", expand=False)
580
+ )
581
 
582
+ # Get auth instance for this provider
583
+ provider_factory, _ = _ensure_providers_loaded()
584
+ auth_class = provider_factory.get_provider_auth_class("iflow")
585
+ auth_instance = auth_class()
586
 
587
+ # List available credentials using auth class
588
+ credentials = auth_instance.list_credentials(_get_oauth_base_dir())
589
+
590
+ if not credentials:
591
+ console.print(
592
+ Panel(
593
+ "No iFlow credentials found. Please add one first using 'Add OAuth Credential'.",
594
+ style="bold red",
595
+ title="No Credentials",
596
+ )
597
+ )
598
  return
599
 
600
  # Display available credentials
601
  cred_text = Text()
602
+ for i, cred_info in enumerate(credentials):
603
+ cred_text.append(
604
+ f" {i + 1}. {Path(cred_info['file_path']).name} ({cred_info['email']})\n"
605
+ )
 
 
 
 
606
 
607
+ console.print(
608
+ Panel(cred_text, title="Available iFlow Credentials", style="bold blue")
609
+ )
610
 
611
  choice = Prompt.ask(
612
+ Text.from_markup(
613
+ "[bold]Please select a credential to export or type [red]'b'[/red] to go back[/bold]"
614
+ ),
615
+ choices=[str(i + 1) for i in range(len(credentials))] + ["b"],
616
+ show_choices=False,
617
  )
618
 
619
+ if choice.lower() == "b":
620
  return
621
 
622
  try:
623
  choice_index = int(choice) - 1
624
+ if 0 <= choice_index < len(credentials):
625
+ cred_info = credentials[choice_index]
626
 
627
+ # Use auth class to export
628
+ env_path = auth_instance.export_credential_to_env(
629
+ cred_info["file_path"], _get_oauth_base_dir()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
630
  )
631
+
632
+ if env_path:
633
+ numbered_prefix = f"IFLOW_{cred_info['number']}"
634
+ success_text = Text.from_markup(
635
+ f"Successfully exported credential to [bold yellow]'{Path(env_path).name}'[/bold yellow]\n\n"
636
+ f"[bold]Environment variable prefix:[/bold] [cyan]{numbered_prefix}_*[/cyan]\n\n"
637
+ f"[bold]To use this credential:[/bold]\n"
638
+ f"1. Copy the contents to your main .env file, OR\n"
639
+ f"2. Source it: [bold cyan]source {Path(env_path).name}[/bold cyan] (Linux/Mac)\n\n"
640
+ f"[bold]To combine multiple credentials:[/bold]\n"
641
+ f"Copy lines from multiple .env files into one file.\n"
642
+ f"Each credential uses a unique number ({numbered_prefix}_*)."
643
+ )
644
+ console.print(Panel(success_text, style="bold green", title="Success"))
645
+ else:
646
+ console.print(
647
+ Panel(
648
+ "Failed to export credential", style="bold red", title="Error"
649
+ )
650
+ )
651
  else:
652
  console.print("[bold red]Invalid choice. Please try again.[/bold red]")
653
  except ValueError:
654
+ console.print(
655
+ "[bold red]Invalid input. Please enter a number or 'b'.[/bold red]"
656
+ )
657
  except Exception as e:
658
+ console.print(
659
+ Panel(
660
+ f"An error occurred during export: {e}", style="bold red", title="Error"
661
+ )
662
+ )
663
 
664
 
665
  async def export_antigravity_to_env():
666
  """
667
  Export an Antigravity credential JSON file to .env format.
668
+ Uses the auth class's build_env_lines() and list_credentials() methods.
669
  """
670
+ console.print(
671
+ Panel(
672
+ "[bold cyan]Export Antigravity Credential to .env[/bold cyan]", expand=False
673
+ )
674
+ )
675
+
676
+ # Get auth instance for this provider
677
+ provider_factory, _ = _ensure_providers_loaded()
678
+ auth_class = provider_factory.get_provider_auth_class("antigravity")
679
+ auth_instance = auth_class()
680
 
681
+ # List available credentials using auth class
682
+ credentials = auth_instance.list_credentials(_get_oauth_base_dir())
683
 
684
+ if not credentials:
685
+ console.print(
686
+ Panel(
687
+ "No Antigravity credentials found. Please add one first using 'Add OAuth Credential'.",
688
+ style="bold red",
689
+ title="No Credentials",
690
+ )
691
+ )
692
  return
693
 
694
  # Display available credentials
695
  cred_text = Text()
696
+ for i, cred_info in enumerate(credentials):
697
+ cred_text.append(
698
+ f" {i + 1}. {Path(cred_info['file_path']).name} ({cred_info['email']})\n"
699
+ )
 
 
 
 
700
 
701
+ console.print(
702
+ Panel(cred_text, title="Available Antigravity Credentials", style="bold blue")
703
+ )
704
 
705
  choice = Prompt.ask(
706
+ Text.from_markup(
707
+ "[bold]Please select a credential to export or type [red]'b'[/red] to go back[/bold]"
708
+ ),
709
+ choices=[str(i + 1) for i in range(len(credentials))] + ["b"],
710
+ show_choices=False,
711
  )
712
 
713
+ if choice.lower() == "b":
714
  return
715
 
716
  try:
717
  choice_index = int(choice) - 1
718
+ if 0 <= choice_index < len(credentials):
719
+ cred_info = credentials[choice_index]
 
 
 
 
720
 
721
+ # Use auth class to export
722
+ env_path = auth_instance.export_credential_to_env(
723
+ cred_info["file_path"], _get_oauth_base_dir()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
724
  )
725
 
726
+ if env_path:
727
+ numbered_prefix = f"ANTIGRAVITY_{cred_info['number']}"
728
+ success_text = Text.from_markup(
729
+ f"Successfully exported credential to [bold yellow]'{Path(env_path).name}'[/bold yellow]\n\n"
730
+ f"[bold]Environment variable prefix:[/bold] [cyan]{numbered_prefix}_*[/cyan]\n\n"
731
+ f"[bold]To use this credential:[/bold]\n"
732
+ f"1. Copy the contents to your main .env file, OR\n"
733
+ f"2. Source it: [bold cyan]source {Path(env_path).name}[/bold cyan] (Linux/Mac)\n"
734
+ f"3. Or on Windows: [bold cyan]Get-Content {Path(env_path).name} | ForEach-Object {{ $_ -replace '^([^#].*)$', 'set $1' }} | cmd[/bold cyan]\n\n"
735
+ f"[bold]To combine multiple credentials:[/bold]\n"
736
+ f"Copy lines from multiple .env files into one file.\n"
737
+ f"Each credential uses a unique number ({numbered_prefix}_*)."
738
+ )
739
+ console.print(Panel(success_text, style="bold green", title="Success"))
740
+ else:
741
+ console.print(
742
+ Panel(
743
+ "Failed to export credential", style="bold red", title="Error"
744
+ )
745
+ )
746
  else:
747
  console.print("[bold red]Invalid choice. Please try again.[/bold red]")
748
  except ValueError:
749
+ console.print(
750
+ "[bold red]Invalid input. Please enter a number or 'b'.[/bold red]"
751
+ )
752
  except Exception as e:
753
+ console.print(
754
+ Panel(
755
+ f"An error occurred during export: {e}", style="bold red", title="Error"
756
+ )
757
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
758
 
759
 
760
  async def export_all_provider_credentials(provider_name: str):
761
  """
762
  Export all credentials for a specific provider to individual .env files.
763
+ Uses the auth class's list_credentials() and export_credential_to_env() methods.
764
  """
765
+ # Get auth instance for this provider
766
+ provider_factory, _ = _ensure_providers_loaded()
767
+ try:
768
+ auth_class = provider_factory.get_provider_auth_class(provider_name)
769
+ auth_instance = auth_class()
770
+ except Exception:
 
 
771
  console.print(f"[bold red]Unknown provider: {provider_name}[/bold red]")
772
  return
773
+
774
+ display_name = provider_name.replace("_", " ").title()
775
+
776
+ console.print(
777
+ Panel(
778
+ f"[bold cyan]Export All {display_name} Credentials[/bold cyan]",
779
+ expand=False,
780
+ )
781
+ )
782
+
783
+ # List all credentials using auth class
784
+ credentials = auth_instance.list_credentials(_get_oauth_base_dir())
785
+
786
+ if not credentials:
787
+ console.print(
788
+ Panel(
789
+ f"No {display_name} credentials found.",
790
+ style="bold red",
791
+ title="No Credentials",
792
+ )
793
+ )
794
  return
795
+
796
  exported_count = 0
797
+ for cred_info in credentials:
798
  try:
799
+ # Use auth class to export
800
+ env_path = auth_instance.export_credential_to_env(
801
+ cred_info["file_path"], _get_oauth_base_dir()
802
+ )
803
+
804
+ if env_path:
805
+ console.print(
806
+ f" ✓ Exported [cyan]{Path(cred_info['file_path']).name}[/cyan] → [yellow]{Path(env_path).name}[/yellow]"
807
+ )
808
+ exported_count += 1
809
+ else:
810
+ console.print(
811
+ f" ✗ Failed to export {Path(cred_info['file_path']).name}"
812
+ )
813
+
 
 
 
 
814
  except Exception as e:
815
+ console.print(
816
+ f" ✗ Failed to export {Path(cred_info['file_path']).name}: {e}"
817
+ )
818
+
819
+ console.print(
820
+ Panel(
821
+ f"Successfully exported {exported_count}/{len(credentials)} {display_name} credentials to individual .env files.",
822
+ style="bold green",
823
+ title="Export Complete",
824
+ )
825
+ )
826
 
827
 
828
  async def combine_provider_credentials(provider_name: str):
829
  """
830
  Combine all credentials for a specific provider into a single .env file.
831
+ Uses the auth class's list_credentials() and build_env_lines() methods.
832
  """
833
+ # Get auth instance for this provider
834
+ provider_factory, _ = _ensure_providers_loaded()
835
+ try:
836
+ auth_class = provider_factory.get_provider_auth_class(provider_name)
837
+ auth_instance = auth_class()
838
+ except Exception:
 
 
839
  console.print(f"[bold red]Unknown provider: {provider_name}[/bold red]")
840
  return
841
+
842
+ display_name = provider_name.replace("_", " ").title()
843
+
844
+ console.print(
845
+ Panel(
846
+ f"[bold cyan]Combine All {display_name} Credentials[/bold cyan]",
847
+ expand=False,
848
+ )
849
+ )
850
+
851
+ # List all credentials using auth class
852
+ credentials = auth_instance.list_credentials(_get_oauth_base_dir())
853
+
854
+ if not credentials:
855
+ console.print(
856
+ Panel(
857
+ f"No {display_name} credentials found.",
858
+ style="bold red",
859
+ title="No Credentials",
860
+ )
861
+ )
862
  return
863
+
864
  combined_lines = [
865
  f"# Combined {display_name} Credentials",
866
  f"# Generated at: {time.strftime('%Y-%m-%d %H:%M:%S')}",
867
+ f"# Total credentials: {len(credentials)}",
868
  "#",
869
  "# Copy all lines below into your main .env file",
870
  "",
871
  ]
872
+
873
  combined_count = 0
874
+ for cred_info in credentials:
875
  try:
876
+ # Load credential file
877
+ with open(cred_info["file_path"], "r") as f:
878
  creds = json.load(f)
879
+
880
+ # Use auth class to build env lines
881
+ env_lines = auth_instance.build_env_lines(creds, cred_info["number"])
882
+
883
  combined_lines.extend(env_lines)
884
  combined_lines.append("") # Blank line between credentials
885
  combined_count += 1
886
+
887
  except Exception as e:
888
+ console.print(
889
+ f" ✗ Failed to process {Path(cred_info['file_path']).name}: {e}"
890
+ )
891
+
892
  # Write combined file
893
  combined_filename = f"{provider_name}_all_combined.env"
894
+ combined_filepath = _get_oauth_base_dir() / combined_filename
895
+
896
+ with open(combined_filepath, "w") as f:
897
+ f.write("\n".join(combined_lines))
898
+
899
+ console.print(
900
+ Panel(
901
+ Text.from_markup(
902
+ f"Successfully combined {combined_count} {display_name} credentials into:\n"
903
+ f"[bold yellow]{combined_filepath}[/bold yellow]\n\n"
904
+ f"[bold]To use:[/bold] Copy the contents into your main .env file."
905
+ ),
906
+ style="bold green",
907
+ title="Combine Complete",
908
+ )
909
+ )
910
 
911
 
912
  async def combine_all_credentials():
913
  """
914
  Combine ALL credentials from ALL providers into a single .env file.
915
+ Uses auth class list_credentials() and build_env_lines() methods.
916
  """
917
+ console.print(
918
+ Panel("[bold cyan]Combine All Provider Credentials[/bold cyan]", expand=False)
919
+ )
920
+
921
+ # List of providers that support OAuth credentials
922
+ oauth_providers = ["gemini_cli", "qwen_code", "iflow", "antigravity"]
923
+
924
+ provider_factory, _ = _ensure_providers_loaded()
925
+
926
  combined_lines = [
927
  "# Combined All Provider Credentials",
928
  f"# Generated at: {time.strftime('%Y-%m-%d %H:%M:%S')}",
 
930
  "# Copy all lines below into your main .env file",
931
  "",
932
  ]
933
+
934
  total_count = 0
935
  provider_counts = {}
936
+
937
+ for provider_name in oauth_providers:
938
+ try:
939
+ auth_class = provider_factory.get_provider_auth_class(provider_name)
940
+ auth_instance = auth_class()
941
+ except Exception:
942
+ continue # Skip providers that don't have auth classes
943
+
944
+ credentials = auth_instance.list_credentials(_get_oauth_base_dir())
945
+
946
+ if not credentials:
947
  continue
948
+
949
+ display_name = provider_name.replace("_", " ").title()
950
  combined_lines.append(f"# ===== {display_name} Credentials =====")
951
  combined_lines.append("")
952
+
953
  provider_count = 0
954
+ for cred_info in credentials:
955
  try:
956
+ # Load credential file
957
+ with open(cred_info["file_path"], "r") as f:
958
  creds = json.load(f)
959
+
960
+ # Use auth class to build env lines
961
+ env_lines = auth_instance.build_env_lines(creds, cred_info["number"])
962
+
963
  combined_lines.extend(env_lines)
964
  combined_lines.append("")
965
  provider_count += 1
966
  total_count += 1
967
+
968
  except Exception as e:
969
+ console.print(
970
+ f" ✗ Failed to process {Path(cred_info['file_path']).name}: {e}"
971
+ )
972
+
973
  provider_counts[display_name] = provider_count
974
+
975
  if total_count == 0:
976
+ console.print(
977
+ Panel(
978
+ "No credentials found to combine.",
979
+ style="bold red",
980
+ title="No Credentials",
981
+ )
982
+ )
983
  return
984
+
985
  # Write combined file
986
  combined_filename = "all_providers_combined.env"
987
+ combined_filepath = _get_oauth_base_dir() / combined_filename
988
+
989
+ with open(combined_filepath, "w") as f:
990
+ f.write("\n".join(combined_lines))
991
+
992
  # Build summary
993
+ summary_lines = [
994
+ f" • {name}: {count} credential(s)" for name, count in provider_counts.items()
995
+ ]
996
  summary = "\n".join(summary_lines)
997
+
998
+ console.print(
999
+ Panel(
1000
+ Text.from_markup(
1001
+ f"Successfully combined {total_count} credentials from {len(provider_counts)} providers:\n"
1002
+ f"{summary}\n\n"
1003
+ f"[bold]Output file:[/bold] [yellow]{combined_filepath}[/yellow]\n\n"
1004
+ f"[bold]To use:[/bold] Copy the contents into your main .env file."
1005
+ ),
1006
+ style="bold green",
1007
+ title="Combine Complete",
1008
+ )
1009
+ )
1010
 
1011
 
1012
  async def export_credentials_submenu():
 
1015
  """
1016
  while True:
1017
  clear_screen()
1018
+ console.print(
1019
+ Panel(
1020
+ "[bold cyan]Export Credentials to .env[/bold cyan]",
1021
+ title="--- API Key Proxy ---",
1022
+ expand=False,
1023
+ )
1024
+ )
1025
+
1026
+ console.print(
1027
+ Panel(
1028
+ Text.from_markup(
1029
+ "[bold]Individual Exports:[/bold]\n"
1030
+ "1. Export Gemini CLI credential\n"
1031
+ "2. Export Qwen Code credential\n"
1032
+ "3. Export iFlow credential\n"
1033
+ "4. Export Antigravity credential\n"
1034
+ "\n"
1035
+ "[bold]Bulk Exports (per provider):[/bold]\n"
1036
+ "5. Export ALL Gemini CLI credentials\n"
1037
+ "6. Export ALL Qwen Code credentials\n"
1038
+ "7. Export ALL iFlow credentials\n"
1039
+ "8. Export ALL Antigravity credentials\n"
1040
+ "\n"
1041
+ "[bold]Combine Credentials:[/bold]\n"
1042
+ "9. Combine all Gemini CLI into one file\n"
1043
+ "10. Combine all Qwen Code into one file\n"
1044
+ "11. Combine all iFlow into one file\n"
1045
+ "12. Combine all Antigravity into one file\n"
1046
+ "13. Combine ALL providers into one file"
1047
+ ),
1048
+ title="Choose export option",
1049
+ style="bold blue",
1050
+ )
1051
+ )
1052
 
1053
  export_choice = Prompt.ask(
1054
+ Text.from_markup(
1055
+ "[bold]Please select an option or type [red]'b'[/red] to go back[/bold]"
1056
+ ),
1057
+ choices=[
1058
+ "1",
1059
+ "2",
1060
+ "3",
1061
+ "4",
1062
+ "5",
1063
+ "6",
1064
+ "7",
1065
+ "8",
1066
+ "9",
1067
+ "10",
1068
+ "11",
1069
+ "12",
1070
+ "13",
1071
+ "b",
1072
+ ],
1073
+ show_choices=False,
1074
  )
1075
 
1076
+ if export_choice.lower() == "b":
1077
  break
1078
 
1079
  # Individual exports
 
1137
  async def main(clear_on_start=True):
1138
  """
1139
  An interactive CLI tool to add new credentials.
1140
+
1141
  Args:
1142
+ clear_on_start: If False, skip initial screen clear (used when called from launcher
1143
  to preserve the loading screen)
1144
  """
1145
  ensure_env_defaults()
1146
+
1147
  # Only show header if we're clearing (standalone mode)
1148
  if clear_on_start:
1149
+ console.print(
1150
+ Panel(
1151
+ "[bold cyan]Interactive Credential Setup[/bold cyan]",
1152
+ title="--- API Key Proxy ---",
1153
+ expand=False,
1154
+ )
1155
+ )
1156
+
1157
  while True:
1158
  # Clear screen between menu selections for cleaner UX
1159
  clear_screen()
1160
+ console.print(
1161
+ Panel(
1162
+ "[bold cyan]Interactive Credential Setup[/bold cyan]",
1163
+ title="--- API Key Proxy ---",
1164
+ expand=False,
1165
+ )
1166
+ )
1167
+
1168
+ console.print(
1169
+ Panel(
1170
+ Text.from_markup(
1171
+ "1. Add OAuth Credential\n2. Add API Key\n3. Export Credentials"
1172
+ ),
1173
+ title="Choose credential type",
1174
+ style="bold blue",
1175
+ )
1176
+ )
1177
 
1178
  setup_type = Prompt.ask(
1179
+ Text.from_markup(
1180
+ "[bold]Please select an option or type [red]'q'[/red] to quit[/bold]"
1181
+ ),
1182
  choices=["1", "2", "3", "q"],
1183
+ show_choices=False,
1184
  )
1185
 
1186
+ if setup_type.lower() == "q":
1187
  break
1188
 
1189
  if setup_type == "1":
 
1195
  "iflow": "iFlow (OAuth - also supports API keys)",
1196
  "antigravity": "Antigravity (OAuth)",
1197
  }
1198
+
1199
  provider_text = Text()
1200
  for i, provider in enumerate(available_providers):
1201
+ display_name = oauth_friendly_names.get(
1202
+ provider, provider.replace("_", " ").title()
1203
+ )
1204
  provider_text.append(f" {i + 1}. {display_name}\n")
1205
+
1206
+ console.print(
1207
+ Panel(
1208
+ provider_text,
1209
+ title="Available Providers for OAuth",
1210
+ style="bold blue",
1211
+ )
1212
+ )
1213
 
1214
  choice = Prompt.ask(
1215
+ Text.from_markup(
1216
+ "[bold]Please select a provider or type [red]'b'[/red] to go back[/bold]"
1217
+ ),
1218
  choices=[str(i + 1) for i in range(len(available_providers))] + ["b"],
1219
+ show_choices=False,
1220
  )
1221
 
1222
+ if choice.lower() == "b":
1223
  continue
1224
+
1225
  try:
1226
  choice_index = int(choice) - 1
1227
  if 0 <= choice_index < len(available_providers):
1228
  provider_name = available_providers[choice_index]
1229
+ display_name = oauth_friendly_names.get(
1230
+ provider_name, provider_name.replace("_", " ").title()
1231
+ )
1232
+ console.print(
1233
+ f"\nStarting OAuth setup for [bold cyan]{display_name}[/bold cyan]..."
1234
+ )
1235
  await setup_new_credential(provider_name)
1236
  # Don't clear after OAuth - user needs to see full flow
1237
  console.print("\n[dim]Press Enter to return to main menu...[/dim]")
1238
  input()
1239
  else:
1240
+ console.print(
1241
+ "[bold red]Invalid choice. Please try again.[/bold red]"
1242
+ )
1243
  await asyncio.sleep(1.5)
1244
  except ValueError:
1245
+ console.print(
1246
+ "[bold red]Invalid input. Please enter a number or 'b'.[/bold red]"
1247
+ )
1248
  await asyncio.sleep(1.5)
1249
 
1250
  elif setup_type == "2":
1251
  await setup_api_key()
1252
+ # console.print("\n[dim]Press Enter to return to main menu...[/dim]")
1253
+ # input()
1254
 
1255
  elif setup_type == "3":
1256
  await export_credentials_submenu()
1257
 
1258
+
1259
  def run_credential_tool(from_launcher=False):
1260
  """
1261
  Entry point for credential tool.
1262
+
1263
  Args:
1264
  from_launcher: If True, skip loading screen (launcher already showed it)
1265
  """
1266
  # Check if we need to show loading screen
1267
  if not from_launcher:
1268
  # Standalone mode - show full loading UI
1269
+ os.system("cls" if os.name == "nt" else "clear")
1270
+
1271
  _start_time = time.time()
1272
+
1273
  # Phase 1: Show initial message
1274
  print("━" * 70)
1275
  print("Interactive Credential Setup Tool")
1276
  print("GitHub: https://github.com/Mirrowel/LLM-API-Key-Proxy")
1277
  print("━" * 70)
1278
  print("Loading credential management components...")
1279
+
1280
  # Phase 2: Load dependencies with spinner
1281
  with console.status("Loading authentication providers...", spinner="dots"):
1282
  _ensure_providers_loaded()
 
1285
  with console.status("Initializing credential tool...", spinner="dots"):
1286
  time.sleep(0.2) # Brief pause for UI consistency
1287
  console.print("✓ Credential tool initialized")
1288
+
1289
  _elapsed = time.time() - _start_time
1290
  _, PROVIDER_PLUGINS = _ensure_providers_loaded()
1291
+ print(
1292
+ f"✓ Tool ready in {_elapsed:.2f}s ({len(PROVIDER_PLUGINS)} providers available)"
1293
+ )
1294
+
1295
  # Small delay to let user see the ready message
1296
  time.sleep(0.5)
1297
+
1298
  # Run the main async event loop
1299
  # If from launcher, don't clear screen at start to preserve loading messages
1300
  try:
src/rotator_library/failure_logger.py CHANGED
@@ -1,47 +1,93 @@
1
  import logging
2
  import json
3
  from logging.handlers import RotatingFileHandler
4
- import os
5
  from datetime import datetime
 
 
6
  from .error_handler import mask_credential
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
 
 
 
8
 
9
- def setup_failure_logger():
10
- """Sets up a dedicated JSON logger for writing detailed failure logs to a file."""
11
- log_dir = "logs"
12
- if not os.path.exists(log_dir):
13
- os.makedirs(log_dir)
14
 
15
- # Create a logger specifically for failures.
16
- # This logger will NOT propagate to the root logger.
 
17
  logger = logging.getLogger("failure_logger")
18
  logger.setLevel(logging.INFO)
19
  logger.propagate = False
20
 
21
- # Use a rotating file handler
22
- handler = RotatingFileHandler(
23
- os.path.join(log_dir, "failures.log"),
24
- maxBytes=5 * 1024 * 1024, # 5 MB
25
- backupCount=2,
26
- )
27
-
28
- # Custom JSON formatter for structured logs
29
- class JsonFormatter(logging.Formatter):
30
- def format(self, record):
31
- # The message is already a dict, so we just format it as a JSON string
32
- return json.dumps(record.msg)
33
 
34
- handler.setFormatter(JsonFormatter())
 
35
 
36
- # Add handler only if it hasn't been added before
37
- if not logger.handlers:
 
 
 
 
38
  logger.addHandler(handler)
 
 
 
 
39
 
40
  return logger
41
 
42
 
43
- # Initialize the dedicated logger for detailed failure logs
44
- failure_logger = setup_failure_logger()
 
 
 
 
 
 
 
 
 
 
 
 
 
45
 
46
  # Get the main library logger for concise, propagated messages
47
  main_lib_logger = logging.getLogger("rotator_library")
@@ -52,10 +98,27 @@ def _extract_response_body(error: Exception) -> str:
52
  Extract the full response body from various error types.
53
 
54
  Handles:
 
55
  - httpx.HTTPStatusError: response.text or response.content
56
  - litellm exceptions: various response attributes
57
  - Other exceptions: str(error)
58
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
  # Try to get response body from httpx errors
60
  if hasattr(error, "response") and error.response is not None:
61
  response = error.response
@@ -145,11 +208,19 @@ def log_failure(
145
  "request_headers": request_headers,
146
  "error_chain": error_chain if len(error_chain) > 1 else None,
147
  }
148
- failure_logger.error(detailed_log_data)
149
 
150
  # 2. Log a concise summary to the main library logger, which will propagate
151
  summary_message = (
152
  f"API call failed for model {model} with key {mask_credential(api_key)}. "
153
  f"Error: {type(error).__name__}. See failures.log for details."
154
  )
 
 
 
 
 
 
 
 
 
155
  main_lib_logger.error(summary_message)
 
1
  import logging
2
  import json
3
  from logging.handlers import RotatingFileHandler
4
+ from pathlib import Path
5
  from datetime import datetime
6
+ from typing import Optional, Union
7
+
8
  from .error_handler import mask_credential
9
+ from .utils.paths import get_logs_dir
10
+
11
+
12
+ class JsonFormatter(logging.Formatter):
13
+ """Custom JSON formatter for structured logs."""
14
+
15
+ def format(self, record):
16
+ # The message is already a dict, so we just format it as a JSON string
17
+ return json.dumps(record.msg)
18
+
19
+
20
+ # Module-level state for lazy initialization
21
+ _failure_logger: Optional[logging.Logger] = None
22
+ _configured_logs_dir: Optional[Path] = None
23
+
24
+
25
+ def configure_failure_logger(logs_dir: Optional[Union[Path, str]] = None) -> None:
26
+ """
27
+ Configure the failure logger to use a specific logs directory.
28
+
29
+ Call this before first use if you want to override the default location.
30
+ If not called, the logger will use get_logs_dir() on first use.
31
+
32
+ Args:
33
+ logs_dir: Path to the logs directory. If None, uses get_logs_dir().
34
+ """
35
+ global _configured_logs_dir, _failure_logger
36
+ _configured_logs_dir = Path(logs_dir) if logs_dir else None
37
+ # Reset logger so it gets reconfigured on next use
38
+ _failure_logger = None
39
+
40
 
41
+ def _setup_failure_logger(logs_dir: Path) -> logging.Logger:
42
+ """
43
+ Sets up a dedicated JSON logger for writing detailed failure logs to a file.
44
 
45
+ Args:
46
+ logs_dir: Path to the logs directory.
 
 
 
47
 
48
+ Returns:
49
+ Configured logger instance.
50
+ """
51
  logger = logging.getLogger("failure_logger")
52
  logger.setLevel(logging.INFO)
53
  logger.propagate = False
54
 
55
+ # Clear existing handlers to prevent duplicates on re-setup
56
+ logger.handlers.clear()
 
 
 
 
 
 
 
 
 
 
57
 
58
+ try:
59
+ logs_dir.mkdir(parents=True, exist_ok=True)
60
 
61
+ handler = RotatingFileHandler(
62
+ logs_dir / "failures.log",
63
+ maxBytes=5 * 1024 * 1024, # 5 MB
64
+ backupCount=2,
65
+ )
66
+ handler.setFormatter(JsonFormatter())
67
  logger.addHandler(handler)
68
+ except (OSError, PermissionError, IOError) as e:
69
+ logging.warning(f"Cannot create failure log file handler: {e}")
70
+ # Add NullHandler to prevent "no handlers" warning
71
+ logger.addHandler(logging.NullHandler())
72
 
73
  return logger
74
 
75
 
76
+ def get_failure_logger() -> logging.Logger:
77
+ """
78
+ Get the failure logger, initializing it lazily if needed.
79
+
80
+ Returns:
81
+ The configured failure logger.
82
+ """
83
+ global _failure_logger, _configured_logs_dir
84
+
85
+ if _failure_logger is None:
86
+ logs_dir = _configured_logs_dir if _configured_logs_dir else get_logs_dir()
87
+ _failure_logger = _setup_failure_logger(logs_dir)
88
+
89
+ return _failure_logger
90
+
91
 
92
  # Get the main library logger for concise, propagated messages
93
  main_lib_logger = logging.getLogger("rotator_library")
 
98
  Extract the full response body from various error types.
99
 
100
  Handles:
101
+ - StreamedAPIError: wraps original exception in .data attribute
102
  - httpx.HTTPStatusError: response.text or response.content
103
  - litellm exceptions: various response attributes
104
  - Other exceptions: str(error)
105
  """
106
+ # Handle StreamedAPIError which wraps the original exception in .data
107
+ # This is used by our streaming wrapper when catching provider errors
108
+ if hasattr(error, "data") and error.data is not None:
109
+ inner = error.data
110
+ # If data is a dict (parsed JSON error), return it as JSON
111
+ if isinstance(inner, dict):
112
+ try:
113
+ return json.dumps(inner, indent=2)
114
+ except Exception:
115
+ return str(inner)
116
+ # If data is an exception, recurse to extract from it
117
+ if isinstance(inner, Exception):
118
+ result = _extract_response_body(inner)
119
+ if result:
120
+ return result
121
+
122
  # Try to get response body from httpx errors
123
  if hasattr(error, "response") and error.response is not None:
124
  response = error.response
 
208
  "request_headers": request_headers,
209
  "error_chain": error_chain if len(error_chain) > 1 else None,
210
  }
 
211
 
212
  # 2. Log a concise summary to the main library logger, which will propagate
213
  summary_message = (
214
  f"API call failed for model {model} with key {mask_credential(api_key)}. "
215
  f"Error: {type(error).__name__}. See failures.log for details."
216
  )
217
+
218
+ # Log to failure logger with resilience - if it fails, just continue
219
+ try:
220
+ get_failure_logger().error(detailed_log_data)
221
+ except (OSError, IOError) as e:
222
+ # Log file write failed - log to console instead
223
+ logging.warning(f"Failed to write to failures.log: {e}")
224
+
225
+ # Console log always succeeds
226
  main_lib_logger.error(summary_message)
src/rotator_library/providers/antigravity_auth_base.py CHANGED
@@ -1,16 +1,36 @@
1
  # src/rotator_library/providers/antigravity_auth_base.py
2
 
 
 
 
 
 
 
 
 
 
3
  from .google_oauth_base import GoogleOAuthBase
4
 
 
 
 
 
 
 
5
  class AntigravityAuthBase(GoogleOAuthBase):
6
  """
7
  Antigravity OAuth2 authentication implementation.
8
-
9
  Inherits all OAuth functionality from GoogleOAuthBase with Antigravity-specific configuration.
10
  Uses Antigravity's OAuth credentials and includes additional scopes for cclog and experimentsandconfigs.
 
 
 
11
  """
12
-
13
- CLIENT_ID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com"
 
 
14
  CLIENT_SECRET = "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"
15
  OAUTH_SCOPES = [
16
  "https://www.googleapis.com/auth/cloud-platform",
@@ -22,3 +42,600 @@ class AntigravityAuthBase(GoogleOAuthBase):
22
  ENV_PREFIX = "ANTIGRAVITY"
23
  CALLBACK_PORT = 51121
24
  CALLBACK_PATH = "/oauthcallback"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  # src/rotator_library/providers/antigravity_auth_base.py
2
 
3
+ import asyncio
4
+ import json
5
+ import logging
6
+ import os
7
+ from pathlib import Path
8
+ from typing import Any, Dict, Optional, List
9
+
10
+ import httpx
11
+
12
  from .google_oauth_base import GoogleOAuthBase
13
 
14
+ lib_logger = logging.getLogger("rotator_library")
15
+
16
+ # Code Assist endpoint for project discovery
17
+ CODE_ASSIST_ENDPOINT = "https://cloudcode-pa.googleapis.com/v1internal"
18
+
19
+
20
  class AntigravityAuthBase(GoogleOAuthBase):
21
  """
22
  Antigravity OAuth2 authentication implementation.
23
+
24
  Inherits all OAuth functionality from GoogleOAuthBase with Antigravity-specific configuration.
25
  Uses Antigravity's OAuth credentials and includes additional scopes for cclog and experimentsandconfigs.
26
+
27
+ Also provides project/tier discovery functionality that runs during authentication,
28
+ ensuring credentials have their tier and project_id cached before any API requests.
29
  """
30
+
31
+ CLIENT_ID = (
32
+ "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com"
33
+ )
34
  CLIENT_SECRET = "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"
35
  OAUTH_SCOPES = [
36
  "https://www.googleapis.com/auth/cloud-platform",
 
42
  ENV_PREFIX = "ANTIGRAVITY"
43
  CALLBACK_PORT = 51121
44
  CALLBACK_PATH = "/oauthcallback"
45
+
46
+ def __init__(self):
47
+ super().__init__()
48
+ # Project and tier caches - shared between auth base and provider
49
+ self.project_id_cache: Dict[str, str] = {}
50
+ self.project_tier_cache: Dict[str, str] = {}
51
+
52
+ # =========================================================================
53
+ # POST-AUTH DISCOVERY HOOK
54
+ # =========================================================================
55
+
56
+ async def _post_auth_discovery(
57
+ self, credential_path: str, access_token: str
58
+ ) -> None:
59
+ """
60
+ Discover and cache tier/project information immediately after OAuth authentication.
61
+
62
+ This is called by GoogleOAuthBase._perform_interactive_oauth() after successful auth,
63
+ ensuring tier and project_id are cached during the authentication flow rather than
64
+ waiting for the first API request.
65
+
66
+ Args:
67
+ credential_path: Path to the credential file
68
+ access_token: The newly obtained access token
69
+ """
70
+ lib_logger.debug(
71
+ f"Starting post-auth discovery for Antigravity credential: {Path(credential_path).name}"
72
+ )
73
+
74
+ # Skip if already discovered (shouldn't happen during fresh auth, but be defensive)
75
+ if (
76
+ credential_path in self.project_id_cache
77
+ and credential_path in self.project_tier_cache
78
+ ):
79
+ lib_logger.debug(
80
+ f"Tier and project already cached for {Path(credential_path).name}, skipping discovery"
81
+ )
82
+ return
83
+
84
+ # Call _discover_project_id which handles tier/project discovery and persistence
85
+ # Pass empty litellm_params since we're in auth context (no model-specific overrides)
86
+ project_id = await self._discover_project_id(
87
+ credential_path, access_token, litellm_params={}
88
+ )
89
+
90
+ tier = self.project_tier_cache.get(credential_path, "unknown")
91
+ lib_logger.info(
92
+ f"Post-auth discovery complete for {Path(credential_path).name}: "
93
+ f"tier={tier}, project={project_id}"
94
+ )
95
+
96
+ # =========================================================================
97
+ # PROJECT ID DISCOVERY
98
+ # =========================================================================
99
+
100
+ async def _discover_project_id(
101
+ self, credential_path: str, access_token: str, litellm_params: Dict[str, Any]
102
+ ) -> str:
103
+ """
104
+ Discovers the Google Cloud Project ID, with caching and onboarding for new accounts.
105
+
106
+ This follows the official Gemini CLI discovery flow adapted for Antigravity:
107
+ 1. Check in-memory cache
108
+ 2. Check configured project_id override (litellm_params or env var)
109
+ 3. Check persisted project_id in credential file
110
+ 4. Call loadCodeAssist to check if user is already known (has currentTier)
111
+ - If currentTier exists AND cloudaicompanionProject returned: use server's project
112
+ - If no currentTier: user needs onboarding
113
+ 5. Onboard user (FREE tier: pass cloudaicompanionProject=None for server-managed)
114
+ 6. Fallback to GCP Resource Manager project listing
115
+
116
+ Note: Unlike GeminiCli, Antigravity doesn't use tier-based credential prioritization,
117
+ but we still cache tier info for debugging and consistency.
118
+ """
119
+ lib_logger.debug(
120
+ f"Starting Antigravity project discovery for credential: {credential_path}"
121
+ )
122
+
123
+ # Check in-memory cache first
124
+ if credential_path in self.project_id_cache:
125
+ cached_project = self.project_id_cache[credential_path]
126
+ lib_logger.debug(f"Using cached project ID: {cached_project}")
127
+ return cached_project
128
+
129
+ # Check for configured project ID override (from litellm_params or env var)
130
+ configured_project_id = (
131
+ litellm_params.get("project_id")
132
+ or os.getenv("ANTIGRAVITY_PROJECT_ID")
133
+ or os.getenv("GOOGLE_CLOUD_PROJECT")
134
+ )
135
+ if configured_project_id:
136
+ lib_logger.debug(
137
+ f"Found configured project_id override: {configured_project_id}"
138
+ )
139
+
140
+ # Load credentials from file to check for persisted project_id and tier
141
+ # Skip for env:// paths (environment-based credentials don't persist to files)
142
+ credential_index = self._parse_env_credential_path(credential_path)
143
+ if credential_index is None:
144
+ # Only try to load from file if it's not an env:// path
145
+ try:
146
+ with open(credential_path, "r") as f:
147
+ creds = json.load(f)
148
+
149
+ metadata = creds.get("_proxy_metadata", {})
150
+ persisted_project_id = metadata.get("project_id")
151
+ persisted_tier = metadata.get("tier")
152
+
153
+ if persisted_project_id:
154
+ lib_logger.info(
155
+ f"Loaded persisted project ID from credential file: {persisted_project_id}"
156
+ )
157
+ self.project_id_cache[credential_path] = persisted_project_id
158
+
159
+ # Also load tier if available (for debugging/logging purposes)
160
+ if persisted_tier:
161
+ self.project_tier_cache[credential_path] = persisted_tier
162
+ lib_logger.debug(f"Loaded persisted tier: {persisted_tier}")
163
+
164
+ return persisted_project_id
165
+ except (FileNotFoundError, json.JSONDecodeError, KeyError) as e:
166
+ lib_logger.debug(f"Could not load persisted project ID from file: {e}")
167
+
168
+ lib_logger.debug(
169
+ "No cached or configured project ID found, initiating discovery..."
170
+ )
171
+ headers = {
172
+ "Authorization": f"Bearer {access_token}",
173
+ "Content-Type": "application/json",
174
+ }
175
+
176
+ discovered_project_id = None
177
+ discovered_tier = None
178
+
179
+ async with httpx.AsyncClient() as client:
180
+ # 1. Try discovery endpoint with loadCodeAssist
181
+ lib_logger.debug(
182
+ "Attempting project discovery via Code Assist loadCodeAssist endpoint..."
183
+ )
184
+ try:
185
+ # Build metadata - include duetProject only if we have a configured project
186
+ core_client_metadata = {
187
+ "ideType": "IDE_UNSPECIFIED",
188
+ "platform": "PLATFORM_UNSPECIFIED",
189
+ "pluginType": "GEMINI",
190
+ }
191
+ if configured_project_id:
192
+ core_client_metadata["duetProject"] = configured_project_id
193
+
194
+ # Build load request - pass configured_project_id if available, otherwise None
195
+ load_request = {
196
+ "cloudaicompanionProject": configured_project_id, # Can be None
197
+ "metadata": core_client_metadata,
198
+ }
199
+
200
+ lib_logger.debug(
201
+ f"Sending loadCodeAssist request with cloudaicompanionProject={configured_project_id}"
202
+ )
203
+ response = await client.post(
204
+ f"{CODE_ASSIST_ENDPOINT}:loadCodeAssist",
205
+ headers=headers,
206
+ json=load_request,
207
+ timeout=20,
208
+ )
209
+ response.raise_for_status()
210
+ data = response.json()
211
+
212
+ # Log full response for debugging
213
+ lib_logger.debug(
214
+ f"loadCodeAssist full response keys: {list(data.keys())}"
215
+ )
216
+
217
+ # Extract tier information
218
+ allowed_tiers = data.get("allowedTiers", [])
219
+ current_tier = data.get("currentTier")
220
+
221
+ lib_logger.debug(f"=== Tier Information ===")
222
+ lib_logger.debug(f"currentTier: {current_tier}")
223
+ lib_logger.debug(f"allowedTiers count: {len(allowed_tiers)}")
224
+ for i, tier in enumerate(allowed_tiers):
225
+ tier_id = tier.get("id", "unknown")
226
+ is_default = tier.get("isDefault", False)
227
+ user_defined = tier.get("userDefinedCloudaicompanionProject", False)
228
+ lib_logger.debug(
229
+ f" Tier {i + 1}: id={tier_id}, isDefault={is_default}, userDefinedProject={user_defined}"
230
+ )
231
+ lib_logger.debug(f"========================")
232
+
233
+ # Determine the current tier ID
234
+ current_tier_id = None
235
+ if current_tier:
236
+ current_tier_id = current_tier.get("id")
237
+ lib_logger.debug(f"User has currentTier: {current_tier_id}")
238
+
239
+ # Check if user is already known to server (has currentTier)
240
+ if current_tier_id:
241
+ # User is already onboarded - check for project from server
242
+ server_project = data.get("cloudaicompanionProject")
243
+
244
+ # Check if this tier requires user-defined project (paid tiers)
245
+ requires_user_project = any(
246
+ t.get("id") == current_tier_id
247
+ and t.get("userDefinedCloudaicompanionProject", False)
248
+ for t in allowed_tiers
249
+ )
250
+ is_free_tier = current_tier_id == "free-tier"
251
+
252
+ if server_project:
253
+ # Server returned a project - use it (server wins)
254
+ project_id = server_project
255
+ lib_logger.debug(f"Server returned project: {project_id}")
256
+ elif configured_project_id:
257
+ # No server project but we have configured one - use it
258
+ project_id = configured_project_id
259
+ lib_logger.debug(
260
+ f"No server project, using configured: {project_id}"
261
+ )
262
+ elif is_free_tier:
263
+ # Free tier user without server project - try onboarding
264
+ lib_logger.debug(
265
+ "Free tier user with currentTier but no project - will try onboarding"
266
+ )
267
+ project_id = None
268
+ elif requires_user_project:
269
+ # Paid tier requires a project ID to be set
270
+ raise ValueError(
271
+ f"Paid tier '{current_tier_id}' requires setting ANTIGRAVITY_PROJECT_ID environment variable."
272
+ )
273
+ else:
274
+ # Unknown tier without project - proceed to onboarding
275
+ lib_logger.warning(
276
+ f"Tier '{current_tier_id}' has no project and none configured - will try onboarding"
277
+ )
278
+ project_id = None
279
+
280
+ if project_id:
281
+ # Cache tier info
282
+ self.project_tier_cache[credential_path] = current_tier_id
283
+ discovered_tier = current_tier_id
284
+
285
+ # Log appropriately based on tier
286
+ is_paid = current_tier_id and current_tier_id not in [
287
+ "free-tier",
288
+ "legacy-tier",
289
+ "unknown",
290
+ ]
291
+ if is_paid:
292
+ lib_logger.info(
293
+ f"Using Antigravity paid tier '{current_tier_id}' with project: {project_id}"
294
+ )
295
+ else:
296
+ lib_logger.info(
297
+ f"Discovered Antigravity project ID via loadCodeAssist: {project_id}"
298
+ )
299
+
300
+ self.project_id_cache[credential_path] = project_id
301
+ discovered_project_id = project_id
302
+
303
+ # Persist to credential file
304
+ await self._persist_project_metadata(
305
+ credential_path, project_id, discovered_tier
306
+ )
307
+
308
+ return project_id
309
+
310
+ # 2. User needs onboarding - no currentTier or no project found
311
+ lib_logger.info(
312
+ "No existing Antigravity session found (no currentTier), attempting to onboard user..."
313
+ )
314
+
315
+ # Determine which tier to onboard with
316
+ onboard_tier = None
317
+ for tier in allowed_tiers:
318
+ if tier.get("isDefault"):
319
+ onboard_tier = tier
320
+ break
321
+
322
+ # Fallback to legacy tier if no default
323
+ if not onboard_tier and allowed_tiers:
324
+ for tier in allowed_tiers:
325
+ if tier.get("id") == "legacy-tier":
326
+ onboard_tier = tier
327
+ break
328
+ if not onboard_tier:
329
+ onboard_tier = allowed_tiers[0]
330
+
331
+ if not onboard_tier:
332
+ raise ValueError("No onboarding tiers available from server")
333
+
334
+ tier_id = onboard_tier.get("id", "free-tier")
335
+ requires_user_project = onboard_tier.get(
336
+ "userDefinedCloudaicompanionProject", False
337
+ )
338
+
339
+ lib_logger.debug(
340
+ f"Onboarding with tier: {tier_id}, requiresUserProject: {requires_user_project}"
341
+ )
342
+
343
+ # Build onboard request based on tier type
344
+ # FREE tier: cloudaicompanionProject = None (server-managed)
345
+ # PAID tier: cloudaicompanionProject = configured_project_id
346
+ is_free_tier = tier_id == "free-tier"
347
+
348
+ if is_free_tier:
349
+ # Free tier uses server-managed project
350
+ onboard_request = {
351
+ "tierId": tier_id,
352
+ "cloudaicompanionProject": None, # Server will create/manage
353
+ "metadata": core_client_metadata,
354
+ }
355
+ lib_logger.debug(
356
+ "Free tier onboarding: using server-managed project"
357
+ )
358
+ else:
359
+ # Paid/legacy tier requires user-provided project
360
+ if not configured_project_id and requires_user_project:
361
+ raise ValueError(
362
+ f"Tier '{tier_id}' requires setting ANTIGRAVITY_PROJECT_ID environment variable."
363
+ )
364
+ onboard_request = {
365
+ "tierId": tier_id,
366
+ "cloudaicompanionProject": configured_project_id,
367
+ "metadata": {
368
+ **core_client_metadata,
369
+ "duetProject": configured_project_id,
370
+ }
371
+ if configured_project_id
372
+ else core_client_metadata,
373
+ }
374
+ lib_logger.debug(
375
+ f"Paid tier onboarding: using project {configured_project_id}"
376
+ )
377
+
378
+ lib_logger.debug("Initiating onboardUser request...")
379
+ lro_response = await client.post(
380
+ f"{CODE_ASSIST_ENDPOINT}:onboardUser",
381
+ headers=headers,
382
+ json=onboard_request,
383
+ timeout=30,
384
+ )
385
+ lro_response.raise_for_status()
386
+ lro_data = lro_response.json()
387
+ lib_logger.debug(
388
+ f"Initial onboarding response: done={lro_data.get('done')}"
389
+ )
390
+
391
+ # Poll for onboarding completion (up to 5 minutes)
392
+ for i in range(150): # 150 × 2s = 5 minutes
393
+ if lro_data.get("done"):
394
+ lib_logger.debug(
395
+ f"Onboarding completed after {i} polling attempts"
396
+ )
397
+ break
398
+ await asyncio.sleep(2)
399
+ if (i + 1) % 15 == 0: # Log every 30 seconds
400
+ lib_logger.info(
401
+ f"Still waiting for onboarding completion... ({(i + 1) * 2}s elapsed)"
402
+ )
403
+ lib_logger.debug(
404
+ f"Polling onboarding status... (Attempt {i + 1}/150)"
405
+ )
406
+ lro_response = await client.post(
407
+ f"{CODE_ASSIST_ENDPOINT}:onboardUser",
408
+ headers=headers,
409
+ json=onboard_request,
410
+ timeout=30,
411
+ )
412
+ lro_response.raise_for_status()
413
+ lro_data = lro_response.json()
414
+
415
+ if not lro_data.get("done"):
416
+ lib_logger.error("Onboarding process timed out after 5 minutes")
417
+ raise ValueError(
418
+ "Onboarding process timed out after 5 minutes. Please try again or contact support."
419
+ )
420
+
421
+ # Extract project ID from LRO response
422
+ # Note: onboardUser returns response.cloudaicompanionProject as an object with .id
423
+ lro_response_data = lro_data.get("response", {})
424
+ lro_project_obj = lro_response_data.get("cloudaicompanionProject", {})
425
+ project_id = (
426
+ lro_project_obj.get("id")
427
+ if isinstance(lro_project_obj, dict)
428
+ else None
429
+ )
430
+
431
+ # Fallback to configured project if LRO didn't return one
432
+ if not project_id and configured_project_id:
433
+ project_id = configured_project_id
434
+ lib_logger.debug(
435
+ f"LRO didn't return project, using configured: {project_id}"
436
+ )
437
+
438
+ if not project_id:
439
+ lib_logger.error(
440
+ "Onboarding completed but no project ID in response and none configured"
441
+ )
442
+ raise ValueError(
443
+ "Onboarding completed, but no project ID was returned. "
444
+ "For paid tiers, set ANTIGRAVITY_PROJECT_ID environment variable."
445
+ )
446
+
447
+ lib_logger.debug(
448
+ f"Successfully extracted project ID from onboarding response: {project_id}"
449
+ )
450
+
451
+ # Cache tier info
452
+ self.project_tier_cache[credential_path] = tier_id
453
+ discovered_tier = tier_id
454
+ lib_logger.debug(f"Cached tier information: {tier_id}")
455
+
456
+ # Log concise message based on tier
457
+ is_paid = tier_id and tier_id not in ["free-tier", "legacy-tier"]
458
+ if is_paid:
459
+ lib_logger.info(
460
+ f"Using Antigravity paid tier '{tier_id}' with project: {project_id}"
461
+ )
462
+ else:
463
+ lib_logger.info(
464
+ f"Successfully onboarded user and discovered project ID: {project_id}"
465
+ )
466
+
467
+ self.project_id_cache[credential_path] = project_id
468
+ discovered_project_id = project_id
469
+
470
+ # Persist to credential file
471
+ await self._persist_project_metadata(
472
+ credential_path, project_id, discovered_tier
473
+ )
474
+
475
+ return project_id
476
+
477
+ except httpx.HTTPStatusError as e:
478
+ error_body = ""
479
+ try:
480
+ error_body = e.response.text
481
+ except Exception:
482
+ pass
483
+ if e.response.status_code == 403:
484
+ lib_logger.error(
485
+ f"Antigravity Code Assist API access denied (403). Response: {error_body}"
486
+ )
487
+ lib_logger.error(
488
+ "Possible causes: 1) cloudaicompanion.googleapis.com API not enabled, 2) Wrong project ID for paid tier, 3) Account lacks permissions"
489
+ )
490
+ elif e.response.status_code == 404:
491
+ lib_logger.warning(
492
+ f"Antigravity Code Assist endpoint not found (404). Falling back to project listing."
493
+ )
494
+ elif e.response.status_code == 412:
495
+ # Precondition Failed - often means wrong project for free tier onboarding
496
+ lib_logger.error(
497
+ f"Precondition failed (412): {error_body}. This may mean the project ID is incompatible with the selected tier."
498
+ )
499
+ else:
500
+ lib_logger.warning(
501
+ f"Antigravity onboarding/discovery failed with status {e.response.status_code}: {error_body}. Falling back to project listing."
502
+ )
503
+ except httpx.RequestError as e:
504
+ lib_logger.warning(
505
+ f"Antigravity onboarding/discovery network error: {e}. Falling back to project listing."
506
+ )
507
+
508
+ # 3. Fallback to listing all available GCP projects (last resort)
509
+ lib_logger.debug(
510
+ "Attempting to discover project via GCP Resource Manager API..."
511
+ )
512
+ try:
513
+ async with httpx.AsyncClient() as client:
514
+ lib_logger.debug(
515
+ "Querying Cloud Resource Manager for available projects..."
516
+ )
517
+ response = await client.get(
518
+ "https://cloudresourcemanager.googleapis.com/v1/projects",
519
+ headers=headers,
520
+ timeout=20,
521
+ )
522
+ response.raise_for_status()
523
+ projects = response.json().get("projects", [])
524
+ lib_logger.debug(f"Found {len(projects)} total projects")
525
+ active_projects = [
526
+ p for p in projects if p.get("lifecycleState") == "ACTIVE"
527
+ ]
528
+ lib_logger.debug(f"Found {len(active_projects)} active projects")
529
+
530
+ if not projects:
531
+ lib_logger.error(
532
+ "No GCP projects found for this account. Please create a project in Google Cloud Console."
533
+ )
534
+ elif not active_projects:
535
+ lib_logger.error(
536
+ "No active GCP projects found. Please activate a project in Google Cloud Console."
537
+ )
538
+ else:
539
+ project_id = active_projects[0]["projectId"]
540
+ lib_logger.info(
541
+ f"Discovered Antigravity project ID from active projects list: {project_id}"
542
+ )
543
+ lib_logger.debug(
544
+ f"Selected first active project: {project_id} (out of {len(active_projects)} active projects)"
545
+ )
546
+ self.project_id_cache[credential_path] = project_id
547
+ discovered_project_id = project_id
548
+
549
+ # Persist to credential file (no tier info from resource manager)
550
+ await self._persist_project_metadata(
551
+ credential_path, project_id, None
552
+ )
553
+
554
+ return project_id
555
+ except httpx.HTTPStatusError as e:
556
+ if e.response.status_code == 403:
557
+ lib_logger.error(
558
+ "Failed to list GCP projects due to a 403 Forbidden error. The Cloud Resource Manager API may not be enabled, or your account lacks the 'resourcemanager.projects.list' permission."
559
+ )
560
+ else:
561
+ lib_logger.error(
562
+ f"Failed to list GCP projects with status {e.response.status_code}: {e}"
563
+ )
564
+ except httpx.RequestError as e:
565
+ lib_logger.error(f"Network error while listing GCP projects: {e}")
566
+
567
+ raise ValueError(
568
+ "Could not auto-discover Antigravity project ID. Possible causes:\n"
569
+ " 1. The cloudaicompanion.googleapis.com API is not enabled (enable it in Google Cloud Console)\n"
570
+ " 2. No active GCP projects exist for this account (create one in Google Cloud Console)\n"
571
+ " 3. Account lacks necessary permissions\n"
572
+ "To manually specify a project, set ANTIGRAVITY_PROJECT_ID in your .env file."
573
+ )
574
+
575
+ async def _persist_project_metadata(
576
+ self, credential_path: str, project_id: str, tier: Optional[str]
577
+ ):
578
+ """Persists project ID and tier to the credential file for faster future startups."""
579
+ # Skip persistence for env:// paths (environment-based credentials)
580
+ credential_index = self._parse_env_credential_path(credential_path)
581
+ if credential_index is not None:
582
+ lib_logger.debug(
583
+ f"Skipping project metadata persistence for env:// credential path: {credential_path}"
584
+ )
585
+ return
586
+
587
+ try:
588
+ # Load current credentials
589
+ with open(credential_path, "r") as f:
590
+ creds = json.load(f)
591
+
592
+ # Update metadata
593
+ if "_proxy_metadata" not in creds:
594
+ creds["_proxy_metadata"] = {}
595
+
596
+ creds["_proxy_metadata"]["project_id"] = project_id
597
+ if tier:
598
+ creds["_proxy_metadata"]["tier"] = tier
599
+
600
+ # Save back using the existing save method (handles atomic writes and permissions)
601
+ await self._save_credentials(credential_path, creds)
602
+
603
+ lib_logger.debug(
604
+ f"Persisted project_id and tier to credential file: {credential_path}"
605
+ )
606
+ except Exception as e:
607
+ lib_logger.warning(
608
+ f"Failed to persist project metadata to credential file: {e}"
609
+ )
610
+ # Non-fatal - just means slower startup next time
611
+
612
+ # =========================================================================
613
+ # CREDENTIAL MANAGEMENT OVERRIDES
614
+ # =========================================================================
615
+
616
+ def _get_provider_file_prefix(self) -> str:
617
+ """Return the file prefix for Antigravity credentials."""
618
+ return "antigravity"
619
+
620
+ def build_env_lines(self, creds: Dict[str, Any], cred_number: int) -> List[str]:
621
+ """
622
+ Generate .env file lines for an Antigravity credential.
623
+
624
+ Includes tier and project_id from _proxy_metadata.
625
+ """
626
+ # Get base lines from parent class
627
+ lines = super().build_env_lines(creds, cred_number)
628
+
629
+ # Add Antigravity-specific fields (tier and project_id)
630
+ metadata = creds.get("_proxy_metadata", {})
631
+ prefix = f"{self.ENV_PREFIX}_{cred_number}"
632
+
633
+ project_id = metadata.get("project_id", "")
634
+ tier = metadata.get("tier", "")
635
+
636
+ if project_id:
637
+ lines.append(f"{prefix}_PROJECT_ID={project_id}")
638
+ if tier:
639
+ lines.append(f"{prefix}_TIER={tier}")
640
+
641
+ return lines
src/rotator_library/providers/antigravity_provider.py CHANGED
@@ -38,6 +38,8 @@ from .provider_interface import ProviderInterface, UsageResetConfigDef, QuotaGro
38
  from .antigravity_auth_base import AntigravityAuthBase
39
  from .provider_cache import ProviderCache
40
  from ..model_definitions import ModelDefinitions
 
 
41
 
42
 
43
  # =============================================================================
@@ -105,12 +107,23 @@ DEFAULT_SAFETY_SETTINGS = [
105
  {"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "BLOCK_NONE"},
106
  ]
107
 
108
- # Directory paths
109
- _BASE_DIR = Path(__file__).resolve().parent.parent.parent.parent
110
- LOGS_DIR = _BASE_DIR / "logs" / "antigravity_logs"
111
- CACHE_DIR = _BASE_DIR / "cache" / "antigravity"
112
- GEMINI3_SIGNATURE_CACHE_FILE = CACHE_DIR / "gemini3_signatures.json"
113
- CLAUDE_THINKING_CACHE_FILE = CACHE_DIR / "claude_thinking.json"
 
 
 
 
 
 
 
 
 
 
 
114
 
115
  # Gemini 3 tool fix system instruction (prevents hallucination)
116
  DEFAULT_GEMINI3_SYSTEM_INSTRUCTION = """<CRITICAL_TOOL_USAGE_INSTRUCTIONS>
@@ -327,6 +340,33 @@ def _recursively_parse_json_strings(obj: Any) -> Any:
327
  return obj
328
 
329
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
330
  def _clean_claude_schema(schema: Any) -> Any:
331
  """
332
  Recursively clean JSON Schema for Antigravity/Google's Proto-based API.
@@ -384,7 +424,6 @@ def _clean_claude_schema(schema: Any) -> Any:
384
  return first_option
385
 
386
  cleaned = {}
387
-
388
  # Handle 'const' by converting to 'enum' with single value
389
  if "const" in schema:
390
  const_value = schema["const"]
@@ -425,7 +464,9 @@ class AntigravityFileLogger:
425
 
426
  timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
427
  safe_model = model_name.replace("/", "_").replace(":", "_")
428
- self.log_dir = LOGS_DIR / f"{timestamp}_{safe_model}_{uuid.uuid4()}"
 
 
429
 
430
  try:
431
  self.log_dir.mkdir(parents=True, exist_ok=True)
@@ -658,9 +699,6 @@ class AntigravityProvider(AntigravityAuthBase, ProviderInterface):
658
  error_obj = data.get("error", data)
659
  details = error_obj.get("details", [])
660
 
661
- if not details:
662
- return None
663
-
664
  result = {
665
  "retry_after": None,
666
  "reason": None,
@@ -711,6 +749,15 @@ class AntigravityProvider(AntigravityAuthBase, ProviderInterface):
711
 
712
  # Return None if we couldn't extract retry_after
713
  if not result["retry_after"]:
 
 
 
 
 
 
 
 
 
714
  return None
715
 
716
  return result
@@ -718,12 +765,7 @@ class AntigravityProvider(AntigravityAuthBase, ProviderInterface):
718
  def __init__(self):
719
  super().__init__()
720
  self.model_definitions = ModelDefinitions()
721
- self.project_id_cache: Dict[
722
- str, str
723
- ] = {} # Cache project ID per credential path
724
- self.project_tier_cache: Dict[
725
- str, str
726
- ] = {} # Cache project tier per credential path (for debugging)
727
 
728
  # Base URL management
729
  self._base_url_index = 0
@@ -735,13 +777,13 @@ class AntigravityProvider(AntigravityAuthBase, ProviderInterface):
735
 
736
  # Initialize caches using shared ProviderCache
737
  self._signature_cache = ProviderCache(
738
- GEMINI3_SIGNATURE_CACHE_FILE,
739
  memory_ttl,
740
  disk_ttl,
741
  env_prefix="ANTIGRAVITY_SIGNATURE",
742
  )
743
  self._thinking_cache = ProviderCache(
744
- CLAUDE_THINKING_CACHE_FILE,
745
  memory_ttl,
746
  disk_ttl,
747
  env_prefix="ANTIGRAVITY_THINKING",
@@ -871,9 +913,48 @@ class AntigravityProvider(AntigravityAuthBase, ProviderInterface):
871
 
872
  This ensures all credential priorities are known before any API calls,
873
  preventing unknown credentials from getting priority 999.
 
 
 
874
  """
 
875
  await self._load_persisted_tiers(credential_paths)
876
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
877
  async def _load_persisted_tiers(
878
  self, credential_paths: List[str]
879
  ) -> Dict[str, str]:
@@ -931,6 +1012,8 @@ class AntigravityProvider(AntigravityAuthBase, ProviderInterface):
931
 
932
  return loaded
933
 
 
 
934
  # =========================================================================
935
  # MODEL UTILITIES
936
  # =========================================================================
@@ -1007,524 +1090,7 @@ class AntigravityProvider(AntigravityAuthBase, ProviderInterface):
1007
 
1008
  return "thinking_" + "_".join(key_parts) if key_parts else None
1009
 
1010
- # =========================================================================
1011
- # PROJECT ID DISCOVERY
1012
- # =========================================================================
1013
-
1014
- async def _discover_project_id(
1015
- self, credential_path: str, access_token: str, litellm_params: Dict[str, Any]
1016
- ) -> str:
1017
- """
1018
- Discovers the Google Cloud Project ID, with caching and onboarding for new accounts.
1019
-
1020
- This follows the official Gemini CLI discovery flow adapted for Antigravity:
1021
- 1. Check in-memory cache
1022
- 2. Check configured project_id override (litellm_params or env var)
1023
- 3. Check persisted project_id in credential file
1024
- 4. Call loadCodeAssist to check if user is already known (has currentTier)
1025
- - If currentTier exists AND cloudaicompanionProject returned: use server's project
1026
- - If no currentTier: user needs onboarding
1027
- 5. Onboard user (FREE tier: pass cloudaicompanionProject=None for server-managed)
1028
- 6. Fallback to GCP Resource Manager project listing
1029
-
1030
- Note: Unlike GeminiCli, Antigravity doesn't use tier-based credential prioritization,
1031
- but we still cache tier info for debugging and consistency.
1032
- """
1033
- lib_logger.debug(
1034
- f"Starting Antigravity project discovery for credential: {credential_path}"
1035
- )
1036
-
1037
- # Check in-memory cache first
1038
- if credential_path in self.project_id_cache:
1039
- cached_project = self.project_id_cache[credential_path]
1040
- lib_logger.debug(f"Using cached project ID: {cached_project}")
1041
- return cached_project
1042
-
1043
- # Check for configured project ID override (from litellm_params or env var)
1044
- configured_project_id = (
1045
- litellm_params.get("project_id")
1046
- or os.getenv("ANTIGRAVITY_PROJECT_ID")
1047
- or os.getenv("GOOGLE_CLOUD_PROJECT")
1048
- )
1049
- if configured_project_id:
1050
- lib_logger.debug(
1051
- f"Found configured project_id override: {configured_project_id}"
1052
- )
1053
-
1054
- # Load credentials from file to check for persisted project_id and tier
1055
- # Skip for env:// paths (environment-based credentials don't persist to files)
1056
- credential_index = self._parse_env_credential_path(credential_path)
1057
- if credential_index is None:
1058
- # Only try to load from file if it's not an env:// path
1059
- try:
1060
- with open(credential_path, "r") as f:
1061
- creds = json.load(f)
1062
-
1063
- metadata = creds.get("_proxy_metadata", {})
1064
- persisted_project_id = metadata.get("project_id")
1065
- persisted_tier = metadata.get("tier")
1066
-
1067
- if persisted_project_id:
1068
- lib_logger.info(
1069
- f"Loaded persisted project ID from credential file: {persisted_project_id}"
1070
- )
1071
- self.project_id_cache[credential_path] = persisted_project_id
1072
-
1073
- # Also load tier if available (for debugging/logging purposes)
1074
- if persisted_tier:
1075
- self.project_tier_cache[credential_path] = persisted_tier
1076
- lib_logger.debug(f"Loaded persisted tier: {persisted_tier}")
1077
-
1078
- return persisted_project_id
1079
- except (FileNotFoundError, json.JSONDecodeError, KeyError) as e:
1080
- lib_logger.debug(f"Could not load persisted project ID from file: {e}")
1081
-
1082
- lib_logger.debug(
1083
- "No cached or configured project ID found, initiating discovery..."
1084
- )
1085
- headers = {
1086
- "Authorization": f"Bearer {access_token}",
1087
- "Content-Type": "application/json",
1088
- }
1089
-
1090
- discovered_project_id = None
1091
- discovered_tier = None
1092
-
1093
- # Use production endpoint for loadCodeAssist (more reliable than sandbox URLs)
1094
- code_assist_endpoint = "https://cloudcode-pa.googleapis.com/v1internal"
1095
-
1096
- async with httpx.AsyncClient() as client:
1097
- # 1. Try discovery endpoint with loadCodeAssist
1098
- lib_logger.debug(
1099
- "Attempting project discovery via Code Assist loadCodeAssist endpoint..."
1100
- )
1101
- try:
1102
- # Build metadata - include duetProject only if we have a configured project
1103
- core_client_metadata = {
1104
- "ideType": "IDE_UNSPECIFIED",
1105
- "platform": "PLATFORM_UNSPECIFIED",
1106
- "pluginType": "GEMINI",
1107
- }
1108
- if configured_project_id:
1109
- core_client_metadata["duetProject"] = configured_project_id
1110
-
1111
- # Build load request - pass configured_project_id if available, otherwise None
1112
- load_request = {
1113
- "cloudaicompanionProject": configured_project_id, # Can be None
1114
- "metadata": core_client_metadata,
1115
- }
1116
-
1117
- lib_logger.debug(
1118
- f"Sending loadCodeAssist request with cloudaicompanionProject={configured_project_id}"
1119
- )
1120
- response = await client.post(
1121
- f"{code_assist_endpoint}:loadCodeAssist",
1122
- headers=headers,
1123
- json=load_request,
1124
- timeout=20,
1125
- )
1126
- response.raise_for_status()
1127
- data = response.json()
1128
-
1129
- # Log full response for debugging
1130
- lib_logger.debug(
1131
- f"loadCodeAssist full response keys: {list(data.keys())}"
1132
- )
1133
-
1134
- # Extract tier information
1135
- allowed_tiers = data.get("allowedTiers", [])
1136
- current_tier = data.get("currentTier")
1137
-
1138
- lib_logger.debug(f"=== Tier Information ===")
1139
- lib_logger.debug(f"currentTier: {current_tier}")
1140
- lib_logger.debug(f"allowedTiers count: {len(allowed_tiers)}")
1141
- for i, tier in enumerate(allowed_tiers):
1142
- tier_id = tier.get("id", "unknown")
1143
- is_default = tier.get("isDefault", False)
1144
- user_defined = tier.get("userDefinedCloudaicompanionProject", False)
1145
- lib_logger.debug(
1146
- f" Tier {i + 1}: id={tier_id}, isDefault={is_default}, userDefinedProject={user_defined}"
1147
- )
1148
- lib_logger.debug(f"========================")
1149
-
1150
- # Determine the current tier ID
1151
- current_tier_id = None
1152
- if current_tier:
1153
- current_tier_id = current_tier.get("id")
1154
- lib_logger.debug(f"User has currentTier: {current_tier_id}")
1155
-
1156
- # Check if user is already known to server (has currentTier)
1157
- if current_tier_id:
1158
- # User is already onboarded - check for project from server
1159
- server_project = data.get("cloudaicompanionProject")
1160
-
1161
- # Check if this tier requires user-defined project (paid tiers)
1162
- requires_user_project = any(
1163
- t.get("id") == current_tier_id
1164
- and t.get("userDefinedCloudaicompanionProject", False)
1165
- for t in allowed_tiers
1166
- )
1167
- is_free_tier = current_tier_id == "free-tier"
1168
-
1169
- if server_project:
1170
- # Server returned a project - use it (server wins)
1171
- project_id = server_project
1172
- lib_logger.debug(f"Server returned project: {project_id}")
1173
- elif configured_project_id:
1174
- # No server project but we have configured one - use it
1175
- project_id = configured_project_id
1176
- lib_logger.debug(
1177
- f"No server project, using configured: {project_id}"
1178
- )
1179
- elif is_free_tier:
1180
- # Free tier user without server project - try onboarding
1181
- lib_logger.debug(
1182
- "Free tier user with currentTier but no project - will try onboarding"
1183
- )
1184
- project_id = None
1185
- elif requires_user_project:
1186
- # Paid tier requires a project ID to be set
1187
- raise ValueError(
1188
- f"Paid tier '{current_tier_id}' requires setting ANTIGRAVITY_PROJECT_ID environment variable."
1189
- )
1190
- else:
1191
- # Unknown tier without project - proceed to onboarding
1192
- lib_logger.warning(
1193
- f"Tier '{current_tier_id}' has no project and none configured - will try onboarding"
1194
- )
1195
- project_id = None
1196
-
1197
- if project_id:
1198
- # Cache tier info
1199
- self.project_tier_cache[credential_path] = current_tier_id
1200
- discovered_tier = current_tier_id
1201
-
1202
- # Log appropriately based on tier
1203
- is_paid = current_tier_id and current_tier_id not in [
1204
- "free-tier",
1205
- "legacy-tier",
1206
- "unknown",
1207
- ]
1208
- if is_paid:
1209
- lib_logger.info(
1210
- f"Using Antigravity paid tier '{current_tier_id}' with project: {project_id}"
1211
- )
1212
- else:
1213
- lib_logger.info(
1214
- f"Discovered Antigravity project ID via loadCodeAssist: {project_id}"
1215
- )
1216
-
1217
- self.project_id_cache[credential_path] = project_id
1218
- discovered_project_id = project_id
1219
-
1220
- # Persist to credential file
1221
- await self._persist_project_metadata(
1222
- credential_path, project_id, discovered_tier
1223
- )
1224
-
1225
- return project_id
1226
-
1227
- # 2. User needs onboarding - no currentTier or no project found
1228
- lib_logger.info(
1229
- "No existing Antigravity session found (no currentTier), attempting to onboard user..."
1230
- )
1231
-
1232
- # Determine which tier to onboard with
1233
- onboard_tier = None
1234
- for tier in allowed_tiers:
1235
- if tier.get("isDefault"):
1236
- onboard_tier = tier
1237
- break
1238
-
1239
- # Fallback to legacy tier if no default
1240
- if not onboard_tier and allowed_tiers:
1241
- for tier in allowed_tiers:
1242
- if tier.get("id") == "legacy-tier":
1243
- onboard_tier = tier
1244
- break
1245
- if not onboard_tier:
1246
- onboard_tier = allowed_tiers[0]
1247
-
1248
- if not onboard_tier:
1249
- raise ValueError("No onboarding tiers available from server")
1250
-
1251
- tier_id = onboard_tier.get("id", "free-tier")
1252
- requires_user_project = onboard_tier.get(
1253
- "userDefinedCloudaicompanionProject", False
1254
- )
1255
-
1256
- lib_logger.debug(
1257
- f"Onboarding with tier: {tier_id}, requiresUserProject: {requires_user_project}"
1258
- )
1259
-
1260
- # Build onboard request based on tier type
1261
- # FREE tier: cloudaicompanionProject = None (server-managed)
1262
- # PAID tier: cloudaicompanionProject = configured_project_id
1263
- is_free_tier = tier_id == "free-tier"
1264
-
1265
- if is_free_tier:
1266
- # Free tier uses server-managed project
1267
- onboard_request = {
1268
- "tierId": tier_id,
1269
- "cloudaicompanionProject": None, # Server will create/manage
1270
- "metadata": core_client_metadata,
1271
- }
1272
- lib_logger.debug(
1273
- "Free tier onboarding: using server-managed project"
1274
- )
1275
- else:
1276
- # Paid/legacy tier requires user-provided project
1277
- if not configured_project_id and requires_user_project:
1278
- raise ValueError(
1279
- f"Tier '{tier_id}' requires setting ANTIGRAVITY_PROJECT_ID environment variable."
1280
- )
1281
- onboard_request = {
1282
- "tierId": tier_id,
1283
- "cloudaicompanionProject": configured_project_id,
1284
- "metadata": {
1285
- **core_client_metadata,
1286
- "duetProject": configured_project_id,
1287
- }
1288
- if configured_project_id
1289
- else core_client_metadata,
1290
- }
1291
- lib_logger.debug(
1292
- f"Paid tier onboarding: using project {configured_project_id}"
1293
- )
1294
-
1295
- lib_logger.debug("Initiating onboardUser request...")
1296
- lro_response = await client.post(
1297
- f"{code_assist_endpoint}:onboardUser",
1298
- headers=headers,
1299
- json=onboard_request,
1300
- timeout=30,
1301
- )
1302
- lro_response.raise_for_status()
1303
- lro_data = lro_response.json()
1304
- lib_logger.debug(
1305
- f"Initial onboarding response: done={lro_data.get('done')}"
1306
- )
1307
-
1308
- # Poll for onboarding completion (up to 5 minutes)
1309
- for i in range(150): # 150 × 2s = 5 minutes
1310
- if lro_data.get("done"):
1311
- lib_logger.debug(
1312
- f"Onboarding completed after {i} polling attempts"
1313
- )
1314
- break
1315
- await asyncio.sleep(2)
1316
- if (i + 1) % 15 == 0: # Log every 30 seconds
1317
- lib_logger.info(
1318
- f"Still waiting for onboarding completion... ({(i + 1) * 2}s elapsed)"
1319
- )
1320
- lib_logger.debug(
1321
- f"Polling onboarding status... (Attempt {i + 1}/150)"
1322
- )
1323
- lro_response = await client.post(
1324
- f"{code_assist_endpoint}:onboardUser",
1325
- headers=headers,
1326
- json=onboard_request,
1327
- timeout=30,
1328
- )
1329
- lro_response.raise_for_status()
1330
- lro_data = lro_response.json()
1331
-
1332
- if not lro_data.get("done"):
1333
- lib_logger.error("Onboarding process timed out after 5 minutes")
1334
- raise ValueError(
1335
- "Onboarding process timed out after 5 minutes. Please try again or contact support."
1336
- )
1337
-
1338
- # Extract project ID from LRO response
1339
- # Note: onboardUser returns response.cloudaicompanionProject as an object with .id
1340
- lro_response_data = lro_data.get("response", {})
1341
- lro_project_obj = lro_response_data.get("cloudaicompanionProject", {})
1342
- project_id = (
1343
- lro_project_obj.get("id")
1344
- if isinstance(lro_project_obj, dict)
1345
- else None
1346
- )
1347
-
1348
- # Fallback to configured project if LRO didn't return one
1349
- if not project_id and configured_project_id:
1350
- project_id = configured_project_id
1351
- lib_logger.debug(
1352
- f"LRO didn't return project, using configured: {project_id}"
1353
- )
1354
-
1355
- if not project_id:
1356
- lib_logger.error(
1357
- "Onboarding completed but no project ID in response and none configured"
1358
- )
1359
- raise ValueError(
1360
- "Onboarding completed, but no project ID was returned. "
1361
- "For paid tiers, set ANTIGRAVITY_PROJECT_ID environment variable."
1362
- )
1363
-
1364
- lib_logger.debug(
1365
- f"Successfully extracted project ID from onboarding response: {project_id}"
1366
- )
1367
-
1368
- # Cache tier info
1369
- self.project_tier_cache[credential_path] = tier_id
1370
- discovered_tier = tier_id
1371
- lib_logger.debug(f"Cached tier information: {tier_id}")
1372
-
1373
- # Log concise message based on tier
1374
- is_paid = tier_id and tier_id not in ["free-tier", "legacy-tier"]
1375
- if is_paid:
1376
- lib_logger.info(
1377
- f"Using Antigravity paid tier '{tier_id}' with project: {project_id}"
1378
- )
1379
- else:
1380
- lib_logger.info(
1381
- f"Successfully onboarded user and discovered project ID: {project_id}"
1382
- )
1383
-
1384
- self.project_id_cache[credential_path] = project_id
1385
- discovered_project_id = project_id
1386
-
1387
- # Persist to credential file
1388
- await self._persist_project_metadata(
1389
- credential_path, project_id, discovered_tier
1390
- )
1391
-
1392
- return project_id
1393
-
1394
- except httpx.HTTPStatusError as e:
1395
- error_body = ""
1396
- try:
1397
- error_body = e.response.text
1398
- except Exception:
1399
- pass
1400
- if e.response.status_code == 403:
1401
- lib_logger.error(
1402
- f"Antigravity Code Assist API access denied (403). Response: {error_body}"
1403
- )
1404
- lib_logger.error(
1405
- "Possible causes: 1) cloudaicompanion.googleapis.com API not enabled, 2) Wrong project ID for paid tier, 3) Account lacks permissions"
1406
- )
1407
- elif e.response.status_code == 404:
1408
- lib_logger.warning(
1409
- f"Antigravity Code Assist endpoint not found (404). Falling back to project listing."
1410
- )
1411
- elif e.response.status_code == 412:
1412
- # Precondition Failed - often means wrong project for free tier onboarding
1413
- lib_logger.error(
1414
- f"Precondition failed (412): {error_body}. This may mean the project ID is incompatible with the selected tier."
1415
- )
1416
- else:
1417
- lib_logger.warning(
1418
- f"Antigravity onboarding/discovery failed with status {e.response.status_code}: {error_body}. Falling back to project listing."
1419
- )
1420
- except httpx.RequestError as e:
1421
- lib_logger.warning(
1422
- f"Antigravity onboarding/discovery network error: {e}. Falling back to project listing."
1423
- )
1424
-
1425
- # 3. Fallback to listing all available GCP projects (last resort)
1426
- lib_logger.debug(
1427
- "Attempting to discover project via GCP Resource Manager API..."
1428
- )
1429
- try:
1430
- async with httpx.AsyncClient() as client:
1431
- lib_logger.debug(
1432
- "Querying Cloud Resource Manager for available projects..."
1433
- )
1434
- response = await client.get(
1435
- "https://cloudresourcemanager.googleapis.com/v1/projects",
1436
- headers=headers,
1437
- timeout=20,
1438
- )
1439
- response.raise_for_status()
1440
- projects = response.json().get("projects", [])
1441
- lib_logger.debug(f"Found {len(projects)} total projects")
1442
- active_projects = [
1443
- p for p in projects if p.get("lifecycleState") == "ACTIVE"
1444
- ]
1445
- lib_logger.debug(f"Found {len(active_projects)} active projects")
1446
-
1447
- if not projects:
1448
- lib_logger.error(
1449
- "No GCP projects found for this account. Please create a project in Google Cloud Console."
1450
- )
1451
- elif not active_projects:
1452
- lib_logger.error(
1453
- "No active GCP projects found. Please activate a project in Google Cloud Console."
1454
- )
1455
- else:
1456
- project_id = active_projects[0]["projectId"]
1457
- lib_logger.info(
1458
- f"Discovered Antigravity project ID from active projects list: {project_id}"
1459
- )
1460
- lib_logger.debug(
1461
- f"Selected first active project: {project_id} (out of {len(active_projects)} active projects)"
1462
- )
1463
- self.project_id_cache[credential_path] = project_id
1464
- discovered_project_id = project_id
1465
-
1466
- # Persist to credential file (no tier info from resource manager)
1467
- await self._persist_project_metadata(
1468
- credential_path, project_id, None
1469
- )
1470
-
1471
- return project_id
1472
- except httpx.HTTPStatusError as e:
1473
- if e.response.status_code == 403:
1474
- lib_logger.error(
1475
- "Failed to list GCP projects due to a 403 Forbidden error. The Cloud Resource Manager API may not be enabled, or your account lacks the 'resourcemanager.projects.list' permission."
1476
- )
1477
- else:
1478
- lib_logger.error(
1479
- f"Failed to list GCP projects with status {e.response.status_code}: {e}"
1480
- )
1481
- except httpx.RequestError as e:
1482
- lib_logger.error(f"Network error while listing GCP projects: {e}")
1483
-
1484
- raise ValueError(
1485
- "Could not auto-discover Antigravity project ID. Possible causes:\n"
1486
- " 1. The cloudaicompanion.googleapis.com API is not enabled (enable it in Google Cloud Console)\n"
1487
- " 2. No active GCP projects exist for this account (create one in Google Cloud Console)\n"
1488
- " 3. Account lacks necessary permissions\n"
1489
- "To manually specify a project, set ANTIGRAVITY_PROJECT_ID in your .env file."
1490
- )
1491
-
1492
- async def _persist_project_metadata(
1493
- self, credential_path: str, project_id: str, tier: Optional[str]
1494
- ):
1495
- """Persists project ID and tier to the credential file for faster future startups."""
1496
- # Skip persistence for env:// paths (environment-based credentials)
1497
- credential_index = self._parse_env_credential_path(credential_path)
1498
- if credential_index is not None:
1499
- lib_logger.debug(
1500
- f"Skipping project metadata persistence for env:// credential path: {credential_path}"
1501
- )
1502
- return
1503
-
1504
- try:
1505
- # Load current credentials
1506
- with open(credential_path, "r") as f:
1507
- creds = json.load(f)
1508
-
1509
- # Update metadata
1510
- if "_proxy_metadata" not in creds:
1511
- creds["_proxy_metadata"] = {}
1512
-
1513
- creds["_proxy_metadata"]["project_id"] = project_id
1514
- if tier:
1515
- creds["_proxy_metadata"]["tier"] = tier
1516
-
1517
- # Save back using the existing save method (handles atomic writes and permissions)
1518
- await self._save_credentials(credential_path, creds)
1519
-
1520
- lib_logger.debug(
1521
- f"Persisted project_id and tier to credential file: {credential_path}"
1522
- )
1523
- except Exception as e:
1524
- lib_logger.warning(
1525
- f"Failed to persist project metadata to credential file: {e}"
1526
- )
1527
- # Non-fatal - just means slower startup next time
1528
 
1529
  # =========================================================================
1530
  # THINKING MODE SANITIZATION
@@ -2424,7 +1990,7 @@ class AntigravityProvider(AntigravityAuthBase, ProviderInterface):
2424
  elif first_func_in_msg:
2425
  # Only add bypass to the first function call if no sig available
2426
  func_part["thoughtSignature"] = "skip_thought_signature_validator"
2427
- lib_logger.warning(
2428
  f"Missing thoughtSignature for first func call {tool_id}, using bypass"
2429
  )
2430
  # Subsequent parallel calls: no signature field at all
@@ -2559,9 +2125,9 @@ class AntigravityProvider(AntigravityAuthBase, ProviderInterface):
2559
  f"Ignoring duplicate - this may indicate malformed conversation history."
2560
  )
2561
  continue
2562
- #lib_logger.debug(
2563
  # f"[Grouping] Collected response for ID: {resp_id}"
2564
- #)
2565
  collected_responses[resp_id] = resp
2566
 
2567
  # Try to satisfy pending groups (newest first)
@@ -2576,10 +2142,10 @@ class AntigravityProvider(AntigravityAuthBase, ProviderInterface):
2576
  collected_responses.pop(gid) for gid in group_ids
2577
  ]
2578
  new_contents.append({"parts": group_responses, "role": "user"})
2579
- #lib_logger.debug(
2580
  # f"[Grouping] Satisfied group with {len(group_responses)} responses: "
2581
  # f"ids={group_ids}"
2582
- #)
2583
  pending_groups.pop(i)
2584
  break
2585
  continue
@@ -2599,10 +2165,10 @@ class AntigravityProvider(AntigravityAuthBase, ProviderInterface):
2599
  ]
2600
 
2601
  if call_ids:
2602
- #lib_logger.debug(
2603
  # f"[Grouping] Created pending group expecting {len(call_ids)} responses: "
2604
  # f"ids={call_ids}, names={func_names}"
2605
- #)
2606
  pending_groups.append(
2607
  {
2608
  "ids": call_ids,
@@ -2967,12 +2533,41 @@ class AntigravityProvider(AntigravityAuthBase, ProviderInterface):
2967
 
2968
  if params and isinstance(params, dict):
2969
  schema = dict(params)
2970
- schema.pop("$schema", None)
2971
  schema.pop("strict", None)
 
 
 
2972
  schema = _normalize_type_arrays(schema)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2973
  func_decl["parametersJsonSchema"] = schema
2974
  else:
2975
- func_decl["parametersJsonSchema"] = {"type": "object", "properties": {}}
 
 
 
 
 
 
 
 
 
 
 
2976
 
2977
  gemini_tools.append({"functionDeclarations": [func_decl]})
2978
 
@@ -3097,17 +2692,19 @@ class AntigravityProvider(AntigravityAuthBase, ProviderInterface):
3097
  return antigravity_payload
3098
 
3099
  def _apply_claude_tool_transform(self, payload: Dict[str, Any]) -> None:
3100
- """Apply Claude-specific tool schema transformations."""
 
 
 
 
3101
  tools = payload["request"].get("tools", [])
3102
  for tool in tools:
3103
  for func_decl in tool.get("functionDeclarations", []):
3104
  if "parametersJsonSchema" in func_decl:
3105
  params = func_decl["parametersJsonSchema"]
3106
- params = (
3107
- _clean_claude_schema(params)
3108
- if isinstance(params, dict)
3109
- else params
3110
- )
3111
  func_decl["parameters"] = params
3112
  del func_decl["parametersJsonSchema"]
3113
 
@@ -3336,6 +2933,13 @@ class AntigravityProvider(AntigravityAuthBase, ProviderInterface):
3336
  raw_args = func_call.get("args", {})
3337
  parsed_args = _recursively_parse_json_strings(raw_args)
3338
 
 
 
 
 
 
 
 
3339
  tool_call = {
3340
  "id": tool_id,
3341
  "type": "function",
@@ -3405,7 +3009,7 @@ class AntigravityProvider(AntigravityAuthBase, ProviderInterface):
3405
  }
3406
 
3407
  self._thinking_cache.store(cache_key, json.dumps(data))
3408
- lib_logger.info(f"Cached thinking: {cache_key[:50]}...")
3409
 
3410
  # =========================================================================
3411
  # PROVIDER INTERFACE IMPLEMENTATION
@@ -3703,7 +3307,12 @@ class AntigravityProvider(AntigravityAuthBase, ProviderInterface):
3703
  file_logger: Optional[AntigravityFileLogger] = None,
3704
  ) -> litellm.ModelResponse:
3705
  """Handle non-streaming completion."""
3706
- response = await client.post(url, headers=headers, json=payload, timeout=600.0)
 
 
 
 
 
3707
  response.raise_for_status()
3708
 
3709
  data = response.json()
@@ -3736,11 +3345,15 @@ class AntigravityProvider(AntigravityAuthBase, ProviderInterface):
3736
  }
3737
 
3738
  async with client.stream(
3739
- "POST", url, headers=headers, json=payload, timeout=600.0
 
 
 
 
3740
  ) as response:
3741
  if response.status_code >= 400:
3742
- # Read error body for raise_for_status to include in exception
3743
- # Terminal logging commented out - errors are logged in failures.log
3744
  try:
3745
  await response.aread()
3746
  # lib_logger.error(
 
38
  from .antigravity_auth_base import AntigravityAuthBase
39
  from .provider_cache import ProviderCache
40
  from ..model_definitions import ModelDefinitions
41
+ from ..timeout_config import TimeoutConfig
42
+ from ..utils.paths import get_logs_dir, get_cache_dir
43
 
44
 
45
  # =============================================================================
 
107
  {"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "BLOCK_NONE"},
108
  ]
109
 
110
+
111
+ # Directory paths - use centralized path management
112
+ def _get_antigravity_logs_dir():
113
+ return get_logs_dir() / "antigravity_logs"
114
+
115
+
116
+ def _get_antigravity_cache_dir():
117
+ return get_cache_dir(subdir="antigravity")
118
+
119
+
120
+ def _get_gemini3_signature_cache_file():
121
+ return _get_antigravity_cache_dir() / "gemini3_signatures.json"
122
+
123
+
124
+ def _get_claude_thinking_cache_file():
125
+ return _get_antigravity_cache_dir() / "claude_thinking.json"
126
+
127
 
128
  # Gemini 3 tool fix system instruction (prevents hallucination)
129
  DEFAULT_GEMINI3_SYSTEM_INSTRUCTION = """<CRITICAL_TOOL_USAGE_INSTRUCTIONS>
 
340
  return obj
341
 
342
 
343
+ def _inline_schema_refs(schema: Dict[str, Any]) -> Dict[str, Any]:
344
+ """Inline local $ref definitions before sanitization."""
345
+ if not isinstance(schema, dict):
346
+ return schema
347
+
348
+ defs = schema.get("$defs", schema.get("definitions", {}))
349
+ if not defs:
350
+ return schema
351
+
352
+ def resolve(node, seen=()):
353
+ if not isinstance(node, dict):
354
+ return [resolve(x, seen) for x in node] if isinstance(node, list) else node
355
+ if "$ref" in node:
356
+ ref = node["$ref"]
357
+ if ref in seen: # Circular - drop it
358
+ return {k: resolve(v, seen) for k, v in node.items() if k != "$ref"}
359
+ for prefix in ("#/$defs/", "#/definitions/"):
360
+ if isinstance(ref, str) and ref.startswith(prefix):
361
+ name = ref[len(prefix) :]
362
+ if name in defs:
363
+ return resolve(copy.deepcopy(defs[name]), seen + (ref,))
364
+ return {k: resolve(v, seen) for k, v in node.items() if k != "$ref"}
365
+ return {k: resolve(v, seen) for k, v in node.items()}
366
+
367
+ return resolve(schema)
368
+
369
+
370
  def _clean_claude_schema(schema: Any) -> Any:
371
  """
372
  Recursively clean JSON Schema for Antigravity/Google's Proto-based API.
 
424
  return first_option
425
 
426
  cleaned = {}
 
427
  # Handle 'const' by converting to 'enum' with single value
428
  if "const" in schema:
429
  const_value = schema["const"]
 
464
 
465
  timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
466
  safe_model = model_name.replace("/", "_").replace(":", "_")
467
+ self.log_dir = (
468
+ _get_antigravity_logs_dir() / f"{timestamp}_{safe_model}_{uuid.uuid4()}"
469
+ )
470
 
471
  try:
472
  self.log_dir.mkdir(parents=True, exist_ok=True)
 
699
  error_obj = data.get("error", data)
700
  details = error_obj.get("details", [])
701
 
 
 
 
702
  result = {
703
  "retry_after": None,
704
  "reason": None,
 
749
 
750
  # Return None if we couldn't extract retry_after
751
  if not result["retry_after"]:
752
+ # Handle bare RESOURCE_EXHAUSTED without timing details
753
+ error_status = error_obj.get("status", "")
754
+ error_code = error_obj.get("code")
755
+
756
+ if error_status == "RESOURCE_EXHAUSTED" or error_code == 429:
757
+ result["retry_after"] = 60 # Default fallback
758
+ result["reason"] = result.get("reason") or "RESOURCE_EXHAUSTED"
759
+ return result
760
+
761
  return None
762
 
763
  return result
 
765
  def __init__(self):
766
  super().__init__()
767
  self.model_definitions = ModelDefinitions()
768
+ # NOTE: project_id_cache and project_tier_cache are inherited from AntigravityAuthBase
 
 
 
 
 
769
 
770
  # Base URL management
771
  self._base_url_index = 0
 
777
 
778
  # Initialize caches using shared ProviderCache
779
  self._signature_cache = ProviderCache(
780
+ _get_gemini3_signature_cache_file(),
781
  memory_ttl,
782
  disk_ttl,
783
  env_prefix="ANTIGRAVITY_SIGNATURE",
784
  )
785
  self._thinking_cache = ProviderCache(
786
+ _get_claude_thinking_cache_file(),
787
  memory_ttl,
788
  disk_ttl,
789
  env_prefix="ANTIGRAVITY_THINKING",
 
913
 
914
  This ensures all credential priorities are known before any API calls,
915
  preventing unknown credentials from getting priority 999.
916
+
917
+ For credentials without persisted tier info (new or corrupted), performs
918
+ full discovery to ensure proper prioritization in sequential rotation mode.
919
  """
920
+ # Step 1: Load persisted tiers from files
921
  await self._load_persisted_tiers(credential_paths)
922
 
923
+ # Step 2: Identify credentials still missing tier info
924
+ credentials_needing_discovery = [
925
+ path
926
+ for path in credential_paths
927
+ if path not in self.project_tier_cache
928
+ and self._parse_env_credential_path(path) is None # Skip env:// paths
929
+ ]
930
+
931
+ if not credentials_needing_discovery:
932
+ return # All credentials have tier info
933
+
934
+ lib_logger.info(
935
+ f"Antigravity: Discovering tier info for {len(credentials_needing_discovery)} credential(s)..."
936
+ )
937
+
938
+ # Step 3: Perform discovery for each missing credential (sequential to avoid rate limits)
939
+ for credential_path in credentials_needing_discovery:
940
+ try:
941
+ auth_header = await self.get_auth_header(credential_path)
942
+ access_token = auth_header["Authorization"].split(" ")[1]
943
+ await self._discover_project_id(
944
+ credential_path, access_token, litellm_params={}
945
+ )
946
+ discovered_tier = self.project_tier_cache.get(
947
+ credential_path, "unknown"
948
+ )
949
+ lib_logger.debug(
950
+ f"Discovered tier '{discovered_tier}' for {Path(credential_path).name}"
951
+ )
952
+ except Exception as e:
953
+ lib_logger.warning(
954
+ f"Failed to discover tier for {Path(credential_path).name}: {e}. "
955
+ f"Credential will use default priority."
956
+ )
957
+
958
  async def _load_persisted_tiers(
959
  self, credential_paths: List[str]
960
  ) -> Dict[str, str]:
 
1012
 
1013
  return loaded
1014
 
1015
+ # NOTE: _post_auth_discovery() is inherited from AntigravityAuthBase
1016
+
1017
  # =========================================================================
1018
  # MODEL UTILITIES
1019
  # =========================================================================
 
1090
 
1091
  return "thinking_" + "_".join(key_parts) if key_parts else None
1092
 
1093
+ # NOTE: _discover_project_id() and _persist_project_metadata() are inherited from AntigravityAuthBase
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1094
 
1095
  # =========================================================================
1096
  # THINKING MODE SANITIZATION
 
1990
  elif first_func_in_msg:
1991
  # Only add bypass to the first function call if no sig available
1992
  func_part["thoughtSignature"] = "skip_thought_signature_validator"
1993
+ lib_logger.debug(
1994
  f"Missing thoughtSignature for first func call {tool_id}, using bypass"
1995
  )
1996
  # Subsequent parallel calls: no signature field at all
 
2125
  f"Ignoring duplicate - this may indicate malformed conversation history."
2126
  )
2127
  continue
2128
+ # lib_logger.debug(
2129
  # f"[Grouping] Collected response for ID: {resp_id}"
2130
+ # )
2131
  collected_responses[resp_id] = resp
2132
 
2133
  # Try to satisfy pending groups (newest first)
 
2142
  collected_responses.pop(gid) for gid in group_ids
2143
  ]
2144
  new_contents.append({"parts": group_responses, "role": "user"})
2145
+ # lib_logger.debug(
2146
  # f"[Grouping] Satisfied group with {len(group_responses)} responses: "
2147
  # f"ids={group_ids}"
2148
+ # )
2149
  pending_groups.pop(i)
2150
  break
2151
  continue
 
2165
  ]
2166
 
2167
  if call_ids:
2168
+ # lib_logger.debug(
2169
  # f"[Grouping] Created pending group expecting {len(call_ids)} responses: "
2170
  # f"ids={call_ids}, names={func_names}"
2171
+ # )
2172
  pending_groups.append(
2173
  {
2174
  "ids": call_ids,
 
2533
 
2534
  if params and isinstance(params, dict):
2535
  schema = dict(params)
 
2536
  schema.pop("strict", None)
2537
+ # Inline $ref definitions, then strip unsupported keywords
2538
+ schema = _inline_schema_refs(schema)
2539
+ schema = _clean_claude_schema(schema)
2540
  schema = _normalize_type_arrays(schema)
2541
+
2542
+ # Workaround: Antigravity/Gemini fails to emit functionCall
2543
+ # when tool has empty properties {}. Inject a dummy optional
2544
+ # parameter to ensure the tool call is emitted.
2545
+ # Using a required confirmation parameter forces the model to
2546
+ # commit to the tool call rather than just thinking about it.
2547
+ props = schema.get("properties", {})
2548
+ if not props:
2549
+ schema["properties"] = {
2550
+ "_confirm": {
2551
+ "type": "string",
2552
+ "description": "Enter 'yes' to proceed",
2553
+ }
2554
+ }
2555
+ schema["required"] = ["_confirm"]
2556
+
2557
  func_decl["parametersJsonSchema"] = schema
2558
  else:
2559
+ # No parameters provided - use default with required confirm param
2560
+ # to ensure the tool call is emitted properly
2561
+ func_decl["parametersJsonSchema"] = {
2562
+ "type": "object",
2563
+ "properties": {
2564
+ "_confirm": {
2565
+ "type": "string",
2566
+ "description": "Enter 'yes' to proceed",
2567
+ }
2568
+ },
2569
+ "required": ["_confirm"],
2570
+ }
2571
 
2572
  gemini_tools.append({"functionDeclarations": [func_decl]})
2573
 
 
2692
  return antigravity_payload
2693
 
2694
  def _apply_claude_tool_transform(self, payload: Dict[str, Any]) -> None:
2695
+ """Apply Claude-specific tool schema transformations.
2696
+
2697
+ Converts parametersJsonSchema to parameters and applies Claude-specific
2698
+ schema sanitization (inlines $ref, removes unsupported JSON Schema fields).
2699
+ """
2700
  tools = payload["request"].get("tools", [])
2701
  for tool in tools:
2702
  for func_decl in tool.get("functionDeclarations", []):
2703
  if "parametersJsonSchema" in func_decl:
2704
  params = func_decl["parametersJsonSchema"]
2705
+ if isinstance(params, dict):
2706
+ params = _inline_schema_refs(params)
2707
+ params = _clean_claude_schema(params)
 
 
2708
  func_decl["parameters"] = params
2709
  del func_decl["parametersJsonSchema"]
2710
 
 
2933
  raw_args = func_call.get("args", {})
2934
  parsed_args = _recursively_parse_json_strings(raw_args)
2935
 
2936
+ # Strip the injected _confirm parameter ONLY if it's the sole parameter
2937
+ # This ensures we only strip our injection, not legitimate user params
2938
+ if isinstance(parsed_args, dict) and "_confirm" in parsed_args:
2939
+ if len(parsed_args) == 1:
2940
+ # _confirm is the only param - this was our injection
2941
+ parsed_args.pop("_confirm")
2942
+
2943
  tool_call = {
2944
  "id": tool_id,
2945
  "type": "function",
 
3009
  }
3010
 
3011
  self._thinking_cache.store(cache_key, json.dumps(data))
3012
+ lib_logger.debug(f"Cached thinking: {cache_key[:50]}...")
3013
 
3014
  # =========================================================================
3015
  # PROVIDER INTERFACE IMPLEMENTATION
 
3307
  file_logger: Optional[AntigravityFileLogger] = None,
3308
  ) -> litellm.ModelResponse:
3309
  """Handle non-streaming completion."""
3310
+ response = await client.post(
3311
+ url,
3312
+ headers=headers,
3313
+ json=payload,
3314
+ timeout=TimeoutConfig.non_streaming(),
3315
+ )
3316
  response.raise_for_status()
3317
 
3318
  data = response.json()
 
3345
  }
3346
 
3347
  async with client.stream(
3348
+ "POST",
3349
+ url,
3350
+ headers=headers,
3351
+ json=payload,
3352
+ timeout=TimeoutConfig.streaming(),
3353
  ) as response:
3354
  if response.status_code >= 400:
3355
+ # Read error body so it's available in response.text for logging
3356
+ # The actual logging happens in failure_logger via _extract_response_body
3357
  try:
3358
  await response.aread()
3359
  # lib_logger.error(
src/rotator_library/providers/gemini_auth_base.py CHANGED
@@ -1,15 +1,35 @@
1
  # src/rotator_library/providers/gemini_auth_base.py
2
 
 
 
 
 
 
 
 
 
 
3
  from .google_oauth_base import GoogleOAuthBase
4
 
 
 
 
 
 
 
5
  class GeminiAuthBase(GoogleOAuthBase):
6
  """
7
  Gemini CLI OAuth2 authentication implementation.
8
-
9
  Inherits all OAuth functionality from GoogleOAuthBase with Gemini-specific configuration.
 
 
 
10
  """
11
-
12
- CLIENT_ID = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com"
 
 
13
  CLIENT_SECRET = "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"
14
  OAUTH_SCOPES = [
15
  "https://www.googleapis.com/auth/cloud-platform",
@@ -18,4 +38,606 @@ class GeminiAuthBase(GoogleOAuthBase):
18
  ]
19
  ENV_PREFIX = "GEMINI_CLI"
20
  CALLBACK_PORT = 8085
21
- CALLBACK_PATH = "/oauth2callback"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  # src/rotator_library/providers/gemini_auth_base.py
2
 
3
+ import asyncio
4
+ import json
5
+ import logging
6
+ import os
7
+ from pathlib import Path
8
+ from typing import Any, Dict, Optional, List
9
+
10
+ import httpx
11
+
12
  from .google_oauth_base import GoogleOAuthBase
13
 
14
+ lib_logger = logging.getLogger("rotator_library")
15
+
16
+ # Code Assist endpoint for project discovery
17
+ CODE_ASSIST_ENDPOINT = "https://cloudcode-pa.googleapis.com/v1internal"
18
+
19
+
20
  class GeminiAuthBase(GoogleOAuthBase):
21
  """
22
  Gemini CLI OAuth2 authentication implementation.
23
+
24
  Inherits all OAuth functionality from GoogleOAuthBase with Gemini-specific configuration.
25
+
26
+ Also provides project/tier discovery functionality that runs during authentication,
27
+ ensuring credentials have their tier and project_id cached before any API requests.
28
  """
29
+
30
+ CLIENT_ID = (
31
+ "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com"
32
+ )
33
  CLIENT_SECRET = "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"
34
  OAUTH_SCOPES = [
35
  "https://www.googleapis.com/auth/cloud-platform",
 
38
  ]
39
  ENV_PREFIX = "GEMINI_CLI"
40
  CALLBACK_PORT = 8085
41
+ CALLBACK_PATH = "/oauth2callback"
42
+
43
+ def __init__(self):
44
+ super().__init__()
45
+ # Project and tier caches - shared between auth base and provider
46
+ self.project_id_cache: Dict[str, str] = {}
47
+ self.project_tier_cache: Dict[str, str] = {}
48
+
49
+ # =========================================================================
50
+ # POST-AUTH DISCOVERY HOOK
51
+ # =========================================================================
52
+
53
+ async def _post_auth_discovery(
54
+ self, credential_path: str, access_token: str
55
+ ) -> None:
56
+ """
57
+ Discover and cache tier/project information immediately after OAuth authentication.
58
+
59
+ This is called by GoogleOAuthBase._perform_interactive_oauth() after successful auth,
60
+ ensuring tier and project_id are cached during the authentication flow rather than
61
+ waiting for the first API request.
62
+
63
+ Args:
64
+ credential_path: Path to the credential file
65
+ access_token: The newly obtained access token
66
+ """
67
+ lib_logger.debug(
68
+ f"Starting post-auth discovery for GeminiCli credential: {Path(credential_path).name}"
69
+ )
70
+
71
+ # Skip if already discovered (shouldn't happen during fresh auth, but be defensive)
72
+ if (
73
+ credential_path in self.project_id_cache
74
+ and credential_path in self.project_tier_cache
75
+ ):
76
+ lib_logger.debug(
77
+ f"Tier and project already cached for {Path(credential_path).name}, skipping discovery"
78
+ )
79
+ return
80
+
81
+ # Call _discover_project_id which handles tier/project discovery and persistence
82
+ # Pass empty litellm_params since we're in auth context (no model-specific overrides)
83
+ project_id = await self._discover_project_id(
84
+ credential_path, access_token, litellm_params={}
85
+ )
86
+
87
+ tier = self.project_tier_cache.get(credential_path, "unknown")
88
+ lib_logger.info(
89
+ f"Post-auth discovery complete for {Path(credential_path).name}: "
90
+ f"tier={tier}, project={project_id}"
91
+ )
92
+
93
+ # =========================================================================
94
+ # PROJECT ID DISCOVERY
95
+ # =========================================================================
96
+
97
+ async def _discover_project_id(
98
+ self, credential_path: str, access_token: str, litellm_params: Dict[str, Any]
99
+ ) -> str:
100
+ """
101
+ Discovers the Google Cloud Project ID, with caching and onboarding for new accounts.
102
+
103
+ This follows the official Gemini CLI discovery flow:
104
+ 1. Check in-memory cache
105
+ 2. Check configured project_id override (litellm_params or env var)
106
+ 3. Check persisted project_id in credential file
107
+ 4. Call loadCodeAssist to check if user is already known (has currentTier)
108
+ - If currentTier exists AND cloudaicompanionProject returned: use server's project
109
+ - If currentTier exists but NO cloudaicompanionProject: use configured project_id (paid tier requires this)
110
+ - If no currentTier: user needs onboarding
111
+ 5. Onboard user based on tier:
112
+ - FREE tier: pass cloudaicompanionProject=None (server-managed)
113
+ - PAID tier: pass cloudaicompanionProject=configured_project_id
114
+ 6. Fallback to GCP Resource Manager project listing
115
+ """
116
+ lib_logger.debug(
117
+ f"Starting project discovery for credential: {credential_path}"
118
+ )
119
+
120
+ # Check in-memory cache first
121
+ if credential_path in self.project_id_cache:
122
+ cached_project = self.project_id_cache[credential_path]
123
+ lib_logger.debug(f"Using cached project ID: {cached_project}")
124
+ return cached_project
125
+
126
+ # Check for configured project ID override (from litellm_params or env var)
127
+ # This is REQUIRED for paid tier users per the official CLI behavior
128
+ configured_project_id = litellm_params.get("project_id") or os.getenv(
129
+ "GEMINI_CLI_PROJECT_ID"
130
+ )
131
+ if configured_project_id:
132
+ lib_logger.debug(
133
+ f"Found configured project_id override: {configured_project_id}"
134
+ )
135
+
136
+ # Load credentials from file to check for persisted project_id and tier
137
+ # Skip for env:// paths (environment-based credentials don't persist to files)
138
+ credential_index = self._parse_env_credential_path(credential_path)
139
+ if credential_index is None:
140
+ # Only try to load from file if it's not an env:// path
141
+ try:
142
+ with open(credential_path, "r") as f:
143
+ creds = json.load(f)
144
+
145
+ metadata = creds.get("_proxy_metadata", {})
146
+ persisted_project_id = metadata.get("project_id")
147
+ persisted_tier = metadata.get("tier")
148
+
149
+ if persisted_project_id:
150
+ lib_logger.info(
151
+ f"Loaded persisted project ID from credential file: {persisted_project_id}"
152
+ )
153
+ self.project_id_cache[credential_path] = persisted_project_id
154
+
155
+ # Also load tier if available
156
+ if persisted_tier:
157
+ self.project_tier_cache[credential_path] = persisted_tier
158
+ lib_logger.debug(f"Loaded persisted tier: {persisted_tier}")
159
+
160
+ return persisted_project_id
161
+ except (FileNotFoundError, json.JSONDecodeError, KeyError) as e:
162
+ lib_logger.debug(f"Could not load persisted project ID from file: {e}")
163
+
164
+ lib_logger.debug(
165
+ "No cached or configured project ID found, initiating discovery..."
166
+ )
167
+ headers = {
168
+ "Authorization": f"Bearer {access_token}",
169
+ "Content-Type": "application/json",
170
+ }
171
+
172
+ discovered_project_id = None
173
+ discovered_tier = None
174
+
175
+ async with httpx.AsyncClient() as client:
176
+ # 1. Try discovery endpoint with loadCodeAssist
177
+ lib_logger.debug(
178
+ "Attempting project discovery via Code Assist loadCodeAssist endpoint..."
179
+ )
180
+ try:
181
+ # Build metadata - include duetProject only if we have a configured project
182
+ core_client_metadata = {
183
+ "ideType": "IDE_UNSPECIFIED",
184
+ "platform": "PLATFORM_UNSPECIFIED",
185
+ "pluginType": "GEMINI",
186
+ }
187
+ if configured_project_id:
188
+ core_client_metadata["duetProject"] = configured_project_id
189
+
190
+ # Build load request - pass configured_project_id if available, otherwise None
191
+ load_request = {
192
+ "cloudaicompanionProject": configured_project_id, # Can be None
193
+ "metadata": core_client_metadata,
194
+ }
195
+
196
+ lib_logger.debug(
197
+ f"Sending loadCodeAssist request with cloudaicompanionProject={configured_project_id}"
198
+ )
199
+ response = await client.post(
200
+ f"{CODE_ASSIST_ENDPOINT}:loadCodeAssist",
201
+ headers=headers,
202
+ json=load_request,
203
+ timeout=20,
204
+ )
205
+ response.raise_for_status()
206
+ data = response.json()
207
+
208
+ # Log full response for debugging
209
+ lib_logger.debug(
210
+ f"loadCodeAssist full response keys: {list(data.keys())}"
211
+ )
212
+
213
+ # Extract and log ALL tier information for debugging
214
+ allowed_tiers = data.get("allowedTiers", [])
215
+ current_tier = data.get("currentTier")
216
+
217
+ lib_logger.debug(f"=== Tier Information ===")
218
+ lib_logger.debug(f"currentTier: {current_tier}")
219
+ lib_logger.debug(f"allowedTiers count: {len(allowed_tiers)}")
220
+ for i, tier in enumerate(allowed_tiers):
221
+ tier_id = tier.get("id", "unknown")
222
+ is_default = tier.get("isDefault", False)
223
+ user_defined = tier.get("userDefinedCloudaicompanionProject", False)
224
+ lib_logger.debug(
225
+ f" Tier {i + 1}: id={tier_id}, isDefault={is_default}, userDefinedProject={user_defined}"
226
+ )
227
+ lib_logger.debug(f"========================")
228
+
229
+ # Determine the current tier ID
230
+ current_tier_id = None
231
+ if current_tier:
232
+ current_tier_id = current_tier.get("id")
233
+ lib_logger.debug(f"User has currentTier: {current_tier_id}")
234
+
235
+ # Check if user is already known to server (has currentTier)
236
+ if current_tier_id:
237
+ # User is already onboarded - check for project from server
238
+ server_project = data.get("cloudaicompanionProject")
239
+
240
+ # Check if this tier requires user-defined project (paid tiers)
241
+ requires_user_project = any(
242
+ t.get("id") == current_tier_id
243
+ and t.get("userDefinedCloudaicompanionProject", False)
244
+ for t in allowed_tiers
245
+ )
246
+ is_free_tier = current_tier_id == "free-tier"
247
+
248
+ if server_project:
249
+ # Server returned a project - use it (server wins)
250
+ # This is the normal case for FREE tier users
251
+ project_id = server_project
252
+ lib_logger.debug(f"Server returned project: {project_id}")
253
+ elif configured_project_id:
254
+ # No server project but we have configured one - use it
255
+ # This is the PAID TIER case where server doesn't return a project
256
+ project_id = configured_project_id
257
+ lib_logger.debug(
258
+ f"No server project, using configured: {project_id}"
259
+ )
260
+ elif is_free_tier:
261
+ # Free tier user without server project - this shouldn't happen normally
262
+ # but let's not fail, just proceed to onboarding
263
+ lib_logger.debug(
264
+ "Free tier user with currentTier but no project - will try onboarding"
265
+ )
266
+ project_id = None
267
+ elif requires_user_project:
268
+ # Paid tier requires a project ID to be set
269
+ raise ValueError(
270
+ f"Paid tier '{current_tier_id}' requires setting GEMINI_CLI_PROJECT_ID environment variable. "
271
+ "See https://goo.gle/gemini-cli-auth-docs#workspace-gca"
272
+ )
273
+ else:
274
+ # Unknown tier without project - proceed carefully
275
+ lib_logger.warning(
276
+ f"Tier '{current_tier_id}' has no project and none configured - will try onboarding"
277
+ )
278
+ project_id = None
279
+
280
+ if project_id:
281
+ # Cache tier info
282
+ self.project_tier_cache[credential_path] = current_tier_id
283
+ discovered_tier = current_tier_id
284
+
285
+ # Log appropriately based on tier
286
+ is_paid = current_tier_id and current_tier_id not in [
287
+ "free-tier",
288
+ "legacy-tier",
289
+ "unknown",
290
+ ]
291
+ if is_paid:
292
+ lib_logger.info(
293
+ f"Using Gemini paid tier '{current_tier_id}' with project: {project_id}"
294
+ )
295
+ else:
296
+ lib_logger.info(
297
+ f"Discovered Gemini project ID via loadCodeAssist: {project_id}"
298
+ )
299
+
300
+ self.project_id_cache[credential_path] = project_id
301
+ discovered_project_id = project_id
302
+
303
+ # Persist to credential file
304
+ await self._persist_project_metadata(
305
+ credential_path, project_id, discovered_tier
306
+ )
307
+
308
+ return project_id
309
+
310
+ # 2. User needs onboarding - no currentTier
311
+ lib_logger.info(
312
+ "No existing Gemini session found (no currentTier), attempting to onboard user..."
313
+ )
314
+
315
+ # Determine which tier to onboard with
316
+ onboard_tier = None
317
+ for tier in allowed_tiers:
318
+ if tier.get("isDefault"):
319
+ onboard_tier = tier
320
+ break
321
+
322
+ # Fallback to LEGACY tier if no default (requires user project)
323
+ if not onboard_tier and allowed_tiers:
324
+ # Look for legacy-tier as fallback
325
+ for tier in allowed_tiers:
326
+ if tier.get("id") == "legacy-tier":
327
+ onboard_tier = tier
328
+ break
329
+ # If still no tier, use first available
330
+ if not onboard_tier:
331
+ onboard_tier = allowed_tiers[0]
332
+
333
+ if not onboard_tier:
334
+ raise ValueError("No onboarding tiers available from server")
335
+
336
+ tier_id = onboard_tier.get("id", "free-tier")
337
+ requires_user_project = onboard_tier.get(
338
+ "userDefinedCloudaicompanionProject", False
339
+ )
340
+
341
+ lib_logger.debug(
342
+ f"Onboarding with tier: {tier_id}, requiresUserProject: {requires_user_project}"
343
+ )
344
+
345
+ # Build onboard request based on tier type (following official CLI logic)
346
+ # FREE tier: cloudaicompanionProject = None (server-managed)
347
+ # PAID tier: cloudaicompanionProject = configured_project_id (user must provide)
348
+ is_free_tier = tier_id == "free-tier"
349
+
350
+ if is_free_tier:
351
+ # Free tier uses server-managed project
352
+ onboard_request = {
353
+ "tierId": tier_id,
354
+ "cloudaicompanionProject": None, # Server will create/manage
355
+ "metadata": core_client_metadata,
356
+ }
357
+ lib_logger.debug(
358
+ "Free tier onboarding: using server-managed project"
359
+ )
360
+ else:
361
+ # Paid/legacy tier requires user-provided project
362
+ if not configured_project_id and requires_user_project:
363
+ raise ValueError(
364
+ f"Tier '{tier_id}' requires setting GEMINI_CLI_PROJECT_ID environment variable. "
365
+ "See https://goo.gle/gemini-cli-auth-docs#workspace-gca"
366
+ )
367
+ onboard_request = {
368
+ "tierId": tier_id,
369
+ "cloudaicompanionProject": configured_project_id,
370
+ "metadata": {
371
+ **core_client_metadata,
372
+ "duetProject": configured_project_id,
373
+ }
374
+ if configured_project_id
375
+ else core_client_metadata,
376
+ }
377
+ lib_logger.debug(
378
+ f"Paid tier onboarding: using project {configured_project_id}"
379
+ )
380
+
381
+ lib_logger.debug("Initiating onboardUser request...")
382
+ lro_response = await client.post(
383
+ f"{CODE_ASSIST_ENDPOINT}:onboardUser",
384
+ headers=headers,
385
+ json=onboard_request,
386
+ timeout=30,
387
+ )
388
+ lro_response.raise_for_status()
389
+ lro_data = lro_response.json()
390
+ lib_logger.debug(
391
+ f"Initial onboarding response: done={lro_data.get('done')}"
392
+ )
393
+
394
+ for i in range(150): # Poll for up to 5 minutes (150 × 2s)
395
+ if lro_data.get("done"):
396
+ lib_logger.debug(
397
+ f"Onboarding completed after {i} polling attempts"
398
+ )
399
+ break
400
+ await asyncio.sleep(2)
401
+ if (i + 1) % 15 == 0: # Log every 30 seconds
402
+ lib_logger.info(
403
+ f"Still waiting for onboarding completion... ({(i + 1) * 2}s elapsed)"
404
+ )
405
+ lib_logger.debug(
406
+ f"Polling onboarding status... (Attempt {i + 1}/150)"
407
+ )
408
+ lro_response = await client.post(
409
+ f"{CODE_ASSIST_ENDPOINT}:onboardUser",
410
+ headers=headers,
411
+ json=onboard_request,
412
+ timeout=30,
413
+ )
414
+ lro_response.raise_for_status()
415
+ lro_data = lro_response.json()
416
+
417
+ if not lro_data.get("done"):
418
+ lib_logger.error("Onboarding process timed out after 5 minutes")
419
+ raise ValueError(
420
+ "Onboarding process timed out after 5 minutes. Please try again or contact support."
421
+ )
422
+
423
+ # Extract project ID from LRO response
424
+ # Note: onboardUser returns response.cloudaicompanionProject as an object with .id
425
+ lro_response_data = lro_data.get("response", {})
426
+ lro_project_obj = lro_response_data.get("cloudaicompanionProject", {})
427
+ project_id = (
428
+ lro_project_obj.get("id")
429
+ if isinstance(lro_project_obj, dict)
430
+ else None
431
+ )
432
+
433
+ # Fallback to configured project if LRO didn't return one
434
+ if not project_id and configured_project_id:
435
+ project_id = configured_project_id
436
+ lib_logger.debug(
437
+ f"LRO didn't return project, using configured: {project_id}"
438
+ )
439
+
440
+ if not project_id:
441
+ lib_logger.error(
442
+ "Onboarding completed but no project ID in response and none configured"
443
+ )
444
+ raise ValueError(
445
+ "Onboarding completed, but no project ID was returned. "
446
+ "For paid tiers, set GEMINI_CLI_PROJECT_ID environment variable."
447
+ )
448
+
449
+ lib_logger.debug(
450
+ f"Successfully extracted project ID from onboarding response: {project_id}"
451
+ )
452
+
453
+ # Cache tier info
454
+ self.project_tier_cache[credential_path] = tier_id
455
+ discovered_tier = tier_id
456
+ lib_logger.debug(f"Cached tier information: {tier_id}")
457
+
458
+ # Log concise message for paid projects
459
+ is_paid = tier_id and tier_id not in ["free-tier", "legacy-tier"]
460
+ if is_paid:
461
+ lib_logger.info(
462
+ f"Using Gemini paid tier '{tier_id}' with project: {project_id}"
463
+ )
464
+ else:
465
+ lib_logger.info(
466
+ f"Successfully onboarded user and discovered project ID: {project_id}"
467
+ )
468
+
469
+ self.project_id_cache[credential_path] = project_id
470
+ discovered_project_id = project_id
471
+
472
+ # Persist to credential file
473
+ await self._persist_project_metadata(
474
+ credential_path, project_id, discovered_tier
475
+ )
476
+
477
+ return project_id
478
+
479
+ except httpx.HTTPStatusError as e:
480
+ error_body = ""
481
+ try:
482
+ error_body = e.response.text
483
+ except Exception:
484
+ pass
485
+ if e.response.status_code == 403:
486
+ lib_logger.error(
487
+ f"Gemini Code Assist API access denied (403). Response: {error_body}"
488
+ )
489
+ lib_logger.error(
490
+ "Possible causes: 1) cloudaicompanion.googleapis.com API not enabled, 2) Wrong project ID for paid tier, 3) Account lacks permissions"
491
+ )
492
+ elif e.response.status_code == 404:
493
+ lib_logger.warning(
494
+ f"Gemini Code Assist endpoint not found (404). Falling back to project listing."
495
+ )
496
+ elif e.response.status_code == 412:
497
+ # Precondition Failed - often means wrong project for free tier onboarding
498
+ lib_logger.error(
499
+ f"Precondition failed (412): {error_body}. This may mean the project ID is incompatible with the selected tier."
500
+ )
501
+ else:
502
+ lib_logger.warning(
503
+ f"Gemini onboarding/discovery failed with status {e.response.status_code}: {error_body}. Falling back to project listing."
504
+ )
505
+ except httpx.RequestError as e:
506
+ lib_logger.warning(
507
+ f"Gemini onboarding/discovery network error: {e}. Falling back to project listing."
508
+ )
509
+
510
+ # 3. Fallback to listing all available GCP projects (last resort)
511
+ lib_logger.debug(
512
+ "Attempting to discover project via GCP Resource Manager API..."
513
+ )
514
+ try:
515
+ async with httpx.AsyncClient() as client:
516
+ lib_logger.debug(
517
+ "Querying Cloud Resource Manager for available projects..."
518
+ )
519
+ response = await client.get(
520
+ "https://cloudresourcemanager.googleapis.com/v1/projects",
521
+ headers=headers,
522
+ timeout=20,
523
+ )
524
+ response.raise_for_status()
525
+ projects = response.json().get("projects", [])
526
+ lib_logger.debug(f"Found {len(projects)} total projects")
527
+ active_projects = [
528
+ p for p in projects if p.get("lifecycleState") == "ACTIVE"
529
+ ]
530
+ lib_logger.debug(f"Found {len(active_projects)} active projects")
531
+
532
+ if not projects:
533
+ lib_logger.error(
534
+ "No GCP projects found for this account. Please create a project in Google Cloud Console."
535
+ )
536
+ elif not active_projects:
537
+ lib_logger.error(
538
+ "No active GCP projects found. Please activate a project in Google Cloud Console."
539
+ )
540
+ else:
541
+ project_id = active_projects[0]["projectId"]
542
+ lib_logger.info(
543
+ f"Discovered Gemini project ID from active projects list: {project_id}"
544
+ )
545
+ lib_logger.debug(
546
+ f"Selected first active project: {project_id} (out of {len(active_projects)} active projects)"
547
+ )
548
+ self.project_id_cache[credential_path] = project_id
549
+ discovered_project_id = project_id
550
+
551
+ # Persist to credential file (no tier info from resource manager)
552
+ await self._persist_project_metadata(
553
+ credential_path, project_id, None
554
+ )
555
+
556
+ return project_id
557
+ except httpx.HTTPStatusError as e:
558
+ if e.response.status_code == 403:
559
+ lib_logger.error(
560
+ "Failed to list GCP projects due to a 403 Forbidden error. The Cloud Resource Manager API may not be enabled, or your account lacks the 'resourcemanager.projects.list' permission."
561
+ )
562
+ else:
563
+ lib_logger.error(
564
+ f"Failed to list GCP projects with status {e.response.status_code}: {e}"
565
+ )
566
+ except httpx.RequestError as e:
567
+ lib_logger.error(f"Network error while listing GCP projects: {e}")
568
+
569
+ raise ValueError(
570
+ "Could not auto-discover Gemini project ID. Possible causes:\n"
571
+ " 1. The cloudaicompanion.googleapis.com API is not enabled (enable it in Google Cloud Console)\n"
572
+ " 2. No active GCP projects exist for this account (create one in Google Cloud Console)\n"
573
+ " 3. Account lacks necessary permissions\n"
574
+ "To manually specify a project, set GEMINI_CLI_PROJECT_ID in your .env file."
575
+ )
576
+
577
+ async def _persist_project_metadata(
578
+ self, credential_path: str, project_id: str, tier: Optional[str]
579
+ ):
580
+ """Persists project ID and tier to the credential file for faster future startups."""
581
+ # Skip persistence for env:// paths (environment-based credentials)
582
+ credential_index = self._parse_env_credential_path(credential_path)
583
+ if credential_index is not None:
584
+ lib_logger.debug(
585
+ f"Skipping project metadata persistence for env:// credential path: {credential_path}"
586
+ )
587
+ return
588
+
589
+ try:
590
+ # Load current credentials
591
+ with open(credential_path, "r") as f:
592
+ creds = json.load(f)
593
+
594
+ # Update metadata
595
+ if "_proxy_metadata" not in creds:
596
+ creds["_proxy_metadata"] = {}
597
+
598
+ creds["_proxy_metadata"]["project_id"] = project_id
599
+ if tier:
600
+ creds["_proxy_metadata"]["tier"] = tier
601
+
602
+ # Save back using the existing save method (handles atomic writes and permissions)
603
+ await self._save_credentials(credential_path, creds)
604
+
605
+ lib_logger.debug(
606
+ f"Persisted project_id and tier to credential file: {credential_path}"
607
+ )
608
+ except Exception as e:
609
+ lib_logger.warning(
610
+ f"Failed to persist project metadata to credential file: {e}"
611
+ )
612
+ # Non-fatal - just means slower startup next time
613
+
614
+ # =========================================================================
615
+ # CREDENTIAL MANAGEMENT OVERRIDES
616
+ # =========================================================================
617
+
618
+ def _get_provider_file_prefix(self) -> str:
619
+ """Return the file prefix for Gemini CLI credentials."""
620
+ return "gemini_cli"
621
+
622
+ def build_env_lines(self, creds: Dict[str, Any], cred_number: int) -> List[str]:
623
+ """
624
+ Generate .env file lines for a Gemini CLI credential.
625
+
626
+ Includes tier and project_id from _proxy_metadata.
627
+ """
628
+ # Get base lines from parent class
629
+ lines = super().build_env_lines(creds, cred_number)
630
+
631
+ # Add Gemini-specific fields (tier and project_id)
632
+ metadata = creds.get("_proxy_metadata", {})
633
+ prefix = f"{self.ENV_PREFIX}_{cred_number}"
634
+
635
+ project_id = metadata.get("project_id", "")
636
+ tier = metadata.get("tier", "")
637
+
638
+ if project_id:
639
+ lines.append(f"{prefix}_PROJECT_ID={project_id}")
640
+ if tier:
641
+ lines.append(f"{prefix}_TIER={tier}")
642
+
643
+ return lines
src/rotator_library/providers/gemini_cli_provider.py CHANGED
@@ -11,6 +11,8 @@ from .provider_interface import ProviderInterface
11
  from .gemini_auth_base import GeminiAuthBase
12
  from .provider_cache import ProviderCache
13
  from ..model_definitions import ModelDefinitions
 
 
14
  import litellm
15
  from litellm.exceptions import RateLimitError
16
  from ..error_handler import extract_retry_after_from_body
@@ -21,8 +23,22 @@ from datetime import datetime
21
 
22
  lib_logger = logging.getLogger("rotator_library")
23
 
24
- LOGS_DIR = Path(__file__).resolve().parent.parent.parent.parent / "logs"
25
- GEMINI_CLI_LOGS_DIR = LOGS_DIR / "gemini_cli_logs"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
 
28
  class _GeminiCliFileLogger:
@@ -38,7 +54,7 @@ class _GeminiCliFileLogger:
38
  # Sanitize model name for directory
39
  safe_model_name = model_name.replace("/", "_").replace(":", "_")
40
  self.log_dir = (
41
- GEMINI_CLI_LOGS_DIR / f"{timestamp}_{safe_model_name}_{request_id}"
42
  )
43
  try:
44
  self.log_dir.mkdir(parents=True, exist_ok=True)
@@ -102,12 +118,6 @@ HARDCODED_MODELS = [
102
  "gemini-3-pro-preview",
103
  ]
104
 
105
- # Cache directory for Gemini CLI
106
- CACHE_DIR = (
107
- Path(__file__).resolve().parent.parent.parent.parent / "cache" / "gemini_cli"
108
- )
109
- GEMINI3_SIGNATURE_CACHE_FILE = CACHE_DIR / "gemini3_signatures.json"
110
-
111
  # Gemini 3 tool fix system instruction (prevents hallucination)
112
  DEFAULT_GEMINI3_SYSTEM_INSTRUCTION = """<CRITICAL_TOOL_USAGE_INSTRUCTIONS>
113
  You are operating in a CUSTOM ENVIRONMENT where tool definitions COMPLETELY DIFFER from your training data.
@@ -173,6 +183,98 @@ FINISH_REASON_MAP = {
173
  }
174
 
175
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
176
  def _env_bool(key: str, default: bool = False) -> bool:
177
  """Get boolean from environment variable."""
178
  return os.getenv(key, str(default).lower()).lower() in ("true", "1", "yes")
@@ -186,8 +288,8 @@ def _env_int(key: str, default: int) -> int:
186
  class GeminiCliProvider(GeminiAuthBase, ProviderInterface):
187
  skip_cost_calculation = True
188
 
189
- # Balanced by default - Gemini CLI has short cooldowns (seconds, not hours)
190
- default_rotation_mode: str = "balanced"
191
 
192
  # =========================================================================
193
  # TIER CONFIGURATION
@@ -234,32 +336,156 @@ class GeminiCliProvider(GeminiAuthBase, ProviderInterface):
234
  error: Exception, error_body: Optional[str] = None
235
  ) -> Optional[Dict[str, Any]]:
236
  """
237
- Parse Gemini CLI quota errors.
238
-
239
- Uses the same Google RPC format as Antigravity but typically has
240
- much shorter cooldown durations (seconds to minutes, not hours).
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
241
 
242
  Args:
243
  error: The caught exception
244
  error_body: Optional raw response body string
245
 
246
  Returns:
247
- Same format as AntigravityProvider.parse_quota_error()
 
 
 
 
 
 
248
  """
249
- # Reuse the same parsing logic as Antigravity since both use Google RPC format
250
- from .antigravity_provider import AntigravityProvider
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
251
 
252
- return AntigravityProvider.parse_quota_error(error, error_body)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
253
 
254
  def __init__(self):
255
  super().__init__()
256
  self.model_definitions = ModelDefinitions()
257
- self.project_id_cache: Dict[
258
- str, str
259
- ] = {} # Cache project ID per credential path
260
- self.project_tier_cache: Dict[
261
- str, str
262
- ] = {} # Cache project tier per credential path
263
 
264
  # Gemini 3 configuration from environment
265
  memory_ttl = _env_int("GEMINI_CLI_SIGNATURE_CACHE_TTL", 3600)
@@ -267,7 +493,7 @@ class GeminiCliProvider(GeminiAuthBase, ProviderInterface):
267
 
268
  # Initialize signature cache for Gemini 3 thoughtSignatures
269
  self._signature_cache = ProviderCache(
270
- GEMINI3_SIGNATURE_CACHE_FILE,
271
  memory_ttl,
272
  disk_ttl,
273
  env_prefix="GEMINI_CLI_SIGNATURE",
@@ -381,7 +607,7 @@ class GeminiCliProvider(GeminiAuthBase, ProviderInterface):
381
 
382
  # Gemini 3 requires paid tier
383
  if model_name.startswith("gemini-3-"):
384
- return 1 # Only priority 1 (paid) credentials
385
 
386
  return None # All other models have no restrictions
387
 
@@ -391,9 +617,48 @@ class GeminiCliProvider(GeminiAuthBase, ProviderInterface):
391
 
392
  This ensures all credential priorities are known before any API calls,
393
  preventing unknown credentials from getting priority 999.
 
 
 
394
  """
 
395
  await self._load_persisted_tiers(credential_paths)
396
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
397
  async def _load_persisted_tiers(
398
  self, credential_paths: List[str]
399
  ) -> Dict[str, str]:
@@ -451,6 +716,8 @@ class GeminiCliProvider(GeminiAuthBase, ProviderInterface):
451
 
452
  return loaded
453
 
 
 
454
  # =========================================================================
455
  # MODEL UTILITIES
456
  # =========================================================================
@@ -466,520 +733,7 @@ class GeminiCliProvider(GeminiAuthBase, ProviderInterface):
466
  return name[len(self._gemini3_tool_prefix) :]
467
  return name
468
 
469
- async def _discover_project_id(
470
- self, credential_path: str, access_token: str, litellm_params: Dict[str, Any]
471
- ) -> str:
472
- """
473
- Discovers the Google Cloud Project ID, with caching and onboarding for new accounts.
474
-
475
- This follows the official Gemini CLI discovery flow:
476
- 1. Check in-memory cache
477
- 2. Check configured project_id override (litellm_params or env var)
478
- 3. Check persisted project_id in credential file
479
- 4. Call loadCodeAssist to check if user is already known (has currentTier)
480
- - If currentTier exists AND cloudaicompanionProject returned: use server's project
481
- - If currentTier exists but NO cloudaicompanionProject: use configured project_id (paid tier requires this)
482
- - If no currentTier: user needs onboarding
483
- 5. Onboard user based on tier:
484
- - FREE tier: pass cloudaicompanionProject=None (server-managed)
485
- - PAID tier: pass cloudaicompanionProject=configured_project_id
486
- 6. Fallback to GCP Resource Manager project listing
487
- """
488
- lib_logger.debug(
489
- f"Starting project discovery for credential: {credential_path}"
490
- )
491
-
492
- # Check in-memory cache first
493
- if credential_path in self.project_id_cache:
494
- cached_project = self.project_id_cache[credential_path]
495
- lib_logger.debug(f"Using cached project ID: {cached_project}")
496
- return cached_project
497
-
498
- # Check for configured project ID override (from litellm_params or env var)
499
- # This is REQUIRED for paid tier users per the official CLI behavior
500
- configured_project_id = litellm_params.get("project_id")
501
- if configured_project_id:
502
- lib_logger.debug(
503
- f"Found configured project_id override: {configured_project_id}"
504
- )
505
-
506
- # Load credentials from file to check for persisted project_id and tier
507
- # Skip for env:// paths (environment-based credentials don't persist to files)
508
- credential_index = self._parse_env_credential_path(credential_path)
509
- if credential_index is None:
510
- # Only try to load from file if it's not an env:// path
511
- try:
512
- with open(credential_path, "r") as f:
513
- creds = json.load(f)
514
-
515
- metadata = creds.get("_proxy_metadata", {})
516
- persisted_project_id = metadata.get("project_id")
517
- persisted_tier = metadata.get("tier")
518
-
519
- if persisted_project_id:
520
- lib_logger.info(
521
- f"Loaded persisted project ID from credential file: {persisted_project_id}"
522
- )
523
- self.project_id_cache[credential_path] = persisted_project_id
524
-
525
- # Also load tier if available
526
- if persisted_tier:
527
- self.project_tier_cache[credential_path] = persisted_tier
528
- lib_logger.debug(f"Loaded persisted tier: {persisted_tier}")
529
-
530
- return persisted_project_id
531
- except (FileNotFoundError, json.JSONDecodeError, KeyError) as e:
532
- lib_logger.debug(f"Could not load persisted project ID from file: {e}")
533
-
534
- lib_logger.debug(
535
- "No cached or configured project ID found, initiating discovery..."
536
- )
537
- headers = {
538
- "Authorization": f"Bearer {access_token}",
539
- "Content-Type": "application/json",
540
- }
541
-
542
- discovered_project_id = None
543
- discovered_tier = None
544
-
545
- async with httpx.AsyncClient() as client:
546
- # 1. Try discovery endpoint with loadCodeAssist
547
- lib_logger.debug(
548
- "Attempting project discovery via Code Assist loadCodeAssist endpoint..."
549
- )
550
- try:
551
- # Build metadata - include duetProject only if we have a configured project
552
- core_client_metadata = {
553
- "ideType": "IDE_UNSPECIFIED",
554
- "platform": "PLATFORM_UNSPECIFIED",
555
- "pluginType": "GEMINI",
556
- }
557
- if configured_project_id:
558
- core_client_metadata["duetProject"] = configured_project_id
559
-
560
- # Build load request - pass configured_project_id if available, otherwise None
561
- load_request = {
562
- "cloudaicompanionProject": configured_project_id, # Can be None
563
- "metadata": core_client_metadata,
564
- }
565
-
566
- lib_logger.debug(
567
- f"Sending loadCodeAssist request with cloudaicompanionProject={configured_project_id}"
568
- )
569
- response = await client.post(
570
- f"{CODE_ASSIST_ENDPOINT}:loadCodeAssist",
571
- headers=headers,
572
- json=load_request,
573
- timeout=20,
574
- )
575
- response.raise_for_status()
576
- data = response.json()
577
-
578
- # Log full response for debugging
579
- lib_logger.debug(
580
- f"loadCodeAssist full response keys: {list(data.keys())}"
581
- )
582
-
583
- # Extract and log ALL tier information for debugging
584
- allowed_tiers = data.get("allowedTiers", [])
585
- current_tier = data.get("currentTier")
586
-
587
- lib_logger.debug(f"=== Tier Information ===")
588
- lib_logger.debug(f"currentTier: {current_tier}")
589
- lib_logger.debug(f"allowedTiers count: {len(allowed_tiers)}")
590
- for i, tier in enumerate(allowed_tiers):
591
- tier_id = tier.get("id", "unknown")
592
- is_default = tier.get("isDefault", False)
593
- user_defined = tier.get("userDefinedCloudaicompanionProject", False)
594
- lib_logger.debug(
595
- f" Tier {i + 1}: id={tier_id}, isDefault={is_default}, userDefinedProject={user_defined}"
596
- )
597
- lib_logger.debug(f"========================")
598
-
599
- # Determine the current tier ID
600
- current_tier_id = None
601
- if current_tier:
602
- current_tier_id = current_tier.get("id")
603
- lib_logger.debug(f"User has currentTier: {current_tier_id}")
604
-
605
- # Check if user is already known to server (has currentTier)
606
- if current_tier_id:
607
- # User is already onboarded - check for project from server
608
- server_project = data.get("cloudaicompanionProject")
609
-
610
- # Check if this tier requires user-defined project (paid tiers)
611
- requires_user_project = any(
612
- t.get("id") == current_tier_id
613
- and t.get("userDefinedCloudaicompanionProject", False)
614
- for t in allowed_tiers
615
- )
616
- is_free_tier = current_tier_id == "free-tier"
617
-
618
- if server_project:
619
- # Server returned a project - use it (server wins)
620
- # This is the normal case for FREE tier users
621
- project_id = server_project
622
- lib_logger.debug(f"Server returned project: {project_id}")
623
- elif configured_project_id:
624
- # No server project but we have configured one - use it
625
- # This is the PAID TIER case where server doesn't return a project
626
- project_id = configured_project_id
627
- lib_logger.debug(
628
- f"No server project, using configured: {project_id}"
629
- )
630
- elif is_free_tier:
631
- # Free tier user without server project - this shouldn't happen normally
632
- # but let's not fail, just proceed to onboarding
633
- lib_logger.debug(
634
- "Free tier user with currentTier but no project - will try onboarding"
635
- )
636
- project_id = None
637
- elif requires_user_project:
638
- # Paid tier requires a project ID to be set
639
- raise ValueError(
640
- f"Paid tier '{current_tier_id}' requires setting GEMINI_CLI_PROJECT_ID environment variable. "
641
- "See https://goo.gle/gemini-cli-auth-docs#workspace-gca"
642
- )
643
- else:
644
- # Unknown tier without project - proceed carefully
645
- lib_logger.warning(
646
- f"Tier '{current_tier_id}' has no project and none configured - will try onboarding"
647
- )
648
- project_id = None
649
-
650
- if project_id:
651
- # Cache tier info
652
- self.project_tier_cache[credential_path] = current_tier_id
653
- discovered_tier = current_tier_id
654
-
655
- # Log appropriately based on tier
656
- is_paid = current_tier_id and current_tier_id not in [
657
- "free-tier",
658
- "legacy-tier",
659
- "unknown",
660
- ]
661
- if is_paid:
662
- lib_logger.info(
663
- f"Using Gemini paid tier '{current_tier_id}' with project: {project_id}"
664
- )
665
- else:
666
- lib_logger.info(
667
- f"Discovered Gemini project ID via loadCodeAssist: {project_id}"
668
- )
669
-
670
- self.project_id_cache[credential_path] = project_id
671
- discovered_project_id = project_id
672
-
673
- # Persist to credential file
674
- await self._persist_project_metadata(
675
- credential_path, project_id, discovered_tier
676
- )
677
-
678
- return project_id
679
-
680
- # 2. User needs onboarding - no currentTier
681
- lib_logger.info(
682
- "No existing Gemini session found (no currentTier), attempting to onboard user..."
683
- )
684
-
685
- # Determine which tier to onboard with
686
- onboard_tier = None
687
- for tier in allowed_tiers:
688
- if tier.get("isDefault"):
689
- onboard_tier = tier
690
- break
691
-
692
- # Fallback to LEGACY tier if no default (requires user project)
693
- if not onboard_tier and allowed_tiers:
694
- # Look for legacy-tier as fallback
695
- for tier in allowed_tiers:
696
- if tier.get("id") == "legacy-tier":
697
- onboard_tier = tier
698
- break
699
- # If still no tier, use first available
700
- if not onboard_tier:
701
- onboard_tier = allowed_tiers[0]
702
-
703
- if not onboard_tier:
704
- raise ValueError("No onboarding tiers available from server")
705
-
706
- tier_id = onboard_tier.get("id", "free-tier")
707
- requires_user_project = onboard_tier.get(
708
- "userDefinedCloudaicompanionProject", False
709
- )
710
-
711
- lib_logger.debug(
712
- f"Onboarding with tier: {tier_id}, requiresUserProject: {requires_user_project}"
713
- )
714
-
715
- # Build onboard request based on tier type (following official CLI logic)
716
- # FREE tier: cloudaicompanionProject = None (server-managed)
717
- # PAID tier: cloudaicompanionProject = configured_project_id (user must provide)
718
- is_free_tier = tier_id == "free-tier"
719
-
720
- if is_free_tier:
721
- # Free tier uses server-managed project
722
- onboard_request = {
723
- "tierId": tier_id,
724
- "cloudaicompanionProject": None, # Server will create/manage
725
- "metadata": core_client_metadata,
726
- }
727
- lib_logger.debug(
728
- "Free tier onboarding: using server-managed project"
729
- )
730
- else:
731
- # Paid/legacy tier requires user-provided project
732
- if not configured_project_id and requires_user_project:
733
- raise ValueError(
734
- f"Tier '{tier_id}' requires setting GEMINI_CLI_PROJECT_ID environment variable. "
735
- "See https://goo.gle/gemini-cli-auth-docs#workspace-gca"
736
- )
737
- onboard_request = {
738
- "tierId": tier_id,
739
- "cloudaicompanionProject": configured_project_id,
740
- "metadata": {
741
- **core_client_metadata,
742
- "duetProject": configured_project_id,
743
- }
744
- if configured_project_id
745
- else core_client_metadata,
746
- }
747
- lib_logger.debug(
748
- f"Paid tier onboarding: using project {configured_project_id}"
749
- )
750
-
751
- lib_logger.debug("Initiating onboardUser request...")
752
- lro_response = await client.post(
753
- f"{CODE_ASSIST_ENDPOINT}:onboardUser",
754
- headers=headers,
755
- json=onboard_request,
756
- timeout=30,
757
- )
758
- lro_response.raise_for_status()
759
- lro_data = lro_response.json()
760
- lib_logger.debug(
761
- f"Initial onboarding response: done={lro_data.get('done')}"
762
- )
763
-
764
- for i in range(150): # Poll for up to 5 minutes (150 × 2s)
765
- if lro_data.get("done"):
766
- lib_logger.debug(
767
- f"Onboarding completed after {i} polling attempts"
768
- )
769
- break
770
- await asyncio.sleep(2)
771
- if (i + 1) % 15 == 0: # Log every 30 seconds
772
- lib_logger.info(
773
- f"Still waiting for onboarding completion... ({(i + 1) * 2}s elapsed)"
774
- )
775
- lib_logger.debug(
776
- f"Polling onboarding status... (Attempt {i + 1}/150)"
777
- )
778
- lro_response = await client.post(
779
- f"{CODE_ASSIST_ENDPOINT}:onboardUser",
780
- headers=headers,
781
- json=onboard_request,
782
- timeout=30,
783
- )
784
- lro_response.raise_for_status()
785
- lro_data = lro_response.json()
786
-
787
- if not lro_data.get("done"):
788
- lib_logger.error("Onboarding process timed out after 5 minutes")
789
- raise ValueError(
790
- "Onboarding process timed out after 5 minutes. Please try again or contact support."
791
- )
792
-
793
- # Extract project ID from LRO response
794
- # Note: onboardUser returns response.cloudaicompanionProject as an object with .id
795
- lro_response_data = lro_data.get("response", {})
796
- lro_project_obj = lro_response_data.get("cloudaicompanionProject", {})
797
- project_id = (
798
- lro_project_obj.get("id")
799
- if isinstance(lro_project_obj, dict)
800
- else None
801
- )
802
-
803
- # Fallback to configured project if LRO didn't return one
804
- if not project_id and configured_project_id:
805
- project_id = configured_project_id
806
- lib_logger.debug(
807
- f"LRO didn't return project, using configured: {project_id}"
808
- )
809
-
810
- if not project_id:
811
- lib_logger.error(
812
- "Onboarding completed but no project ID in response and none configured"
813
- )
814
- raise ValueError(
815
- "Onboarding completed, but no project ID was returned. "
816
- "For paid tiers, set GEMINI_CLI_PROJECT_ID environment variable."
817
- )
818
-
819
- lib_logger.debug(
820
- f"Successfully extracted project ID from onboarding response: {project_id}"
821
- )
822
-
823
- # Cache tier info
824
- self.project_tier_cache[credential_path] = tier_id
825
- discovered_tier = tier_id
826
- lib_logger.debug(f"Cached tier information: {tier_id}")
827
-
828
- # Log concise message for paid projects
829
- is_paid = tier_id and tier_id not in ["free-tier", "legacy-tier"]
830
- if is_paid:
831
- lib_logger.info(
832
- f"Using Gemini paid tier '{tier_id}' with project: {project_id}"
833
- )
834
- else:
835
- lib_logger.info(
836
- f"Successfully onboarded user and discovered project ID: {project_id}"
837
- )
838
-
839
- self.project_id_cache[credential_path] = project_id
840
- discovered_project_id = project_id
841
-
842
- # Persist to credential file
843
- await self._persist_project_metadata(
844
- credential_path, project_id, discovered_tier
845
- )
846
-
847
- return project_id
848
-
849
- except httpx.HTTPStatusError as e:
850
- error_body = ""
851
- try:
852
- error_body = e.response.text
853
- except Exception:
854
- pass
855
- if e.response.status_code == 403:
856
- lib_logger.error(
857
- f"Gemini Code Assist API access denied (403). Response: {error_body}"
858
- )
859
- lib_logger.error(
860
- "Possible causes: 1) cloudaicompanion.googleapis.com API not enabled, 2) Wrong project ID for paid tier, 3) Account lacks permissions"
861
- )
862
- elif e.response.status_code == 404:
863
- lib_logger.warning(
864
- f"Gemini Code Assist endpoint not found (404). Falling back to project listing."
865
- )
866
- elif e.response.status_code == 412:
867
- # Precondition Failed - often means wrong project for free tier onboarding
868
- lib_logger.error(
869
- f"Precondition failed (412): {error_body}. This may mean the project ID is incompatible with the selected tier."
870
- )
871
- else:
872
- lib_logger.warning(
873
- f"Gemini onboarding/discovery failed with status {e.response.status_code}: {error_body}. Falling back to project listing."
874
- )
875
- except httpx.RequestError as e:
876
- lib_logger.warning(
877
- f"Gemini onboarding/discovery network error: {e}. Falling back to project listing."
878
- )
879
-
880
- # 3. Fallback to listing all available GCP projects (last resort)
881
- lib_logger.debug(
882
- "Attempting to discover project via GCP Resource Manager API..."
883
- )
884
- try:
885
- async with httpx.AsyncClient() as client:
886
- lib_logger.debug(
887
- "Querying Cloud Resource Manager for available projects..."
888
- )
889
- response = await client.get(
890
- "https://cloudresourcemanager.googleapis.com/v1/projects",
891
- headers=headers,
892
- timeout=20,
893
- )
894
- response.raise_for_status()
895
- projects = response.json().get("projects", [])
896
- lib_logger.debug(f"Found {len(projects)} total projects")
897
- active_projects = [
898
- p for p in projects if p.get("lifecycleState") == "ACTIVE"
899
- ]
900
- lib_logger.debug(f"Found {len(active_projects)} active projects")
901
-
902
- if not projects:
903
- lib_logger.error(
904
- "No GCP projects found for this account. Please create a project in Google Cloud Console."
905
- )
906
- elif not active_projects:
907
- lib_logger.error(
908
- "No active GCP projects found. Please activate a project in Google Cloud Console."
909
- )
910
- else:
911
- project_id = active_projects[0]["projectId"]
912
- lib_logger.info(
913
- f"Discovered Gemini project ID from active projects list: {project_id}"
914
- )
915
- lib_logger.debug(
916
- f"Selected first active project: {project_id} (out of {len(active_projects)} active projects)"
917
- )
918
- self.project_id_cache[credential_path] = project_id
919
- discovered_project_id = project_id
920
-
921
- # [NEW] Persist to credential file (no tier info from resource manager)
922
- await self._persist_project_metadata(
923
- credential_path, project_id, None
924
- )
925
-
926
- return project_id
927
- except httpx.HTTPStatusError as e:
928
- if e.response.status_code == 403:
929
- lib_logger.error(
930
- "Failed to list GCP projects due to a 403 Forbidden error. The Cloud Resource Manager API may not be enabled, or your account lacks the 'resourcemanager.projects.list' permission."
931
- )
932
- else:
933
- lib_logger.error(
934
- f"Failed to list GCP projects with status {e.response.status_code}: {e}"
935
- )
936
- except httpx.RequestError as e:
937
- lib_logger.error(f"Network error while listing GCP projects: {e}")
938
-
939
- raise ValueError(
940
- "Could not auto-discover Gemini project ID. Possible causes:\n"
941
- " 1. The cloudaicompanion.googleapis.com API is not enabled (enable it in Google Cloud Console)\n"
942
- " 2. No active GCP projects exist for this account (create one in Google Cloud Console)\n"
943
- " 3. Account lacks necessary permissions\n"
944
- "To manually specify a project, set GEMINI_CLI_PROJECT_ID in your .env file."
945
- )
946
-
947
- async def _persist_project_metadata(
948
- self, credential_path: str, project_id: str, tier: Optional[str]
949
- ):
950
- """Persists project ID and tier to the credential file for faster future startups."""
951
- # Skip persistence for env:// paths (environment-based credentials)
952
- credential_index = self._parse_env_credential_path(credential_path)
953
- if credential_index is not None:
954
- lib_logger.debug(
955
- f"Skipping project metadata persistence for env:// credential path: {credential_path}"
956
- )
957
- return
958
-
959
- try:
960
- # Load current credentials
961
- with open(credential_path, "r") as f:
962
- creds = json.load(f)
963
-
964
- # Update metadata
965
- if "_proxy_metadata" not in creds:
966
- creds["_proxy_metadata"] = {}
967
-
968
- creds["_proxy_metadata"]["project_id"] = project_id
969
- if tier:
970
- creds["_proxy_metadata"]["tier"] = tier
971
-
972
- # Save back using the existing save method (handles atomic writes and permissions)
973
- await self._save_credentials(credential_path, creds)
974
-
975
- lib_logger.debug(
976
- f"Persisted project_id and tier to credential file: {credential_path}"
977
- )
978
- except Exception as e:
979
- lib_logger.warning(
980
- f"Failed to persist project metadata to credential file: {e}"
981
- )
982
- # Non-fatal - just means slower startup next time
983
 
984
  def _check_mixed_tier_warning(self):
985
  """Check if mixed free/paid tier credentials are loaded and emit warning."""
@@ -1166,7 +920,7 @@ class GeminiCliProvider(GeminiAuthBase, ProviderInterface):
1166
  func_part["thoughtSignature"] = (
1167
  "skip_thought_signature_validator"
1168
  )
1169
- lib_logger.warning(
1170
  f"Missing thoughtSignature for first func call {tool_id}, using bypass"
1171
  )
1172
  # Subsequent parallel calls: no signature field at all
@@ -1178,23 +932,39 @@ class GeminiCliProvider(GeminiAuthBase, ProviderInterface):
1178
  elif role == "tool":
1179
  tool_call_id = msg.get("tool_call_id")
1180
  function_name = tool_call_id_to_name.get(tool_call_id)
1181
- if function_name:
1182
- # Add prefix for Gemini 3
1183
- if is_gemini_3 and self._enable_gemini3_tool_fix:
1184
- function_name = f"{self._gemini3_tool_prefix}{function_name}"
1185
-
1186
- # Wrap the tool response in a 'result' object
1187
- response_content = {"result": content}
1188
- # Accumulate tool responses - they'll be combined into one user message
1189
- pending_tool_parts.append(
1190
- {
1191
- "functionResponse": {
1192
- "name": function_name,
1193
- "response": response_content,
1194
- "id": tool_call_id,
1195
- }
1196
- }
 
1197
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1198
  # Don't add parts here - tool responses are handled via pending_tool_parts
1199
  continue
1200
 
@@ -1210,6 +980,216 @@ class GeminiCliProvider(GeminiAuthBase, ProviderInterface):
1210
 
1211
  return system_instruction, gemini_contents
1212
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1213
  def _handle_reasoning_parameters(
1214
  self, payload: Dict[str, Any], model: str
1215
  ) -> Optional[Dict[str, Any]]:
@@ -1329,13 +1309,24 @@ class GeminiCliProvider(GeminiAuthBase, ProviderInterface):
1329
  # Get current tool index from accumulator (default 0) and increment
1330
  current_tool_idx = accumulator.get("tool_idx", 0) if accumulator else 0
1331
 
 
 
 
 
 
 
 
 
 
 
 
1332
  tool_call = {
1333
  "index": current_tool_idx,
1334
  "id": tool_call_id,
1335
  "type": "function",
1336
  "function": {
1337
  "name": function_name,
1338
- "arguments": json.dumps(function_call.get("args", {})),
1339
  },
1340
  }
1341
 
@@ -1643,13 +1634,32 @@ class GeminiCliProvider(GeminiAuthBase, ProviderInterface):
1643
  schema = self._gemini_cli_transform_schema(
1644
  new_function["parameters"]
1645
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
1646
  new_function["parametersJsonSchema"] = schema
1647
  del new_function["parameters"]
1648
  elif "parametersJsonSchema" not in new_function:
1649
- # Set default empty schema if neither exists
1650
  new_function["parametersJsonSchema"] = {
1651
  "type": "object",
1652
- "properties": {},
 
 
 
 
 
 
1653
  }
1654
 
1655
  # Gemini 3 specific transformations
@@ -1889,6 +1899,9 @@ class GeminiCliProvider(GeminiAuthBase, ProviderInterface):
1889
  system_instruction, contents = self._transform_messages(
1890
  kwargs.get("messages", []), model_name
1891
  )
 
 
 
1892
  request_payload = {
1893
  "model": model_name,
1894
  "project": project_id,
@@ -1965,7 +1978,7 @@ class GeminiCliProvider(GeminiAuthBase, ProviderInterface):
1965
  headers=final_headers,
1966
  json=request_payload,
1967
  params={"alt": "sse"},
1968
- timeout=600,
1969
  ) as response:
1970
  # Read and log error body before raise_for_status for better debugging
1971
  if response.status_code >= 400:
@@ -2176,6 +2189,8 @@ class GeminiCliProvider(GeminiAuthBase, ProviderInterface):
2176
 
2177
  # Transform messages to Gemini format
2178
  system_instruction, contents = self._transform_messages(messages)
 
 
2179
 
2180
  # Build request payload
2181
  request_payload = {
 
11
  from .gemini_auth_base import GeminiAuthBase
12
  from .provider_cache import ProviderCache
13
  from ..model_definitions import ModelDefinitions
14
+ from ..timeout_config import TimeoutConfig
15
+ from ..utils.paths import get_logs_dir, get_cache_dir
16
  import litellm
17
  from litellm.exceptions import RateLimitError
18
  from ..error_handler import extract_retry_after_from_body
 
23
 
24
  lib_logger = logging.getLogger("rotator_library")
25
 
26
+
27
+ def _get_gemini_cli_logs_dir() -> Path:
28
+ """Get the Gemini CLI logs directory."""
29
+ logs_dir = get_logs_dir() / "gemini_cli_logs"
30
+ logs_dir.mkdir(parents=True, exist_ok=True)
31
+ return logs_dir
32
+
33
+
34
+ def _get_gemini_cli_cache_dir() -> Path:
35
+ """Get the Gemini CLI cache directory."""
36
+ return get_cache_dir(subdir="gemini_cli")
37
+
38
+
39
+ def _get_gemini3_signature_cache_file() -> Path:
40
+ """Get the Gemini 3 signature cache file path."""
41
+ return _get_gemini_cli_cache_dir() / "gemini3_signatures.json"
42
 
43
 
44
  class _GeminiCliFileLogger:
 
54
  # Sanitize model name for directory
55
  safe_model_name = model_name.replace("/", "_").replace(":", "_")
56
  self.log_dir = (
57
+ _get_gemini_cli_logs_dir() / f"{timestamp}_{safe_model_name}_{request_id}"
58
  )
59
  try:
60
  self.log_dir.mkdir(parents=True, exist_ok=True)
 
118
  "gemini-3-pro-preview",
119
  ]
120
 
 
 
 
 
 
 
121
  # Gemini 3 tool fix system instruction (prevents hallucination)
122
  DEFAULT_GEMINI3_SYSTEM_INSTRUCTION = """<CRITICAL_TOOL_USAGE_INSTRUCTIONS>
123
  You are operating in a CUSTOM ENVIRONMENT where tool definitions COMPLETELY DIFFER from your training data.
 
183
  }
184
 
185
 
186
+ def _recursively_parse_json_strings(obj: Any) -> Any:
187
+ """
188
+ Recursively parse JSON strings in nested data structures.
189
+
190
+ Gemini sometimes returns tool arguments with JSON-stringified values:
191
+ {"files": "[{...}]"} instead of {"files": [{...}]}.
192
+
193
+ Additionally handles:
194
+ - Malformed double-encoded JSON (extra trailing '}' or ']')
195
+ - Escaped string content (\n, \t, etc.)
196
+ """
197
+ if isinstance(obj, dict):
198
+ return {k: _recursively_parse_json_strings(v) for k, v in obj.items()}
199
+ elif isinstance(obj, list):
200
+ return [_recursively_parse_json_strings(item) for item in obj]
201
+ elif isinstance(obj, str):
202
+ stripped = obj.strip()
203
+
204
+ # Check if string contains control character escape sequences that need unescaping
205
+ # This handles cases where diff content has literal \n or \t instead of actual newlines/tabs
206
+ #
207
+ # IMPORTANT: We intentionally do NOT unescape strings containing \" or \\
208
+ # because these are typically intentional escapes in code/config content
209
+ # (e.g., JSON embedded in YAML: BOT_NAMES_JSON: '["mirrobot", ...]')
210
+ # Unescaping these would corrupt the content and cause issues like
211
+ # oldString and newString becoming identical when they should differ.
212
+ has_control_char_escapes = "\\n" in obj or "\\t" in obj
213
+ has_intentional_escapes = '\\"' in obj or "\\\\" in obj
214
+
215
+ if has_control_char_escapes and not has_intentional_escapes:
216
+ try:
217
+ # Use json.loads with quotes to properly unescape the string
218
+ # This converts \n -> newline, \t -> tab
219
+ unescaped = json.loads(f'"{obj}"')
220
+ # Log the fix with a snippet for debugging
221
+ snippet = obj[:80] + "..." if len(obj) > 80 else obj
222
+ lib_logger.debug(
223
+ f"[GeminiCli] Unescaped control chars in string: "
224
+ f"{len(obj) - len(unescaped)} chars changed. Snippet: {snippet!r}"
225
+ )
226
+ return unescaped
227
+ except (json.JSONDecodeError, ValueError):
228
+ # If unescaping fails, continue with original processing
229
+ pass
230
+
231
+ # Check if it looks like JSON (starts with { or [)
232
+ if stripped and stripped[0] in ("{", "["):
233
+ # Try standard parsing first
234
+ if (stripped.startswith("{") and stripped.endswith("}")) or (
235
+ stripped.startswith("[") and stripped.endswith("]")
236
+ ):
237
+ try:
238
+ parsed = json.loads(obj)
239
+ return _recursively_parse_json_strings(parsed)
240
+ except (json.JSONDecodeError, ValueError):
241
+ pass
242
+
243
+ # Handle malformed JSON: array that doesn't end with ]
244
+ # e.g., '[{"path": "..."}]}' instead of '[{"path": "..."}]'
245
+ if stripped.startswith("[") and not stripped.endswith("]"):
246
+ try:
247
+ # Find the last ] and truncate there
248
+ last_bracket = stripped.rfind("]")
249
+ if last_bracket > 0:
250
+ cleaned = stripped[: last_bracket + 1]
251
+ parsed = json.loads(cleaned)
252
+ lib_logger.warning(
253
+ f"[GeminiCli] Auto-corrected malformed JSON string: "
254
+ f"truncated {len(stripped) - len(cleaned)} extra chars"
255
+ )
256
+ return _recursively_parse_json_strings(parsed)
257
+ except (json.JSONDecodeError, ValueError):
258
+ pass
259
+
260
+ # Handle malformed JSON: object that doesn't end with }
261
+ if stripped.startswith("{") and not stripped.endswith("}"):
262
+ try:
263
+ # Find the last } and truncate there
264
+ last_brace = stripped.rfind("}")
265
+ if last_brace > 0:
266
+ cleaned = stripped[: last_brace + 1]
267
+ parsed = json.loads(cleaned)
268
+ lib_logger.warning(
269
+ f"[GeminiCli] Auto-corrected malformed JSON string: "
270
+ f"truncated {len(stripped) - len(cleaned)} extra chars"
271
+ )
272
+ return _recursively_parse_json_strings(parsed)
273
+ except (json.JSONDecodeError, ValueError):
274
+ pass
275
+ return obj
276
+
277
+
278
  def _env_bool(key: str, default: bool = False) -> bool:
279
  """Get boolean from environment variable."""
280
  return os.getenv(key, str(default).lower()).lower() in ("true", "1", "yes")
 
288
  class GeminiCliProvider(GeminiAuthBase, ProviderInterface):
289
  skip_cost_calculation = True
290
 
291
+ # Sequential mode - stick with one credential until it gets a 429, then switch
292
+ default_rotation_mode: str = "sequential"
293
 
294
  # =========================================================================
295
  # TIER CONFIGURATION
 
336
  error: Exception, error_body: Optional[str] = None
337
  ) -> Optional[Dict[str, Any]]:
338
  """
339
+ Parse Gemini CLI rate limit/quota errors.
340
+
341
+ Handles the Gemini CLI error format which embeds reset time in the message:
342
+ "You have exhausted your capacity on this model. Your quota will reset after 2s."
343
+
344
+ Unlike Antigravity which uses structured RetryInfo/quotaResetDelay metadata,
345
+ Gemini CLI embeds the reset time in a human-readable message.
346
+
347
+ Example error format:
348
+ {
349
+ "error": {
350
+ "code": 429,
351
+ "message": "You have exhausted your capacity on this model. Your quota will reset after 2s.",
352
+ "status": "RESOURCE_EXHAUSTED",
353
+ "details": [
354
+ {
355
+ "@type": "type.googleapis.com/google.rpc.ErrorInfo",
356
+ "reason": "RATE_LIMIT_EXCEEDED",
357
+ "domain": "cloudcode-pa.googleapis.com",
358
+ "metadata": { "uiMessage": "true", "model": "gemini-3-pro-preview" }
359
+ }
360
+ ]
361
+ }
362
+ }
363
 
364
  Args:
365
  error: The caught exception
366
  error_body: Optional raw response body string
367
 
368
  Returns:
369
+ None if not a parseable quota error, otherwise:
370
+ {
371
+ "retry_after": int,
372
+ "reason": str | None,
373
+ "reset_timestamp": str | None,
374
+ "quota_reset_timestamp": float | None,
375
+ }
376
  """
377
+ import re as regex_module
378
+
379
+ # Get error body from exception if not provided
380
+ body = error_body
381
+ if not body:
382
+ if hasattr(error, "response") and hasattr(error.response, "text"):
383
+ try:
384
+ body = error.response.text
385
+ except Exception:
386
+ pass
387
+ if not body and hasattr(error, "body"):
388
+ body = str(error.body)
389
+ if not body and hasattr(error, "message"):
390
+ body = str(error.message)
391
+ if not body:
392
+ body = str(error)
393
+
394
+ if not body:
395
+ return None
396
+
397
+ result = {
398
+ "retry_after": None,
399
+ "reason": None,
400
+ "reset_timestamp": None,
401
+ "quota_reset_timestamp": None,
402
+ }
403
+
404
+ # 1. Try to extract retry time from human-readable message
405
+ # Pattern: "Your quota will reset after 2s." or "quota will reset after 156h14m36s"
406
+ retry_after = extract_retry_after_from_body(body)
407
+ if retry_after:
408
+ result["retry_after"] = retry_after
409
+
410
+ # 2. Try to parse JSON to get structured details (reason, any RetryInfo fallback)
411
+ try:
412
+ json_match = regex_module.search(r"\{[\s\S]*\}", body)
413
+ if json_match:
414
+ data = json.loads(json_match.group(0))
415
+ error_obj = data.get("error", data)
416
+ details = error_obj.get("details", [])
417
+
418
+ for detail in details:
419
+ detail_type = detail.get("@type", "")
420
+
421
+ # Extract reason from ErrorInfo
422
+ if "ErrorInfo" in detail_type:
423
+ if not result["reason"]:
424
+ result["reason"] = detail.get("reason")
425
+ # Check metadata for any additional timing info
426
+ metadata = detail.get("metadata", {})
427
+ quota_delay = metadata.get("quotaResetDelay")
428
+ if quota_delay and not result["retry_after"]:
429
+ parsed = GeminiCliProvider._parse_duration(quota_delay)
430
+ if parsed:
431
+ result["retry_after"] = parsed
432
+
433
+ # Check for RetryInfo (fallback, in case format changes)
434
+ if "RetryInfo" in detail_type and not result["retry_after"]:
435
+ retry_delay = detail.get("retryDelay")
436
+ if retry_delay:
437
+ parsed = GeminiCliProvider._parse_duration(retry_delay)
438
+ if parsed:
439
+ result["retry_after"] = parsed
440
+
441
+ except (json.JSONDecodeError, AttributeError, TypeError):
442
+ pass
443
 
444
+ # Return None if we couldn't extract retry_after
445
+ if not result["retry_after"]:
446
+ return None
447
+
448
+ return result
449
+
450
+ @staticmethod
451
+ def _parse_duration(duration_str: str) -> Optional[int]:
452
+ """
453
+ Parse duration strings like '2s', '156h14m36.73s', '515092.73s' to seconds.
454
+
455
+ Args:
456
+ duration_str: Duration string to parse
457
+
458
+ Returns:
459
+ Total seconds as integer, or None if parsing fails
460
+ """
461
+ import re as regex_module
462
+
463
+ if not duration_str:
464
+ return None
465
+
466
+ # Handle pure seconds format: "515092.730699158s" or "2s"
467
+ pure_seconds_match = regex_module.match(r"^([\d.]+)s$", duration_str)
468
+ if pure_seconds_match:
469
+ return int(float(pure_seconds_match.group(1)))
470
+
471
+ # Handle compound format: "143h4m52.730699158s"
472
+ total_seconds = 0
473
+ patterns = [
474
+ (r"(\d+)h", 3600), # hours
475
+ (r"(\d+)m", 60), # minutes
476
+ (r"([\d.]+)s", 1), # seconds
477
+ ]
478
+ for pattern, multiplier in patterns:
479
+ match = regex_module.search(pattern, duration_str)
480
+ if match:
481
+ total_seconds += float(match.group(1)) * multiplier
482
+
483
+ return int(total_seconds) if total_seconds > 0 else None
484
 
485
  def __init__(self):
486
  super().__init__()
487
  self.model_definitions = ModelDefinitions()
488
+ # NOTE: project_id_cache and project_tier_cache are inherited from GeminiAuthBase
 
 
 
 
 
489
 
490
  # Gemini 3 configuration from environment
491
  memory_ttl = _env_int("GEMINI_CLI_SIGNATURE_CACHE_TTL", 3600)
 
493
 
494
  # Initialize signature cache for Gemini 3 thoughtSignatures
495
  self._signature_cache = ProviderCache(
496
+ _get_gemini3_signature_cache_file(),
497
  memory_ttl,
498
  disk_ttl,
499
  env_prefix="GEMINI_CLI_SIGNATURE",
 
607
 
608
  # Gemini 3 requires paid tier
609
  if model_name.startswith("gemini-3-"):
610
+ return 2 # Only priority 2 (paid) credentials
611
 
612
  return None # All other models have no restrictions
613
 
 
617
 
618
  This ensures all credential priorities are known before any API calls,
619
  preventing unknown credentials from getting priority 999.
620
+
621
+ For credentials without persisted tier info (new or corrupted), performs
622
+ full discovery to ensure proper prioritization in sequential rotation mode.
623
  """
624
+ # Step 1: Load persisted tiers from files
625
  await self._load_persisted_tiers(credential_paths)
626
 
627
+ # Step 2: Identify credentials still missing tier info
628
+ credentials_needing_discovery = [
629
+ path
630
+ for path in credential_paths
631
+ if path not in self.project_tier_cache
632
+ and self._parse_env_credential_path(path) is None # Skip env:// paths
633
+ ]
634
+
635
+ if not credentials_needing_discovery:
636
+ return # All credentials have tier info
637
+
638
+ lib_logger.info(
639
+ f"GeminiCli: Discovering tier info for {len(credentials_needing_discovery)} credential(s)..."
640
+ )
641
+
642
+ # Step 3: Perform discovery for each missing credential (sequential to avoid rate limits)
643
+ for credential_path in credentials_needing_discovery:
644
+ try:
645
+ auth_header = await self.get_auth_header(credential_path)
646
+ access_token = auth_header["Authorization"].split(" ")[1]
647
+ await self._discover_project_id(
648
+ credential_path, access_token, litellm_params={}
649
+ )
650
+ discovered_tier = self.project_tier_cache.get(
651
+ credential_path, "unknown"
652
+ )
653
+ lib_logger.debug(
654
+ f"Discovered tier '{discovered_tier}' for {Path(credential_path).name}"
655
+ )
656
+ except Exception as e:
657
+ lib_logger.warning(
658
+ f"Failed to discover tier for {Path(credential_path).name}: {e}. "
659
+ f"Credential will use default priority."
660
+ )
661
+
662
  async def _load_persisted_tiers(
663
  self, credential_paths: List[str]
664
  ) -> Dict[str, str]:
 
716
 
717
  return loaded
718
 
719
+ # NOTE: _post_auth_discovery() is inherited from GeminiAuthBase
720
+
721
  # =========================================================================
722
  # MODEL UTILITIES
723
  # =========================================================================
 
733
  return name[len(self._gemini3_tool_prefix) :]
734
  return name
735
 
736
+ # NOTE: _discover_project_id() and _persist_project_metadata() are inherited from GeminiAuthBase
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
737
 
738
  def _check_mixed_tier_warning(self):
739
  """Check if mixed free/paid tier credentials are loaded and emit warning."""
 
920
  func_part["thoughtSignature"] = (
921
  "skip_thought_signature_validator"
922
  )
923
+ lib_logger.debug(
924
  f"Missing thoughtSignature for first func call {tool_id}, using bypass"
925
  )
926
  # Subsequent parallel calls: no signature field at all
 
932
  elif role == "tool":
933
  tool_call_id = msg.get("tool_call_id")
934
  function_name = tool_call_id_to_name.get(tool_call_id)
935
+
936
+ # Log warning if tool_call_id not found in mapping (can happen after context compaction)
937
+ if not function_name:
938
+ lib_logger.warning(
939
+ f"[ID Mismatch] Tool response has ID '{tool_call_id}' which was not found in tool_id_to_name map. "
940
+ f"Available IDs: {list(tool_call_id_to_name.keys())}. Using 'unknown_function' as fallback."
941
+ )
942
+ function_name = "unknown_function"
943
+
944
+ # Add prefix for Gemini 3
945
+ if is_gemini_3 and self._enable_gemini3_tool_fix:
946
+ function_name = f"{self._gemini3_tool_prefix}{function_name}"
947
+
948
+ # Try to parse content as JSON first, fall back to string
949
+ try:
950
+ parsed_content = (
951
+ json.loads(content) if isinstance(content, str) else content
952
  )
953
+ except (json.JSONDecodeError, TypeError):
954
+ parsed_content = content
955
+
956
+ # Wrap the tool response in a 'result' object
957
+ response_content = {"result": parsed_content}
958
+ # Accumulate tool responses - they'll be combined into one user message
959
+ pending_tool_parts.append(
960
+ {
961
+ "functionResponse": {
962
+ "name": function_name,
963
+ "response": response_content,
964
+ "id": tool_call_id,
965
+ }
966
+ }
967
+ )
968
  # Don't add parts here - tool responses are handled via pending_tool_parts
969
  continue
970
 
 
980
 
981
  return system_instruction, gemini_contents
982
 
983
+ def _fix_tool_response_grouping(
984
+ self, contents: List[Dict[str, Any]]
985
+ ) -> List[Dict[str, Any]]:
986
+ """
987
+ Group function calls with their responses for Gemini CLI compatibility.
988
+
989
+ Converts linear format (call, response, call, response)
990
+ to grouped format (model with calls, user with all responses).
991
+
992
+ IMPORTANT: Preserves ID-based pairing to prevent mismatches.
993
+ When IDs don't match, attempts recovery by:
994
+ 1. Matching by function name first
995
+ 2. Matching by order if names don't match
996
+ 3. Inserting placeholder responses if responses are missing
997
+ 4. Inserting responses at the CORRECT position (after their corresponding call)
998
+ """
999
+ new_contents = []
1000
+ # Each pending group tracks:
1001
+ # - ids: expected response IDs
1002
+ # - func_names: expected function names (for orphan matching)
1003
+ # - insert_after_idx: position in new_contents where model message was added
1004
+ pending_groups = []
1005
+ collected_responses = {} # Dict mapping ID -> response_part
1006
+
1007
+ for content in contents:
1008
+ role = content.get("role")
1009
+ parts = content.get("parts", [])
1010
+
1011
+ response_parts = [p for p in parts if "functionResponse" in p]
1012
+
1013
+ if response_parts:
1014
+ # Collect responses by ID (ignore duplicates - keep first occurrence)
1015
+ for resp in response_parts:
1016
+ resp_id = resp.get("functionResponse", {}).get("id", "")
1017
+ if resp_id:
1018
+ if resp_id in collected_responses:
1019
+ lib_logger.warning(
1020
+ f"[Grouping] Duplicate response ID detected: {resp_id}. "
1021
+ f"Ignoring duplicate - this may indicate malformed conversation history."
1022
+ )
1023
+ continue
1024
+ collected_responses[resp_id] = resp
1025
+
1026
+ # Try to satisfy pending groups (newest first)
1027
+ for i in range(len(pending_groups) - 1, -1, -1):
1028
+ group = pending_groups[i]
1029
+ group_ids = group["ids"]
1030
+
1031
+ # Check if we have ALL responses for this group
1032
+ if all(gid in collected_responses for gid in group_ids):
1033
+ # Extract responses in the same order as the function calls
1034
+ group_responses = [
1035
+ collected_responses.pop(gid) for gid in group_ids
1036
+ ]
1037
+ new_contents.append({"parts": group_responses, "role": "user"})
1038
+ pending_groups.pop(i)
1039
+ break
1040
+ continue
1041
+
1042
+ if role == "model":
1043
+ func_calls = [p for p in parts if "functionCall" in p]
1044
+ new_contents.append(content)
1045
+ if func_calls:
1046
+ call_ids = [
1047
+ fc.get("functionCall", {}).get("id", "") for fc in func_calls
1048
+ ]
1049
+ call_ids = [cid for cid in call_ids if cid] # Filter empty IDs
1050
+
1051
+ # Also extract function names for orphan matching
1052
+ func_names = [
1053
+ fc.get("functionCall", {}).get("name", "") for fc in func_calls
1054
+ ]
1055
+
1056
+ if call_ids:
1057
+ pending_groups.append(
1058
+ {
1059
+ "ids": call_ids,
1060
+ "func_names": func_names,
1061
+ "insert_after_idx": len(new_contents) - 1,
1062
+ }
1063
+ )
1064
+ else:
1065
+ new_contents.append(content)
1066
+
1067
+ # Handle remaining groups (shouldn't happen in well-formed conversations)
1068
+ # Attempt recovery by matching orphans to unsatisfied calls
1069
+ # Process in REVERSE order of insert_after_idx so insertions don't shift indices
1070
+ pending_groups.sort(key=lambda g: g["insert_after_idx"], reverse=True)
1071
+
1072
+ for group in pending_groups:
1073
+ group_ids = group["ids"]
1074
+ group_func_names = group.get("func_names", [])
1075
+ insert_idx = group["insert_after_idx"] + 1
1076
+ group_responses = []
1077
+
1078
+ lib_logger.debug(
1079
+ f"[Grouping Recovery] Processing unsatisfied group: "
1080
+ f"ids={group_ids}, names={group_func_names}, insert_at={insert_idx}"
1081
+ )
1082
+
1083
+ for i, expected_id in enumerate(group_ids):
1084
+ expected_name = group_func_names[i] if i < len(group_func_names) else ""
1085
+
1086
+ if expected_id in collected_responses:
1087
+ # Direct ID match
1088
+ group_responses.append(collected_responses.pop(expected_id))
1089
+ lib_logger.debug(
1090
+ f"[Grouping Recovery] Direct ID match for '{expected_id}'"
1091
+ )
1092
+ elif collected_responses:
1093
+ # Try to find orphan with matching function name first
1094
+ matched_orphan_id = None
1095
+
1096
+ # First pass: match by function name
1097
+ for orphan_id, orphan_resp in collected_responses.items():
1098
+ orphan_name = orphan_resp.get("functionResponse", {}).get(
1099
+ "name", ""
1100
+ )
1101
+ # Match if names are equal
1102
+ if orphan_name == expected_name:
1103
+ matched_orphan_id = orphan_id
1104
+ lib_logger.debug(
1105
+ f"[Grouping Recovery] Matched orphan '{orphan_id}' by name '{orphan_name}'"
1106
+ )
1107
+ break
1108
+
1109
+ # Second pass: if no name match, try "unknown_function" orphans
1110
+ if not matched_orphan_id:
1111
+ for orphan_id, orphan_resp in collected_responses.items():
1112
+ orphan_name = orphan_resp.get("functionResponse", {}).get(
1113
+ "name", ""
1114
+ )
1115
+ if orphan_name == "unknown_function":
1116
+ matched_orphan_id = orphan_id
1117
+ lib_logger.debug(
1118
+ f"[Grouping Recovery] Matched unknown_function orphan '{orphan_id}' "
1119
+ f"to expected '{expected_name}'"
1120
+ )
1121
+ break
1122
+
1123
+ # Third pass: if still no match, take first available (order-based)
1124
+ if not matched_orphan_id:
1125
+ matched_orphan_id = next(iter(collected_responses))
1126
+ lib_logger.debug(
1127
+ f"[Grouping Recovery] No name match, using first available orphan '{matched_orphan_id}'"
1128
+ )
1129
+
1130
+ if matched_orphan_id:
1131
+ orphan_resp = collected_responses.pop(matched_orphan_id)
1132
+
1133
+ # Fix the ID in the response to match the call
1134
+ old_id = orphan_resp["functionResponse"].get("id", "")
1135
+ orphan_resp["functionResponse"]["id"] = expected_id
1136
+
1137
+ # Fix the name if it was "unknown_function"
1138
+ if (
1139
+ orphan_resp["functionResponse"].get("name")
1140
+ == "unknown_function"
1141
+ and expected_name
1142
+ ):
1143
+ orphan_resp["functionResponse"]["name"] = expected_name
1144
+ lib_logger.info(
1145
+ f"[Grouping Recovery] Fixed function name from 'unknown_function' to '{expected_name}'"
1146
+ )
1147
+
1148
+ lib_logger.warning(
1149
+ f"[Grouping] Auto-repaired ID mismatch: mapped response '{old_id}' "
1150
+ f"to call '{expected_id}' (function: {expected_name})"
1151
+ )
1152
+ group_responses.append(orphan_resp)
1153
+ else:
1154
+ # No responses available - create placeholder
1155
+ placeholder_resp = {
1156
+ "functionResponse": {
1157
+ "name": expected_name or "unknown_function",
1158
+ "response": {
1159
+ "result": {
1160
+ "error": "Tool response was lost during context processing. "
1161
+ "This is a recovered placeholder.",
1162
+ "recovered": True,
1163
+ }
1164
+ },
1165
+ "id": expected_id,
1166
+ }
1167
+ }
1168
+ lib_logger.warning(
1169
+ f"[Grouping Recovery] Created placeholder response for missing tool: "
1170
+ f"id='{expected_id}', name='{expected_name}'"
1171
+ )
1172
+ group_responses.append(placeholder_resp)
1173
+
1174
+ if group_responses:
1175
+ # Insert at the correct position (right after the model message with the calls)
1176
+ new_contents.insert(
1177
+ insert_idx, {"parts": group_responses, "role": "user"}
1178
+ )
1179
+ lib_logger.info(
1180
+ f"[Grouping Recovery] Inserted {len(group_responses)} responses at position {insert_idx} "
1181
+ f"(expected {len(group_ids)})"
1182
+ )
1183
+
1184
+ # Warn about unmatched responses
1185
+ if collected_responses:
1186
+ lib_logger.warning(
1187
+ f"[Grouping] {len(collected_responses)} unmatched responses remaining: "
1188
+ f"ids={list(collected_responses.keys())}"
1189
+ )
1190
+
1191
+ return new_contents
1192
+
1193
  def _handle_reasoning_parameters(
1194
  self, payload: Dict[str, Any], model: str
1195
  ) -> Optional[Dict[str, Any]]:
 
1309
  # Get current tool index from accumulator (default 0) and increment
1310
  current_tool_idx = accumulator.get("tool_idx", 0) if accumulator else 0
1311
 
1312
+ # Get args, recursively parse any JSON strings, and strip _confirm if sole param
1313
+ raw_args = function_call.get("args", {})
1314
+ tool_args = _recursively_parse_json_strings(raw_args)
1315
+
1316
+ # Strip _confirm ONLY if it's the sole parameter
1317
+ # This ensures we only strip our injection, not legitimate user params
1318
+ if isinstance(tool_args, dict) and "_confirm" in tool_args:
1319
+ if len(tool_args) == 1:
1320
+ # _confirm is the only param - this was our injection
1321
+ tool_args.pop("_confirm")
1322
+
1323
  tool_call = {
1324
  "index": current_tool_idx,
1325
  "id": tool_call_id,
1326
  "type": "function",
1327
  "function": {
1328
  "name": function_name,
1329
+ "arguments": json.dumps(tool_args),
1330
  },
1331
  }
1332
 
 
1634
  schema = self._gemini_cli_transform_schema(
1635
  new_function["parameters"]
1636
  )
1637
+ # Workaround: Gemini fails to emit functionCall for tools
1638
+ # with empty properties {}. Inject a required confirmation param.
1639
+ # Using a required parameter forces the model to commit to
1640
+ # the tool call rather than just thinking about it.
1641
+ props = schema.get("properties", {})
1642
+ if not props:
1643
+ schema["properties"] = {
1644
+ "_confirm": {
1645
+ "type": "string",
1646
+ "description": "Enter 'yes' to proceed",
1647
+ }
1648
+ }
1649
+ schema["required"] = ["_confirm"]
1650
  new_function["parametersJsonSchema"] = schema
1651
  del new_function["parameters"]
1652
  elif "parametersJsonSchema" not in new_function:
1653
+ # Set default schema with required confirm param if neither exists
1654
  new_function["parametersJsonSchema"] = {
1655
  "type": "object",
1656
+ "properties": {
1657
+ "_confirm": {
1658
+ "type": "string",
1659
+ "description": "Enter 'yes' to proceed",
1660
+ }
1661
+ },
1662
+ "required": ["_confirm"],
1663
  }
1664
 
1665
  # Gemini 3 specific transformations
 
1899
  system_instruction, contents = self._transform_messages(
1900
  kwargs.get("messages", []), model_name
1901
  )
1902
+ # Fix tool response grouping (handles ID mismatches, missing responses)
1903
+ contents = self._fix_tool_response_grouping(contents)
1904
+
1905
  request_payload = {
1906
  "model": model_name,
1907
  "project": project_id,
 
1978
  headers=final_headers,
1979
  json=request_payload,
1980
  params={"alt": "sse"},
1981
+ timeout=TimeoutConfig.streaming(),
1982
  ) as response:
1983
  # Read and log error body before raise_for_status for better debugging
1984
  if response.status_code >= 400:
 
2189
 
2190
  # Transform messages to Gemini format
2191
  system_instruction, contents = self._transform_messages(messages)
2192
+ # Fix tool response grouping (handles ID mismatches, missing responses)
2193
+ contents = self._fix_tool_response_grouping(contents)
2194
 
2195
  # Build request payload
2196
  request_payload = {
src/rotator_library/providers/google_oauth_base.py CHANGED
@@ -1,16 +1,17 @@
1
  # src/rotator_library/providers/google_oauth_base.py
2
 
3
  import os
 
4
  import webbrowser
5
- from typing import Union, Optional
 
6
  import json
7
  import time
8
  import asyncio
9
  import logging
10
  from pathlib import Path
11
  from typing import Dict, Any
12
- import tempfile
13
- import shutil
14
 
15
  import httpx
16
  from rich.console import Console
@@ -20,12 +21,31 @@ from rich.markup import escape as rich_escape
20
 
21
  from ..utils.headless_detection import is_headless_environment
22
  from ..utils.reauth_coordinator import get_reauth_coordinator
 
23
 
24
  lib_logger = logging.getLogger("rotator_library")
25
 
26
  console = Console()
27
 
28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  class GoogleOAuthBase:
30
  """
31
  Base class for Google OAuth2 authentication providers.
@@ -55,6 +75,25 @@ class GoogleOAuthBase:
55
  CALLBACK_PATH: str = "/oauth2callback"
56
  REFRESH_EXPIRY_BUFFER_SECONDS: int = 30 * 60 # 30 minutes
57
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
  def __init__(self):
59
  # Validate that subclass has set required attributes
60
  if self.CLIENT_ID is None:
@@ -83,19 +122,36 @@ class GoogleOAuthBase:
83
  str, float
84
  ] = {} # Track backoff timers (Unix timestamp)
85
 
86
- # [QUEUE SYSTEM] Sequential refresh processing
 
87
  self._refresh_queue: asyncio.Queue = asyncio.Queue()
88
- self._queued_credentials: set = set() # Track credentials already in queue
89
- # [FIX PR#34] Changed from set to dict mapping credential path to timestamp
90
- # This enables TTL-based stale entry cleanup as defense in depth
 
 
 
 
 
 
 
91
  self._unavailable_credentials: Dict[
92
  str, float
93
  ] = {} # Maps credential path -> timestamp when marked unavailable
94
- self._unavailable_ttl_seconds: int = 300 # 5 minutes TTL for stale entries
 
95
  self._queue_tracking_lock = asyncio.Lock() # Protects queue sets
96
- self._queue_processor_task: Optional[asyncio.Task] = (
97
- None # Background worker task
98
- )
 
 
 
 
 
 
 
 
99
 
100
  def _parse_env_credential_path(self, path: str) -> Optional[str]:
101
  """
@@ -228,17 +284,7 @@ class GoogleOAuthBase:
228
  f"Environment variables for {self.ENV_PREFIX} credential index {credential_index} not found"
229
  )
230
 
231
- # For file paths, first try loading from legacy env vars (for backwards compatibility)
232
- env_creds = self._load_from_env()
233
- if env_creds:
234
- lib_logger.info(
235
- f"Using {self.ENV_PREFIX} credentials from environment variables"
236
- )
237
- # Cache env-based credentials using the path as key
238
- self._credentials_cache[path] = env_creds
239
- return env_creds
240
-
241
- # Fall back to file-based loading
242
  try:
243
  lib_logger.debug(
244
  f"Loading {self.ENV_PREFIX} credentials from file: {path}"
@@ -251,6 +297,15 @@ class GoogleOAuthBase:
251
  self._credentials_cache[path] = creds
252
  return creds
253
  except FileNotFoundError:
 
 
 
 
 
 
 
 
 
254
  raise IOError(
255
  f"{self.ENV_PREFIX} OAuth credential file not found at '{path}'"
256
  )
@@ -258,70 +313,29 @@ class GoogleOAuthBase:
258
  raise IOError(
259
  f"Failed to load {self.ENV_PREFIX} OAuth credentials from '{path}': {e}"
260
  )
261
- except Exception as e:
262
- raise IOError(
263
- f"Failed to load {self.ENV_PREFIX} OAuth credentials from '{path}': {e}"
264
- )
265
 
266
  async def _save_credentials(self, path: str, creds: Dict[str, Any]):
 
 
 
 
267
  # Don't save to file if credentials were loaded from environment
268
  if creds.get("_proxy_metadata", {}).get("loaded_from_env"):
269
  lib_logger.debug("Credentials loaded from env, skipping file save")
270
- # Still update cache for in-memory consistency
271
- self._credentials_cache[path] = creds
272
  return
273
 
274
- # [ATOMIC WRITE] Use tempfile + move pattern to ensure atomic writes
275
- # This prevents credential corruption if the process is interrupted during write
276
- parent_dir = os.path.dirname(os.path.abspath(path))
277
- os.makedirs(parent_dir, exist_ok=True)
278
-
279
- tmp_fd = None
280
- tmp_path = None
281
- try:
282
- # Create temp file in same directory as target (ensures same filesystem)
283
- tmp_fd, tmp_path = tempfile.mkstemp(
284
- dir=parent_dir, prefix=".tmp_", suffix=".json", text=True
285
- )
286
-
287
- # Write JSON to temp file
288
- with os.fdopen(tmp_fd, "w") as f:
289
- json.dump(creds, f, indent=2)
290
- tmp_fd = None # fdopen closes the fd
291
-
292
- # Set secure permissions (0600 = owner read/write only)
293
- try:
294
- os.chmod(tmp_path, 0o600)
295
- except (OSError, AttributeError):
296
- # Windows may not support chmod, ignore
297
- pass
298
-
299
- # Atomic move (overwrites target if it exists)
300
- shutil.move(tmp_path, path)
301
- tmp_path = None # Successfully moved
302
-
303
- # Update cache AFTER successful file write (prevents cache/file inconsistency)
304
- self._credentials_cache[path] = creds
305
  lib_logger.debug(
306
- f"Saved updated {self.ENV_PREFIX} OAuth credentials to '{path}' (atomic write)."
307
  )
308
-
309
- except Exception as e:
310
- lib_logger.error(
311
- f"Failed to save updated {self.ENV_PREFIX} OAuth credentials to '{path}': {e}"
312
  )
313
- # Clean up temp file if it still exists
314
- if tmp_fd is not None:
315
- try:
316
- os.close(tmp_fd)
317
- except:
318
- pass
319
- if tmp_path and os.path.exists(tmp_path):
320
- try:
321
- os.unlink(tmp_path)
322
- except:
323
- pass
324
- raise
325
 
326
  def _is_token_expired(self, creds: Dict[str, Any]) -> bool:
327
  expiry = creds.get("token_expiry") # gcloud format
@@ -518,7 +532,7 @@ class GoogleOAuthBase:
518
  """Proactively refresh a credential by queueing it for refresh."""
519
  creds = await self._load_credentials(credential_path)
520
  if self._is_token_expired(creds):
521
- # Queue for refresh with needs_reauth=False (automated refresh)
522
  await self._queue_refresh(credential_path, force=False, needs_reauth=False)
523
 
524
  async def _get_lock(self, path: str) -> asyncio.Lock:
@@ -529,34 +543,69 @@ class GoogleOAuthBase:
529
  self._refresh_locks[path] = asyncio.Lock()
530
  return self._refresh_locks[path]
531
 
 
 
 
 
 
 
 
 
 
 
 
 
 
532
  def is_credential_available(self, path: str) -> bool:
533
- """Check if a credential is available for rotation (not queued/refreshing).
 
 
 
 
534
 
535
- [FIX PR#34] Now includes TTL-based stale entry cleanup as defense in depth.
536
- If a credential has been unavailable for longer than _unavailable_ttl_seconds,
537
- it is automatically cleaned up and considered available.
 
 
 
 
538
  """
539
- if path not in self._unavailable_credentials:
540
- return True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
541
 
542
- # [FIX PR#34] Check if the entry is stale (TTL expired)
543
- marked_time = self._unavailable_credentials.get(path)
544
- if marked_time is not None:
545
- now = time.time()
546
- if now - marked_time > self._unavailable_ttl_seconds:
547
- # Entry is stale - clean it up and return available
548
- lib_logger.warning(
549
- f"Credential '{Path(path).name}' was stuck in unavailable state for "
550
- f"{int(now - marked_time)}s (TTL: {self._unavailable_ttl_seconds}s). "
551
- f"Auto-cleaning stale entry."
 
552
  )
553
- # Note: This is a sync method, so we can't use async lock here.
554
- # However, pop from dict is thread-safe for single operations.
555
- # The _queue_tracking_lock protects concurrent modifications in async context.
556
- self._unavailable_credentials.pop(path, None)
557
- return True
558
 
559
- return False
560
 
561
  async def _ensure_queue_processor_running(self):
562
  """Lazily starts the queue processor if not already running."""
@@ -565,15 +614,27 @@ class GoogleOAuthBase:
565
  self._process_refresh_queue()
566
  )
567
 
 
 
 
 
 
 
 
568
  async def _queue_refresh(
569
  self, path: str, force: bool = False, needs_reauth: bool = False
570
  ):
571
- """Add a credential to the refresh queue if not already queued.
572
 
573
  Args:
574
  path: Credential file path
575
  force: Force refresh even if not expired
576
- needs_reauth: True if full re-authentication needed (bypasses backoff)
 
 
 
 
 
577
  """
578
  # IMPORTANT: Only check backoff for simple automated refreshes
579
  # Re-authentication (interactive OAuth) should BYPASS backoff since it needs user input
@@ -583,108 +644,226 @@ class GoogleOAuthBase:
583
  backoff_until = self._next_refresh_after[path]
584
  if now < backoff_until:
585
  # Credential is in backoff for automated refresh, do not queue
586
- remaining = int(backoff_until - now)
587
- lib_logger.debug(
588
- f"Skipping automated refresh for '{Path(path).name}' (in backoff for {remaining}s)"
589
- )
590
  return
591
 
592
  async with self._queue_tracking_lock:
593
  if path not in self._queued_credentials:
594
  self._queued_credentials.add(path)
595
- # [FIX PR#34] Store timestamp when marking unavailable (for TTL cleanup)
596
- self._unavailable_credentials[path] = time.time()
597
- lib_logger.debug(
598
- f"Marked '{Path(path).name}' as unavailable. "
599
- f"Total unavailable: {len(self._unavailable_credentials)}"
600
- )
601
- await self._refresh_queue.put((path, force, needs_reauth))
602
- await self._ensure_queue_processor_running()
 
 
 
 
 
 
 
 
 
 
603
 
604
  async def _process_refresh_queue(self):
605
- """Background worker that processes refresh requests sequentially."""
 
 
 
 
 
 
 
 
 
606
  while True:
607
  path = None
608
  try:
609
  # Wait for an item with timeout to allow graceful shutdown
610
  try:
611
- path, force, needs_reauth = await asyncio.wait_for(
612
  self._refresh_queue.get(), timeout=60.0
613
  )
614
  except asyncio.TimeoutError:
615
- # [FIX PR#34] Clean up any stale unavailable entries before exiting
616
- # If we're idle for 60s, no refreshes are in progress
617
  async with self._queue_tracking_lock:
618
- if self._unavailable_credentials:
619
- stale_count = len(self._unavailable_credentials)
620
- lib_logger.warning(
621
- f"Queue processor idle timeout. Cleaning {stale_count} "
622
- f"stale unavailable credentials: {list(self._unavailable_credentials.keys())}"
623
- )
624
- self._unavailable_credentials.clear()
625
  self._queue_processor_task = None
 
626
  return
627
 
628
  try:
629
- # Perform the actual refresh (still using per-credential lock)
630
- async with await self._get_lock(path):
631
- # Re-check if still expired (may have changed since queueing)
632
- creds = self._credentials_cache.get(path)
633
- if creds and not self._is_token_expired(creds):
634
- # No longer expired, mark as available
635
- async with self._queue_tracking_lock:
636
- self._unavailable_credentials.pop(path, None)
637
- lib_logger.debug(
638
- f"Credential '{Path(path).name}' no longer expired, marked available. "
639
- f"Remaining unavailable: {len(self._unavailable_credentials)}"
640
- )
641
- continue
 
 
 
 
 
642
 
643
- # Perform refresh
644
- if not creds:
645
- creds = await self._load_credentials(path)
646
- await self._refresh_token(path, creds, force=force)
647
-
648
- # SUCCESS: Mark as available again
649
- async with self._queue_tracking_lock:
650
- self._unavailable_credentials.pop(path, None)
651
- lib_logger.debug(
652
- f"Refresh SUCCESS for '{Path(path).name}', marked available. "
653
- f"Remaining unavailable: {len(self._unavailable_credentials)}"
 
 
 
 
 
 
654
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
655
 
656
  finally:
657
- # [FIX PR#34] Remove from BOTH queued set AND unavailable credentials
658
- # This ensures cleanup happens in ALL exit paths (success, exception, etc.)
659
  async with self._queue_tracking_lock:
660
  self._queued_credentials.discard(path)
661
- # [FIX PR#34] Always clean up unavailable credentials in finally block
662
  self._unavailable_credentials.pop(path, None)
663
- lib_logger.debug(
664
- f"Finally cleanup for '{Path(path).name}'. "
665
- f"Remaining unavailable: {len(self._unavailable_credentials)}"
666
- )
667
- self._refresh_queue.task_done()
 
668
  except asyncio.CancelledError:
669
- # [FIX PR#34] Clean up the current credential before breaking
670
  if path:
671
  async with self._queue_tracking_lock:
 
672
  self._unavailable_credentials.pop(path, None)
673
- lib_logger.debug(
674
- f"CancelledError cleanup for '{Path(path).name}'. "
675
- f"Remaining unavailable: {len(self._unavailable_credentials)}"
676
- )
677
  break
678
  except Exception as e:
679
- lib_logger.error(f"Error in queue processor: {e}")
680
- # Even on error, mark as available (backoff will prevent immediate retry)
681
  if path:
682
  async with self._queue_tracking_lock:
 
683
  self._unavailable_credentials.pop(path, None)
684
- lib_logger.debug(
685
- f"Error cleanup for '{Path(path).name}': {e}. "
686
- f"Remaining unavailable: {len(self._unavailable_credentials)}"
687
- )
688
 
689
  async def _perform_interactive_oauth(
690
  self, path: str, creds: Dict[str, Any], display_name: str
@@ -744,14 +923,14 @@ class GoogleOAuthBase:
744
 
745
  try:
746
  server = await asyncio.start_server(
747
- handle_callback, "127.0.0.1", self.CALLBACK_PORT
748
  )
749
  from urllib.parse import urlencode
750
 
751
  auth_url = "https://accounts.google.com/o/oauth2/v2/auth?" + urlencode(
752
  {
753
  "client_id": self.CLIENT_ID,
754
- "redirect_uri": f"http://localhost:{self.CALLBACK_PORT}{self.CALLBACK_PATH}",
755
  "scope": " ".join(self.OAUTH_SCOPES),
756
  "access_type": "offline",
757
  "response_type": "code",
@@ -826,7 +1005,7 @@ class GoogleOAuthBase:
826
  "code": auth_code.strip(),
827
  "client_id": self.CLIENT_ID,
828
  "client_secret": self.CLIENT_SECRET,
829
- "redirect_uri": f"http://localhost:{self.CALLBACK_PORT}{self.CALLBACK_PATH}",
830
  "grant_type": "authorization_code",
831
  },
832
  )
@@ -864,6 +1043,18 @@ class GoogleOAuthBase:
864
  lib_logger.info(
865
  f"{self.ENV_PREFIX} OAuth initialized successfully for '{display_name}'."
866
  )
 
 
 
 
 
 
 
 
 
 
 
 
867
  return new_creds
868
 
869
  async def initialize_token(
@@ -940,10 +1131,51 @@ class GoogleOAuthBase:
940
  )
941
 
942
  async def get_auth_header(self, credential_path: str) -> Dict[str, str]:
943
- creds = await self._load_credentials(credential_path)
944
- if self._is_token_expired(creds):
945
- creds = await self._refresh_token(credential_path, creds)
946
- return {"Authorization": f"Bearer {creds['access_token']}"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
947
 
948
  async def get_user_info(
949
  self, creds_or_path: Union[Dict[str, Any], str]
@@ -976,3 +1208,372 @@ class GoogleOAuthBase:
976
  if path:
977
  await self._save_credentials(path, creds)
978
  return {"email": user_info.get("email")}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  # src/rotator_library/providers/google_oauth_base.py
2
 
3
  import os
4
+ import re
5
  import webbrowser
6
+ from dataclasses import dataclass, field
7
+ from typing import Union, Optional, List
8
  import json
9
  import time
10
  import asyncio
11
  import logging
12
  from pathlib import Path
13
  from typing import Dict, Any
14
+ from glob import glob
 
15
 
16
  import httpx
17
  from rich.console import Console
 
21
 
22
  from ..utils.headless_detection import is_headless_environment
23
  from ..utils.reauth_coordinator import get_reauth_coordinator
24
+ from ..utils.resilient_io import safe_write_json
25
 
26
  lib_logger = logging.getLogger("rotator_library")
27
 
28
  console = Console()
29
 
30
 
31
+ @dataclass
32
+ class CredentialSetupResult:
33
+ """
34
+ Standardized result structure for credential setup operations.
35
+
36
+ Used by all auth classes to return consistent setup results to the credential tool.
37
+ """
38
+
39
+ success: bool
40
+ file_path: Optional[str] = None
41
+ email: Optional[str] = None
42
+ tier: Optional[str] = None
43
+ project_id: Optional[str] = None
44
+ is_update: bool = False
45
+ error: Optional[str] = None
46
+ credentials: Optional[Dict[str, Any]] = field(default=None, repr=False)
47
+
48
+
49
  class GoogleOAuthBase:
50
  """
51
  Base class for Google OAuth2 authentication providers.
 
75
  CALLBACK_PATH: str = "/oauth2callback"
76
  REFRESH_EXPIRY_BUFFER_SECONDS: int = 30 * 60 # 30 minutes
77
 
78
+ @property
79
+ def callback_port(self) -> int:
80
+ """
81
+ Get the OAuth callback port, checking environment variable first.
82
+
83
+ Reads from {ENV_PREFIX}_OAUTH_PORT environment variable, falling back
84
+ to the class's CALLBACK_PORT default if not set.
85
+ """
86
+ env_var = f"{self.ENV_PREFIX}_OAUTH_PORT"
87
+ env_value = os.getenv(env_var)
88
+ if env_value:
89
+ try:
90
+ return int(env_value)
91
+ except ValueError:
92
+ lib_logger.warning(
93
+ f"Invalid {env_var} value: {env_value}, using default {self.CALLBACK_PORT}"
94
+ )
95
+ return self.CALLBACK_PORT
96
+
97
  def __init__(self):
98
  # Validate that subclass has set required attributes
99
  if self.CLIENT_ID is None:
 
122
  str, float
123
  ] = {} # Track backoff timers (Unix timestamp)
124
 
125
+ # [QUEUE SYSTEM] Sequential refresh processing with two separate queues
126
+ # Normal refresh queue: for proactive token refresh (old token still valid)
127
  self._refresh_queue: asyncio.Queue = asyncio.Queue()
128
+ self._queue_processor_task: Optional[asyncio.Task] = None
129
+
130
+ # Re-auth queue: for invalid refresh tokens (requires user interaction)
131
+ self._reauth_queue: asyncio.Queue = asyncio.Queue()
132
+ self._reauth_processor_task: Optional[asyncio.Task] = None
133
+
134
+ # Tracking sets/dicts
135
+ self._queued_credentials: set = set() # Track credentials in either queue
136
+ # Only credentials in re-auth queue are marked unavailable (not normal refresh)
137
+ # TTL cleanup is defense-in-depth for edge cases where re-auth processor crashes
138
  self._unavailable_credentials: Dict[
139
  str, float
140
  ] = {} # Maps credential path -> timestamp when marked unavailable
141
+ # TTL should exceed reauth timeout (300s) to avoid premature cleanup
142
+ self._unavailable_ttl_seconds: int = 360 # 6 minutes TTL for stale entries
143
  self._queue_tracking_lock = asyncio.Lock() # Protects queue sets
144
+
145
+ # Retry tracking for normal refresh queue
146
+ self._queue_retry_count: Dict[
147
+ str, int
148
+ ] = {} # Track retry attempts per credential
149
+
150
+ # Configuration constants
151
+ self._refresh_timeout_seconds: int = 15 # Max time for single refresh
152
+ self._refresh_interval_seconds: int = 30 # Delay between queue items
153
+ self._refresh_max_retries: int = 3 # Attempts before kicked out
154
+ self._reauth_timeout_seconds: int = 300 # Time for user to complete OAuth
155
 
156
  def _parse_env_credential_path(self, path: str) -> Optional[str]:
157
  """
 
284
  f"Environment variables for {self.ENV_PREFIX} credential index {credential_index} not found"
285
  )
286
 
287
+ # Try file-based loading first (preferred for explicit file paths)
 
 
 
 
 
 
 
 
 
 
288
  try:
289
  lib_logger.debug(
290
  f"Loading {self.ENV_PREFIX} credentials from file: {path}"
 
297
  self._credentials_cache[path] = creds
298
  return creds
299
  except FileNotFoundError:
300
+ # File not found - fall back to legacy env vars for backwards compatibility
301
+ # This handles the case where only env vars are set and file paths are placeholders
302
+ env_creds = self._load_from_env()
303
+ if env_creds:
304
+ lib_logger.info(
305
+ f"File '{path}' not found, using {self.ENV_PREFIX} credentials from environment variables"
306
+ )
307
+ self._credentials_cache[path] = env_creds
308
+ return env_creds
309
  raise IOError(
310
  f"{self.ENV_PREFIX} OAuth credential file not found at '{path}'"
311
  )
 
313
  raise IOError(
314
  f"Failed to load {self.ENV_PREFIX} OAuth credentials from '{path}': {e}"
315
  )
 
 
 
 
316
 
317
  async def _save_credentials(self, path: str, creds: Dict[str, Any]):
318
+ """Save credentials with in-memory fallback if disk unavailable."""
319
+ # Always update cache first (memory is reliable)
320
+ self._credentials_cache[path] = creds
321
+
322
  # Don't save to file if credentials were loaded from environment
323
  if creds.get("_proxy_metadata", {}).get("loaded_from_env"):
324
  lib_logger.debug("Credentials loaded from env, skipping file save")
 
 
325
  return
326
 
327
+ # Attempt disk write - if it fails, we still have the cache
328
+ # buffer_on_failure ensures data is retried periodically and saved on shutdown
329
+ if safe_write_json(
330
+ path, creds, lib_logger, secure_permissions=True, buffer_on_failure=True
331
+ ):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
332
  lib_logger.debug(
333
+ f"Saved updated {self.ENV_PREFIX} OAuth credentials to '{path}'."
334
  )
335
+ else:
336
+ lib_logger.warning(
337
+ f"Credentials for {self.ENV_PREFIX} cached in memory only (buffered for retry)."
 
338
  )
 
 
 
 
 
 
 
 
 
 
 
 
339
 
340
  def _is_token_expired(self, creds: Dict[str, Any]) -> bool:
341
  expiry = creds.get("token_expiry") # gcloud format
 
532
  """Proactively refresh a credential by queueing it for refresh."""
533
  creds = await self._load_credentials(credential_path)
534
  if self._is_token_expired(creds):
535
+ # lib_logger.info(f"Proactive refresh triggered for '{Path(credential_path).name}'")
536
  await self._queue_refresh(credential_path, force=False, needs_reauth=False)
537
 
538
  async def _get_lock(self, path: str) -> asyncio.Lock:
 
543
  self._refresh_locks[path] = asyncio.Lock()
544
  return self._refresh_locks[path]
545
 
546
+ def _is_token_truly_expired(self, creds: Dict[str, Any]) -> bool:
547
+ """Check if token is TRULY expired (past actual expiry, not just threshold).
548
+
549
+ This is different from _is_token_expired() which uses a buffer for proactive refresh.
550
+ This method checks if the token is actually unusable.
551
+ """
552
+ expiry = creds.get("token_expiry") # gcloud format
553
+ if not expiry: # gemini-cli format
554
+ expiry_timestamp = creds.get("expiry_date", 0) / 1000
555
+ else:
556
+ expiry_timestamp = time.mktime(time.strptime(expiry, "%Y-%m-%dT%H:%M:%SZ"))
557
+ return expiry_timestamp < time.time()
558
+
559
  def is_credential_available(self, path: str) -> bool:
560
+ """Check if a credential is available for rotation.
561
+
562
+ Credentials are unavailable if:
563
+ 1. In re-auth queue (token is truly broken, requires user interaction)
564
+ 2. Token is TRULY expired (past actual expiry, not just threshold)
565
 
566
+ Note: Credentials in normal refresh queue are still available because
567
+ the old token is valid until actual expiry.
568
+
569
+ TTL cleanup (defense-in-depth): If a credential has been in the re-auth
570
+ queue longer than _unavailable_ttl_seconds without being processed, it's
571
+ cleaned up. This should only happen if the re-auth processor crashes or
572
+ is cancelled without proper cleanup.
573
  """
574
+ # Check if in re-auth queue (truly unavailable)
575
+ if path in self._unavailable_credentials:
576
+ marked_time = self._unavailable_credentials.get(path)
577
+ if marked_time is not None:
578
+ now = time.time()
579
+ if now - marked_time > self._unavailable_ttl_seconds:
580
+ # Entry is stale - clean it up and return available
581
+ # This is a defense-in-depth for edge cases where re-auth
582
+ # processor crashed or was cancelled without cleanup
583
+ lib_logger.warning(
584
+ f"Credential '{Path(path).name}' stuck in re-auth queue for "
585
+ f"{int(now - marked_time)}s (TTL: {self._unavailable_ttl_seconds}s). "
586
+ f"Re-auth processor may have crashed. Auto-cleaning stale entry."
587
+ )
588
+ # Clean up both tracking structures for consistency
589
+ self._unavailable_credentials.pop(path, None)
590
+ self._queued_credentials.discard(path)
591
+ else:
592
+ return False # Still in re-auth, not available
593
 
594
+ # Check if token is TRULY expired (not just threshold-expired)
595
+ creds = self._credentials_cache.get(path)
596
+ if creds and self._is_token_truly_expired(creds):
597
+ # Token is actually expired - should not be used
598
+ # Queue for refresh if not already queued
599
+ if path not in self._queued_credentials:
600
+ # lib_logger.debug(
601
+ # f"Credential '{Path(path).name}' is truly expired, queueing for refresh"
602
+ # )
603
+ asyncio.create_task(
604
+ self._queue_refresh(path, force=True, needs_reauth=False)
605
  )
606
+ return False
 
 
 
 
607
 
608
+ return True
609
 
610
  async def _ensure_queue_processor_running(self):
611
  """Lazily starts the queue processor if not already running."""
 
614
  self._process_refresh_queue()
615
  )
616
 
617
+ async def _ensure_reauth_processor_running(self):
618
+ """Lazily starts the re-auth queue processor if not already running."""
619
+ if self._reauth_processor_task is None or self._reauth_processor_task.done():
620
+ self._reauth_processor_task = asyncio.create_task(
621
+ self._process_reauth_queue()
622
+ )
623
+
624
  async def _queue_refresh(
625
  self, path: str, force: bool = False, needs_reauth: bool = False
626
  ):
627
+ """Add a credential to the appropriate refresh queue if not already queued.
628
 
629
  Args:
630
  path: Credential file path
631
  force: Force refresh even if not expired
632
+ needs_reauth: True if full re-authentication needed (routes to re-auth queue)
633
+
634
+ Queue routing:
635
+ - needs_reauth=True: Goes to re-auth queue, marks as unavailable
636
+ - needs_reauth=False: Goes to normal refresh queue, does NOT mark unavailable
637
+ (old token is still valid until actual expiry)
638
  """
639
  # IMPORTANT: Only check backoff for simple automated refreshes
640
  # Re-authentication (interactive OAuth) should BYPASS backoff since it needs user input
 
644
  backoff_until = self._next_refresh_after[path]
645
  if now < backoff_until:
646
  # Credential is in backoff for automated refresh, do not queue
647
+ # remaining = int(backoff_until - now)
648
+ # lib_logger.debug(
649
+ # f"Skipping automated refresh for '{Path(path).name}' (in backoff for {remaining}s)"
650
+ # )
651
  return
652
 
653
  async with self._queue_tracking_lock:
654
  if path not in self._queued_credentials:
655
  self._queued_credentials.add(path)
656
+
657
+ if needs_reauth:
658
+ # Re-auth queue: mark as unavailable (token is truly broken)
659
+ self._unavailable_credentials[path] = time.time()
660
+ # lib_logger.debug(
661
+ # f"Queued '{Path(path).name}' for RE-AUTH (marked unavailable). "
662
+ # f"Total unavailable: {len(self._unavailable_credentials)}"
663
+ # )
664
+ await self._reauth_queue.put(path)
665
+ await self._ensure_reauth_processor_running()
666
+ else:
667
+ # Normal refresh queue: do NOT mark unavailable (old token still valid)
668
+ # lib_logger.debug(
669
+ # f"Queued '{Path(path).name}' for refresh (still available). "
670
+ # f"Queue size: {self._refresh_queue.qsize() + 1}"
671
+ # )
672
+ await self._refresh_queue.put((path, force))
673
+ await self._ensure_queue_processor_running()
674
 
675
  async def _process_refresh_queue(self):
676
+ """Background worker that processes normal refresh requests sequentially.
677
+
678
+ Key behaviors:
679
+ - 15s timeout per refresh operation
680
+ - 30s delay between processing credentials (prevents thundering herd)
681
+ - On failure: back of queue, max 3 retries before kicked
682
+ - If 401/403 detected: routes to re-auth queue
683
+ - Does NOT mark credentials unavailable (old token still valid)
684
+ """
685
+ # lib_logger.info("Refresh queue processor started")
686
  while True:
687
  path = None
688
  try:
689
  # Wait for an item with timeout to allow graceful shutdown
690
  try:
691
+ path, force = await asyncio.wait_for(
692
  self._refresh_queue.get(), timeout=60.0
693
  )
694
  except asyncio.TimeoutError:
695
+ # Queue is empty and idle for 60s - clean up and exit
 
696
  async with self._queue_tracking_lock:
697
+ # Clear any stale retry counts
698
+ self._queue_retry_count.clear()
 
 
 
 
 
699
  self._queue_processor_task = None
700
+ # lib_logger.debug("Refresh queue processor idle, shutting down")
701
  return
702
 
703
  try:
704
+ # Quick check if still expired (optimization to avoid unnecessary refresh)
705
+ creds = self._credentials_cache.get(path)
706
+ if creds and not self._is_token_expired(creds):
707
+ # No longer expired, skip refresh
708
+ # lib_logger.debug(
709
+ # f"Credential '{Path(path).name}' no longer expired, skipping refresh"
710
+ # )
711
+ # Clear retry count on skip (not a failure)
712
+ self._queue_retry_count.pop(path, None)
713
+ continue
714
+
715
+ # Perform refresh with timeout
716
+ if not creds:
717
+ creds = await self._load_credentials(path)
718
+
719
+ try:
720
+ async with asyncio.timeout(self._refresh_timeout_seconds):
721
+ await self._refresh_token(path, creds, force=force)
722
 
723
+ # SUCCESS: Clear retry count
724
+ self._queue_retry_count.pop(path, None)
725
+ # lib_logger.info(f"Refresh SUCCESS for '{Path(path).name}'")
726
+
727
+ except asyncio.TimeoutError:
728
+ lib_logger.warning(
729
+ f"Refresh timeout ({self._refresh_timeout_seconds}s) for '{Path(path).name}'"
730
+ )
731
+ await self._handle_refresh_failure(path, force, "timeout")
732
+
733
+ except httpx.HTTPStatusError as e:
734
+ status_code = e.response.status_code
735
+ if status_code in (401, 403):
736
+ # Invalid refresh token - route to re-auth queue
737
+ lib_logger.warning(
738
+ f"Refresh token invalid for '{Path(path).name}' (HTTP {status_code}). "
739
+ f"Routing to re-auth queue."
740
  )
741
+ self._queue_retry_count.pop(path, None) # Clear retry count
742
+ async with self._queue_tracking_lock:
743
+ self._queued_credentials.discard(
744
+ path
745
+ ) # Remove from queued
746
+ await self._queue_refresh(
747
+ path, force=True, needs_reauth=True
748
+ )
749
+ else:
750
+ await self._handle_refresh_failure(
751
+ path, force, f"HTTP {status_code}"
752
+ )
753
+
754
+ except Exception as e:
755
+ await self._handle_refresh_failure(path, force, str(e))
756
+
757
+ finally:
758
+ # Remove from queued set (unless re-queued by failure handler)
759
+ async with self._queue_tracking_lock:
760
+ # Only discard if not re-queued (check if still in queue set from retry)
761
+ if (
762
+ path in self._queued_credentials
763
+ and self._queue_retry_count.get(path, 0) == 0
764
+ ):
765
+ self._queued_credentials.discard(path)
766
+ self._refresh_queue.task_done()
767
+
768
+ # Wait between credentials to spread load
769
+ await asyncio.sleep(self._refresh_interval_seconds)
770
+
771
+ except asyncio.CancelledError:
772
+ # lib_logger.debug("Refresh queue processor cancelled")
773
+ break
774
+ except Exception as e:
775
+ lib_logger.error(f"Error in refresh queue processor: {e}")
776
+ if path:
777
+ async with self._queue_tracking_lock:
778
+ self._queued_credentials.discard(path)
779
+
780
+ async def _handle_refresh_failure(self, path: str, force: bool, error: str):
781
+ """Handle a refresh failure with back-of-line retry logic.
782
+
783
+ - Increments retry count
784
+ - If under max retries: re-adds to END of queue
785
+ - If at max retries: kicks credential out (retried next BackgroundRefresher cycle)
786
+ """
787
+ retry_count = self._queue_retry_count.get(path, 0) + 1
788
+ self._queue_retry_count[path] = retry_count
789
+
790
+ if retry_count >= self._refresh_max_retries:
791
+ # Kicked out until next BackgroundRefresher cycle
792
+ lib_logger.error(
793
+ f"Max retries ({self._refresh_max_retries}) reached for '{Path(path).name}' "
794
+ f"(last error: {error}). Will retry next refresh cycle."
795
+ )
796
+ self._queue_retry_count.pop(path, None)
797
+ async with self._queue_tracking_lock:
798
+ self._queued_credentials.discard(path)
799
+ return
800
+
801
+ # Re-add to END of queue for retry
802
+ lib_logger.warning(
803
+ f"Refresh failed for '{Path(path).name}' ({error}). "
804
+ f"Retry {retry_count}/{self._refresh_max_retries}, back of queue."
805
+ )
806
+ # Keep in queued_credentials set, add back to queue
807
+ await self._refresh_queue.put((path, force))
808
+
809
+ async def _process_reauth_queue(self):
810
+ """Background worker that processes re-auth requests.
811
+
812
+ Key behaviors:
813
+ - Credentials ARE marked unavailable (token is truly broken)
814
+ - Uses ReauthCoordinator for interactive OAuth
815
+ - No automatic retry (requires user action)
816
+ - Cleans up unavailable status when done
817
+ """
818
+ # lib_logger.info("Re-auth queue processor started")
819
+ while True:
820
+ path = None
821
+ try:
822
+ # Wait for an item with timeout to allow graceful shutdown
823
+ try:
824
+ path = await asyncio.wait_for(
825
+ self._reauth_queue.get(), timeout=60.0
826
+ )
827
+ except asyncio.TimeoutError:
828
+ # Queue is empty and idle for 60s - exit
829
+ self._reauth_processor_task = None
830
+ # lib_logger.debug("Re-auth queue processor idle, shutting down")
831
+ return
832
+
833
+ try:
834
+ lib_logger.info(f"Starting re-auth for '{Path(path).name}'...")
835
+ await self.initialize_token(path)
836
+ lib_logger.info(f"Re-auth SUCCESS for '{Path(path).name}'")
837
+
838
+ except Exception as e:
839
+ lib_logger.error(f"Re-auth FAILED for '{Path(path).name}': {e}")
840
+ # No automatic retry for re-auth (requires user action)
841
 
842
  finally:
843
+ # Always clean up
 
844
  async with self._queue_tracking_lock:
845
  self._queued_credentials.discard(path)
 
846
  self._unavailable_credentials.pop(path, None)
847
+ # lib_logger.debug(
848
+ # f"Re-auth cleanup for '{Path(path).name}'. "
849
+ # f"Remaining unavailable: {len(self._unavailable_credentials)}"
850
+ # )
851
+ self._reauth_queue.task_done()
852
+
853
  except asyncio.CancelledError:
854
+ # Clean up current credential before breaking
855
  if path:
856
  async with self._queue_tracking_lock:
857
+ self._queued_credentials.discard(path)
858
  self._unavailable_credentials.pop(path, None)
859
+ # lib_logger.debug("Re-auth queue processor cancelled")
 
 
 
860
  break
861
  except Exception as e:
862
+ lib_logger.error(f"Error in re-auth queue processor: {e}")
 
863
  if path:
864
  async with self._queue_tracking_lock:
865
+ self._queued_credentials.discard(path)
866
  self._unavailable_credentials.pop(path, None)
 
 
 
 
867
 
868
  async def _perform_interactive_oauth(
869
  self, path: str, creds: Dict[str, Any], display_name: str
 
923
 
924
  try:
925
  server = await asyncio.start_server(
926
+ handle_callback, "127.0.0.1", self.callback_port
927
  )
928
  from urllib.parse import urlencode
929
 
930
  auth_url = "https://accounts.google.com/o/oauth2/v2/auth?" + urlencode(
931
  {
932
  "client_id": self.CLIENT_ID,
933
+ "redirect_uri": f"http://localhost:{self.callback_port}{self.CALLBACK_PATH}",
934
  "scope": " ".join(self.OAUTH_SCOPES),
935
  "access_type": "offline",
936
  "response_type": "code",
 
1005
  "code": auth_code.strip(),
1006
  "client_id": self.CLIENT_ID,
1007
  "client_secret": self.CLIENT_SECRET,
1008
+ "redirect_uri": f"http://localhost:{self.callback_port}{self.CALLBACK_PATH}",
1009
  "grant_type": "authorization_code",
1010
  },
1011
  )
 
1043
  lib_logger.info(
1044
  f"{self.ENV_PREFIX} OAuth initialized successfully for '{display_name}'."
1045
  )
1046
+
1047
+ # Perform post-auth discovery (tier, project, etc.) while we have a fresh token
1048
+ if path:
1049
+ try:
1050
+ await self._post_auth_discovery(path, new_creds["access_token"])
1051
+ except Exception as e:
1052
+ # Don't fail auth if discovery fails - it can be retried on first request
1053
+ lib_logger.warning(
1054
+ f"Post-auth discovery failed for '{display_name}': {e}. "
1055
+ "Tier/project will be discovered on first request."
1056
+ )
1057
+
1058
  return new_creds
1059
 
1060
  async def initialize_token(
 
1131
  )
1132
 
1133
  async def get_auth_header(self, credential_path: str) -> Dict[str, str]:
1134
+ """Get auth header with graceful degradation if refresh fails."""
1135
+ try:
1136
+ creds = await self._load_credentials(credential_path)
1137
+ if self._is_token_expired(creds):
1138
+ try:
1139
+ creds = await self._refresh_token(credential_path, creds)
1140
+ except Exception as e:
1141
+ # Check if we have a cached token that might still work
1142
+ cached = self._credentials_cache.get(credential_path)
1143
+ if cached and cached.get("access_token"):
1144
+ lib_logger.warning(
1145
+ f"Token refresh failed for {Path(credential_path).name}: {e}. "
1146
+ "Using cached token (may be expired)."
1147
+ )
1148
+ creds = cached
1149
+ else:
1150
+ raise
1151
+ return {"Authorization": f"Bearer {creds['access_token']}"}
1152
+ except Exception as e:
1153
+ # Check if any cached credential exists as last resort
1154
+ cached = self._credentials_cache.get(credential_path)
1155
+ if cached and cached.get("access_token"):
1156
+ lib_logger.error(
1157
+ f"Credential load failed for {credential_path}: {e}. "
1158
+ "Using stale cached token as last resort."
1159
+ )
1160
+ return {"Authorization": f"Bearer {cached['access_token']}"}
1161
+ raise
1162
+
1163
+ async def _post_auth_discovery(
1164
+ self, credential_path: str, access_token: str
1165
+ ) -> None:
1166
+ """
1167
+ Hook for subclasses to perform post-authentication discovery.
1168
+
1169
+ Called after successful OAuth authentication (both initial and re-auth).
1170
+ Subclasses can override this to discover and cache tier/project information
1171
+ during the authentication flow rather than waiting for the first API request.
1172
+
1173
+ Args:
1174
+ credential_path: Path to the credential file
1175
+ access_token: The newly obtained access token
1176
+ """
1177
+ # Default implementation does nothing - subclasses can override
1178
+ pass
1179
 
1180
  async def get_user_info(
1181
  self, creds_or_path: Union[Dict[str, Any], str]
 
1208
  if path:
1209
  await self._save_credentials(path, creds)
1210
  return {"email": user_info.get("email")}
1211
+
1212
+ # =========================================================================
1213
+ # CREDENTIAL MANAGEMENT METHODS
1214
+ # =========================================================================
1215
+
1216
+ def _get_provider_file_prefix(self) -> str:
1217
+ """
1218
+ Get the file name prefix for this provider's credential files.
1219
+
1220
+ Override in subclasses if the prefix differs from ENV_PREFIX.
1221
+ Default: lowercase ENV_PREFIX with underscores (e.g., "gemini_cli")
1222
+ """
1223
+ return self.ENV_PREFIX.lower()
1224
+
1225
+ def _get_oauth_base_dir(self) -> Path:
1226
+ """
1227
+ Get the base directory for OAuth credential files.
1228
+
1229
+ Can be overridden to customize credential storage location.
1230
+ """
1231
+ return Path.cwd() / "oauth_creds"
1232
+
1233
+ def _find_existing_credential_by_email(
1234
+ self, email: str, base_dir: Optional[Path] = None
1235
+ ) -> Optional[Path]:
1236
+ """
1237
+ Find an existing credential file for the given email.
1238
+
1239
+ Args:
1240
+ email: Email address to search for
1241
+ base_dir: Directory to search in (defaults to oauth_creds)
1242
+
1243
+ Returns:
1244
+ Path to existing credential file, or None if not found
1245
+ """
1246
+ if base_dir is None:
1247
+ base_dir = self._get_oauth_base_dir()
1248
+
1249
+ prefix = self._get_provider_file_prefix()
1250
+ pattern = str(base_dir / f"{prefix}_oauth_*.json")
1251
+
1252
+ for cred_file in glob(pattern):
1253
+ try:
1254
+ with open(cred_file, "r") as f:
1255
+ creds = json.load(f)
1256
+ existing_email = creds.get("_proxy_metadata", {}).get("email")
1257
+ if existing_email == email:
1258
+ return Path(cred_file)
1259
+ except (json.JSONDecodeError, IOError) as e:
1260
+ lib_logger.debug(f"Could not read credential file {cred_file}: {e}")
1261
+ continue
1262
+
1263
+ return None
1264
+
1265
+ def _get_next_credential_number(self, base_dir: Optional[Path] = None) -> int:
1266
+ """
1267
+ Get the next available credential number for new credential files.
1268
+
1269
+ Args:
1270
+ base_dir: Directory to scan (defaults to oauth_creds)
1271
+
1272
+ Returns:
1273
+ Next available credential number (1, 2, 3, etc.)
1274
+ """
1275
+ if base_dir is None:
1276
+ base_dir = self._get_oauth_base_dir()
1277
+
1278
+ prefix = self._get_provider_file_prefix()
1279
+ pattern = str(base_dir / f"{prefix}_oauth_*.json")
1280
+
1281
+ existing_numbers = []
1282
+ for cred_file in glob(pattern):
1283
+ match = re.search(r"_oauth_(\d+)\.json$", cred_file)
1284
+ if match:
1285
+ existing_numbers.append(int(match.group(1)))
1286
+
1287
+ if not existing_numbers:
1288
+ return 1
1289
+ return max(existing_numbers) + 1
1290
+
1291
+ def _build_credential_path(
1292
+ self, base_dir: Optional[Path] = None, number: Optional[int] = None
1293
+ ) -> Path:
1294
+ """
1295
+ Build a path for a new credential file.
1296
+
1297
+ Args:
1298
+ base_dir: Directory for credential files (defaults to oauth_creds)
1299
+ number: Credential number (auto-determined if None)
1300
+
1301
+ Returns:
1302
+ Path for the new credential file
1303
+ """
1304
+ if base_dir is None:
1305
+ base_dir = self._get_oauth_base_dir()
1306
+
1307
+ if number is None:
1308
+ number = self._get_next_credential_number(base_dir)
1309
+
1310
+ prefix = self._get_provider_file_prefix()
1311
+ filename = f"{prefix}_oauth_{number}.json"
1312
+ return base_dir / filename
1313
+
1314
+ async def setup_credential(
1315
+ self, base_dir: Optional[Path] = None
1316
+ ) -> CredentialSetupResult:
1317
+ """
1318
+ Complete credential setup flow: OAuth -> save -> discovery.
1319
+
1320
+ This is the main entry point for setting up new credentials.
1321
+ Handles the entire lifecycle:
1322
+ 1. Perform OAuth authentication
1323
+ 2. Get user info (email) for deduplication
1324
+ 3. Find existing credential or create new file path
1325
+ 4. Save credentials to file
1326
+ 5. Perform post-auth discovery (tier/project for Google OAuth)
1327
+
1328
+ Args:
1329
+ base_dir: Directory for credential files (defaults to oauth_creds)
1330
+
1331
+ Returns:
1332
+ CredentialSetupResult with status and details
1333
+ """
1334
+ if base_dir is None:
1335
+ base_dir = self._get_oauth_base_dir()
1336
+
1337
+ # Ensure directory exists
1338
+ base_dir.mkdir(exist_ok=True)
1339
+
1340
+ try:
1341
+ # Step 1: Perform OAuth authentication (returns credentials dict)
1342
+ temp_creds = {
1343
+ "_proxy_metadata": {"display_name": f"new {self.ENV_PREFIX} credential"}
1344
+ }
1345
+ new_creds = await self.initialize_token(temp_creds)
1346
+
1347
+ # Step 2: Get user info for deduplication
1348
+ user_info = await self.get_user_info(new_creds)
1349
+ email = user_info.get("email")
1350
+
1351
+ if not email:
1352
+ return CredentialSetupResult(
1353
+ success=False, error="Could not retrieve email from OAuth response"
1354
+ )
1355
+
1356
+ # Step 3: Check for existing credential with same email
1357
+ existing_path = self._find_existing_credential_by_email(email, base_dir)
1358
+ is_update = existing_path is not None
1359
+
1360
+ if is_update:
1361
+ file_path = existing_path
1362
+ lib_logger.info(
1363
+ f"Found existing credential for {email}, updating {file_path.name}"
1364
+ )
1365
+ else:
1366
+ file_path = self._build_credential_path(base_dir)
1367
+ lib_logger.info(
1368
+ f"Creating new credential for {email} at {file_path.name}"
1369
+ )
1370
+
1371
+ # Step 4: Save credentials to file
1372
+ await self._save_credentials(str(file_path), new_creds)
1373
+
1374
+ # Step 5: Perform post-auth discovery (tier, project_id)
1375
+ # This is already called in _perform_interactive_oauth, but we call it again
1376
+ # in case credentials were loaded from existing token
1377
+ tier = None
1378
+ project_id = None
1379
+ try:
1380
+ await self._post_auth_discovery(
1381
+ str(file_path), new_creds["access_token"]
1382
+ )
1383
+ # Reload credentials to get discovered metadata
1384
+ with open(file_path, "r") as f:
1385
+ updated_creds = json.load(f)
1386
+ tier = updated_creds.get("_proxy_metadata", {}).get("tier")
1387
+ project_id = updated_creds.get("_proxy_metadata", {}).get("project_id")
1388
+ new_creds = updated_creds
1389
+ except Exception as e:
1390
+ lib_logger.warning(
1391
+ f"Post-auth discovery failed: {e}. Tier/project will be discovered on first request."
1392
+ )
1393
+
1394
+ return CredentialSetupResult(
1395
+ success=True,
1396
+ file_path=str(file_path),
1397
+ email=email,
1398
+ tier=tier,
1399
+ project_id=project_id,
1400
+ is_update=is_update,
1401
+ credentials=new_creds,
1402
+ )
1403
+
1404
+ except Exception as e:
1405
+ lib_logger.error(f"Credential setup failed: {e}")
1406
+ return CredentialSetupResult(success=False, error=str(e))
1407
+
1408
+ def build_env_lines(self, creds: Dict[str, Any], cred_number: int) -> List[str]:
1409
+ """
1410
+ Generate .env file lines for a credential.
1411
+
1412
+ Subclasses should override to include provider-specific fields
1413
+ (e.g., tier, project_id for Google OAuth providers).
1414
+
1415
+ Args:
1416
+ creds: Credential dictionary loaded from JSON
1417
+ cred_number: Credential number (1, 2, 3, etc.)
1418
+
1419
+ Returns:
1420
+ List of .env file lines
1421
+ """
1422
+ email = creds.get("_proxy_metadata", {}).get("email", "unknown")
1423
+ prefix = f"{self.ENV_PREFIX}_{cred_number}"
1424
+
1425
+ lines = [
1426
+ f"# {self.ENV_PREFIX} Credential #{cred_number} for: {email}",
1427
+ f"# Exported from: {self._get_provider_file_prefix()}_oauth_{cred_number}.json",
1428
+ f"# Generated at: {time.strftime('%Y-%m-%d %H:%M:%S')}",
1429
+ "#",
1430
+ "# To combine multiple credentials into one .env file, copy these lines",
1431
+ "# and ensure each credential has a unique number (1, 2, 3, etc.)",
1432
+ "",
1433
+ f"{prefix}_ACCESS_TOKEN={creds.get('access_token', '')}",
1434
+ f"{prefix}_REFRESH_TOKEN={creds.get('refresh_token', '')}",
1435
+ f"{prefix}_SCOPE={creds.get('scope', '')}",
1436
+ f"{prefix}_TOKEN_TYPE={creds.get('token_type', 'Bearer')}",
1437
+ f"{prefix}_ID_TOKEN={creds.get('id_token', '')}",
1438
+ f"{prefix}_EXPIRY_DATE={creds.get('expiry_date', 0)}",
1439
+ f"{prefix}_CLIENT_ID={creds.get('client_id', '')}",
1440
+ f"{prefix}_CLIENT_SECRET={creds.get('client_secret', '')}",
1441
+ f"{prefix}_TOKEN_URI={creds.get('token_uri', 'https://oauth2.googleapis.com/token')}",
1442
+ f"{prefix}_UNIVERSE_DOMAIN={creds.get('universe_domain', 'googleapis.com')}",
1443
+ f"{prefix}_EMAIL={email}",
1444
+ ]
1445
+
1446
+ return lines
1447
+
1448
+ def export_credential_to_env(
1449
+ self, credential_path: str, output_dir: Optional[Path] = None
1450
+ ) -> Optional[str]:
1451
+ """
1452
+ Export a credential file to .env format.
1453
+
1454
+ Args:
1455
+ credential_path: Path to the credential JSON file
1456
+ output_dir: Directory for output .env file (defaults to same as credential)
1457
+
1458
+ Returns:
1459
+ Path to the exported .env file, or None on error
1460
+ """
1461
+ try:
1462
+ cred_path = Path(credential_path)
1463
+
1464
+ # Load credential
1465
+ with open(cred_path, "r") as f:
1466
+ creds = json.load(f)
1467
+
1468
+ # Extract metadata
1469
+ email = creds.get("_proxy_metadata", {}).get("email", "unknown")
1470
+
1471
+ # Get credential number from filename
1472
+ match = re.search(r"_oauth_(\d+)\.json$", cred_path.name)
1473
+ cred_number = int(match.group(1)) if match else 1
1474
+
1475
+ # Build output path
1476
+ if output_dir is None:
1477
+ output_dir = cred_path.parent
1478
+
1479
+ safe_email = email.replace("@", "_at_").replace(".", "_")
1480
+ prefix = self._get_provider_file_prefix()
1481
+ env_filename = f"{prefix}_{cred_number}_{safe_email}.env"
1482
+ env_path = output_dir / env_filename
1483
+
1484
+ # Build and write content
1485
+ env_lines = self.build_env_lines(creds, cred_number)
1486
+ with open(env_path, "w") as f:
1487
+ f.write("\n".join(env_lines))
1488
+
1489
+ lib_logger.info(f"Exported credential to {env_path}")
1490
+ return str(env_path)
1491
+
1492
+ except Exception as e:
1493
+ lib_logger.error(f"Failed to export credential: {e}")
1494
+ return None
1495
+
1496
+ def list_credentials(self, base_dir: Optional[Path] = None) -> List[Dict[str, Any]]:
1497
+ """
1498
+ List all credential files for this provider.
1499
+
1500
+ Args:
1501
+ base_dir: Directory to search (defaults to oauth_creds)
1502
+
1503
+ Returns:
1504
+ List of dicts with credential info:
1505
+ - file_path: Path to credential file
1506
+ - email: User email
1507
+ - tier: Tier info (if available)
1508
+ - project_id: Project ID (if available)
1509
+ - number: Credential number
1510
+ """
1511
+ if base_dir is None:
1512
+ base_dir = self._get_oauth_base_dir()
1513
+
1514
+ prefix = self._get_provider_file_prefix()
1515
+ pattern = str(base_dir / f"{prefix}_oauth_*.json")
1516
+
1517
+ credentials = []
1518
+ for cred_file in sorted(glob(pattern)):
1519
+ try:
1520
+ with open(cred_file, "r") as f:
1521
+ creds = json.load(f)
1522
+
1523
+ metadata = creds.get("_proxy_metadata", {})
1524
+
1525
+ # Extract number from filename
1526
+ match = re.search(r"_oauth_(\d+)\.json$", cred_file)
1527
+ number = int(match.group(1)) if match else 0
1528
+
1529
+ credentials.append(
1530
+ {
1531
+ "file_path": cred_file,
1532
+ "email": metadata.get("email", "unknown"),
1533
+ "tier": metadata.get("tier"),
1534
+ "project_id": metadata.get("project_id"),
1535
+ "number": number,
1536
+ }
1537
+ )
1538
+ except Exception as e:
1539
+ lib_logger.debug(f"Could not read credential file {cred_file}: {e}")
1540
+ continue
1541
+
1542
+ return credentials
1543
+
1544
+ def delete_credential(self, credential_path: str) -> bool:
1545
+ """
1546
+ Delete a credential file.
1547
+
1548
+ Args:
1549
+ credential_path: Path to the credential file
1550
+
1551
+ Returns:
1552
+ True if deleted successfully, False otherwise
1553
+ """
1554
+ try:
1555
+ cred_path = Path(credential_path)
1556
+
1557
+ # Validate that it's one of our credential files
1558
+ prefix = self._get_provider_file_prefix()
1559
+ if not cred_path.name.startswith(f"{prefix}_oauth_"):
1560
+ lib_logger.error(
1561
+ f"File {cred_path.name} does not appear to be a {self.ENV_PREFIX} credential"
1562
+ )
1563
+ return False
1564
+
1565
+ if not cred_path.exists():
1566
+ lib_logger.warning(f"Credential file does not exist: {credential_path}")
1567
+ return False
1568
+
1569
+ # Remove from cache if present
1570
+ self._credentials_cache.pop(credential_path, None)
1571
+
1572
+ # Delete the file
1573
+ cred_path.unlink()
1574
+ lib_logger.info(f"Deleted credential file: {credential_path}")
1575
+ return True
1576
+
1577
+ except Exception as e:
1578
+ lib_logger.error(f"Failed to delete credential: {e}")
1579
+ return False
src/rotator_library/providers/iflow_auth_base.py CHANGED
@@ -9,11 +9,12 @@ import logging
9
  import webbrowser
10
  import socket
11
  import os
 
 
12
  from pathlib import Path
13
- from typing import Dict, Any, Tuple, Union, Optional
 
14
  from urllib.parse import urlencode, parse_qs, urlparse
15
- import tempfile
16
- import shutil
17
 
18
  import httpx
19
  from aiohttp import web
@@ -24,6 +25,7 @@ from rich.text import Text
24
  from rich.markup import escape as rich_escape
25
  from ..utils.headless_detection import is_headless_environment
26
  from ..utils.reauth_coordinator import get_reauth_coordinator
 
27
 
28
  lib_logger = logging.getLogger("rotator_library")
29
 
@@ -40,6 +42,39 @@ IFLOW_CLIENT_SECRET = "4Z3YjXycVsQvyGF1etiNlIBB4RsqSDtW"
40
  # Local callback server port
41
  CALLBACK_PORT = 11451
42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
  # Refresh tokens 24 hours before expiry
44
  REFRESH_EXPIRY_BUFFER_SECONDS = 24 * 60 * 60
45
 
@@ -171,19 +206,36 @@ class IFlowAuthBase:
171
  str, float
172
  ] = {} # Track backoff timers (Unix timestamp)
173
 
174
- # [QUEUE SYSTEM] Sequential refresh processing
 
175
  self._refresh_queue: asyncio.Queue = asyncio.Queue()
176
- self._queued_credentials: set = set() # Track credentials already in queue
177
- # [FIX PR#34] Changed from set to dict mapping credential path to timestamp
178
- # This enables TTL-based stale entry cleanup as defense in depth
 
 
 
 
 
 
 
179
  self._unavailable_credentials: Dict[
180
  str, float
181
  ] = {} # Maps credential path -> timestamp when marked unavailable
182
- self._unavailable_ttl_seconds: int = 300 # 5 minutes TTL for stale entries
 
183
  self._queue_tracking_lock = asyncio.Lock() # Protects queue sets
184
- self._queue_processor_task: Optional[asyncio.Task] = (
185
- None # Background worker task
186
- )
 
 
 
 
 
 
 
 
187
 
188
  def _parse_env_credential_path(self, path: str) -> Optional[str]:
189
  """
@@ -305,76 +357,40 @@ class IFlowAuthBase:
305
  f"Environment variables for iFlow credential index {credential_index} not found"
306
  )
307
 
308
- # For file paths, try loading from legacy env vars first
309
- env_creds = self._load_from_env()
310
- if env_creds:
311
- lib_logger.info("Using iFlow credentials from environment variables")
312
- self._credentials_cache[path] = env_creds
313
- return env_creds
314
-
315
- # Fall back to file-based loading
316
- return await self._read_creds_from_file(path)
 
 
 
 
317
 
318
  async def _save_credentials(self, path: str, creds: Dict[str, Any]):
319
- """Saves credentials to cache and file using atomic writes."""
 
 
 
320
  # Don't save to file if credentials were loaded from environment
321
  if creds.get("_proxy_metadata", {}).get("loaded_from_env"):
322
  lib_logger.debug("Credentials loaded from env, skipping file save")
323
- # Still update cache for in-memory consistency
324
- self._credentials_cache[path] = creds
325
  return
326
 
327
- # [ATOMIC WRITE] Use tempfile + move pattern to ensure atomic writes
328
- # This prevents credential corruption if the process is interrupted during write
329
- parent_dir = os.path.dirname(os.path.abspath(path))
330
- os.makedirs(parent_dir, exist_ok=True)
331
-
332
- tmp_fd = None
333
- tmp_path = None
334
- try:
335
- # Create temp file in same directory as target (ensures same filesystem)
336
- tmp_fd, tmp_path = tempfile.mkstemp(
337
- dir=parent_dir, prefix=".tmp_", suffix=".json", text=True
338
- )
339
-
340
- # Write JSON to temp file
341
- with os.fdopen(tmp_fd, "w") as f:
342
- json.dump(creds, f, indent=2)
343
- tmp_fd = None # fdopen closes the fd
344
-
345
- # Set secure permissions (0600 = owner read/write only)
346
- try:
347
- os.chmod(tmp_path, 0o600)
348
- except (OSError, AttributeError):
349
- # Windows may not support chmod, ignore
350
- pass
351
-
352
- # Atomic move (overwrites target if it exists)
353
- shutil.move(tmp_path, path)
354
- tmp_path = None # Successfully moved
355
-
356
- # Update cache AFTER successful file write
357
- self._credentials_cache[path] = creds
358
- lib_logger.debug(
359
- f"Saved updated iFlow OAuth credentials to '{path}' (atomic write)."
360
- )
361
-
362
- except Exception as e:
363
- lib_logger.error(
364
- f"Failed to save updated iFlow OAuth credentials to '{path}': {e}"
365
  )
366
- # Clean up temp file if it still exists
367
- if tmp_fd is not None:
368
- try:
369
- os.close(tmp_fd)
370
- except:
371
- pass
372
- if tmp_path and os.path.exists(tmp_path):
373
- try:
374
- os.unlink(tmp_path)
375
- except:
376
- pass
377
- raise
378
 
379
  def _is_token_expired(self, creds: Dict[str, Any]) -> bool:
380
  """Checks if the token is expired (with buffer for proactive refresh)."""
@@ -399,6 +415,29 @@ class IFlowAuthBase:
399
 
400
  return expiry_timestamp < time.time() + REFRESH_EXPIRY_BUFFER_SECONDS
401
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
402
  async def _fetch_user_info(self, access_token: str) -> Dict[str, Any]:
403
  """
404
  Fetches user info (including API key) from iFlow API.
@@ -553,6 +592,26 @@ class IFlowAuthBase:
553
  )
554
  response.raise_for_status()
555
  new_token_data = response.json()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
556
  break # Success
557
 
558
  except httpx.HTTPStatusError as e:
@@ -654,6 +713,16 @@ class IFlowAuthBase:
654
  # Update tokens
655
  access_token = new_token_data.get("access_token")
656
  if not access_token:
 
 
 
 
 
 
 
 
 
 
657
  raise ValueError("Missing access_token in refresh response")
658
 
659
  creds_from_file["access_token"] = access_token
@@ -749,7 +818,7 @@ class IFlowAuthBase:
749
  Proactively refreshes tokens if they're close to expiry.
750
  Only applies to OAuth credentials (file paths or env:// paths). Direct API keys are skipped.
751
  """
752
- lib_logger.debug(f"proactively_refresh called for: {credential_identifier}")
753
 
754
  # Try to load credentials - this will fail for direct API keys
755
  # and succeed for OAuth credentials (file paths or env:// paths)
@@ -757,21 +826,21 @@ class IFlowAuthBase:
757
  creds = await self._load_credentials(credential_identifier)
758
  except IOError as e:
759
  # Not a valid credential path (likely a direct API key string)
760
- lib_logger.debug(
761
- f"Skipping refresh for '{credential_identifier}' - not an OAuth credential: {e}"
762
- )
763
  return
764
 
765
  is_expired = self._is_token_expired(creds)
766
- lib_logger.debug(
767
- f"Token expired check for '{Path(credential_identifier).name}': {is_expired}"
768
- )
769
 
770
  if is_expired:
771
- lib_logger.debug(
772
- f"Queueing refresh for '{Path(credential_identifier).name}'"
773
- )
774
- # Queue for refresh with needs_reauth=False (automated refresh)
775
  await self._queue_refresh(
776
  credential_identifier, force=False, needs_reauth=False
777
  )
@@ -785,30 +854,55 @@ class IFlowAuthBase:
785
  return self._refresh_locks[path]
786
 
787
  def is_credential_available(self, path: str) -> bool:
788
- """Check if a credential is available for rotation (not queued/refreshing).
789
 
790
- [FIX PR#34] Now includes TTL-based stale entry cleanup as defense in depth.
791
- If a credential has been unavailable for longer than _unavailable_ttl_seconds,
792
- it is automatically cleaned up and considered available.
 
 
 
 
 
 
 
 
793
  """
794
- if path not in self._unavailable_credentials:
795
- return True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
796
 
797
- # [FIX PR#34] Check if the entry is stale (TTL expired)
798
- marked_time = self._unavailable_credentials.get(path)
799
- if marked_time is not None:
800
- now = time.time()
801
- if now - marked_time > self._unavailable_ttl_seconds:
802
- # Entry is stale - clean it up and return available
803
- lib_logger.warning(
804
- f"Credential '{Path(path).name}' was stuck in unavailable state for "
805
- f"{int(now - marked_time)}s (TTL: {self._unavailable_ttl_seconds}s). "
806
- f"Auto-cleaning stale entry."
 
807
  )
808
- self._unavailable_credentials.pop(path, None)
809
- return True
810
 
811
- return False
812
 
813
  async def _ensure_queue_processor_running(self):
814
  """Lazily starts the queue processor if not already running."""
@@ -817,15 +911,27 @@ class IFlowAuthBase:
817
  self._process_refresh_queue()
818
  )
819
 
 
 
 
 
 
 
 
820
  async def _queue_refresh(
821
  self, path: str, force: bool = False, needs_reauth: bool = False
822
  ):
823
- """Add a credential to the refresh queue if not already queued.
824
 
825
  Args:
826
  path: Credential file path
827
  force: Force refresh even if not expired
828
- needs_reauth: True if full re-authentication needed (bypasses backoff)
 
 
 
 
 
829
  """
830
  # IMPORTANT: Only check backoff for simple automated refreshes
831
  # Re-authentication (interactive OAuth) should BYPASS backoff since it needs user input
@@ -835,114 +941,223 @@ class IFlowAuthBase:
835
  backoff_until = self._next_refresh_after[path]
836
  if now < backoff_until:
837
  # Credential is in backoff for automated refresh, do not queue
838
- remaining = int(backoff_until - now)
839
- lib_logger.debug(
840
- f"Skipping automated refresh for '{Path(path).name}' (in backoff for {remaining}s)"
841
- )
842
  return
843
 
844
  async with self._queue_tracking_lock:
845
  if path not in self._queued_credentials:
846
  self._queued_credentials.add(path)
847
- # [FIX PR#34] Store timestamp when marking unavailable (for TTL cleanup)
848
- self._unavailable_credentials[path] = time.time()
849
- lib_logger.debug(
850
- f"Marked '{Path(path).name}' as unavailable. "
851
- f"Total unavailable: {len(self._unavailable_credentials)}"
852
- )
853
- await self._refresh_queue.put((path, force, needs_reauth))
854
- await self._ensure_queue_processor_running()
 
 
 
 
 
 
 
 
 
 
855
 
856
  async def _process_refresh_queue(self):
857
- """Background worker that processes refresh requests sequentially."""
 
 
 
 
 
 
 
 
 
858
  while True:
859
  path = None
860
  try:
861
  # Wait for an item with timeout to allow graceful shutdown
862
  try:
863
- path, force, needs_reauth = await asyncio.wait_for(
864
  self._refresh_queue.get(), timeout=60.0
865
  )
866
  except asyncio.TimeoutError:
867
- # [FIX PR#34] Clean up any stale unavailable entries before exiting
868
- # If we're idle for 60s, no refreshes are in progress
869
  async with self._queue_tracking_lock:
870
- if self._unavailable_credentials:
871
- stale_count = len(self._unavailable_credentials)
872
- lib_logger.warning(
873
- f"Queue processor idle timeout. Cleaning {stale_count} "
874
- f"stale unavailable credentials: {list(self._unavailable_credentials.keys())}"
875
- )
876
- self._unavailable_credentials.clear()
877
- # [FIX BUG#6] Also clear queued credentials to prevent stuck state
878
- if self._queued_credentials:
879
- lib_logger.debug(
880
- f"Clearing {len(self._queued_credentials)} queued credentials on timeout"
881
- )
882
- self._queued_credentials.clear()
883
  self._queue_processor_task = None
 
884
  return
885
 
886
  try:
887
- # Perform the actual refresh (still using per-credential lock)
888
- async with await self._get_lock(path):
889
- # Re-check if still expired (may have changed since queueing)
890
- creds = self._credentials_cache.get(path)
891
- if creds and not self._is_token_expired(creds):
892
- # No longer expired, mark as available
893
- async with self._queue_tracking_lock:
894
- self._unavailable_credentials.pop(path, None)
895
- lib_logger.debug(
896
- f"Credential '{Path(path).name}' no longer expired, marked available. "
897
- f"Remaining unavailable: {len(self._unavailable_credentials)}"
898
- )
899
- continue
 
 
900
 
901
- # Perform refresh
902
- if not creds:
903
- creds = await self._load_credentials(path)
904
- await self._refresh_token(path, force=force)
905
 
906
- # SUCCESS: Mark as available again
907
- async with self._queue_tracking_lock:
908
- self._unavailable_credentials.pop(path, None)
909
- lib_logger.debug(
910
- f"Refresh SUCCESS for '{Path(path).name}', marked available. "
911
- f"Remaining unavailable: {len(self._unavailable_credentials)}"
 
 
 
 
 
 
 
912
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
913
 
914
  finally:
915
- # [FIX PR#34] Remove from BOTH queued set AND unavailable credentials
916
- # This ensures cleanup happens in ALL exit paths (success, exception, etc.)
917
  async with self._queue_tracking_lock:
918
  self._queued_credentials.discard(path)
919
- # [FIX PR#34] Always clean up unavailable credentials in finally block
920
  self._unavailable_credentials.pop(path, None)
921
- lib_logger.debug(
922
- f"Finally cleanup for '{Path(path).name}'. "
923
- f"Remaining unavailable: {len(self._unavailable_credentials)}"
924
- )
925
- self._refresh_queue.task_done()
 
926
  except asyncio.CancelledError:
927
- # [FIX PR#34] Clean up the current credential before breaking
928
  if path:
929
  async with self._queue_tracking_lock:
 
930
  self._unavailable_credentials.pop(path, None)
931
- lib_logger.debug(
932
- f"CancelledError cleanup for '{Path(path).name}'. "
933
- f"Remaining unavailable: {len(self._unavailable_credentials)}"
934
- )
935
  break
936
  except Exception as e:
937
- lib_logger.error(f"Error in queue processor: {e}")
938
- # Even on error, mark as available (backoff will prevent immediate retry)
939
  if path:
940
  async with self._queue_tracking_lock:
 
941
  self._unavailable_credentials.pop(path, None)
942
- lib_logger.debug(
943
- f"Error cleanup for '{Path(path).name}': {e}. "
944
- f"Remaining unavailable: {len(self._unavailable_credentials)}"
945
- )
946
 
947
  async def _perform_interactive_oauth(
948
  self, path: str, creds: Dict[str, Any], display_name: str
@@ -968,7 +1183,8 @@ class IFlowAuthBase:
968
  state = secrets.token_urlsafe(32)
969
 
970
  # Build authorization URL
971
- redirect_uri = f"http://localhost:{CALLBACK_PORT}/oauth2callback"
 
972
  auth_params = {
973
  "loginMethod": "phone",
974
  "type": "phone",
@@ -979,7 +1195,7 @@ class IFlowAuthBase:
979
  auth_url = f"{IFLOW_OAUTH_AUTHORIZE_ENDPOINT}?{urlencode(auth_params)}"
980
 
981
  # Start OAuth callback server
982
- callback_server = OAuthCallbackServer(port=CALLBACK_PORT)
983
  try:
984
  await callback_server.start(expected_state=state)
985
 
@@ -1182,3 +1398,261 @@ class IFlowAuthBase:
1182
  except Exception as e:
1183
  lib_logger.error(f"Failed to get iFlow user info from credentials: {e}")
1184
  return {"email": None}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  import webbrowser
10
  import socket
11
  import os
12
+ import re
13
+ from dataclasses import dataclass, field
14
  from pathlib import Path
15
+ from glob import glob
16
+ from typing import Dict, Any, Tuple, Union, Optional, List
17
  from urllib.parse import urlencode, parse_qs, urlparse
 
 
18
 
19
  import httpx
20
  from aiohttp import web
 
25
  from rich.markup import escape as rich_escape
26
  from ..utils.headless_detection import is_headless_environment
27
  from ..utils.reauth_coordinator import get_reauth_coordinator
28
+ from ..utils.resilient_io import safe_write_json
29
 
30
  lib_logger = logging.getLogger("rotator_library")
31
 
 
42
  # Local callback server port
43
  CALLBACK_PORT = 11451
44
 
45
+
46
+ @dataclass
47
+ class IFlowCredentialSetupResult:
48
+ """
49
+ Standardized result structure for iFlow credential setup operations.
50
+ """
51
+
52
+ success: bool
53
+ file_path: Optional[str] = None
54
+ email: Optional[str] = None
55
+ is_update: bool = False
56
+ error: Optional[str] = None
57
+ credentials: Optional[Dict[str, Any]] = field(default=None, repr=False)
58
+
59
+
60
+ def get_callback_port() -> int:
61
+ """
62
+ Get the OAuth callback port, checking environment variable first.
63
+
64
+ Reads from IFLOW_OAUTH_PORT environment variable, falling back
65
+ to the default CALLBACK_PORT if not set.
66
+ """
67
+ env_value = os.getenv("IFLOW_OAUTH_PORT")
68
+ if env_value:
69
+ try:
70
+ return int(env_value)
71
+ except ValueError:
72
+ logging.getLogger("rotator_library").warning(
73
+ f"Invalid IFLOW_OAUTH_PORT value: {env_value}, using default {CALLBACK_PORT}"
74
+ )
75
+ return CALLBACK_PORT
76
+
77
+
78
  # Refresh tokens 24 hours before expiry
79
  REFRESH_EXPIRY_BUFFER_SECONDS = 24 * 60 * 60
80
 
 
206
  str, float
207
  ] = {} # Track backoff timers (Unix timestamp)
208
 
209
+ # [QUEUE SYSTEM] Sequential refresh processing with two separate queues
210
+ # Normal refresh queue: for proactive token refresh (old token still valid)
211
  self._refresh_queue: asyncio.Queue = asyncio.Queue()
212
+ self._queue_processor_task: Optional[asyncio.Task] = None
213
+
214
+ # Re-auth queue: for invalid refresh tokens (requires user interaction)
215
+ self._reauth_queue: asyncio.Queue = asyncio.Queue()
216
+ self._reauth_processor_task: Optional[asyncio.Task] = None
217
+
218
+ # Tracking sets/dicts
219
+ self._queued_credentials: set = set() # Track credentials in either queue
220
+ # Only credentials in re-auth queue are marked unavailable (not normal refresh)
221
+ # TTL cleanup is defense-in-depth for edge cases where re-auth processor crashes
222
  self._unavailable_credentials: Dict[
223
  str, float
224
  ] = {} # Maps credential path -> timestamp when marked unavailable
225
+ # TTL should exceed reauth timeout (300s) to avoid premature cleanup
226
+ self._unavailable_ttl_seconds: int = 360 # 6 minutes TTL for stale entries
227
  self._queue_tracking_lock = asyncio.Lock() # Protects queue sets
228
+
229
+ # Retry tracking for normal refresh queue
230
+ self._queue_retry_count: Dict[
231
+ str, int
232
+ ] = {} # Track retry attempts per credential
233
+
234
+ # Configuration constants
235
+ self._refresh_timeout_seconds: int = 15 # Max time for single refresh
236
+ self._refresh_interval_seconds: int = 30 # Delay between queue items
237
+ self._refresh_max_retries: int = 3 # Attempts before kicked out
238
+ self._reauth_timeout_seconds: int = 300 # Time for user to complete OAuth
239
 
240
  def _parse_env_credential_path(self, path: str) -> Optional[str]:
241
  """
 
357
  f"Environment variables for iFlow credential index {credential_index} not found"
358
  )
359
 
360
+ # Try file-based loading first (preferred for explicit file paths)
361
+ try:
362
+ return await self._read_creds_from_file(path)
363
+ except IOError:
364
+ # File not found - fall back to legacy env vars for backwards compatibility
365
+ env_creds = self._load_from_env()
366
+ if env_creds:
367
+ lib_logger.info(
368
+ f"File '{path}' not found, using iFlow credentials from environment variables"
369
+ )
370
+ self._credentials_cache[path] = env_creds
371
+ return env_creds
372
+ raise # Re-raise the original file not found error
373
 
374
  async def _save_credentials(self, path: str, creds: Dict[str, Any]):
375
+ """Save credentials with in-memory fallback if disk unavailable."""
376
+ # Always update cache first (memory is reliable)
377
+ self._credentials_cache[path] = creds
378
+
379
  # Don't save to file if credentials were loaded from environment
380
  if creds.get("_proxy_metadata", {}).get("loaded_from_env"):
381
  lib_logger.debug("Credentials loaded from env, skipping file save")
 
 
382
  return
383
 
384
+ # Attempt disk write - if it fails, we still have the cache
385
+ # buffer_on_failure ensures data is retried periodically and saved on shutdown
386
+ if safe_write_json(
387
+ path, creds, lib_logger, secure_permissions=True, buffer_on_failure=True
388
+ ):
389
+ lib_logger.debug(f"Saved updated iFlow OAuth credentials to '{path}'.")
390
+ else:
391
+ lib_logger.warning(
392
+ "iFlow credentials cached in memory only (buffered for retry)."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
393
  )
 
 
 
 
 
 
 
 
 
 
 
 
394
 
395
  def _is_token_expired(self, creds: Dict[str, Any]) -> bool:
396
  """Checks if the token is expired (with buffer for proactive refresh)."""
 
415
 
416
  return expiry_timestamp < time.time() + REFRESH_EXPIRY_BUFFER_SECONDS
417
 
418
+ def _is_token_truly_expired(self, creds: Dict[str, Any]) -> bool:
419
+ """Check if token is TRULY expired (past actual expiry, not just threshold).
420
+
421
+ This is different from _is_token_expired() which uses a buffer for proactive refresh.
422
+ This method checks if the token is actually unusable.
423
+ """
424
+ expiry_str = creds.get("expiry_date")
425
+ if not expiry_str:
426
+ return True
427
+
428
+ try:
429
+ from datetime import datetime
430
+
431
+ expiry_dt = datetime.fromisoformat(expiry_str.replace("Z", "+00:00"))
432
+ expiry_timestamp = expiry_dt.timestamp()
433
+ except (ValueError, AttributeError):
434
+ try:
435
+ expiry_timestamp = float(expiry_str)
436
+ except (ValueError, TypeError):
437
+ return True
438
+
439
+ return expiry_timestamp < time.time()
440
+
441
  async def _fetch_user_info(self, access_token: str) -> Dict[str, Any]:
442
  """
443
  Fetches user info (including API key) from iFlow API.
 
592
  )
593
  response.raise_for_status()
594
  new_token_data = response.json()
595
+
596
+ # [FIX] Handle wrapped response format: {success: bool, data: {...}}
597
+ # iFlow API may return tokens nested inside a 'data' key
598
+ if (
599
+ isinstance(new_token_data, dict)
600
+ and "data" in new_token_data
601
+ ):
602
+ lib_logger.debug(
603
+ f"iFlow refresh response wrapped in 'data' key, extracting..."
604
+ )
605
+ # Check for error in wrapped response
606
+ if not new_token_data.get("success", True):
607
+ error_msg = new_token_data.get(
608
+ "message", "Unknown error"
609
+ )
610
+ raise ValueError(
611
+ f"iFlow token refresh failed: {error_msg}"
612
+ )
613
+ new_token_data = new_token_data.get("data", {})
614
+
615
  break # Success
616
 
617
  except httpx.HTTPStatusError as e:
 
713
  # Update tokens
714
  access_token = new_token_data.get("access_token")
715
  if not access_token:
716
+ # Log response keys for debugging
717
+ response_keys = (
718
+ list(new_token_data.keys())
719
+ if isinstance(new_token_data, dict)
720
+ else type(new_token_data).__name__
721
+ )
722
+ lib_logger.error(
723
+ f"Missing access_token in refresh response for '{Path(path).name}'. "
724
+ f"Response keys: {response_keys}"
725
+ )
726
  raise ValueError("Missing access_token in refresh response")
727
 
728
  creds_from_file["access_token"] = access_token
 
818
  Proactively refreshes tokens if they're close to expiry.
819
  Only applies to OAuth credentials (file paths or env:// paths). Direct API keys are skipped.
820
  """
821
+ # lib_logger.debug(f"proactively_refresh called for: {credential_identifier}")
822
 
823
  # Try to load credentials - this will fail for direct API keys
824
  # and succeed for OAuth credentials (file paths or env:// paths)
 
826
  creds = await self._load_credentials(credential_identifier)
827
  except IOError as e:
828
  # Not a valid credential path (likely a direct API key string)
829
+ # lib_logger.debug(
830
+ # f"Skipping refresh for '{credential_identifier}' - not an OAuth credential: {e}"
831
+ # )
832
  return
833
 
834
  is_expired = self._is_token_expired(creds)
835
+ # lib_logger.debug(
836
+ # f"Token expired check for '{Path(credential_identifier).name}': {is_expired}"
837
+ # )
838
 
839
  if is_expired:
840
+ # lib_logger.debug(
841
+ # f"Queueing refresh for '{Path(credential_identifier).name}'"
842
+ # )
843
+ # lib_logger.info(f"Proactive refresh triggered for '{Path(credential_identifier).name}'")
844
  await self._queue_refresh(
845
  credential_identifier, force=False, needs_reauth=False
846
  )
 
854
  return self._refresh_locks[path]
855
 
856
  def is_credential_available(self, path: str) -> bool:
857
+ """Check if a credential is available for rotation.
858
 
859
+ Credentials are unavailable if:
860
+ 1. In re-auth queue (token is truly broken, requires user interaction)
861
+ 2. Token is TRULY expired (past actual expiry, not just threshold)
862
+
863
+ Note: Credentials in normal refresh queue are still available because
864
+ the old token is valid until actual expiry.
865
+
866
+ TTL cleanup (defense-in-depth): If a credential has been in the re-auth
867
+ queue longer than _unavailable_ttl_seconds without being processed, it's
868
+ cleaned up. This should only happen if the re-auth processor crashes or
869
+ is cancelled without proper cleanup.
870
  """
871
+ # Check if in re-auth queue (truly unavailable)
872
+ if path in self._unavailable_credentials:
873
+ marked_time = self._unavailable_credentials.get(path)
874
+ if marked_time is not None:
875
+ now = time.time()
876
+ if now - marked_time > self._unavailable_ttl_seconds:
877
+ # Entry is stale - clean it up and return available
878
+ # This is a defense-in-depth for edge cases where re-auth
879
+ # processor crashed or was cancelled without cleanup
880
+ lib_logger.warning(
881
+ f"Credential '{Path(path).name}' stuck in re-auth queue for "
882
+ f"{int(now - marked_time)}s (TTL: {self._unavailable_ttl_seconds}s). "
883
+ f"Re-auth processor may have crashed. Auto-cleaning stale entry."
884
+ )
885
+ # Clean up both tracking structures for consistency
886
+ self._unavailable_credentials.pop(path, None)
887
+ self._queued_credentials.discard(path)
888
+ else:
889
+ return False # Still in re-auth, not available
890
 
891
+ # Check if token is TRULY expired (not just threshold-expired)
892
+ creds = self._credentials_cache.get(path)
893
+ if creds and self._is_token_truly_expired(creds):
894
+ # Token is actually expired - should not be used
895
+ # Queue for refresh if not already queued
896
+ if path not in self._queued_credentials:
897
+ # lib_logger.debug(
898
+ # f"Credential '{Path(path).name}' is truly expired, queueing for refresh"
899
+ # )
900
+ asyncio.create_task(
901
+ self._queue_refresh(path, force=True, needs_reauth=False)
902
  )
903
+ return False
 
904
 
905
+ return True
906
 
907
  async def _ensure_queue_processor_running(self):
908
  """Lazily starts the queue processor if not already running."""
 
911
  self._process_refresh_queue()
912
  )
913
 
914
+ async def _ensure_reauth_processor_running(self):
915
+ """Lazily starts the re-auth queue processor if not already running."""
916
+ if self._reauth_processor_task is None or self._reauth_processor_task.done():
917
+ self._reauth_processor_task = asyncio.create_task(
918
+ self._process_reauth_queue()
919
+ )
920
+
921
  async def _queue_refresh(
922
  self, path: str, force: bool = False, needs_reauth: bool = False
923
  ):
924
+ """Add a credential to the appropriate refresh queue if not already queued.
925
 
926
  Args:
927
  path: Credential file path
928
  force: Force refresh even if not expired
929
+ needs_reauth: True if full re-authentication needed (routes to re-auth queue)
930
+
931
+ Queue routing:
932
+ - needs_reauth=True: Goes to re-auth queue, marks as unavailable
933
+ - needs_reauth=False: Goes to normal refresh queue, does NOT mark unavailable
934
+ (old token is still valid until actual expiry)
935
  """
936
  # IMPORTANT: Only check backoff for simple automated refreshes
937
  # Re-authentication (interactive OAuth) should BYPASS backoff since it needs user input
 
941
  backoff_until = self._next_refresh_after[path]
942
  if now < backoff_until:
943
  # Credential is in backoff for automated refresh, do not queue
944
+ # remaining = int(backoff_until - now)
945
+ # lib_logger.debug(
946
+ # f"Skipping automated refresh for '{Path(path).name}' (in backoff for {remaining}s)"
947
+ # )
948
  return
949
 
950
  async with self._queue_tracking_lock:
951
  if path not in self._queued_credentials:
952
  self._queued_credentials.add(path)
953
+
954
+ if needs_reauth:
955
+ # Re-auth queue: mark as unavailable (token is truly broken)
956
+ self._unavailable_credentials[path] = time.time()
957
+ # lib_logger.debug(
958
+ # f"Queued '{Path(path).name}' for RE-AUTH (marked unavailable). "
959
+ # f"Total unavailable: {len(self._unavailable_credentials)}"
960
+ # )
961
+ await self._reauth_queue.put(path)
962
+ await self._ensure_reauth_processor_running()
963
+ else:
964
+ # Normal refresh queue: do NOT mark unavailable (old token still valid)
965
+ # lib_logger.debug(
966
+ # f"Queued '{Path(path).name}' for refresh (still available). "
967
+ # f"Queue size: {self._refresh_queue.qsize() + 1}"
968
+ # )
969
+ await self._refresh_queue.put((path, force))
970
+ await self._ensure_queue_processor_running()
971
 
972
  async def _process_refresh_queue(self):
973
+ """Background worker that processes normal refresh requests sequentially.
974
+
975
+ Key behaviors:
976
+ - 15s timeout per refresh operation
977
+ - 30s delay between processing credentials (prevents thundering herd)
978
+ - On failure: back of queue, max 3 retries before kicked
979
+ - If 401/403 detected: routes to re-auth queue
980
+ - Does NOT mark credentials unavailable (old token still valid)
981
+ """
982
+ # lib_logger.info("Refresh queue processor started")
983
  while True:
984
  path = None
985
  try:
986
  # Wait for an item with timeout to allow graceful shutdown
987
  try:
988
+ path, force = await asyncio.wait_for(
989
  self._refresh_queue.get(), timeout=60.0
990
  )
991
  except asyncio.TimeoutError:
992
+ # Queue is empty and idle for 60s - clean up and exit
 
993
  async with self._queue_tracking_lock:
994
+ # Clear any stale retry counts
995
+ self._queue_retry_count.clear()
 
 
 
 
 
 
 
 
 
 
 
996
  self._queue_processor_task = None
997
+ # lib_logger.debug("Refresh queue processor idle, shutting down")
998
  return
999
 
1000
  try:
1001
+ # Quick check if still expired (optimization to avoid unnecessary refresh)
1002
+ creds = self._credentials_cache.get(path)
1003
+ if creds and not self._is_token_expired(creds):
1004
+ # No longer expired, skip refresh
1005
+ # lib_logger.debug(
1006
+ # f"Credential '{Path(path).name}' no longer expired, skipping refresh"
1007
+ # )
1008
+ # Clear retry count on skip (not a failure)
1009
+ self._queue_retry_count.pop(path, None)
1010
+ continue
1011
+
1012
+ # Perform refresh with timeout
1013
+ try:
1014
+ async with asyncio.timeout(self._refresh_timeout_seconds):
1015
+ await self._refresh_token(path, force=force)
1016
 
1017
+ # SUCCESS: Clear retry count
1018
+ self._queue_retry_count.pop(path, None)
1019
+ # lib_logger.info(f"Refresh SUCCESS for '{Path(path).name}'")
 
1020
 
1021
+ except asyncio.TimeoutError:
1022
+ lib_logger.warning(
1023
+ f"Refresh timeout ({self._refresh_timeout_seconds}s) for '{Path(path).name}'"
1024
+ )
1025
+ await self._handle_refresh_failure(path, force, "timeout")
1026
+
1027
+ except httpx.HTTPStatusError as e:
1028
+ status_code = e.response.status_code
1029
+ if status_code in (401, 403):
1030
+ # Invalid refresh token - route to re-auth queue
1031
+ lib_logger.warning(
1032
+ f"Refresh token invalid for '{Path(path).name}' (HTTP {status_code}). "
1033
+ f"Routing to re-auth queue."
1034
  )
1035
+ self._queue_retry_count.pop(path, None) # Clear retry count
1036
+ async with self._queue_tracking_lock:
1037
+ self._queued_credentials.discard(
1038
+ path
1039
+ ) # Remove from queued
1040
+ await self._queue_refresh(
1041
+ path, force=True, needs_reauth=True
1042
+ )
1043
+ else:
1044
+ await self._handle_refresh_failure(
1045
+ path, force, f"HTTP {status_code}"
1046
+ )
1047
+
1048
+ except Exception as e:
1049
+ await self._handle_refresh_failure(path, force, str(e))
1050
+
1051
+ finally:
1052
+ # Remove from queued set (unless re-queued by failure handler)
1053
+ async with self._queue_tracking_lock:
1054
+ # Only discard if not re-queued (check if still in queue set from retry)
1055
+ if (
1056
+ path in self._queued_credentials
1057
+ and self._queue_retry_count.get(path, 0) == 0
1058
+ ):
1059
+ self._queued_credentials.discard(path)
1060
+ self._refresh_queue.task_done()
1061
+
1062
+ # Wait between credentials to spread load
1063
+ await asyncio.sleep(self._refresh_interval_seconds)
1064
+
1065
+ except asyncio.CancelledError:
1066
+ # lib_logger.debug("Refresh queue processor cancelled")
1067
+ break
1068
+ except Exception as e:
1069
+ lib_logger.error(f"Error in refresh queue processor: {e}")
1070
+ if path:
1071
+ async with self._queue_tracking_lock:
1072
+ self._queued_credentials.discard(path)
1073
+
1074
+ async def _handle_refresh_failure(self, path: str, force: bool, error: str):
1075
+ """Handle a refresh failure with back-of-line retry logic.
1076
+
1077
+ - Increments retry count
1078
+ - If under max retries: re-adds to END of queue
1079
+ - If at max retries: kicks credential out (retried next BackgroundRefresher cycle)
1080
+ """
1081
+ retry_count = self._queue_retry_count.get(path, 0) + 1
1082
+ self._queue_retry_count[path] = retry_count
1083
+
1084
+ if retry_count >= self._refresh_max_retries:
1085
+ # Kicked out until next BackgroundRefresher cycle
1086
+ lib_logger.error(
1087
+ f"Max retries ({self._refresh_max_retries}) reached for '{Path(path).name}' "
1088
+ f"(last error: {error}). Will retry next refresh cycle."
1089
+ )
1090
+ self._queue_retry_count.pop(path, None)
1091
+ async with self._queue_tracking_lock:
1092
+ self._queued_credentials.discard(path)
1093
+ return
1094
+
1095
+ # Re-add to END of queue for retry
1096
+ lib_logger.warning(
1097
+ f"Refresh failed for '{Path(path).name}' ({error}). "
1098
+ f"Retry {retry_count}/{self._refresh_max_retries}, back of queue."
1099
+ )
1100
+ # Keep in queued_credentials set, add back to queue
1101
+ await self._refresh_queue.put((path, force))
1102
+
1103
+ async def _process_reauth_queue(self):
1104
+ """Background worker that processes re-auth requests.
1105
+
1106
+ Key behaviors:
1107
+ - Credentials ARE marked unavailable (token is truly broken)
1108
+ - Uses ReauthCoordinator for interactive OAuth
1109
+ - No automatic retry (requires user action)
1110
+ - Cleans up unavailable status when done
1111
+ """
1112
+ # lib_logger.info("Re-auth queue processor started")
1113
+ while True:
1114
+ path = None
1115
+ try:
1116
+ # Wait for an item with timeout to allow graceful shutdown
1117
+ try:
1118
+ path = await asyncio.wait_for(
1119
+ self._reauth_queue.get(), timeout=60.0
1120
+ )
1121
+ except asyncio.TimeoutError:
1122
+ # Queue is empty and idle for 60s - exit
1123
+ self._reauth_processor_task = None
1124
+ # lib_logger.debug("Re-auth queue processor idle, shutting down")
1125
+ return
1126
+
1127
+ try:
1128
+ lib_logger.info(f"Starting re-auth for '{Path(path).name}'...")
1129
+ await self.initialize_token(path)
1130
+ lib_logger.info(f"Re-auth SUCCESS for '{Path(path).name}'")
1131
+
1132
+ except Exception as e:
1133
+ lib_logger.error(f"Re-auth FAILED for '{Path(path).name}': {e}")
1134
+ # No automatic retry for re-auth (requires user action)
1135
 
1136
  finally:
1137
+ # Always clean up
 
1138
  async with self._queue_tracking_lock:
1139
  self._queued_credentials.discard(path)
 
1140
  self._unavailable_credentials.pop(path, None)
1141
+ # lib_logger.debug(
1142
+ # f"Re-auth cleanup for '{Path(path).name}'. "
1143
+ # f"Remaining unavailable: {len(self._unavailable_credentials)}"
1144
+ # )
1145
+ self._reauth_queue.task_done()
1146
+
1147
  except asyncio.CancelledError:
1148
+ # Clean up current credential before breaking
1149
  if path:
1150
  async with self._queue_tracking_lock:
1151
+ self._queued_credentials.discard(path)
1152
  self._unavailable_credentials.pop(path, None)
1153
+ # lib_logger.debug("Re-auth queue processor cancelled")
 
 
 
1154
  break
1155
  except Exception as e:
1156
+ lib_logger.error(f"Error in re-auth queue processor: {e}")
 
1157
  if path:
1158
  async with self._queue_tracking_lock:
1159
+ self._queued_credentials.discard(path)
1160
  self._unavailable_credentials.pop(path, None)
 
 
 
 
1161
 
1162
  async def _perform_interactive_oauth(
1163
  self, path: str, creds: Dict[str, Any], display_name: str
 
1183
  state = secrets.token_urlsafe(32)
1184
 
1185
  # Build authorization URL
1186
+ callback_port = get_callback_port()
1187
+ redirect_uri = f"http://localhost:{callback_port}/oauth2callback"
1188
  auth_params = {
1189
  "loginMethod": "phone",
1190
  "type": "phone",
 
1195
  auth_url = f"{IFLOW_OAUTH_AUTHORIZE_ENDPOINT}?{urlencode(auth_params)}"
1196
 
1197
  # Start OAuth callback server
1198
+ callback_server = OAuthCallbackServer(port=callback_port)
1199
  try:
1200
  await callback_server.start(expected_state=state)
1201
 
 
1398
  except Exception as e:
1399
  lib_logger.error(f"Failed to get iFlow user info from credentials: {e}")
1400
  return {"email": None}
1401
+
1402
+ # =========================================================================
1403
+ # CREDENTIAL MANAGEMENT METHODS
1404
+ # =========================================================================
1405
+
1406
+ def _get_provider_file_prefix(self) -> str:
1407
+ """Return the file prefix for iFlow credentials."""
1408
+ return "iflow"
1409
+
1410
+ def _get_oauth_base_dir(self) -> Path:
1411
+ """Get the base directory for OAuth credential files."""
1412
+ return Path.cwd() / "oauth_creds"
1413
+
1414
+ def _find_existing_credential_by_email(
1415
+ self, email: str, base_dir: Optional[Path] = None
1416
+ ) -> Optional[Path]:
1417
+ """Find an existing credential file for the given email."""
1418
+ if base_dir is None:
1419
+ base_dir = self._get_oauth_base_dir()
1420
+
1421
+ prefix = self._get_provider_file_prefix()
1422
+ pattern = str(base_dir / f"{prefix}_oauth_*.json")
1423
+
1424
+ for cred_file in glob(pattern):
1425
+ try:
1426
+ with open(cred_file, "r") as f:
1427
+ creds = json.load(f)
1428
+ existing_email = creds.get("email") or creds.get(
1429
+ "_proxy_metadata", {}
1430
+ ).get("email")
1431
+ if existing_email == email:
1432
+ return Path(cred_file)
1433
+ except (json.JSONDecodeError, IOError) as e:
1434
+ lib_logger.debug(f"Could not read credential file {cred_file}: {e}")
1435
+ continue
1436
+
1437
+ return None
1438
+
1439
+ def _get_next_credential_number(self, base_dir: Optional[Path] = None) -> int:
1440
+ """Get the next available credential number."""
1441
+ if base_dir is None:
1442
+ base_dir = self._get_oauth_base_dir()
1443
+
1444
+ prefix = self._get_provider_file_prefix()
1445
+ pattern = str(base_dir / f"{prefix}_oauth_*.json")
1446
+
1447
+ existing_numbers = []
1448
+ for cred_file in glob(pattern):
1449
+ match = re.search(r"_oauth_(\d+)\.json$", cred_file)
1450
+ if match:
1451
+ existing_numbers.append(int(match.group(1)))
1452
+
1453
+ if not existing_numbers:
1454
+ return 1
1455
+ return max(existing_numbers) + 1
1456
+
1457
+ def _build_credential_path(
1458
+ self, base_dir: Optional[Path] = None, number: Optional[int] = None
1459
+ ) -> Path:
1460
+ """Build a path for a new credential file."""
1461
+ if base_dir is None:
1462
+ base_dir = self._get_oauth_base_dir()
1463
+
1464
+ if number is None:
1465
+ number = self._get_next_credential_number(base_dir)
1466
+
1467
+ prefix = self._get_provider_file_prefix()
1468
+ filename = f"{prefix}_oauth_{number}.json"
1469
+ return base_dir / filename
1470
+
1471
+ async def setup_credential(
1472
+ self, base_dir: Optional[Path] = None
1473
+ ) -> IFlowCredentialSetupResult:
1474
+ """
1475
+ Complete credential setup flow: OAuth -> save.
1476
+
1477
+ This is the main entry point for setting up new credentials.
1478
+ """
1479
+ if base_dir is None:
1480
+ base_dir = self._get_oauth_base_dir()
1481
+
1482
+ # Ensure directory exists
1483
+ base_dir.mkdir(exist_ok=True)
1484
+
1485
+ try:
1486
+ # Step 1: Perform OAuth authentication
1487
+ temp_creds = {"_proxy_metadata": {"display_name": "new iFlow credential"}}
1488
+ new_creds = await self.initialize_token(temp_creds)
1489
+
1490
+ # Step 2: Get user info for deduplication
1491
+ email = new_creds.get("email") or new_creds.get("_proxy_metadata", {}).get(
1492
+ "email"
1493
+ )
1494
+
1495
+ if not email:
1496
+ return IFlowCredentialSetupResult(
1497
+ success=False, error="Could not retrieve email from OAuth response"
1498
+ )
1499
+
1500
+ # Step 3: Check for existing credential with same email
1501
+ existing_path = self._find_existing_credential_by_email(email, base_dir)
1502
+ is_update = existing_path is not None
1503
+
1504
+ if is_update:
1505
+ file_path = existing_path
1506
+ lib_logger.info(
1507
+ f"Found existing credential for {email}, updating {file_path.name}"
1508
+ )
1509
+ else:
1510
+ file_path = self._build_credential_path(base_dir)
1511
+ lib_logger.info(
1512
+ f"Creating new credential for {email} at {file_path.name}"
1513
+ )
1514
+
1515
+ # Step 4: Save credentials to file
1516
+ await self._save_credentials(str(file_path), new_creds)
1517
+
1518
+ return IFlowCredentialSetupResult(
1519
+ success=True,
1520
+ file_path=str(file_path),
1521
+ email=email,
1522
+ is_update=is_update,
1523
+ credentials=new_creds,
1524
+ )
1525
+
1526
+ except Exception as e:
1527
+ lib_logger.error(f"Credential setup failed: {e}")
1528
+ return IFlowCredentialSetupResult(success=False, error=str(e))
1529
+
1530
+ def build_env_lines(self, creds: Dict[str, Any], cred_number: int) -> List[str]:
1531
+ """Generate .env file lines for an iFlow credential."""
1532
+ email = creds.get("email") or creds.get("_proxy_metadata", {}).get(
1533
+ "email", "unknown"
1534
+ )
1535
+ prefix = f"IFLOW_{cred_number}"
1536
+
1537
+ lines = [
1538
+ f"# IFLOW Credential #{cred_number} for: {email}",
1539
+ f"# Exported from: iflow_oauth_{cred_number}.json",
1540
+ f"# Generated at: {time.strftime('%Y-%m-%d %H:%M:%S')}",
1541
+ "#",
1542
+ "# To combine multiple credentials into one .env file, copy these lines",
1543
+ "# and ensure each credential has a unique number (1, 2, 3, etc.)",
1544
+ "",
1545
+ f"{prefix}_ACCESS_TOKEN={creds.get('access_token', '')}",
1546
+ f"{prefix}_REFRESH_TOKEN={creds.get('refresh_token', '')}",
1547
+ f"{prefix}_API_KEY={creds.get('api_key', '')}",
1548
+ f"{prefix}_EXPIRY_DATE={creds.get('expiry_date', '')}",
1549
+ f"{prefix}_EMAIL={email}",
1550
+ f"{prefix}_TOKEN_TYPE={creds.get('token_type', 'Bearer')}",
1551
+ f"{prefix}_SCOPE={creds.get('scope', 'read write')}",
1552
+ ]
1553
+
1554
+ return lines
1555
+
1556
+ def export_credential_to_env(
1557
+ self, credential_path: str, output_dir: Optional[Path] = None
1558
+ ) -> Optional[str]:
1559
+ """Export a credential file to .env format."""
1560
+ try:
1561
+ cred_path = Path(credential_path)
1562
+
1563
+ # Load credential
1564
+ with open(cred_path, "r") as f:
1565
+ creds = json.load(f)
1566
+
1567
+ # Extract metadata
1568
+ email = creds.get("email") or creds.get("_proxy_metadata", {}).get(
1569
+ "email", "unknown"
1570
+ )
1571
+
1572
+ # Get credential number from filename
1573
+ match = re.search(r"_oauth_(\d+)\.json$", cred_path.name)
1574
+ cred_number = int(match.group(1)) if match else 1
1575
+
1576
+ # Build output path
1577
+ if output_dir is None:
1578
+ output_dir = cred_path.parent
1579
+
1580
+ safe_email = email.replace("@", "_at_").replace(".", "_")
1581
+ env_filename = f"iflow_{cred_number}_{safe_email}.env"
1582
+ env_path = output_dir / env_filename
1583
+
1584
+ # Build and write content
1585
+ env_lines = self.build_env_lines(creds, cred_number)
1586
+ with open(env_path, "w") as f:
1587
+ f.write("\n".join(env_lines))
1588
+
1589
+ lib_logger.info(f"Exported credential to {env_path}")
1590
+ return str(env_path)
1591
+
1592
+ except Exception as e:
1593
+ lib_logger.error(f"Failed to export credential: {e}")
1594
+ return None
1595
+
1596
+ def list_credentials(self, base_dir: Optional[Path] = None) -> List[Dict[str, Any]]:
1597
+ """List all iFlow credential files."""
1598
+ if base_dir is None:
1599
+ base_dir = self._get_oauth_base_dir()
1600
+
1601
+ prefix = self._get_provider_file_prefix()
1602
+ pattern = str(base_dir / f"{prefix}_oauth_*.json")
1603
+
1604
+ credentials = []
1605
+ for cred_file in sorted(glob(pattern)):
1606
+ try:
1607
+ with open(cred_file, "r") as f:
1608
+ creds = json.load(f)
1609
+
1610
+ email = creds.get("email") or creds.get("_proxy_metadata", {}).get(
1611
+ "email", "unknown"
1612
+ )
1613
+
1614
+ # Extract number from filename
1615
+ match = re.search(r"_oauth_(\d+)\.json$", cred_file)
1616
+ number = int(match.group(1)) if match else 0
1617
+
1618
+ credentials.append(
1619
+ {
1620
+ "file_path": cred_file,
1621
+ "email": email,
1622
+ "number": number,
1623
+ }
1624
+ )
1625
+ except Exception as e:
1626
+ lib_logger.debug(f"Could not read credential file {cred_file}: {e}")
1627
+ continue
1628
+
1629
+ return credentials
1630
+
1631
+ def delete_credential(self, credential_path: str) -> bool:
1632
+ """Delete a credential file."""
1633
+ try:
1634
+ cred_path = Path(credential_path)
1635
+
1636
+ # Validate that it's one of our credential files
1637
+ prefix = self._get_provider_file_prefix()
1638
+ if not cred_path.name.startswith(f"{prefix}_oauth_"):
1639
+ lib_logger.error(
1640
+ f"File {cred_path.name} does not appear to be an iFlow credential"
1641
+ )
1642
+ return False
1643
+
1644
+ if not cred_path.exists():
1645
+ lib_logger.warning(f"Credential file does not exist: {credential_path}")
1646
+ return False
1647
+
1648
+ # Remove from cache if present
1649
+ self._credentials_cache.pop(credential_path, None)
1650
+
1651
+ # Delete the file
1652
+ cred_path.unlink()
1653
+ lib_logger.info(f"Deleted credential file: {credential_path}")
1654
+ return True
1655
+
1656
+ except Exception as e:
1657
+ lib_logger.error(f"Failed to delete credential: {e}")
1658
+ return False
src/rotator_library/providers/iflow_provider.py CHANGED
@@ -10,19 +10,27 @@ from typing import Union, AsyncGenerator, List, Dict, Any
10
  from .provider_interface import ProviderInterface
11
  from .iflow_auth_base import IFlowAuthBase
12
  from ..model_definitions import ModelDefinitions
 
 
13
  import litellm
14
  from litellm.exceptions import RateLimitError, AuthenticationError
15
  from pathlib import Path
16
  import uuid
17
  from datetime import datetime
18
 
19
- lib_logger = logging.getLogger('rotator_library')
 
 
 
 
 
 
 
20
 
21
- LOGS_DIR = Path(__file__).resolve().parent.parent.parent.parent / "logs"
22
- IFLOW_LOGS_DIR = LOGS_DIR / "iflow_logs"
23
 
24
  class _IFlowFileLogger:
25
  """A simple file logger for a single iFlow transaction."""
 
26
  def __init__(self, model_name: str, enabled: bool = True):
27
  self.enabled = enabled
28
  if not self.enabled:
@@ -31,8 +39,10 @@ class _IFlowFileLogger:
31
  timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
32
  request_id = str(uuid.uuid4())
33
  # Sanitize model name for directory
34
- safe_model_name = model_name.replace('/', '_').replace(':', '_')
35
- self.log_dir = IFLOW_LOGS_DIR / f"{timestamp}_{safe_model_name}_{request_id}"
 
 
36
  try:
37
  self.log_dir.mkdir(parents=True, exist_ok=True)
38
  except Exception as e:
@@ -41,16 +51,20 @@ class _IFlowFileLogger:
41
 
42
  def log_request(self, payload: Dict[str, Any]):
43
  """Logs the request payload sent to iFlow."""
44
- if not self.enabled: return
 
45
  try:
46
- with open(self.log_dir / "request_payload.json", "w", encoding="utf-8") as f:
 
 
47
  json.dump(payload, f, indent=2, ensure_ascii=False)
48
  except Exception as e:
49
  lib_logger.error(f"_IFlowFileLogger: Failed to write request: {e}")
50
 
51
  def log_response_chunk(self, chunk: str):
52
  """Logs a raw chunk from the iFlow response stream."""
53
- if not self.enabled: return
 
54
  try:
55
  with open(self.log_dir / "response_stream.log", "a", encoding="utf-8") as f:
56
  f.write(chunk + "\n")
@@ -59,7 +73,8 @@ class _IFlowFileLogger:
59
 
60
  def log_error(self, error_message: str):
61
  """Logs an error message."""
62
- if not self.enabled: return
 
63
  try:
64
  with open(self.log_dir / "error.log", "a", encoding="utf-8") as f:
65
  f.write(f"[{datetime.utcnow().isoformat()}] {error_message}\n")
@@ -68,13 +83,15 @@ class _IFlowFileLogger:
68
 
69
  def log_final_response(self, response_data: Dict[str, Any]):
70
  """Logs the final, reassembled response."""
71
- if not self.enabled: return
 
72
  try:
73
  with open(self.log_dir / "final_response.json", "w", encoding="utf-8") as f:
74
  json.dump(response_data, f, indent=2, ensure_ascii=False)
75
  except Exception as e:
76
  lib_logger.error(f"_IFlowFileLogger: Failed to write final response: {e}")
77
 
 
78
  # Model list can be expanded as iFlow supports more models
79
  HARDCODED_MODELS = [
80
  "glm-4.6",
@@ -90,14 +107,25 @@ HARDCODED_MODELS = [
90
  "deepseek-v3",
91
  "qwen3-vl-plus",
92
  "qwen3-235b-a22b-instruct",
93
- "qwen3-235b"
94
  ]
95
 
96
  # OpenAI-compatible parameters supported by iFlow API
97
  SUPPORTED_PARAMS = {
98
- 'model', 'messages', 'temperature', 'top_p', 'max_tokens',
99
- 'stream', 'tools', 'tool_choice', 'presence_penalty',
100
- 'frequency_penalty', 'n', 'stop', 'seed', 'response_format'
 
 
 
 
 
 
 
 
 
 
 
101
  }
102
 
103
 
@@ -106,6 +134,7 @@ class IFlowProvider(IFlowAuthBase, ProviderInterface):
106
  iFlow provider using OAuth authentication with local callback server.
107
  API requests use the derived API key (NOT OAuth access_token).
108
  """
 
109
  skip_cost_calculation = True
110
 
111
  def __init__(self):
@@ -128,7 +157,9 @@ class IFlowProvider(IFlowAuthBase, ProviderInterface):
128
  Validates OAuth credentials if applicable.
129
  """
130
  models = []
131
- env_var_ids = set() # Track IDs from env vars to prevent hardcoded/dynamic duplicates
 
 
132
 
133
  def extract_model_id(item) -> str:
134
  """Extract model ID from various formats (dict, string with/without provider prefix)."""
@@ -154,7 +185,9 @@ class IFlowProvider(IFlowAuthBase, ProviderInterface):
154
  # Track the ID to prevent hardcoded/dynamic duplicates
155
  if model_id:
156
  env_var_ids.add(model_id)
157
- lib_logger.info(f"Loaded {len(static_models)} static models for iflow from environment variables")
 
 
158
 
159
  # Source 2: Add hardcoded models (only if ID not already in env vars)
160
  for model_id in HARDCODED_MODELS:
@@ -172,14 +205,17 @@ class IFlowProvider(IFlowAuthBase, ProviderInterface):
172
  models_url = f"{api_base.rstrip('/')}/models"
173
 
174
  response = await client.get(
175
- models_url,
176
- headers={"Authorization": f"Bearer {api_key}"}
177
  )
178
  response.raise_for_status()
179
 
180
  dynamic_data = response.json()
181
  # Handle both {data: [...]} and direct [...] formats
182
- model_list = dynamic_data.get("data", dynamic_data) if isinstance(dynamic_data, dict) else dynamic_data
 
 
 
 
183
 
184
  dynamic_count = 0
185
  for model in model_list:
@@ -190,7 +226,9 @@ class IFlowProvider(IFlowAuthBase, ProviderInterface):
190
  dynamic_count += 1
191
 
192
  if dynamic_count > 0:
193
- lib_logger.debug(f"Discovered {dynamic_count} additional models for iflow from API")
 
 
194
 
195
  except Exception as e:
196
  # Silently ignore dynamic discovery errors
@@ -255,7 +293,7 @@ class IFlowProvider(IFlowAuthBase, ProviderInterface):
255
  payload = {k: v for k, v in kwargs.items() if k in SUPPORTED_PARAMS}
256
 
257
  # Always force streaming for internal processing
258
- payload['stream'] = True
259
 
260
  # NOTE: iFlow API does not support stream_options parameter
261
  # Unlike other providers, we don't include it to avoid HTTP 406 errors
@@ -264,16 +302,22 @@ class IFlowProvider(IFlowAuthBase, ProviderInterface):
264
  if "tools" in payload and payload["tools"]:
265
  payload["tools"] = self._clean_tool_schemas(payload["tools"])
266
  lib_logger.debug(f"Cleaned {len(payload['tools'])} tool schemas")
267
- elif "tools" in payload and isinstance(payload["tools"], list) and len(payload["tools"]) == 0:
 
 
 
 
268
  # Inject dummy tool for empty arrays to prevent streaming issues (similar to Qwen's behavior)
269
- payload["tools"] = [{
270
- "type": "function",
271
- "function": {
272
- "name": "noop",
273
- "description": "Placeholder tool to stabilise streaming",
274
- "parameters": {"type": "object"}
 
 
275
  }
276
- }]
277
  lib_logger.debug("Injected placeholder tool for empty tools array")
278
 
279
  return payload
@@ -282,7 +326,7 @@ class IFlowProvider(IFlowAuthBase, ProviderInterface):
282
  """
283
  Converts a raw iFlow SSE chunk to an OpenAI-compatible chunk.
284
  Since iFlow is OpenAI-compatible, minimal conversion is needed.
285
-
286
  CRITICAL FIX: Handle chunks with BOTH usage and choices (final chunk)
287
  without early return to ensure finish_reason is properly processed.
288
  """
@@ -302,32 +346,36 @@ class IFlowProvider(IFlowAuthBase, ProviderInterface):
302
  "model": model_id,
303
  "object": "chat.completion.chunk",
304
  "id": chunk.get("id", f"chatcmpl-iflow-{time.time()}"),
305
- "created": chunk.get("created", int(time.time()))
306
  }
307
  # Then yield the usage chunk
308
  yield {
309
- "choices": [], "model": model_id, "object": "chat.completion.chunk",
 
 
310
  "id": chunk.get("id", f"chatcmpl-iflow-{time.time()}"),
311
  "created": chunk.get("created", int(time.time())),
312
  "usage": {
313
  "prompt_tokens": usage_data.get("prompt_tokens", 0),
314
  "completion_tokens": usage_data.get("completion_tokens", 0),
315
  "total_tokens": usage_data.get("total_tokens", 0),
316
- }
317
  }
318
  return
319
 
320
  # Handle usage-only chunks
321
  if usage_data:
322
  yield {
323
- "choices": [], "model": model_id, "object": "chat.completion.chunk",
 
 
324
  "id": chunk.get("id", f"chatcmpl-iflow-{time.time()}"),
325
  "created": chunk.get("created", int(time.time())),
326
  "usage": {
327
  "prompt_tokens": usage_data.get("prompt_tokens", 0),
328
  "completion_tokens": usage_data.get("completion_tokens", 0),
329
  "total_tokens": usage_data.get("total_tokens", 0),
330
- }
331
  }
332
  return
333
 
@@ -339,13 +387,15 @@ class IFlowProvider(IFlowAuthBase, ProviderInterface):
339
  "model": model_id,
340
  "object": "chat.completion.chunk",
341
  "id": chunk.get("id", f"chatcmpl-iflow-{time.time()}"),
342
- "created": chunk.get("created", int(time.time()))
343
  }
344
 
345
- def _stream_to_completion_response(self, chunks: List[litellm.ModelResponse]) -> litellm.ModelResponse:
 
 
346
  """
347
  Manually reassembles streaming chunks into a complete response.
348
-
349
  Key improvements:
350
  - Determines finish_reason based on accumulated state (tool_calls vs stop)
351
  - Properly initializes tool_calls with type field
@@ -358,14 +408,16 @@ class IFlowProvider(IFlowAuthBase, ProviderInterface):
358
  final_message = {"role": "assistant"}
359
  aggregated_tool_calls = {}
360
  usage_data = None
361
- chunk_finish_reason = None # Track finish_reason from chunks (but we'll override)
 
 
362
 
363
  # Get the first chunk for basic response metadata
364
  first_chunk = chunks[0]
365
 
366
  # Process each chunk to aggregate content
367
  for chunk in chunks:
368
- if not hasattr(chunk, 'choices') or not chunk.choices:
369
  continue
370
 
371
  choice = chunk.choices[0]
@@ -389,25 +441,48 @@ class IFlowProvider(IFlowAuthBase, ProviderInterface):
389
  index = tc_chunk.get("index", 0)
390
  if index not in aggregated_tool_calls:
391
  # Initialize with type field for OpenAI compatibility
392
- aggregated_tool_calls[index] = {"type": "function", "function": {"name": "", "arguments": ""}}
 
 
 
393
  if "id" in tc_chunk:
394
  aggregated_tool_calls[index]["id"] = tc_chunk["id"]
395
  if "type" in tc_chunk:
396
  aggregated_tool_calls[index]["type"] = tc_chunk["type"]
397
  if "function" in tc_chunk:
398
- if "name" in tc_chunk["function"] and tc_chunk["function"]["name"] is not None:
399
- aggregated_tool_calls[index]["function"]["name"] += tc_chunk["function"]["name"]
400
- if "arguments" in tc_chunk["function"] and tc_chunk["function"]["arguments"] is not None:
401
- aggregated_tool_calls[index]["function"]["arguments"] += tc_chunk["function"]["arguments"]
 
 
 
 
 
 
 
 
 
 
402
 
403
  # Aggregate function calls (legacy format)
404
  if "function_call" in delta and delta["function_call"] is not None:
405
  if "function_call" not in final_message:
406
  final_message["function_call"] = {"name": "", "arguments": ""}
407
- if "name" in delta["function_call"] and delta["function_call"]["name"] is not None:
408
- final_message["function_call"]["name"] += delta["function_call"]["name"]
409
- if "arguments" in delta["function_call"] and delta["function_call"]["arguments"] is not None:
410
- final_message["function_call"]["arguments"] += delta["function_call"]["arguments"]
 
 
 
 
 
 
 
 
 
 
411
 
412
  # Track finish_reason from chunks (for reference only)
413
  if choice.get("finish_reason"):
@@ -415,7 +490,7 @@ class IFlowProvider(IFlowAuthBase, ProviderInterface):
415
 
416
  # Handle usage data from the last chunk that has it
417
  for chunk in reversed(chunks):
418
- if hasattr(chunk, 'usage') and chunk.usage:
419
  usage_data = chunk.usage
420
  break
421
 
@@ -441,7 +516,7 @@ class IFlowProvider(IFlowAuthBase, ProviderInterface):
441
  final_choice = {
442
  "index": 0,
443
  "message": final_message,
444
- "finish_reason": finish_reason
445
  }
446
 
447
  # Create the final ModelResponse
@@ -451,21 +526,20 @@ class IFlowProvider(IFlowAuthBase, ProviderInterface):
451
  "created": first_chunk.created,
452
  "model": first_chunk.model,
453
  "choices": [final_choice],
454
- "usage": usage_data
455
  }
456
 
457
  return litellm.ModelResponse(**final_response_data)
458
 
459
- async def acompletion(self, client: httpx.AsyncClient, **kwargs) -> Union[litellm.ModelResponse, AsyncGenerator[litellm.ModelResponse, None]]:
 
 
460
  credential_path = kwargs.pop("credential_identifier")
461
  enable_request_logging = kwargs.pop("enable_request_logging", False)
462
  model = kwargs["model"]
463
 
464
  # Create dedicated file logger for this request
465
- file_logger = _IFlowFileLogger(
466
- model_name=model,
467
- enabled=enable_request_logging
468
- )
469
 
470
  async def make_request():
471
  """Prepares and makes the actual API call."""
@@ -473,8 +547,8 @@ class IFlowProvider(IFlowAuthBase, ProviderInterface):
473
  api_base, api_key = await self.get_api_details(credential_path)
474
 
475
  # Strip provider prefix from model name (e.g., "iflow/Qwen3-Coder-Plus" -> "Qwen3-Coder-Plus")
476
- model_name = model.split('/')[-1]
477
- kwargs_with_stripped_model = {**kwargs, 'model': model_name}
478
 
479
  # Build clean payload with only supported parameters
480
  payload = self._build_request_payload(**kwargs_with_stripped_model)
@@ -483,7 +557,7 @@ class IFlowProvider(IFlowAuthBase, ProviderInterface):
483
  "Authorization": f"Bearer {api_key}", # Uses api_key from user info
484
  "Content-Type": "application/json",
485
  "Accept": "text/event-stream",
486
- "User-Agent": "iFlow-Cli"
487
  }
488
 
489
  url = f"{api_base.rstrip('/')}/chat/completions"
@@ -492,7 +566,13 @@ class IFlowProvider(IFlowAuthBase, ProviderInterface):
492
  file_logger.log_request(payload)
493
  lib_logger.debug(f"iFlow Request URL: {url}")
494
 
495
- return client.stream("POST", url, headers=headers, json=payload, timeout=600)
 
 
 
 
 
 
496
 
497
  async def stream_handler(response_stream, attempt=1):
498
  """Handles the streaming response and converts chunks."""
@@ -501,11 +581,17 @@ class IFlowProvider(IFlowAuthBase, ProviderInterface):
501
  # Check for HTTP errors before processing stream
502
  if response.status_code >= 400:
503
  error_text = await response.aread()
504
- error_text = error_text.decode('utf-8') if isinstance(error_text, bytes) else error_text
 
 
 
 
505
 
506
  # Handle 401: Force token refresh and retry once
507
  if response.status_code == 401 and attempt == 1:
508
- lib_logger.warning("iFlow returned 401. Forcing token refresh and retrying once.")
 
 
509
  await self._refresh_token(credential_path, force=True)
510
  retry_stream = await make_request()
511
  async for chunk in stream_handler(retry_stream, attempt=2):
@@ -513,50 +599,61 @@ class IFlowProvider(IFlowAuthBase, ProviderInterface):
513
  return
514
 
515
  # Handle 429: Rate limit
516
- elif response.status_code == 429 or "slow_down" in error_text.lower():
 
 
 
517
  raise RateLimitError(
518
  f"iFlow rate limit exceeded: {error_text}",
519
  llm_provider="iflow",
520
  model=model,
521
- response=response
522
  )
523
 
524
  # Handle other errors
525
  else:
526
- error_msg = f"iFlow HTTP {response.status_code} error: {error_text}"
 
 
527
  file_logger.log_error(error_msg)
528
  raise httpx.HTTPStatusError(
529
  f"HTTP {response.status_code}: {error_text}",
530
  request=response.request,
531
- response=response
532
  )
533
 
534
  # Process successful streaming response
535
  async for line in response.aiter_lines():
536
  file_logger.log_response_chunk(line)
537
-
538
  # CRITICAL FIX: Handle both "data:" (no space) and "data: " (with space)
539
- if line.startswith('data:'):
540
  # Extract data after "data:" prefix, handling both formats
541
- if line.startswith('data: '):
542
  data_str = line[6:] # Skip "data: "
543
  else:
544
  data_str = line[5:] # Skip "data:"
545
-
546
  if data_str.strip() == "[DONE]":
547
  break
548
  try:
549
  chunk = json.loads(data_str)
550
- for openai_chunk in self._convert_chunk_to_openai(chunk, model):
 
 
551
  yield litellm.ModelResponse(**openai_chunk)
552
  except json.JSONDecodeError:
553
- lib_logger.warning(f"Could not decode JSON from iFlow: {line}")
 
 
554
 
555
  except httpx.HTTPStatusError:
556
  raise # Re-raise HTTP errors we already handled
557
  except Exception as e:
558
  file_logger.log_error(f"Error during iFlow stream processing: {e}")
559
- lib_logger.error(f"Error during iFlow stream processing: {e}", exc_info=True)
 
 
560
  raise
561
 
562
  async def logging_stream_wrapper():
@@ -574,7 +671,9 @@ class IFlowProvider(IFlowAuthBase, ProviderInterface):
574
  if kwargs.get("stream"):
575
  return logging_stream_wrapper()
576
  else:
 
577
  async def non_stream_wrapper():
578
  chunks = [chunk async for chunk in logging_stream_wrapper()]
579
  return self._stream_to_completion_response(chunks)
 
580
  return await non_stream_wrapper()
 
10
  from .provider_interface import ProviderInterface
11
  from .iflow_auth_base import IFlowAuthBase
12
  from ..model_definitions import ModelDefinitions
13
+ from ..timeout_config import TimeoutConfig
14
+ from ..utils.paths import get_logs_dir
15
  import litellm
16
  from litellm.exceptions import RateLimitError, AuthenticationError
17
  from pathlib import Path
18
  import uuid
19
  from datetime import datetime
20
 
21
+ lib_logger = logging.getLogger("rotator_library")
22
+
23
+
24
+ def _get_iflow_logs_dir() -> Path:
25
+ """Get the iFlow logs directory."""
26
+ logs_dir = get_logs_dir() / "iflow_logs"
27
+ logs_dir.mkdir(parents=True, exist_ok=True)
28
+ return logs_dir
29
 
 
 
30
 
31
  class _IFlowFileLogger:
32
  """A simple file logger for a single iFlow transaction."""
33
+
34
  def __init__(self, model_name: str, enabled: bool = True):
35
  self.enabled = enabled
36
  if not self.enabled:
 
39
  timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
40
  request_id = str(uuid.uuid4())
41
  # Sanitize model name for directory
42
+ safe_model_name = model_name.replace("/", "_").replace(":", "_")
43
+ self.log_dir = (
44
+ _get_iflow_logs_dir() / f"{timestamp}_{safe_model_name}_{request_id}"
45
+ )
46
  try:
47
  self.log_dir.mkdir(parents=True, exist_ok=True)
48
  except Exception as e:
 
51
 
52
  def log_request(self, payload: Dict[str, Any]):
53
  """Logs the request payload sent to iFlow."""
54
+ if not self.enabled:
55
+ return
56
  try:
57
+ with open(
58
+ self.log_dir / "request_payload.json", "w", encoding="utf-8"
59
+ ) as f:
60
  json.dump(payload, f, indent=2, ensure_ascii=False)
61
  except Exception as e:
62
  lib_logger.error(f"_IFlowFileLogger: Failed to write request: {e}")
63
 
64
  def log_response_chunk(self, chunk: str):
65
  """Logs a raw chunk from the iFlow response stream."""
66
+ if not self.enabled:
67
+ return
68
  try:
69
  with open(self.log_dir / "response_stream.log", "a", encoding="utf-8") as f:
70
  f.write(chunk + "\n")
 
73
 
74
  def log_error(self, error_message: str):
75
  """Logs an error message."""
76
+ if not self.enabled:
77
+ return
78
  try:
79
  with open(self.log_dir / "error.log", "a", encoding="utf-8") as f:
80
  f.write(f"[{datetime.utcnow().isoformat()}] {error_message}\n")
 
83
 
84
  def log_final_response(self, response_data: Dict[str, Any]):
85
  """Logs the final, reassembled response."""
86
+ if not self.enabled:
87
+ return
88
  try:
89
  with open(self.log_dir / "final_response.json", "w", encoding="utf-8") as f:
90
  json.dump(response_data, f, indent=2, ensure_ascii=False)
91
  except Exception as e:
92
  lib_logger.error(f"_IFlowFileLogger: Failed to write final response: {e}")
93
 
94
+
95
  # Model list can be expanded as iFlow supports more models
96
  HARDCODED_MODELS = [
97
  "glm-4.6",
 
107
  "deepseek-v3",
108
  "qwen3-vl-plus",
109
  "qwen3-235b-a22b-instruct",
110
+ "qwen3-235b",
111
  ]
112
 
113
  # OpenAI-compatible parameters supported by iFlow API
114
  SUPPORTED_PARAMS = {
115
+ "model",
116
+ "messages",
117
+ "temperature",
118
+ "top_p",
119
+ "max_tokens",
120
+ "stream",
121
+ "tools",
122
+ "tool_choice",
123
+ "presence_penalty",
124
+ "frequency_penalty",
125
+ "n",
126
+ "stop",
127
+ "seed",
128
+ "response_format",
129
  }
130
 
131
 
 
134
  iFlow provider using OAuth authentication with local callback server.
135
  API requests use the derived API key (NOT OAuth access_token).
136
  """
137
+
138
  skip_cost_calculation = True
139
 
140
  def __init__(self):
 
157
  Validates OAuth credentials if applicable.
158
  """
159
  models = []
160
+ env_var_ids = (
161
+ set()
162
+ ) # Track IDs from env vars to prevent hardcoded/dynamic duplicates
163
 
164
  def extract_model_id(item) -> str:
165
  """Extract model ID from various formats (dict, string with/without provider prefix)."""
 
185
  # Track the ID to prevent hardcoded/dynamic duplicates
186
  if model_id:
187
  env_var_ids.add(model_id)
188
+ lib_logger.info(
189
+ f"Loaded {len(static_models)} static models for iflow from environment variables"
190
+ )
191
 
192
  # Source 2: Add hardcoded models (only if ID not already in env vars)
193
  for model_id in HARDCODED_MODELS:
 
205
  models_url = f"{api_base.rstrip('/')}/models"
206
 
207
  response = await client.get(
208
+ models_url, headers={"Authorization": f"Bearer {api_key}"}
 
209
  )
210
  response.raise_for_status()
211
 
212
  dynamic_data = response.json()
213
  # Handle both {data: [...]} and direct [...] formats
214
+ model_list = (
215
+ dynamic_data.get("data", dynamic_data)
216
+ if isinstance(dynamic_data, dict)
217
+ else dynamic_data
218
+ )
219
 
220
  dynamic_count = 0
221
  for model in model_list:
 
226
  dynamic_count += 1
227
 
228
  if dynamic_count > 0:
229
+ lib_logger.debug(
230
+ f"Discovered {dynamic_count} additional models for iflow from API"
231
+ )
232
 
233
  except Exception as e:
234
  # Silently ignore dynamic discovery errors
 
293
  payload = {k: v for k, v in kwargs.items() if k in SUPPORTED_PARAMS}
294
 
295
  # Always force streaming for internal processing
296
+ payload["stream"] = True
297
 
298
  # NOTE: iFlow API does not support stream_options parameter
299
  # Unlike other providers, we don't include it to avoid HTTP 406 errors
 
302
  if "tools" in payload and payload["tools"]:
303
  payload["tools"] = self._clean_tool_schemas(payload["tools"])
304
  lib_logger.debug(f"Cleaned {len(payload['tools'])} tool schemas")
305
+ elif (
306
+ "tools" in payload
307
+ and isinstance(payload["tools"], list)
308
+ and len(payload["tools"]) == 0
309
+ ):
310
  # Inject dummy tool for empty arrays to prevent streaming issues (similar to Qwen's behavior)
311
+ payload["tools"] = [
312
+ {
313
+ "type": "function",
314
+ "function": {
315
+ "name": "noop",
316
+ "description": "Placeholder tool to stabilise streaming",
317
+ "parameters": {"type": "object"},
318
+ },
319
  }
320
+ ]
321
  lib_logger.debug("Injected placeholder tool for empty tools array")
322
 
323
  return payload
 
326
  """
327
  Converts a raw iFlow SSE chunk to an OpenAI-compatible chunk.
328
  Since iFlow is OpenAI-compatible, minimal conversion is needed.
329
+
330
  CRITICAL FIX: Handle chunks with BOTH usage and choices (final chunk)
331
  without early return to ensure finish_reason is properly processed.
332
  """
 
346
  "model": model_id,
347
  "object": "chat.completion.chunk",
348
  "id": chunk.get("id", f"chatcmpl-iflow-{time.time()}"),
349
+ "created": chunk.get("created", int(time.time())),
350
  }
351
  # Then yield the usage chunk
352
  yield {
353
+ "choices": [],
354
+ "model": model_id,
355
+ "object": "chat.completion.chunk",
356
  "id": chunk.get("id", f"chatcmpl-iflow-{time.time()}"),
357
  "created": chunk.get("created", int(time.time())),
358
  "usage": {
359
  "prompt_tokens": usage_data.get("prompt_tokens", 0),
360
  "completion_tokens": usage_data.get("completion_tokens", 0),
361
  "total_tokens": usage_data.get("total_tokens", 0),
362
+ },
363
  }
364
  return
365
 
366
  # Handle usage-only chunks
367
  if usage_data:
368
  yield {
369
+ "choices": [],
370
+ "model": model_id,
371
+ "object": "chat.completion.chunk",
372
  "id": chunk.get("id", f"chatcmpl-iflow-{time.time()}"),
373
  "created": chunk.get("created", int(time.time())),
374
  "usage": {
375
  "prompt_tokens": usage_data.get("prompt_tokens", 0),
376
  "completion_tokens": usage_data.get("completion_tokens", 0),
377
  "total_tokens": usage_data.get("total_tokens", 0),
378
+ },
379
  }
380
  return
381
 
 
387
  "model": model_id,
388
  "object": "chat.completion.chunk",
389
  "id": chunk.get("id", f"chatcmpl-iflow-{time.time()}"),
390
+ "created": chunk.get("created", int(time.time())),
391
  }
392
 
393
+ def _stream_to_completion_response(
394
+ self, chunks: List[litellm.ModelResponse]
395
+ ) -> litellm.ModelResponse:
396
  """
397
  Manually reassembles streaming chunks into a complete response.
398
+
399
  Key improvements:
400
  - Determines finish_reason based on accumulated state (tool_calls vs stop)
401
  - Properly initializes tool_calls with type field
 
408
  final_message = {"role": "assistant"}
409
  aggregated_tool_calls = {}
410
  usage_data = None
411
+ chunk_finish_reason = (
412
+ None # Track finish_reason from chunks (but we'll override)
413
+ )
414
 
415
  # Get the first chunk for basic response metadata
416
  first_chunk = chunks[0]
417
 
418
  # Process each chunk to aggregate content
419
  for chunk in chunks:
420
+ if not hasattr(chunk, "choices") or not chunk.choices:
421
  continue
422
 
423
  choice = chunk.choices[0]
 
441
  index = tc_chunk.get("index", 0)
442
  if index not in aggregated_tool_calls:
443
  # Initialize with type field for OpenAI compatibility
444
+ aggregated_tool_calls[index] = {
445
+ "type": "function",
446
+ "function": {"name": "", "arguments": ""},
447
+ }
448
  if "id" in tc_chunk:
449
  aggregated_tool_calls[index]["id"] = tc_chunk["id"]
450
  if "type" in tc_chunk:
451
  aggregated_tool_calls[index]["type"] = tc_chunk["type"]
452
  if "function" in tc_chunk:
453
+ if (
454
+ "name" in tc_chunk["function"]
455
+ and tc_chunk["function"]["name"] is not None
456
+ ):
457
+ aggregated_tool_calls[index]["function"]["name"] += (
458
+ tc_chunk["function"]["name"]
459
+ )
460
+ if (
461
+ "arguments" in tc_chunk["function"]
462
+ and tc_chunk["function"]["arguments"] is not None
463
+ ):
464
+ aggregated_tool_calls[index]["function"]["arguments"] += (
465
+ tc_chunk["function"]["arguments"]
466
+ )
467
 
468
  # Aggregate function calls (legacy format)
469
  if "function_call" in delta and delta["function_call"] is not None:
470
  if "function_call" not in final_message:
471
  final_message["function_call"] = {"name": "", "arguments": ""}
472
+ if (
473
+ "name" in delta["function_call"]
474
+ and delta["function_call"]["name"] is not None
475
+ ):
476
+ final_message["function_call"]["name"] += delta["function_call"][
477
+ "name"
478
+ ]
479
+ if (
480
+ "arguments" in delta["function_call"]
481
+ and delta["function_call"]["arguments"] is not None
482
+ ):
483
+ final_message["function_call"]["arguments"] += delta[
484
+ "function_call"
485
+ ]["arguments"]
486
 
487
  # Track finish_reason from chunks (for reference only)
488
  if choice.get("finish_reason"):
 
490
 
491
  # Handle usage data from the last chunk that has it
492
  for chunk in reversed(chunks):
493
+ if hasattr(chunk, "usage") and chunk.usage:
494
  usage_data = chunk.usage
495
  break
496
 
 
516
  final_choice = {
517
  "index": 0,
518
  "message": final_message,
519
+ "finish_reason": finish_reason,
520
  }
521
 
522
  # Create the final ModelResponse
 
526
  "created": first_chunk.created,
527
  "model": first_chunk.model,
528
  "choices": [final_choice],
529
+ "usage": usage_data,
530
  }
531
 
532
  return litellm.ModelResponse(**final_response_data)
533
 
534
+ async def acompletion(
535
+ self, client: httpx.AsyncClient, **kwargs
536
+ ) -> Union[litellm.ModelResponse, AsyncGenerator[litellm.ModelResponse, None]]:
537
  credential_path = kwargs.pop("credential_identifier")
538
  enable_request_logging = kwargs.pop("enable_request_logging", False)
539
  model = kwargs["model"]
540
 
541
  # Create dedicated file logger for this request
542
+ file_logger = _IFlowFileLogger(model_name=model, enabled=enable_request_logging)
 
 
 
543
 
544
  async def make_request():
545
  """Prepares and makes the actual API call."""
 
547
  api_base, api_key = await self.get_api_details(credential_path)
548
 
549
  # Strip provider prefix from model name (e.g., "iflow/Qwen3-Coder-Plus" -> "Qwen3-Coder-Plus")
550
+ model_name = model.split("/")[-1]
551
+ kwargs_with_stripped_model = {**kwargs, "model": model_name}
552
 
553
  # Build clean payload with only supported parameters
554
  payload = self._build_request_payload(**kwargs_with_stripped_model)
 
557
  "Authorization": f"Bearer {api_key}", # Uses api_key from user info
558
  "Content-Type": "application/json",
559
  "Accept": "text/event-stream",
560
+ "User-Agent": "iFlow-Cli",
561
  }
562
 
563
  url = f"{api_base.rstrip('/')}/chat/completions"
 
566
  file_logger.log_request(payload)
567
  lib_logger.debug(f"iFlow Request URL: {url}")
568
 
569
+ return client.stream(
570
+ "POST",
571
+ url,
572
+ headers=headers,
573
+ json=payload,
574
+ timeout=TimeoutConfig.streaming(),
575
+ )
576
 
577
  async def stream_handler(response_stream, attempt=1):
578
  """Handles the streaming response and converts chunks."""
 
581
  # Check for HTTP errors before processing stream
582
  if response.status_code >= 400:
583
  error_text = await response.aread()
584
+ error_text = (
585
+ error_text.decode("utf-8")
586
+ if isinstance(error_text, bytes)
587
+ else error_text
588
+ )
589
 
590
  # Handle 401: Force token refresh and retry once
591
  if response.status_code == 401 and attempt == 1:
592
+ lib_logger.warning(
593
+ "iFlow returned 401. Forcing token refresh and retrying once."
594
+ )
595
  await self._refresh_token(credential_path, force=True)
596
  retry_stream = await make_request()
597
  async for chunk in stream_handler(retry_stream, attempt=2):
 
599
  return
600
 
601
  # Handle 429: Rate limit
602
+ elif (
603
+ response.status_code == 429
604
+ or "slow_down" in error_text.lower()
605
+ ):
606
  raise RateLimitError(
607
  f"iFlow rate limit exceeded: {error_text}",
608
  llm_provider="iflow",
609
  model=model,
610
+ response=response,
611
  )
612
 
613
  # Handle other errors
614
  else:
615
+ error_msg = (
616
+ f"iFlow HTTP {response.status_code} error: {error_text}"
617
+ )
618
  file_logger.log_error(error_msg)
619
  raise httpx.HTTPStatusError(
620
  f"HTTP {response.status_code}: {error_text}",
621
  request=response.request,
622
+ response=response,
623
  )
624
 
625
  # Process successful streaming response
626
  async for line in response.aiter_lines():
627
  file_logger.log_response_chunk(line)
628
+
629
  # CRITICAL FIX: Handle both "data:" (no space) and "data: " (with space)
630
+ if line.startswith("data:"):
631
  # Extract data after "data:" prefix, handling both formats
632
+ if line.startswith("data: "):
633
  data_str = line[6:] # Skip "data: "
634
  else:
635
  data_str = line[5:] # Skip "data:"
636
+
637
  if data_str.strip() == "[DONE]":
638
  break
639
  try:
640
  chunk = json.loads(data_str)
641
+ for openai_chunk in self._convert_chunk_to_openai(
642
+ chunk, model
643
+ ):
644
  yield litellm.ModelResponse(**openai_chunk)
645
  except json.JSONDecodeError:
646
+ lib_logger.warning(
647
+ f"Could not decode JSON from iFlow: {line}"
648
+ )
649
 
650
  except httpx.HTTPStatusError:
651
  raise # Re-raise HTTP errors we already handled
652
  except Exception as e:
653
  file_logger.log_error(f"Error during iFlow stream processing: {e}")
654
+ lib_logger.error(
655
+ f"Error during iFlow stream processing: {e}", exc_info=True
656
+ )
657
  raise
658
 
659
  async def logging_stream_wrapper():
 
671
  if kwargs.get("stream"):
672
  return logging_stream_wrapper()
673
  else:
674
+
675
  async def non_stream_wrapper():
676
  chunks = [chunk async for chunk in logging_stream_wrapper()]
677
  return self._stream_to_completion_response(chunks)
678
+
679
  return await non_stream_wrapper()
src/rotator_library/providers/provider_cache.py CHANGED
@@ -20,19 +20,20 @@ import asyncio
20
  import json
21
  import logging
22
  import os
23
- import shutil
24
- import tempfile
25
  import time
26
  from pathlib import Path
27
  from typing import Any, Dict, Optional, Tuple
28
 
29
- lib_logger = logging.getLogger('rotator_library')
 
 
30
 
31
 
32
  # =============================================================================
33
  # UTILITY FUNCTIONS
34
  # =============================================================================
35
 
 
36
  def _env_bool(key: str, default: bool = False) -> bool:
37
  """Get boolean from environment variable."""
38
  return os.getenv(key, str(default).lower()).lower() in ("true", "1", "yes")
@@ -47,18 +48,19 @@ def _env_int(key: str, default: int) -> int:
47
  # PROVIDER CACHE CLASS
48
  # =============================================================================
49
 
 
50
  class ProviderCache:
51
  """
52
  Server-side cache for provider conversation state preservation.
53
-
54
  A generic, modular cache supporting any key-value data that providers need
55
  to persist across requests. Features:
56
-
57
  - Dual-TTL system: configurable memory TTL, longer disk TTL
58
  - Async disk persistence with batched writes
59
  - Background cleanup task for expired entries
60
  - Statistics tracking (hits, misses, writes)
61
-
62
  Args:
63
  cache_file: Path to disk cache file
64
  memory_ttl_seconds: In-memory entry lifetime (default: 1 hour)
@@ -67,13 +69,13 @@ class ProviderCache:
67
  write_interval: Seconds between background disk writes (default: 60)
68
  cleanup_interval: Seconds between expired entry cleanup (default: 30 min)
69
  env_prefix: Environment variable prefix for configuration overrides
70
-
71
  Environment Variables (with default prefix "PROVIDER_CACHE"):
72
  {PREFIX}_ENABLE: Enable/disable disk persistence
73
  {PREFIX}_WRITE_INTERVAL: Background write interval in seconds
74
  {PREFIX}_CLEANUP_INTERVAL: Cleanup interval in seconds
75
  """
76
-
77
  def __init__(
78
  self,
79
  cache_file: Path,
@@ -82,7 +84,7 @@ class ProviderCache:
82
  enable_disk: Optional[bool] = None,
83
  write_interval: Optional[int] = None,
84
  cleanup_interval: Optional[int] = None,
85
- env_prefix: str = "PROVIDER_CACHE"
86
  ):
87
  # In-memory cache: {cache_key: (data, timestamp)}
88
  self._cache: Dict[str, Tuple[str, float]] = {}
@@ -90,25 +92,42 @@ class ProviderCache:
90
  self._disk_ttl = disk_ttl_seconds
91
  self._lock = asyncio.Lock()
92
  self._disk_lock = asyncio.Lock()
93
-
94
  # Disk persistence configuration
95
  self._cache_file = cache_file
96
- self._enable_disk = enable_disk if enable_disk is not None else _env_bool(f"{env_prefix}_ENABLE", True)
 
 
 
 
97
  self._dirty = False
98
- self._write_interval = write_interval or _env_int(f"{env_prefix}_WRITE_INTERVAL", 60)
99
- self._cleanup_interval = cleanup_interval or _env_int(f"{env_prefix}_CLEANUP_INTERVAL", 1800)
100
-
 
 
 
 
101
  # Background tasks
102
  self._writer_task: Optional[asyncio.Task] = None
103
  self._cleanup_task: Optional[asyncio.Task] = None
104
  self._running = False
105
-
106
  # Statistics
107
- self._stats = {"memory_hits": 0, "disk_hits": 0, "misses": 0, "writes": 0}
108
-
 
 
 
 
 
 
 
 
 
109
  # Metadata about this cache instance
110
  self._cache_name = cache_file.stem if cache_file else "unnamed"
111
-
112
  if self._enable_disk:
113
  lib_logger.debug(
114
  f"ProviderCache[{self._cache_name}]: Disk enabled "
@@ -117,123 +136,120 @@ class ProviderCache:
117
  asyncio.create_task(self._async_init())
118
  else:
119
  lib_logger.debug(f"ProviderCache[{self._cache_name}]: Memory-only mode")
120
-
121
  # =========================================================================
122
  # INITIALIZATION
123
  # =========================================================================
124
-
125
  async def _async_init(self) -> None:
126
  """Async initialization: load from disk and start background tasks."""
127
  try:
128
  await self._load_from_disk()
129
  await self._start_background_tasks()
130
  except Exception as e:
131
- lib_logger.error(f"ProviderCache[{self._cache_name}] async init failed: {e}")
132
-
 
 
133
  async def _load_from_disk(self) -> None:
134
  """Load cache from disk file with TTL validation."""
135
  if not self._enable_disk or not self._cache_file.exists():
136
  return
137
-
138
  try:
139
  async with self._disk_lock:
140
- with open(self._cache_file, 'r', encoding='utf-8') as f:
141
  data = json.load(f)
142
-
143
  if data.get("version") != "1.0":
144
- lib_logger.warning(f"ProviderCache[{self._cache_name}]: Version mismatch, starting fresh")
 
 
145
  return
146
-
147
  now = time.time()
148
  entries = data.get("entries", {})
149
  loaded = expired = 0
150
-
151
  for cache_key, entry in entries.items():
152
  age = now - entry.get("timestamp", 0)
153
  if age <= self._disk_ttl:
154
- value = entry.get("value", entry.get("signature", "")) # Support both formats
 
 
155
  if value:
156
  self._cache[cache_key] = (value, entry["timestamp"])
157
  loaded += 1
158
  else:
159
  expired += 1
160
-
161
  lib_logger.debug(
162
  f"ProviderCache[{self._cache_name}]: Loaded {loaded} entries ({expired} expired)"
163
  )
164
  except json.JSONDecodeError as e:
165
- lib_logger.warning(f"ProviderCache[{self._cache_name}]: File corrupted: {e}")
 
 
166
  except Exception as e:
167
  lib_logger.error(f"ProviderCache[{self._cache_name}]: Load failed: {e}")
168
-
169
  # =========================================================================
170
  # DISK PERSISTENCE
171
  # =========================================================================
172
-
173
- async def _save_to_disk(self) -> None:
174
- """Persist cache to disk using atomic write."""
 
 
 
 
175
  if not self._enable_disk:
176
- return
177
-
178
- try:
179
- async with self._disk_lock:
180
- self._cache_file.parent.mkdir(parents=True, exist_ok=True)
181
-
182
- cache_data = {
183
- "version": "1.0",
184
- "memory_ttl_seconds": self._memory_ttl,
185
- "disk_ttl_seconds": self._disk_ttl,
186
- "entries": {
187
- key: {"value": val, "timestamp": ts}
188
- for key, (val, ts) in self._cache.items()
189
- },
190
- "statistics": {
191
- "total_entries": len(self._cache),
192
- "last_write": time.time(),
193
- **self._stats
194
- }
195
- }
196
-
197
- # Atomic write using temp file
198
- parent_dir = self._cache_file.parent
199
- tmp_fd, tmp_path = tempfile.mkstemp(dir=parent_dir, prefix='.tmp_', suffix='.json')
200
-
201
- try:
202
- with os.fdopen(tmp_fd, 'w', encoding='utf-8') as f:
203
- json.dump(cache_data, f, indent=2)
204
-
205
- # Set restrictive permissions (if supported)
206
- try:
207
- os.chmod(tmp_path, 0o600)
208
- except (OSError, AttributeError):
209
- pass
210
-
211
- shutil.move(tmp_path, self._cache_file)
212
- self._stats["writes"] += 1
213
- lib_logger.debug(
214
- f"ProviderCache[{self._cache_name}]: Saved {len(self._cache)} entries"
215
- )
216
- except Exception:
217
- if tmp_path and os.path.exists(tmp_path):
218
- os.unlink(tmp_path)
219
- raise
220
- except Exception as e:
221
- lib_logger.error(f"ProviderCache[{self._cache_name}]: Disk save failed: {e}")
222
-
223
  # =========================================================================
224
  # BACKGROUND TASKS
225
  # =========================================================================
226
-
227
  async def _start_background_tasks(self) -> None:
228
  """Start background writer and cleanup tasks."""
229
  if not self._enable_disk or self._running:
230
  return
231
-
232
  self._running = True
233
  self._writer_task = asyncio.create_task(self._writer_loop())
234
  self._cleanup_task = asyncio.create_task(self._cleanup_loop())
235
  lib_logger.debug(f"ProviderCache[{self._cache_name}]: Started background tasks")
236
-
237
  async def _writer_loop(self) -> None:
238
  """Background task: periodically flush dirty cache to disk."""
239
  try:
@@ -241,13 +257,17 @@ class ProviderCache:
241
  await asyncio.sleep(self._write_interval)
242
  if self._dirty:
243
  try:
244
- await self._save_to_disk()
245
- self._dirty = False
 
 
246
  except Exception as e:
247
- lib_logger.error(f"ProviderCache[{self._cache_name}]: Writer error: {e}")
 
 
248
  except asyncio.CancelledError:
249
  pass
250
-
251
  async def _cleanup_loop(self) -> None:
252
  """Background task: periodically clean up expired entries."""
253
  try:
@@ -256,12 +276,14 @@ class ProviderCache:
256
  await self._cleanup_expired()
257
  except asyncio.CancelledError:
258
  pass
259
-
260
  async def _cleanup_expired(self) -> None:
261
  """Remove expired entries from memory cache."""
262
  async with self._lock:
263
  now = time.time()
264
- expired = [k for k, (_, ts) in self._cache.items() if now - ts > self._memory_ttl]
 
 
265
  for k in expired:
266
  del self._cache[k]
267
  if expired:
@@ -269,42 +291,42 @@ class ProviderCache:
269
  lib_logger.debug(
270
  f"ProviderCache[{self._cache_name}]: Cleaned {len(expired)} expired entries"
271
  )
272
-
273
  # =========================================================================
274
  # CORE OPERATIONS
275
  # =========================================================================
276
-
277
  def store(self, key: str, value: str) -> None:
278
  """
279
  Store a value synchronously (schedules async storage).
280
-
281
  Args:
282
  key: Cache key
283
  value: Value to store (typically JSON-serialized data)
284
  """
285
  asyncio.create_task(self._async_store(key, value))
286
-
287
  async def _async_store(self, key: str, value: str) -> None:
288
  """Async implementation of store."""
289
  async with self._lock:
290
  self._cache[key] = (value, time.time())
291
  self._dirty = True
292
-
293
  async def store_async(self, key: str, value: str) -> None:
294
  """
295
  Store a value asynchronously (awaitable).
296
-
297
  Use this when you need to ensure the value is stored before continuing.
298
  """
299
  await self._async_store(key, value)
300
-
301
  def retrieve(self, key: str) -> Optional[str]:
302
  """
303
  Retrieve a value by key (synchronous, with optional async disk fallback).
304
-
305
  Args:
306
  key: Cache key
307
-
308
  Returns:
309
  Cached value if found and not expired, None otherwise
310
  """
@@ -316,17 +338,17 @@ class ProviderCache:
316
  else:
317
  del self._cache[key]
318
  self._dirty = True
319
-
320
  self._stats["misses"] += 1
321
  if self._enable_disk:
322
  # Schedule async disk lookup for next time
323
  asyncio.create_task(self._check_disk_fallback(key))
324
  return None
325
-
326
  async def retrieve_async(self, key: str) -> Optional[str]:
327
  """
328
  Retrieve a value asynchronously (checks disk if not in memory).
329
-
330
  Use this when you can await and need guaranteed disk fallback.
331
  """
332
  # Check memory first
@@ -340,24 +362,24 @@ class ProviderCache:
340
  if key in self._cache:
341
  del self._cache[key]
342
  self._dirty = True
343
-
344
  # Check disk
345
  if self._enable_disk:
346
  return await self._disk_retrieve(key)
347
-
348
  self._stats["misses"] += 1
349
  return None
350
-
351
  async def _check_disk_fallback(self, key: str) -> None:
352
  """Check disk for key and load into memory if found (background)."""
353
  try:
354
  if not self._cache_file.exists():
355
  return
356
-
357
  async with self._disk_lock:
358
- with open(self._cache_file, 'r', encoding='utf-8') as f:
359
  data = json.load(f)
360
-
361
  entries = data.get("entries", {})
362
  if key in entries:
363
  entry = entries[key]
@@ -372,19 +394,21 @@ class ProviderCache:
372
  f"ProviderCache[{self._cache_name}]: Loaded {key} from disk"
373
  )
374
  except Exception as e:
375
- lib_logger.debug(f"ProviderCache[{self._cache_name}]: Disk fallback failed: {e}")
376
-
 
 
377
  async def _disk_retrieve(self, key: str) -> Optional[str]:
378
  """Direct disk retrieval with loading into memory."""
379
  try:
380
  if not self._cache_file.exists():
381
  self._stats["misses"] += 1
382
  return None
383
-
384
  async with self._disk_lock:
385
- with open(self._cache_file, 'r', encoding='utf-8') as f:
386
  data = json.load(f)
387
-
388
  entries = data.get("entries", {})
389
  if key in entries:
390
  entry = entries[key]
@@ -396,34 +420,37 @@ class ProviderCache:
396
  self._cache[key] = (value, ts)
397
  self._stats["disk_hits"] += 1
398
  return value
399
-
400
  self._stats["misses"] += 1
401
  return None
402
  except Exception as e:
403
- lib_logger.debug(f"ProviderCache[{self._cache_name}]: Disk retrieve failed: {e}")
 
 
404
  self._stats["misses"] += 1
405
  return None
406
-
407
  # =========================================================================
408
  # UTILITY METHODS
409
  # =========================================================================
410
-
411
  def contains(self, key: str) -> bool:
412
  """Check if key exists in memory cache (without updating stats)."""
413
  if key in self._cache:
414
  _, timestamp = self._cache[key]
415
  return time.time() - timestamp <= self._memory_ttl
416
  return False
417
-
418
  def get_stats(self) -> Dict[str, Any]:
419
- """Get cache statistics."""
420
  return {
421
  **self._stats,
422
  "memory_entries": len(self._cache),
423
  "dirty": self._dirty,
424
- "disk_enabled": self._enable_disk
 
425
  }
426
-
427
  async def clear(self) -> None:
428
  """Clear all cached data."""
429
  async with self._lock:
@@ -431,12 +458,12 @@ class ProviderCache:
431
  self._dirty = True
432
  if self._enable_disk:
433
  await self._save_to_disk()
434
-
435
  async def shutdown(self) -> None:
436
  """Graceful shutdown: flush pending writes and stop background tasks."""
437
  lib_logger.info(f"ProviderCache[{self._cache_name}]: Shutting down...")
438
  self._running = False
439
-
440
  # Cancel background tasks
441
  for task in (self._writer_task, self._cleanup_task):
442
  if task:
@@ -445,11 +472,11 @@ class ProviderCache:
445
  await task
446
  except asyncio.CancelledError:
447
  pass
448
-
449
  # Final save
450
  if self._dirty and self._enable_disk:
451
  await self._save_to_disk()
452
-
453
  lib_logger.info(
454
  f"ProviderCache[{self._cache_name}]: Shutdown complete "
455
  f"(stats: mem_hits={self._stats['memory_hits']}, "
@@ -461,38 +488,39 @@ class ProviderCache:
461
  # CONVENIENCE FACTORY
462
  # =============================================================================
463
 
 
464
  def create_provider_cache(
465
  name: str,
466
  cache_dir: Optional[Path] = None,
467
  memory_ttl_seconds: int = 3600,
468
  disk_ttl_seconds: int = 86400,
469
- env_prefix: Optional[str] = None
470
  ) -> ProviderCache:
471
  """
472
  Factory function to create a provider cache with sensible defaults.
473
-
474
  Args:
475
  name: Cache name (used as filename and for logging)
476
  cache_dir: Directory for cache file (default: project_root/cache/provider_name)
477
  memory_ttl_seconds: In-memory TTL
478
  disk_ttl_seconds: Disk TTL
479
  env_prefix: Environment variable prefix (default: derived from name)
480
-
481
  Returns:
482
  Configured ProviderCache instance
483
  """
484
  if cache_dir is None:
485
  cache_dir = Path(__file__).resolve().parent.parent.parent.parent / "cache"
486
-
487
  cache_file = cache_dir / f"{name}.json"
488
-
489
  if env_prefix is None:
490
  # Convert name to env prefix: "gemini3_signatures" -> "GEMINI3_SIGNATURES_CACHE"
491
  env_prefix = f"{name.upper().replace('-', '_')}_CACHE"
492
-
493
  return ProviderCache(
494
  cache_file=cache_file,
495
  memory_ttl_seconds=memory_ttl_seconds,
496
  disk_ttl_seconds=disk_ttl_seconds,
497
- env_prefix=env_prefix
498
  )
 
20
  import json
21
  import logging
22
  import os
 
 
23
  import time
24
  from pathlib import Path
25
  from typing import Any, Dict, Optional, Tuple
26
 
27
+ from ..utils.resilient_io import safe_write_json
28
+
29
+ lib_logger = logging.getLogger("rotator_library")
30
 
31
 
32
  # =============================================================================
33
  # UTILITY FUNCTIONS
34
  # =============================================================================
35
 
36
+
37
  def _env_bool(key: str, default: bool = False) -> bool:
38
  """Get boolean from environment variable."""
39
  return os.getenv(key, str(default).lower()).lower() in ("true", "1", "yes")
 
48
  # PROVIDER CACHE CLASS
49
  # =============================================================================
50
 
51
+
52
  class ProviderCache:
53
  """
54
  Server-side cache for provider conversation state preservation.
55
+
56
  A generic, modular cache supporting any key-value data that providers need
57
  to persist across requests. Features:
58
+
59
  - Dual-TTL system: configurable memory TTL, longer disk TTL
60
  - Async disk persistence with batched writes
61
  - Background cleanup task for expired entries
62
  - Statistics tracking (hits, misses, writes)
63
+
64
  Args:
65
  cache_file: Path to disk cache file
66
  memory_ttl_seconds: In-memory entry lifetime (default: 1 hour)
 
69
  write_interval: Seconds between background disk writes (default: 60)
70
  cleanup_interval: Seconds between expired entry cleanup (default: 30 min)
71
  env_prefix: Environment variable prefix for configuration overrides
72
+
73
  Environment Variables (with default prefix "PROVIDER_CACHE"):
74
  {PREFIX}_ENABLE: Enable/disable disk persistence
75
  {PREFIX}_WRITE_INTERVAL: Background write interval in seconds
76
  {PREFIX}_CLEANUP_INTERVAL: Cleanup interval in seconds
77
  """
78
+
79
  def __init__(
80
  self,
81
  cache_file: Path,
 
84
  enable_disk: Optional[bool] = None,
85
  write_interval: Optional[int] = None,
86
  cleanup_interval: Optional[int] = None,
87
+ env_prefix: str = "PROVIDER_CACHE",
88
  ):
89
  # In-memory cache: {cache_key: (data, timestamp)}
90
  self._cache: Dict[str, Tuple[str, float]] = {}
 
92
  self._disk_ttl = disk_ttl_seconds
93
  self._lock = asyncio.Lock()
94
  self._disk_lock = asyncio.Lock()
95
+
96
  # Disk persistence configuration
97
  self._cache_file = cache_file
98
+ self._enable_disk = (
99
+ enable_disk
100
+ if enable_disk is not None
101
+ else _env_bool(f"{env_prefix}_ENABLE", True)
102
+ )
103
  self._dirty = False
104
+ self._write_interval = write_interval or _env_int(
105
+ f"{env_prefix}_WRITE_INTERVAL", 60
106
+ )
107
+ self._cleanup_interval = cleanup_interval or _env_int(
108
+ f"{env_prefix}_CLEANUP_INTERVAL", 1800
109
+ )
110
+
111
  # Background tasks
112
  self._writer_task: Optional[asyncio.Task] = None
113
  self._cleanup_task: Optional[asyncio.Task] = None
114
  self._running = False
115
+
116
  # Statistics
117
+ self._stats = {
118
+ "memory_hits": 0,
119
+ "disk_hits": 0,
120
+ "misses": 0,
121
+ "writes": 0,
122
+ "disk_errors": 0,
123
+ }
124
+
125
+ # Track disk health for monitoring
126
+ self._disk_available = True
127
+
128
  # Metadata about this cache instance
129
  self._cache_name = cache_file.stem if cache_file else "unnamed"
130
+
131
  if self._enable_disk:
132
  lib_logger.debug(
133
  f"ProviderCache[{self._cache_name}]: Disk enabled "
 
136
  asyncio.create_task(self._async_init())
137
  else:
138
  lib_logger.debug(f"ProviderCache[{self._cache_name}]: Memory-only mode")
139
+
140
  # =========================================================================
141
  # INITIALIZATION
142
  # =========================================================================
143
+
144
  async def _async_init(self) -> None:
145
  """Async initialization: load from disk and start background tasks."""
146
  try:
147
  await self._load_from_disk()
148
  await self._start_background_tasks()
149
  except Exception as e:
150
+ lib_logger.error(
151
+ f"ProviderCache[{self._cache_name}] async init failed: {e}"
152
+ )
153
+
154
  async def _load_from_disk(self) -> None:
155
  """Load cache from disk file with TTL validation."""
156
  if not self._enable_disk or not self._cache_file.exists():
157
  return
158
+
159
  try:
160
  async with self._disk_lock:
161
+ with open(self._cache_file, "r", encoding="utf-8") as f:
162
  data = json.load(f)
163
+
164
  if data.get("version") != "1.0":
165
+ lib_logger.warning(
166
+ f"ProviderCache[{self._cache_name}]: Version mismatch, starting fresh"
167
+ )
168
  return
169
+
170
  now = time.time()
171
  entries = data.get("entries", {})
172
  loaded = expired = 0
173
+
174
  for cache_key, entry in entries.items():
175
  age = now - entry.get("timestamp", 0)
176
  if age <= self._disk_ttl:
177
+ value = entry.get(
178
+ "value", entry.get("signature", "")
179
+ ) # Support both formats
180
  if value:
181
  self._cache[cache_key] = (value, entry["timestamp"])
182
  loaded += 1
183
  else:
184
  expired += 1
185
+
186
  lib_logger.debug(
187
  f"ProviderCache[{self._cache_name}]: Loaded {loaded} entries ({expired} expired)"
188
  )
189
  except json.JSONDecodeError as e:
190
+ lib_logger.warning(
191
+ f"ProviderCache[{self._cache_name}]: File corrupted: {e}"
192
+ )
193
  except Exception as e:
194
  lib_logger.error(f"ProviderCache[{self._cache_name}]: Load failed: {e}")
195
+
196
  # =========================================================================
197
  # DISK PERSISTENCE
198
  # =========================================================================
199
+
200
+ async def _save_to_disk(self) -> bool:
201
+ """Persist cache to disk using atomic write with health tracking.
202
+
203
+ Returns:
204
+ True if write succeeded, False otherwise.
205
+ """
206
  if not self._enable_disk:
207
+ return True # Not an error if disk is disabled
208
+
209
+ async with self._disk_lock:
210
+ cache_data = {
211
+ "version": "1.0",
212
+ "memory_ttl_seconds": self._memory_ttl,
213
+ "disk_ttl_seconds": self._disk_ttl,
214
+ "entries": {
215
+ key: {"value": val, "timestamp": ts}
216
+ for key, (val, ts) in self._cache.items()
217
+ },
218
+ "statistics": {
219
+ "total_entries": len(self._cache),
220
+ "last_write": time.time(),
221
+ **self._stats,
222
+ },
223
+ }
224
+
225
+ if safe_write_json(
226
+ self._cache_file, cache_data, lib_logger, secure_permissions=True
227
+ ):
228
+ self._stats["writes"] += 1
229
+ self._disk_available = True
230
+ lib_logger.debug(
231
+ f"ProviderCache[{self._cache_name}]: Saved {len(self._cache)} entries"
232
+ )
233
+ return True
234
+ else:
235
+ self._stats["disk_errors"] += 1
236
+ self._disk_available = False
237
+ return False
238
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
239
  # =========================================================================
240
  # BACKGROUND TASKS
241
  # =========================================================================
242
+
243
  async def _start_background_tasks(self) -> None:
244
  """Start background writer and cleanup tasks."""
245
  if not self._enable_disk or self._running:
246
  return
247
+
248
  self._running = True
249
  self._writer_task = asyncio.create_task(self._writer_loop())
250
  self._cleanup_task = asyncio.create_task(self._cleanup_loop())
251
  lib_logger.debug(f"ProviderCache[{self._cache_name}]: Started background tasks")
252
+
253
  async def _writer_loop(self) -> None:
254
  """Background task: periodically flush dirty cache to disk."""
255
  try:
 
257
  await asyncio.sleep(self._write_interval)
258
  if self._dirty:
259
  try:
260
+ success = await self._save_to_disk()
261
+ if success:
262
+ self._dirty = False
263
+ # If save failed, _dirty remains True so we retry next interval
264
  except Exception as e:
265
+ lib_logger.error(
266
+ f"ProviderCache[{self._cache_name}]: Writer error: {e}"
267
+ )
268
  except asyncio.CancelledError:
269
  pass
270
+
271
  async def _cleanup_loop(self) -> None:
272
  """Background task: periodically clean up expired entries."""
273
  try:
 
276
  await self._cleanup_expired()
277
  except asyncio.CancelledError:
278
  pass
279
+
280
  async def _cleanup_expired(self) -> None:
281
  """Remove expired entries from memory cache."""
282
  async with self._lock:
283
  now = time.time()
284
+ expired = [
285
+ k for k, (_, ts) in self._cache.items() if now - ts > self._memory_ttl
286
+ ]
287
  for k in expired:
288
  del self._cache[k]
289
  if expired:
 
291
  lib_logger.debug(
292
  f"ProviderCache[{self._cache_name}]: Cleaned {len(expired)} expired entries"
293
  )
294
+
295
  # =========================================================================
296
  # CORE OPERATIONS
297
  # =========================================================================
298
+
299
  def store(self, key: str, value: str) -> None:
300
  """
301
  Store a value synchronously (schedules async storage).
302
+
303
  Args:
304
  key: Cache key
305
  value: Value to store (typically JSON-serialized data)
306
  """
307
  asyncio.create_task(self._async_store(key, value))
308
+
309
  async def _async_store(self, key: str, value: str) -> None:
310
  """Async implementation of store."""
311
  async with self._lock:
312
  self._cache[key] = (value, time.time())
313
  self._dirty = True
314
+
315
  async def store_async(self, key: str, value: str) -> None:
316
  """
317
  Store a value asynchronously (awaitable).
318
+
319
  Use this when you need to ensure the value is stored before continuing.
320
  """
321
  await self._async_store(key, value)
322
+
323
  def retrieve(self, key: str) -> Optional[str]:
324
  """
325
  Retrieve a value by key (synchronous, with optional async disk fallback).
326
+
327
  Args:
328
  key: Cache key
329
+
330
  Returns:
331
  Cached value if found and not expired, None otherwise
332
  """
 
338
  else:
339
  del self._cache[key]
340
  self._dirty = True
341
+
342
  self._stats["misses"] += 1
343
  if self._enable_disk:
344
  # Schedule async disk lookup for next time
345
  asyncio.create_task(self._check_disk_fallback(key))
346
  return None
347
+
348
  async def retrieve_async(self, key: str) -> Optional[str]:
349
  """
350
  Retrieve a value asynchronously (checks disk if not in memory).
351
+
352
  Use this when you can await and need guaranteed disk fallback.
353
  """
354
  # Check memory first
 
362
  if key in self._cache:
363
  del self._cache[key]
364
  self._dirty = True
365
+
366
  # Check disk
367
  if self._enable_disk:
368
  return await self._disk_retrieve(key)
369
+
370
  self._stats["misses"] += 1
371
  return None
372
+
373
  async def _check_disk_fallback(self, key: str) -> None:
374
  """Check disk for key and load into memory if found (background)."""
375
  try:
376
  if not self._cache_file.exists():
377
  return
378
+
379
  async with self._disk_lock:
380
+ with open(self._cache_file, "r", encoding="utf-8") as f:
381
  data = json.load(f)
382
+
383
  entries = data.get("entries", {})
384
  if key in entries:
385
  entry = entries[key]
 
394
  f"ProviderCache[{self._cache_name}]: Loaded {key} from disk"
395
  )
396
  except Exception as e:
397
+ lib_logger.debug(
398
+ f"ProviderCache[{self._cache_name}]: Disk fallback failed: {e}"
399
+ )
400
+
401
  async def _disk_retrieve(self, key: str) -> Optional[str]:
402
  """Direct disk retrieval with loading into memory."""
403
  try:
404
  if not self._cache_file.exists():
405
  self._stats["misses"] += 1
406
  return None
407
+
408
  async with self._disk_lock:
409
+ with open(self._cache_file, "r", encoding="utf-8") as f:
410
  data = json.load(f)
411
+
412
  entries = data.get("entries", {})
413
  if key in entries:
414
  entry = entries[key]
 
420
  self._cache[key] = (value, ts)
421
  self._stats["disk_hits"] += 1
422
  return value
423
+
424
  self._stats["misses"] += 1
425
  return None
426
  except Exception as e:
427
+ lib_logger.debug(
428
+ f"ProviderCache[{self._cache_name}]: Disk retrieve failed: {e}"
429
+ )
430
  self._stats["misses"] += 1
431
  return None
432
+
433
  # =========================================================================
434
  # UTILITY METHODS
435
  # =========================================================================
436
+
437
  def contains(self, key: str) -> bool:
438
  """Check if key exists in memory cache (without updating stats)."""
439
  if key in self._cache:
440
  _, timestamp = self._cache[key]
441
  return time.time() - timestamp <= self._memory_ttl
442
  return False
443
+
444
  def get_stats(self) -> Dict[str, Any]:
445
+ """Get cache statistics including disk health."""
446
  return {
447
  **self._stats,
448
  "memory_entries": len(self._cache),
449
  "dirty": self._dirty,
450
+ "disk_enabled": self._enable_disk,
451
+ "disk_available": self._disk_available,
452
  }
453
+
454
  async def clear(self) -> None:
455
  """Clear all cached data."""
456
  async with self._lock:
 
458
  self._dirty = True
459
  if self._enable_disk:
460
  await self._save_to_disk()
461
+
462
  async def shutdown(self) -> None:
463
  """Graceful shutdown: flush pending writes and stop background tasks."""
464
  lib_logger.info(f"ProviderCache[{self._cache_name}]: Shutting down...")
465
  self._running = False
466
+
467
  # Cancel background tasks
468
  for task in (self._writer_task, self._cleanup_task):
469
  if task:
 
472
  await task
473
  except asyncio.CancelledError:
474
  pass
475
+
476
  # Final save
477
  if self._dirty and self._enable_disk:
478
  await self._save_to_disk()
479
+
480
  lib_logger.info(
481
  f"ProviderCache[{self._cache_name}]: Shutdown complete "
482
  f"(stats: mem_hits={self._stats['memory_hits']}, "
 
488
  # CONVENIENCE FACTORY
489
  # =============================================================================
490
 
491
+
492
  def create_provider_cache(
493
  name: str,
494
  cache_dir: Optional[Path] = None,
495
  memory_ttl_seconds: int = 3600,
496
  disk_ttl_seconds: int = 86400,
497
+ env_prefix: Optional[str] = None,
498
  ) -> ProviderCache:
499
  """
500
  Factory function to create a provider cache with sensible defaults.
501
+
502
  Args:
503
  name: Cache name (used as filename and for logging)
504
  cache_dir: Directory for cache file (default: project_root/cache/provider_name)
505
  memory_ttl_seconds: In-memory TTL
506
  disk_ttl_seconds: Disk TTL
507
  env_prefix: Environment variable prefix (default: derived from name)
508
+
509
  Returns:
510
  Configured ProviderCache instance
511
  """
512
  if cache_dir is None:
513
  cache_dir = Path(__file__).resolve().parent.parent.parent.parent / "cache"
514
+
515
  cache_file = cache_dir / f"{name}.json"
516
+
517
  if env_prefix is None:
518
  # Convert name to env prefix: "gemini3_signatures" -> "GEMINI3_SIGNATURES_CACHE"
519
  env_prefix = f"{name.upper().replace('-', '_')}_CACHE"
520
+
521
  return ProviderCache(
522
  cache_file=cache_file,
523
  memory_ttl_seconds=memory_ttl_seconds,
524
  disk_ttl_seconds=disk_ttl_seconds,
525
+ env_prefix=env_prefix,
526
  )
src/rotator_library/providers/qwen_auth_base.py CHANGED
@@ -9,10 +9,11 @@ import asyncio
9
  import logging
10
  import webbrowser
11
  import os
 
 
12
  from pathlib import Path
13
- from typing import Dict, Any, Tuple, Union, Optional
14
- import tempfile
15
- import shutil
16
 
17
  import httpx
18
  from rich.console import Console
@@ -23,6 +24,7 @@ from rich.markup import escape as rich_escape
23
 
24
  from ..utils.headless_detection import is_headless_environment
25
  from ..utils.reauth_coordinator import get_reauth_coordinator
 
26
 
27
  lib_logger = logging.getLogger("rotator_library")
28
 
@@ -36,6 +38,20 @@ REFRESH_EXPIRY_BUFFER_SECONDS = 3 * 60 * 60 # 3 hours buffer before expiry
36
  console = Console()
37
 
38
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
  class QwenAuthBase:
40
  def __init__(self):
41
  self._credentials_cache: Dict[str, Dict[str, Any]] = {}
@@ -51,19 +67,36 @@ class QwenAuthBase:
51
  str, float
52
  ] = {} # Track backoff timers (Unix timestamp)
53
 
54
- # [QUEUE SYSTEM] Sequential refresh processing
 
55
  self._refresh_queue: asyncio.Queue = asyncio.Queue()
56
- self._queued_credentials: set = set() # Track credentials already in queue
57
- # [FIX PR#34] Changed from set to dict mapping credential path to timestamp
58
- # This enables TTL-based stale entry cleanup as defense in depth
 
 
 
 
 
 
 
59
  self._unavailable_credentials: Dict[
60
  str, float
61
  ] = {} # Maps credential path -> timestamp when marked unavailable
62
- self._unavailable_ttl_seconds: int = 300 # 5 minutes TTL for stale entries
 
63
  self._queue_tracking_lock = asyncio.Lock() # Protects queue sets
64
- self._queue_processor_task: Optional[asyncio.Task] = (
65
- None # Background worker task
66
- )
 
 
 
 
 
 
 
 
67
 
68
  def _parse_env_credential_path(self, path: str) -> Optional[str]:
69
  """
@@ -188,81 +221,54 @@ class QwenAuthBase:
188
  f"Environment variables for Qwen Code credential index {credential_index} not found"
189
  )
190
 
191
- # For file paths, try loading from legacy env vars first
192
- env_creds = self._load_from_env()
193
- if env_creds:
194
- lib_logger.info(
195
- "Using Qwen Code credentials from environment variables"
196
- )
197
- self._credentials_cache[path] = env_creds
198
- return env_creds
199
-
200
- # Fall back to file-based loading
201
- return await self._read_creds_from_file(path)
 
 
202
 
203
  async def _save_credentials(self, path: str, creds: Dict[str, Any]):
 
 
 
 
204
  # Don't save to file if credentials were loaded from environment
205
  if creds.get("_proxy_metadata", {}).get("loaded_from_env"):
206
  lib_logger.debug("Credentials loaded from env, skipping file save")
207
- # Still update cache for in-memory consistency
208
- self._credentials_cache[path] = creds
209
  return
210
 
211
- # [ATOMIC WRITE] Use tempfile + move pattern to ensure atomic writes
212
- parent_dir = os.path.dirname(os.path.abspath(path))
213
- os.makedirs(parent_dir, exist_ok=True)
214
-
215
- tmp_fd = None
216
- tmp_path = None
217
- try:
218
- # Create temp file in same directory as target (ensures same filesystem)
219
- tmp_fd, tmp_path = tempfile.mkstemp(
220
- dir=parent_dir, prefix=".tmp_", suffix=".json", text=True
221
- )
222
-
223
- # Write JSON to temp file
224
- with os.fdopen(tmp_fd, "w") as f:
225
- json.dump(creds, f, indent=2)
226
- tmp_fd = None # fdopen closes the fd
227
-
228
- # Set secure permissions (0600 = owner read/write only)
229
- try:
230
- os.chmod(tmp_path, 0o600)
231
- except (OSError, AttributeError):
232
- # Windows may not support chmod, ignore
233
- pass
234
-
235
- # Atomic move (overwrites target if it exists)
236
- shutil.move(tmp_path, path)
237
- tmp_path = None # Successfully moved
238
-
239
- # Update cache AFTER successful file write
240
- self._credentials_cache[path] = creds
241
- lib_logger.debug(
242
- f"Saved updated Qwen OAuth credentials to '{path}' (atomic write)."
243
- )
244
-
245
- except Exception as e:
246
- lib_logger.error(
247
- f"Failed to save updated Qwen OAuth credentials to '{path}': {e}"
248
  )
249
- # Clean up temp file if it still exists
250
- if tmp_fd is not None:
251
- try:
252
- os.close(tmp_fd)
253
- except:
254
- pass
255
- if tmp_path and os.path.exists(tmp_path):
256
- try:
257
- os.unlink(tmp_path)
258
- except:
259
- pass
260
- raise
261
 
262
  def _is_token_expired(self, creds: Dict[str, Any]) -> bool:
263
  expiry_timestamp = creds.get("expiry_date", 0) / 1000
264
  return expiry_timestamp < time.time() + REFRESH_EXPIRY_BUFFER_SECONDS
265
 
 
 
 
 
 
 
 
 
 
266
  async def _refresh_token(self, path: str, force: bool = False) -> Dict[str, Any]:
267
  async with await self._get_lock(path):
268
  cached_creds = self._credentials_cache.get(path)
@@ -476,7 +482,7 @@ class QwenAuthBase:
476
  Proactively refreshes tokens if they're close to expiry.
477
  Only applies to OAuth credentials (file paths or env:// paths). Direct API keys are skipped.
478
  """
479
- lib_logger.debug(f"proactively_refresh called for: {credential_identifier}")
480
 
481
  # Try to load credentials - this will fail for direct API keys
482
  # and succeed for OAuth credentials (file paths or env:// paths)
@@ -484,21 +490,21 @@ class QwenAuthBase:
484
  creds = await self._load_credentials(credential_identifier)
485
  except IOError as e:
486
  # Not a valid credential path (likely a direct API key string)
487
- lib_logger.debug(
488
- f"Skipping refresh for '{credential_identifier}' - not an OAuth credential: {e}"
489
- )
490
  return
491
 
492
  is_expired = self._is_token_expired(creds)
493
- lib_logger.debug(
494
- f"Token expired check for '{Path(credential_identifier).name}': {is_expired}"
495
- )
496
 
497
  if is_expired:
498
- lib_logger.debug(
499
- f"Queueing refresh for '{Path(credential_identifier).name}'"
500
- )
501
- # Queue for refresh with needs_reauth=False (automated refresh)
502
  await self._queue_refresh(
503
  credential_identifier, force=False, needs_reauth=False
504
  )
@@ -511,30 +517,55 @@ class QwenAuthBase:
511
  return self._refresh_locks[path]
512
 
513
  def is_credential_available(self, path: str) -> bool:
514
- """Check if a credential is available for rotation (not queued/refreshing).
 
 
 
 
 
 
 
515
 
516
- [FIX PR#34] Now includes TTL-based stale entry cleanup as defense in depth.
517
- If a credential has been unavailable for longer than _unavailable_ttl_seconds,
518
- it is automatically cleaned up and considered available.
 
519
  """
520
- if path not in self._unavailable_credentials:
521
- return True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
522
 
523
- # [FIX PR#34] Check if the entry is stale (TTL expired)
524
- marked_time = self._unavailable_credentials.get(path)
525
- if marked_time is not None:
526
- now = time.time()
527
- if now - marked_time > self._unavailable_ttl_seconds:
528
- # Entry is stale - clean it up and return available
529
- lib_logger.warning(
530
- f"Credential '{Path(path).name}' was stuck in unavailable state for "
531
- f"{int(now - marked_time)}s (TTL: {self._unavailable_ttl_seconds}s). "
532
- f"Auto-cleaning stale entry."
 
533
  )
534
- self._unavailable_credentials.pop(path, None)
535
- return True
536
 
537
- return False
538
 
539
  async def _ensure_queue_processor_running(self):
540
  """Lazily starts the queue processor if not already running."""
@@ -543,15 +574,27 @@ class QwenAuthBase:
543
  self._process_refresh_queue()
544
  )
545
 
 
 
 
 
 
 
 
546
  async def _queue_refresh(
547
  self, path: str, force: bool = False, needs_reauth: bool = False
548
  ):
549
- """Add a credential to the refresh queue if not already queued.
550
 
551
  Args:
552
  path: Credential file path
553
  force: Force refresh even if not expired
554
- needs_reauth: True if full re-authentication needed (bypasses backoff)
 
 
 
 
 
555
  """
556
  # IMPORTANT: Only check backoff for simple automated refreshes
557
  # Re-authentication (interactive OAuth) should BYPASS backoff since it needs user input
@@ -561,114 +604,223 @@ class QwenAuthBase:
561
  backoff_until = self._next_refresh_after[path]
562
  if now < backoff_until:
563
  # Credential is in backoff for automated refresh, do not queue
564
- remaining = int(backoff_until - now)
565
- lib_logger.debug(
566
- f"Skipping automated refresh for '{Path(path).name}' (in backoff for {remaining}s)"
567
- )
568
  return
569
 
570
  async with self._queue_tracking_lock:
571
  if path not in self._queued_credentials:
572
  self._queued_credentials.add(path)
573
- # [FIX PR#34] Store timestamp when marking unavailable (for TTL cleanup)
574
- self._unavailable_credentials[path] = time.time()
575
- lib_logger.debug(
576
- f"Marked '{Path(path).name}' as unavailable. "
577
- f"Total unavailable: {len(self._unavailable_credentials)}"
578
- )
579
- await self._refresh_queue.put((path, force, needs_reauth))
580
- await self._ensure_queue_processor_running()
 
 
 
 
 
 
 
 
 
 
581
 
582
  async def _process_refresh_queue(self):
583
- """Background worker that processes refresh requests sequentially."""
 
 
 
 
 
 
 
 
 
584
  while True:
585
  path = None
586
  try:
587
  # Wait for an item with timeout to allow graceful shutdown
588
  try:
589
- path, force, needs_reauth = await asyncio.wait_for(
590
  self._refresh_queue.get(), timeout=60.0
591
  )
592
  except asyncio.TimeoutError:
593
- # [FIX PR#34] Clean up any stale unavailable entries before exiting
594
- # If we're idle for 60s, no refreshes are in progress
595
  async with self._queue_tracking_lock:
596
- if self._unavailable_credentials:
597
- stale_count = len(self._unavailable_credentials)
598
- lib_logger.warning(
599
- f"Queue processor idle timeout. Cleaning {stale_count} "
600
- f"stale unavailable credentials: {list(self._unavailable_credentials.keys())}"
601
- )
602
- self._unavailable_credentials.clear()
603
- # [FIX BUG#6] Also clear queued credentials to prevent stuck state
604
- if self._queued_credentials:
605
- lib_logger.debug(
606
- f"Clearing {len(self._queued_credentials)} queued credentials on timeout"
607
- )
608
- self._queued_credentials.clear()
609
  self._queue_processor_task = None
 
610
  return
611
 
612
  try:
613
- # Perform the actual refresh (still using per-credential lock)
614
- async with await self._get_lock(path):
615
- # Re-check if still expired (may have changed since queueing)
616
- creds = self._credentials_cache.get(path)
617
- if creds and not self._is_token_expired(creds):
618
- # No longer expired, mark as available
619
- async with self._queue_tracking_lock:
620
- self._unavailable_credentials.pop(path, None)
621
- lib_logger.debug(
622
- f"Credential '{Path(path).name}' no longer expired, marked available. "
623
- f"Remaining unavailable: {len(self._unavailable_credentials)}"
624
- )
625
- continue
 
 
626
 
627
- # Perform refresh
628
- if not creds:
629
- creds = await self._load_credentials(path)
630
- await self._refresh_token(path, force=force)
631
 
632
- # SUCCESS: Mark as available again
633
- async with self._queue_tracking_lock:
634
- self._unavailable_credentials.pop(path, None)
635
- lib_logger.debug(
636
- f"Refresh SUCCESS for '{Path(path).name}', marked available. "
637
- f"Remaining unavailable: {len(self._unavailable_credentials)}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
638
  )
639
 
 
 
 
640
  finally:
641
- # [FIX PR#34] Remove from BOTH queued set AND unavailable credentials
642
- # This ensures cleanup happens in ALL exit paths (success, exception, etc.)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
643
  async with self._queue_tracking_lock:
644
  self._queued_credentials.discard(path)
645
- # [FIX PR#34] Always clean up unavailable credentials in finally block
646
  self._unavailable_credentials.pop(path, None)
647
- lib_logger.debug(
648
- f"Finally cleanup for '{Path(path).name}'. "
649
- f"Remaining unavailable: {len(self._unavailable_credentials)}"
650
- )
651
- self._refresh_queue.task_done()
 
652
  except asyncio.CancelledError:
653
- # [FIX PR#34] Clean up the current credential before breaking
654
  if path:
655
  async with self._queue_tracking_lock:
 
656
  self._unavailable_credentials.pop(path, None)
657
- lib_logger.debug(
658
- f"CancelledError cleanup for '{Path(path).name}'. "
659
- f"Remaining unavailable: {len(self._unavailable_credentials)}"
660
- )
661
  break
662
  except Exception as e:
663
- lib_logger.error(f"Error in queue processor: {e}")
664
- # Even on error, mark as available (backoff will prevent immediate retry)
665
  if path:
666
  async with self._queue_tracking_lock:
 
667
  self._unavailable_credentials.pop(path, None)
668
- lib_logger.debug(
669
- f"Error cleanup for '{Path(path).name}': {e}. "
670
- f"Remaining unavailable: {len(self._unavailable_credentials)}"
671
- )
672
 
673
  async def _perform_interactive_oauth(
674
  self, path: str, creds: Dict[str, Any], display_name: str
@@ -965,3 +1117,251 @@ class QwenAuthBase:
965
  except Exception as e:
966
  lib_logger.error(f"Failed to get Qwen user info from credentials: {e}")
967
  return {"email": None}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  import logging
10
  import webbrowser
11
  import os
12
+ import re
13
+ from dataclasses import dataclass, field
14
  from pathlib import Path
15
+ from glob import glob
16
+ from typing import Dict, Any, Tuple, Union, Optional, List
 
17
 
18
  import httpx
19
  from rich.console import Console
 
24
 
25
  from ..utils.headless_detection import is_headless_environment
26
  from ..utils.reauth_coordinator import get_reauth_coordinator
27
+ from ..utils.resilient_io import safe_write_json
28
 
29
  lib_logger = logging.getLogger("rotator_library")
30
 
 
38
  console = Console()
39
 
40
 
41
+ @dataclass
42
+ class QwenCredentialSetupResult:
43
+ """
44
+ Standardized result structure for Qwen credential setup operations.
45
+ """
46
+
47
+ success: bool
48
+ file_path: Optional[str] = None
49
+ email: Optional[str] = None
50
+ is_update: bool = False
51
+ error: Optional[str] = None
52
+ credentials: Optional[Dict[str, Any]] = field(default=None, repr=False)
53
+
54
+
55
  class QwenAuthBase:
56
  def __init__(self):
57
  self._credentials_cache: Dict[str, Dict[str, Any]] = {}
 
67
  str, float
68
  ] = {} # Track backoff timers (Unix timestamp)
69
 
70
+ # [QUEUE SYSTEM] Sequential refresh processing with two separate queues
71
+ # Normal refresh queue: for proactive token refresh (old token still valid)
72
  self._refresh_queue: asyncio.Queue = asyncio.Queue()
73
+ self._queue_processor_task: Optional[asyncio.Task] = None
74
+
75
+ # Re-auth queue: for invalid refresh tokens (requires user interaction)
76
+ self._reauth_queue: asyncio.Queue = asyncio.Queue()
77
+ self._reauth_processor_task: Optional[asyncio.Task] = None
78
+
79
+ # Tracking sets/dicts
80
+ self._queued_credentials: set = set() # Track credentials in either queue
81
+ # Only credentials in re-auth queue are marked unavailable (not normal refresh)
82
+ # TTL cleanup is defense-in-depth for edge cases where re-auth processor crashes
83
  self._unavailable_credentials: Dict[
84
  str, float
85
  ] = {} # Maps credential path -> timestamp when marked unavailable
86
+ # TTL should exceed reauth timeout (300s) to avoid premature cleanup
87
+ self._unavailable_ttl_seconds: int = 360 # 6 minutes TTL for stale entries
88
  self._queue_tracking_lock = asyncio.Lock() # Protects queue sets
89
+
90
+ # Retry tracking for normal refresh queue
91
+ self._queue_retry_count: Dict[
92
+ str, int
93
+ ] = {} # Track retry attempts per credential
94
+
95
+ # Configuration constants
96
+ self._refresh_timeout_seconds: int = 15 # Max time for single refresh
97
+ self._refresh_interval_seconds: int = 30 # Delay between queue items
98
+ self._refresh_max_retries: int = 3 # Attempts before kicked out
99
+ self._reauth_timeout_seconds: int = 300 # Time for user to complete OAuth
100
 
101
  def _parse_env_credential_path(self, path: str) -> Optional[str]:
102
  """
 
221
  f"Environment variables for Qwen Code credential index {credential_index} not found"
222
  )
223
 
224
+ # Try file-based loading first (preferred for explicit file paths)
225
+ try:
226
+ return await self._read_creds_from_file(path)
227
+ except IOError:
228
+ # File not found - fall back to legacy env vars for backwards compatibility
229
+ env_creds = self._load_from_env()
230
+ if env_creds:
231
+ lib_logger.info(
232
+ f"File '{path}' not found, using Qwen Code credentials from environment variables"
233
+ )
234
+ self._credentials_cache[path] = env_creds
235
+ return env_creds
236
+ raise # Re-raise the original file not found error
237
 
238
  async def _save_credentials(self, path: str, creds: Dict[str, Any]):
239
+ """Save credentials with in-memory fallback if disk unavailable."""
240
+ # Always update cache first (memory is reliable)
241
+ self._credentials_cache[path] = creds
242
+
243
  # Don't save to file if credentials were loaded from environment
244
  if creds.get("_proxy_metadata", {}).get("loaded_from_env"):
245
  lib_logger.debug("Credentials loaded from env, skipping file save")
 
 
246
  return
247
 
248
+ # Attempt disk write - if it fails, we still have the cache
249
+ # buffer_on_failure ensures data is retried periodically and saved on shutdown
250
+ if safe_write_json(
251
+ path, creds, lib_logger, secure_permissions=True, buffer_on_failure=True
252
+ ):
253
+ lib_logger.debug(f"Saved updated Qwen OAuth credentials to '{path}'.")
254
+ else:
255
+ lib_logger.warning(
256
+ "Qwen credentials cached in memory only (buffered for retry)."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
257
  )
 
 
 
 
 
 
 
 
 
 
 
 
258
 
259
  def _is_token_expired(self, creds: Dict[str, Any]) -> bool:
260
  expiry_timestamp = creds.get("expiry_date", 0) / 1000
261
  return expiry_timestamp < time.time() + REFRESH_EXPIRY_BUFFER_SECONDS
262
 
263
+ def _is_token_truly_expired(self, creds: Dict[str, Any]) -> bool:
264
+ """Check if token is TRULY expired (past actual expiry, not just threshold).
265
+
266
+ This is different from _is_token_expired() which uses a buffer for proactive refresh.
267
+ This method checks if the token is actually unusable.
268
+ """
269
+ expiry_timestamp = creds.get("expiry_date", 0) / 1000
270
+ return expiry_timestamp < time.time()
271
+
272
  async def _refresh_token(self, path: str, force: bool = False) -> Dict[str, Any]:
273
  async with await self._get_lock(path):
274
  cached_creds = self._credentials_cache.get(path)
 
482
  Proactively refreshes tokens if they're close to expiry.
483
  Only applies to OAuth credentials (file paths or env:// paths). Direct API keys are skipped.
484
  """
485
+ # lib_logger.debug(f"proactively_refresh called for: {credential_identifier}")
486
 
487
  # Try to load credentials - this will fail for direct API keys
488
  # and succeed for OAuth credentials (file paths or env:// paths)
 
490
  creds = await self._load_credentials(credential_identifier)
491
  except IOError as e:
492
  # Not a valid credential path (likely a direct API key string)
493
+ # lib_logger.debug(
494
+ # f"Skipping refresh for '{credential_identifier}' - not an OAuth credential: {e}"
495
+ # )
496
  return
497
 
498
  is_expired = self._is_token_expired(creds)
499
+ # lib_logger.debug(
500
+ # f"Token expired check for '{Path(credential_identifier).name}': {is_expired}"
501
+ # )
502
 
503
  if is_expired:
504
+ # lib_logger.debug(
505
+ # f"Queueing refresh for '{Path(credential_identifier).name}'"
506
+ # )
507
+ # lib_logger.info(f"Proactive refresh triggered for '{Path(credential_identifier).name}'")
508
  await self._queue_refresh(
509
  credential_identifier, force=False, needs_reauth=False
510
  )
 
517
  return self._refresh_locks[path]
518
 
519
  def is_credential_available(self, path: str) -> bool:
520
+ """Check if a credential is available for rotation.
521
+
522
+ Credentials are unavailable if:
523
+ 1. In re-auth queue (token is truly broken, requires user interaction)
524
+ 2. Token is TRULY expired (past actual expiry, not just threshold)
525
+
526
+ Note: Credentials in normal refresh queue are still available because
527
+ the old token is valid until actual expiry.
528
 
529
+ TTL cleanup (defense-in-depth): If a credential has been in the re-auth
530
+ queue longer than _unavailable_ttl_seconds without being processed, it's
531
+ cleaned up. This should only happen if the re-auth processor crashes or
532
+ is cancelled without proper cleanup.
533
  """
534
+ # Check if in re-auth queue (truly unavailable)
535
+ if path in self._unavailable_credentials:
536
+ marked_time = self._unavailable_credentials.get(path)
537
+ if marked_time is not None:
538
+ now = time.time()
539
+ if now - marked_time > self._unavailable_ttl_seconds:
540
+ # Entry is stale - clean it up and return available
541
+ # This is a defense-in-depth for edge cases where re-auth
542
+ # processor crashed or was cancelled without cleanup
543
+ lib_logger.warning(
544
+ f"Credential '{Path(path).name}' stuck in re-auth queue for "
545
+ f"{int(now - marked_time)}s (TTL: {self._unavailable_ttl_seconds}s). "
546
+ f"Re-auth processor may have crashed. Auto-cleaning stale entry."
547
+ )
548
+ # Clean up both tracking structures for consistency
549
+ self._unavailable_credentials.pop(path, None)
550
+ self._queued_credentials.discard(path)
551
+ else:
552
+ return False # Still in re-auth, not available
553
 
554
+ # Check if token is TRULY expired (not just threshold-expired)
555
+ creds = self._credentials_cache.get(path)
556
+ if creds and self._is_token_truly_expired(creds):
557
+ # Token is actually expired - should not be used
558
+ # Queue for refresh if not already queued
559
+ if path not in self._queued_credentials:
560
+ # lib_logger.debug(
561
+ # f"Credential '{Path(path).name}' is truly expired, queueing for refresh"
562
+ # )
563
+ asyncio.create_task(
564
+ self._queue_refresh(path, force=True, needs_reauth=False)
565
  )
566
+ return False
 
567
 
568
+ return True
569
 
570
  async def _ensure_queue_processor_running(self):
571
  """Lazily starts the queue processor if not already running."""
 
574
  self._process_refresh_queue()
575
  )
576
 
577
+ async def _ensure_reauth_processor_running(self):
578
+ """Lazily starts the re-auth queue processor if not already running."""
579
+ if self._reauth_processor_task is None or self._reauth_processor_task.done():
580
+ self._reauth_processor_task = asyncio.create_task(
581
+ self._process_reauth_queue()
582
+ )
583
+
584
  async def _queue_refresh(
585
  self, path: str, force: bool = False, needs_reauth: bool = False
586
  ):
587
+ """Add a credential to the appropriate refresh queue if not already queued.
588
 
589
  Args:
590
  path: Credential file path
591
  force: Force refresh even if not expired
592
+ needs_reauth: True if full re-authentication needed (routes to re-auth queue)
593
+
594
+ Queue routing:
595
+ - needs_reauth=True: Goes to re-auth queue, marks as unavailable
596
+ - needs_reauth=False: Goes to normal refresh queue, does NOT mark unavailable
597
+ (old token is still valid until actual expiry)
598
  """
599
  # IMPORTANT: Only check backoff for simple automated refreshes
600
  # Re-authentication (interactive OAuth) should BYPASS backoff since it needs user input
 
604
  backoff_until = self._next_refresh_after[path]
605
  if now < backoff_until:
606
  # Credential is in backoff for automated refresh, do not queue
607
+ # remaining = int(backoff_until - now)
608
+ # lib_logger.debug(
609
+ # f"Skipping automated refresh for '{Path(path).name}' (in backoff for {remaining}s)"
610
+ # )
611
  return
612
 
613
  async with self._queue_tracking_lock:
614
  if path not in self._queued_credentials:
615
  self._queued_credentials.add(path)
616
+
617
+ if needs_reauth:
618
+ # Re-auth queue: mark as unavailable (token is truly broken)
619
+ self._unavailable_credentials[path] = time.time()
620
+ # lib_logger.debug(
621
+ # f"Queued '{Path(path).name}' for RE-AUTH (marked unavailable). "
622
+ # f"Total unavailable: {len(self._unavailable_credentials)}"
623
+ # )
624
+ await self._reauth_queue.put(path)
625
+ await self._ensure_reauth_processor_running()
626
+ else:
627
+ # Normal refresh queue: do NOT mark unavailable (old token still valid)
628
+ # lib_logger.debug(
629
+ # f"Queued '{Path(path).name}' for refresh (still available). "
630
+ # f"Queue size: {self._refresh_queue.qsize() + 1}"
631
+ # )
632
+ await self._refresh_queue.put((path, force))
633
+ await self._ensure_queue_processor_running()
634
 
635
  async def _process_refresh_queue(self):
636
+ """Background worker that processes normal refresh requests sequentially.
637
+
638
+ Key behaviors:
639
+ - 15s timeout per refresh operation
640
+ - 30s delay between processing credentials (prevents thundering herd)
641
+ - On failure: back of queue, max 3 retries before kicked
642
+ - If 401/403 detected: routes to re-auth queue
643
+ - Does NOT mark credentials unavailable (old token still valid)
644
+ """
645
+ # lib_logger.info("Refresh queue processor started")
646
  while True:
647
  path = None
648
  try:
649
  # Wait for an item with timeout to allow graceful shutdown
650
  try:
651
+ path, force = await asyncio.wait_for(
652
  self._refresh_queue.get(), timeout=60.0
653
  )
654
  except asyncio.TimeoutError:
655
+ # Queue is empty and idle for 60s - clean up and exit
 
656
  async with self._queue_tracking_lock:
657
+ # Clear any stale retry counts
658
+ self._queue_retry_count.clear()
 
 
 
 
 
 
 
 
 
 
 
659
  self._queue_processor_task = None
660
+ # lib_logger.debug("Refresh queue processor idle, shutting down")
661
  return
662
 
663
  try:
664
+ # Quick check if still expired (optimization to avoid unnecessary refresh)
665
+ creds = self._credentials_cache.get(path)
666
+ if creds and not self._is_token_expired(creds):
667
+ # No longer expired, skip refresh
668
+ # lib_logger.debug(
669
+ # f"Credential '{Path(path).name}' no longer expired, skipping refresh"
670
+ # )
671
+ # Clear retry count on skip (not a failure)
672
+ self._queue_retry_count.pop(path, None)
673
+ continue
674
+
675
+ # Perform refresh with timeout
676
+ try:
677
+ async with asyncio.timeout(self._refresh_timeout_seconds):
678
+ await self._refresh_token(path, force=force)
679
 
680
+ # SUCCESS: Clear retry count
681
+ self._queue_retry_count.pop(path, None)
682
+ # lib_logger.info(f"Refresh SUCCESS for '{Path(path).name}'")
 
683
 
684
+ except asyncio.TimeoutError:
685
+ lib_logger.warning(
686
+ f"Refresh timeout ({self._refresh_timeout_seconds}s) for '{Path(path).name}'"
687
+ )
688
+ await self._handle_refresh_failure(path, force, "timeout")
689
+
690
+ except httpx.HTTPStatusError as e:
691
+ status_code = e.response.status_code
692
+ if status_code in (401, 403):
693
+ # Invalid refresh token - route to re-auth queue
694
+ lib_logger.warning(
695
+ f"Refresh token invalid for '{Path(path).name}' (HTTP {status_code}). "
696
+ f"Routing to re-auth queue."
697
+ )
698
+ self._queue_retry_count.pop(path, None) # Clear retry count
699
+ async with self._queue_tracking_lock:
700
+ self._queued_credentials.discard(
701
+ path
702
+ ) # Remove from queued
703
+ await self._queue_refresh(
704
+ path, force=True, needs_reauth=True
705
+ )
706
+ else:
707
+ await self._handle_refresh_failure(
708
+ path, force, f"HTTP {status_code}"
709
  )
710
 
711
+ except Exception as e:
712
+ await self._handle_refresh_failure(path, force, str(e))
713
+
714
  finally:
715
+ # Remove from queued set (unless re-queued by failure handler)
716
+ async with self._queue_tracking_lock:
717
+ # Only discard if not re-queued (check if still in queue set from retry)
718
+ if (
719
+ path in self._queued_credentials
720
+ and self._queue_retry_count.get(path, 0) == 0
721
+ ):
722
+ self._queued_credentials.discard(path)
723
+ self._refresh_queue.task_done()
724
+
725
+ # Wait between credentials to spread load
726
+ await asyncio.sleep(self._refresh_interval_seconds)
727
+
728
+ except asyncio.CancelledError:
729
+ # lib_logger.debug("Refresh queue processor cancelled")
730
+ break
731
+ except Exception as e:
732
+ lib_logger.error(f"Error in refresh queue processor: {e}")
733
+ if path:
734
+ async with self._queue_tracking_lock:
735
+ self._queued_credentials.discard(path)
736
+
737
+ async def _handle_refresh_failure(self, path: str, force: bool, error: str):
738
+ """Handle a refresh failure with back-of-line retry logic.
739
+
740
+ - Increments retry count
741
+ - If under max retries: re-adds to END of queue
742
+ - If at max retries: kicks credential out (retried next BackgroundRefresher cycle)
743
+ """
744
+ retry_count = self._queue_retry_count.get(path, 0) + 1
745
+ self._queue_retry_count[path] = retry_count
746
+
747
+ if retry_count >= self._refresh_max_retries:
748
+ # Kicked out until next BackgroundRefresher cycle
749
+ lib_logger.error(
750
+ f"Max retries ({self._refresh_max_retries}) reached for '{Path(path).name}' "
751
+ f"(last error: {error}). Will retry next refresh cycle."
752
+ )
753
+ self._queue_retry_count.pop(path, None)
754
+ async with self._queue_tracking_lock:
755
+ self._queued_credentials.discard(path)
756
+ return
757
+
758
+ # Re-add to END of queue for retry
759
+ lib_logger.warning(
760
+ f"Refresh failed for '{Path(path).name}' ({error}). "
761
+ f"Retry {retry_count}/{self._refresh_max_retries}, back of queue."
762
+ )
763
+ # Keep in queued_credentials set, add back to queue
764
+ await self._refresh_queue.put((path, force))
765
+
766
+ async def _process_reauth_queue(self):
767
+ """Background worker that processes re-auth requests.
768
+
769
+ Key behaviors:
770
+ - Credentials ARE marked unavailable (token is truly broken)
771
+ - Uses ReauthCoordinator for interactive OAuth
772
+ - No automatic retry (requires user action)
773
+ - Cleans up unavailable status when done
774
+ """
775
+ # lib_logger.info("Re-auth queue processor started")
776
+ while True:
777
+ path = None
778
+ try:
779
+ # Wait for an item with timeout to allow graceful shutdown
780
+ try:
781
+ path = await asyncio.wait_for(
782
+ self._reauth_queue.get(), timeout=60.0
783
+ )
784
+ except asyncio.TimeoutError:
785
+ # Queue is empty and idle for 60s - exit
786
+ self._reauth_processor_task = None
787
+ # lib_logger.debug("Re-auth queue processor idle, shutting down")
788
+ return
789
+
790
+ try:
791
+ lib_logger.info(f"Starting re-auth for '{Path(path).name}'...")
792
+ await self.initialize_token(path)
793
+ lib_logger.info(f"Re-auth SUCCESS for '{Path(path).name}'")
794
+
795
+ except Exception as e:
796
+ lib_logger.error(f"Re-auth FAILED for '{Path(path).name}': {e}")
797
+ # No automatic retry for re-auth (requires user action)
798
+
799
+ finally:
800
+ # Always clean up
801
  async with self._queue_tracking_lock:
802
  self._queued_credentials.discard(path)
 
803
  self._unavailable_credentials.pop(path, None)
804
+ # lib_logger.debug(
805
+ # f"Re-auth cleanup for '{Path(path).name}'. "
806
+ # f"Remaining unavailable: {len(self._unavailable_credentials)}"
807
+ # )
808
+ self._reauth_queue.task_done()
809
+
810
  except asyncio.CancelledError:
811
+ # Clean up current credential before breaking
812
  if path:
813
  async with self._queue_tracking_lock:
814
+ self._queued_credentials.discard(path)
815
  self._unavailable_credentials.pop(path, None)
816
+ # lib_logger.debug("Re-auth queue processor cancelled")
 
 
 
817
  break
818
  except Exception as e:
819
+ lib_logger.error(f"Error in re-auth queue processor: {e}")
 
820
  if path:
821
  async with self._queue_tracking_lock:
822
+ self._queued_credentials.discard(path)
823
  self._unavailable_credentials.pop(path, None)
 
 
 
 
824
 
825
  async def _perform_interactive_oauth(
826
  self, path: str, creds: Dict[str, Any], display_name: str
 
1117
  except Exception as e:
1118
  lib_logger.error(f"Failed to get Qwen user info from credentials: {e}")
1119
  return {"email": None}
1120
+
1121
+ # =========================================================================
1122
+ # CREDENTIAL MANAGEMENT METHODS
1123
+ # =========================================================================
1124
+
1125
+ def _get_provider_file_prefix(self) -> str:
1126
+ """Return the file prefix for Qwen credentials."""
1127
+ return "qwen_code"
1128
+
1129
+ def _get_oauth_base_dir(self) -> Path:
1130
+ """Get the base directory for OAuth credential files."""
1131
+ return Path.cwd() / "oauth_creds"
1132
+
1133
+ def _find_existing_credential_by_email(
1134
+ self, email: str, base_dir: Optional[Path] = None
1135
+ ) -> Optional[Path]:
1136
+ """Find an existing credential file for the given email."""
1137
+ if base_dir is None:
1138
+ base_dir = self._get_oauth_base_dir()
1139
+
1140
+ prefix = self._get_provider_file_prefix()
1141
+ pattern = str(base_dir / f"{prefix}_oauth_*.json")
1142
+
1143
+ for cred_file in glob(pattern):
1144
+ try:
1145
+ with open(cred_file, "r") as f:
1146
+ creds = json.load(f)
1147
+ existing_email = creds.get("_proxy_metadata", {}).get("email")
1148
+ if existing_email == email:
1149
+ return Path(cred_file)
1150
+ except (json.JSONDecodeError, IOError) as e:
1151
+ lib_logger.debug(f"Could not read credential file {cred_file}: {e}")
1152
+ continue
1153
+
1154
+ return None
1155
+
1156
+ def _get_next_credential_number(self, base_dir: Optional[Path] = None) -> int:
1157
+ """Get the next available credential number."""
1158
+ if base_dir is None:
1159
+ base_dir = self._get_oauth_base_dir()
1160
+
1161
+ prefix = self._get_provider_file_prefix()
1162
+ pattern = str(base_dir / f"{prefix}_oauth_*.json")
1163
+
1164
+ existing_numbers = []
1165
+ for cred_file in glob(pattern):
1166
+ match = re.search(r"_oauth_(\d+)\.json$", cred_file)
1167
+ if match:
1168
+ existing_numbers.append(int(match.group(1)))
1169
+
1170
+ if not existing_numbers:
1171
+ return 1
1172
+ return max(existing_numbers) + 1
1173
+
1174
+ def _build_credential_path(
1175
+ self, base_dir: Optional[Path] = None, number: Optional[int] = None
1176
+ ) -> Path:
1177
+ """Build a path for a new credential file."""
1178
+ if base_dir is None:
1179
+ base_dir = self._get_oauth_base_dir()
1180
+
1181
+ if number is None:
1182
+ number = self._get_next_credential_number(base_dir)
1183
+
1184
+ prefix = self._get_provider_file_prefix()
1185
+ filename = f"{prefix}_oauth_{number}.json"
1186
+ return base_dir / filename
1187
+
1188
+ async def setup_credential(
1189
+ self, base_dir: Optional[Path] = None
1190
+ ) -> QwenCredentialSetupResult:
1191
+ """
1192
+ Complete credential setup flow: OAuth -> save.
1193
+
1194
+ This is the main entry point for setting up new credentials.
1195
+ """
1196
+ if base_dir is None:
1197
+ base_dir = self._get_oauth_base_dir()
1198
+
1199
+ # Ensure directory exists
1200
+ base_dir.mkdir(exist_ok=True)
1201
+
1202
+ try:
1203
+ # Step 1: Perform OAuth authentication
1204
+ temp_creds = {
1205
+ "_proxy_metadata": {"display_name": "new Qwen Code credential"}
1206
+ }
1207
+ new_creds = await self.initialize_token(temp_creds)
1208
+
1209
+ # Step 2: Get user info for deduplication
1210
+ email = new_creds.get("_proxy_metadata", {}).get("email")
1211
+
1212
+ if not email:
1213
+ return QwenCredentialSetupResult(
1214
+ success=False, error="Could not retrieve email from OAuth response"
1215
+ )
1216
+
1217
+ # Step 3: Check for existing credential with same email
1218
+ existing_path = self._find_existing_credential_by_email(email, base_dir)
1219
+ is_update = existing_path is not None
1220
+
1221
+ if is_update:
1222
+ file_path = existing_path
1223
+ lib_logger.info(
1224
+ f"Found existing credential for {email}, updating {file_path.name}"
1225
+ )
1226
+ else:
1227
+ file_path = self._build_credential_path(base_dir)
1228
+ lib_logger.info(
1229
+ f"Creating new credential for {email} at {file_path.name}"
1230
+ )
1231
+
1232
+ # Step 4: Save credentials to file
1233
+ await self._save_credentials(str(file_path), new_creds)
1234
+
1235
+ return QwenCredentialSetupResult(
1236
+ success=True,
1237
+ file_path=str(file_path),
1238
+ email=email,
1239
+ is_update=is_update,
1240
+ credentials=new_creds,
1241
+ )
1242
+
1243
+ except Exception as e:
1244
+ lib_logger.error(f"Credential setup failed: {e}")
1245
+ return QwenCredentialSetupResult(success=False, error=str(e))
1246
+
1247
+ def build_env_lines(self, creds: Dict[str, Any], cred_number: int) -> List[str]:
1248
+ """Generate .env file lines for a Qwen credential."""
1249
+ email = creds.get("_proxy_metadata", {}).get("email", "unknown")
1250
+ prefix = f"QWEN_CODE_{cred_number}"
1251
+
1252
+ lines = [
1253
+ f"# QWEN_CODE Credential #{cred_number} for: {email}",
1254
+ f"# Exported from: qwen_code_oauth_{cred_number}.json",
1255
+ f"# Generated at: {time.strftime('%Y-%m-%d %H:%M:%S')}",
1256
+ "#",
1257
+ "# To combine multiple credentials into one .env file, copy these lines",
1258
+ "# and ensure each credential has a unique number (1, 2, 3, etc.)",
1259
+ "",
1260
+ f"{prefix}_ACCESS_TOKEN={creds.get('access_token', '')}",
1261
+ f"{prefix}_REFRESH_TOKEN={creds.get('refresh_token', '')}",
1262
+ f"{prefix}_EXPIRY_DATE={creds.get('expiry_date', 0)}",
1263
+ f"{prefix}_RESOURCE_URL={creds.get('resource_url', 'https://portal.qwen.ai/v1')}",
1264
+ f"{prefix}_EMAIL={email}",
1265
+ ]
1266
+
1267
+ return lines
1268
+
1269
+ def export_credential_to_env(
1270
+ self, credential_path: str, output_dir: Optional[Path] = None
1271
+ ) -> Optional[str]:
1272
+ """Export a credential file to .env format."""
1273
+ try:
1274
+ cred_path = Path(credential_path)
1275
+
1276
+ # Load credential
1277
+ with open(cred_path, "r") as f:
1278
+ creds = json.load(f)
1279
+
1280
+ # Extract metadata
1281
+ email = creds.get("_proxy_metadata", {}).get("email", "unknown")
1282
+
1283
+ # Get credential number from filename
1284
+ match = re.search(r"_oauth_(\d+)\.json$", cred_path.name)
1285
+ cred_number = int(match.group(1)) if match else 1
1286
+
1287
+ # Build output path
1288
+ if output_dir is None:
1289
+ output_dir = cred_path.parent
1290
+
1291
+ safe_email = email.replace("@", "_at_").replace(".", "_")
1292
+ env_filename = f"qwen_code_{cred_number}_{safe_email}.env"
1293
+ env_path = output_dir / env_filename
1294
+
1295
+ # Build and write content
1296
+ env_lines = self.build_env_lines(creds, cred_number)
1297
+ with open(env_path, "w") as f:
1298
+ f.write("\n".join(env_lines))
1299
+
1300
+ lib_logger.info(f"Exported credential to {env_path}")
1301
+ return str(env_path)
1302
+
1303
+ except Exception as e:
1304
+ lib_logger.error(f"Failed to export credential: {e}")
1305
+ return None
1306
+
1307
+ def list_credentials(self, base_dir: Optional[Path] = None) -> List[Dict[str, Any]]:
1308
+ """List all Qwen credential files."""
1309
+ if base_dir is None:
1310
+ base_dir = self._get_oauth_base_dir()
1311
+
1312
+ prefix = self._get_provider_file_prefix()
1313
+ pattern = str(base_dir / f"{prefix}_oauth_*.json")
1314
+
1315
+ credentials = []
1316
+ for cred_file in sorted(glob(pattern)):
1317
+ try:
1318
+ with open(cred_file, "r") as f:
1319
+ creds = json.load(f)
1320
+
1321
+ metadata = creds.get("_proxy_metadata", {})
1322
+
1323
+ # Extract number from filename
1324
+ match = re.search(r"_oauth_(\d+)\.json$", cred_file)
1325
+ number = int(match.group(1)) if match else 0
1326
+
1327
+ credentials.append(
1328
+ {
1329
+ "file_path": cred_file,
1330
+ "email": metadata.get("email", "unknown"),
1331
+ "number": number,
1332
+ }
1333
+ )
1334
+ except Exception as e:
1335
+ lib_logger.debug(f"Could not read credential file {cred_file}: {e}")
1336
+ continue
1337
+
1338
+ return credentials
1339
+
1340
+ def delete_credential(self, credential_path: str) -> bool:
1341
+ """Delete a credential file."""
1342
+ try:
1343
+ cred_path = Path(credential_path)
1344
+
1345
+ # Validate that it's one of our credential files
1346
+ prefix = self._get_provider_file_prefix()
1347
+ if not cred_path.name.startswith(f"{prefix}_oauth_"):
1348
+ lib_logger.error(
1349
+ f"File {cred_path.name} does not appear to be a Qwen Code credential"
1350
+ )
1351
+ return False
1352
+
1353
+ if not cred_path.exists():
1354
+ lib_logger.warning(f"Credential file does not exist: {credential_path}")
1355
+ return False
1356
+
1357
+ # Remove from cache if present
1358
+ self._credentials_cache.pop(credential_path, None)
1359
+
1360
+ # Delete the file
1361
+ cred_path.unlink()
1362
+ lib_logger.info(f"Deleted credential file: {credential_path}")
1363
+ return True
1364
+
1365
+ except Exception as e:
1366
+ lib_logger.error(f"Failed to delete credential: {e}")
1367
+ return False
src/rotator_library/providers/qwen_code_provider.py CHANGED
@@ -10,19 +10,27 @@ from typing import Union, AsyncGenerator, List, Dict, Any
10
  from .provider_interface import ProviderInterface
11
  from .qwen_auth_base import QwenAuthBase
12
  from ..model_definitions import ModelDefinitions
 
 
13
  import litellm
14
  from litellm.exceptions import RateLimitError, AuthenticationError
15
  from pathlib import Path
16
  import uuid
17
  from datetime import datetime
18
 
19
- lib_logger = logging.getLogger('rotator_library')
 
 
 
 
 
 
 
20
 
21
- LOGS_DIR = Path(__file__).resolve().parent.parent.parent.parent / "logs"
22
- QWEN_CODE_LOGS_DIR = LOGS_DIR / "qwen_code_logs"
23
 
24
  class _QwenCodeFileLogger:
25
  """A simple file logger for a single Qwen Code transaction."""
 
26
  def __init__(self, model_name: str, enabled: bool = True):
27
  self.enabled = enabled
28
  if not self.enabled:
@@ -31,8 +39,10 @@ class _QwenCodeFileLogger:
31
  timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
32
  request_id = str(uuid.uuid4())
33
  # Sanitize model name for directory
34
- safe_model_name = model_name.replace('/', '_').replace(':', '_')
35
- self.log_dir = QWEN_CODE_LOGS_DIR / f"{timestamp}_{safe_model_name}_{request_id}"
 
 
36
  try:
37
  self.log_dir.mkdir(parents=True, exist_ok=True)
38
  except Exception as e:
@@ -41,25 +51,32 @@ class _QwenCodeFileLogger:
41
 
42
  def log_request(self, payload: Dict[str, Any]):
43
  """Logs the request payload sent to Qwen Code."""
44
- if not self.enabled: return
 
45
  try:
46
- with open(self.log_dir / "request_payload.json", "w", encoding="utf-8") as f:
 
 
47
  json.dump(payload, f, indent=2, ensure_ascii=False)
48
  except Exception as e:
49
  lib_logger.error(f"_QwenCodeFileLogger: Failed to write request: {e}")
50
 
51
  def log_response_chunk(self, chunk: str):
52
  """Logs a raw chunk from the Qwen Code response stream."""
53
- if not self.enabled: return
 
54
  try:
55
  with open(self.log_dir / "response_stream.log", "a", encoding="utf-8") as f:
56
  f.write(chunk + "\n")
57
  except Exception as e:
58
- lib_logger.error(f"_QwenCodeFileLogger: Failed to write response chunk: {e}")
 
 
59
 
60
  def log_error(self, error_message: str):
61
  """Logs an error message."""
62
- if not self.enabled: return
 
63
  try:
64
  with open(self.log_dir / "error.log", "a", encoding="utf-8") as f:
65
  f.write(f"[{datetime.utcnow().isoformat()}] {error_message}\n")
@@ -68,28 +85,41 @@ class _QwenCodeFileLogger:
68
 
69
  def log_final_response(self, response_data: Dict[str, Any]):
70
  """Logs the final, reassembled response."""
71
- if not self.enabled: return
 
72
  try:
73
  with open(self.log_dir / "final_response.json", "w", encoding="utf-8") as f:
74
  json.dump(response_data, f, indent=2, ensure_ascii=False)
75
  except Exception as e:
76
- lib_logger.error(f"_QwenCodeFileLogger: Failed to write final response: {e}")
 
 
 
77
 
78
- HARDCODED_MODELS = [
79
- "qwen3-coder-plus",
80
- "qwen3-coder-flash"
81
- ]
82
 
83
  # OpenAI-compatible parameters supported by Qwen Code API
84
  SUPPORTED_PARAMS = {
85
- 'model', 'messages', 'temperature', 'top_p', 'max_tokens',
86
- 'stream', 'tools', 'tool_choice', 'presence_penalty',
87
- 'frequency_penalty', 'n', 'stop', 'seed', 'response_format'
 
 
 
 
 
 
 
 
 
 
 
88
  }
89
 
 
90
  class QwenCodeProvider(QwenAuthBase, ProviderInterface):
91
  skip_cost_calculation = True
92
- REASONING_START_MARKER = 'THINK||'
93
 
94
  def __init__(self):
95
  super().__init__()
@@ -111,7 +141,9 @@ class QwenCodeProvider(QwenAuthBase, ProviderInterface):
111
  Validates OAuth credentials if applicable.
112
  """
113
  models = []
114
- env_var_ids = set() # Track IDs from env vars to prevent hardcoded/dynamic duplicates
 
 
115
 
116
  def extract_model_id(item) -> str:
117
  """Extract model ID from various formats (dict, string with/without provider prefix)."""
@@ -137,7 +169,9 @@ class QwenCodeProvider(QwenAuthBase, ProviderInterface):
137
  # Track the ID to prevent hardcoded/dynamic duplicates
138
  if model_id:
139
  env_var_ids.add(model_id)
140
- lib_logger.info(f"Loaded {len(static_models)} static models for qwen_code from environment variables")
 
 
141
 
142
  # Source 2: Add hardcoded models (only if ID not already in env vars)
143
  for model_id in HARDCODED_MODELS:
@@ -155,14 +189,17 @@ class QwenCodeProvider(QwenAuthBase, ProviderInterface):
155
  models_url = f"{api_base.rstrip('/')}/v1/models"
156
 
157
  response = await client.get(
158
- models_url,
159
- headers={"Authorization": f"Bearer {access_token}"}
160
  )
161
  response.raise_for_status()
162
 
163
  dynamic_data = response.json()
164
  # Handle both {data: [...]} and direct [...] formats
165
- model_list = dynamic_data.get("data", dynamic_data) if isinstance(dynamic_data, dict) else dynamic_data
 
 
 
 
166
 
167
  dynamic_count = 0
168
  for model in model_list:
@@ -173,7 +210,9 @@ class QwenCodeProvider(QwenAuthBase, ProviderInterface):
173
  dynamic_count += 1
174
 
175
  if dynamic_count > 0:
176
- lib_logger.debug(f"Discovered {dynamic_count} additional models for qwen_code from API")
 
 
177
 
178
  except Exception as e:
179
  # Silently ignore dynamic discovery errors
@@ -238,10 +277,10 @@ class QwenCodeProvider(QwenAuthBase, ProviderInterface):
238
  payload = {k: v for k, v in kwargs.items() if k in SUPPORTED_PARAMS}
239
 
240
  # Always force streaming for internal processing
241
- payload['stream'] = True
242
 
243
  # Always include usage data in stream
244
- payload['stream_options'] = {"include_usage": True}
245
 
246
  # Handle tool schema cleaning
247
  if "tools" in payload and payload["tools"]:
@@ -250,22 +289,26 @@ class QwenCodeProvider(QwenAuthBase, ProviderInterface):
250
  elif not payload.get("tools"):
251
  # Per Qwen Code API bug (see: https://github.com/qianwen-team/flash-dance/issues/2),
252
  # injecting a dummy tool prevents stream corruption when no tools are provided
253
- payload["tools"] = [{
254
- "type": "function",
255
- "function": {
256
- "name": "do_not_call_me",
257
- "description": "Do not call this tool.",
258
- "parameters": {"type": "object", "properties": {}}
 
 
259
  }
260
- }]
261
- lib_logger.debug("Injected dummy tool to prevent Qwen API stream corruption")
 
 
262
 
263
  return payload
264
 
265
  def _convert_chunk_to_openai(self, chunk: Dict[str, Any], model_id: str):
266
  """
267
  Converts a raw Qwen SSE chunk to an OpenAI-compatible chunk.
268
-
269
  CRITICAL FIX: Handle chunks with BOTH usage and choices (final chunk)
270
  without early return to ensure finish_reason is properly processed.
271
  """
@@ -287,32 +330,42 @@ class QwenCodeProvider(QwenAuthBase, ProviderInterface):
287
 
288
  # Yield the choice chunk first (contains finish_reason)
289
  yield {
290
- "choices": [{"index": 0, "delta": delta, "finish_reason": finish_reason}],
291
- "model": model_id, "object": "chat.completion.chunk",
292
- "id": chunk_id, "created": chunk_created
 
 
 
 
293
  }
294
  # Then yield the usage chunk
295
  yield {
296
- "choices": [], "model": model_id, "object": "chat.completion.chunk",
297
- "id": chunk_id, "created": chunk_created,
 
 
 
298
  "usage": {
299
  "prompt_tokens": usage_data.get("prompt_tokens", 0),
300
  "completion_tokens": usage_data.get("completion_tokens", 0),
301
  "total_tokens": usage_data.get("total_tokens", 0),
302
- }
303
  }
304
  return
305
 
306
  # Handle usage-only chunks
307
  if usage_data:
308
  yield {
309
- "choices": [], "model": model_id, "object": "chat.completion.chunk",
310
- "id": chunk_id, "created": chunk_created,
 
 
 
311
  "usage": {
312
  "prompt_tokens": usage_data.get("prompt_tokens", 0),
313
  "completion_tokens": usage_data.get("completion_tokens", 0),
314
  "total_tokens": usage_data.get("total_tokens", 0),
315
- }
316
  }
317
  return
318
 
@@ -327,35 +380,52 @@ class QwenCodeProvider(QwenAuthBase, ProviderInterface):
327
  # Handle <think> tags for reasoning content
328
  content = delta.get("content")
329
  if content and ("<think>" in content or "</think>" in content):
330
- parts = content.replace("<think>", f"||{self.REASONING_START_MARKER}").replace("</think>", f"||/{self.REASONING_START_MARKER}").split("||")
 
 
 
 
331
  for part in parts:
332
- if not part: continue
333
-
 
334
  new_delta = {}
335
  if part.startswith(self.REASONING_START_MARKER):
336
- new_delta['reasoning_content'] = part.replace(self.REASONING_START_MARKER, "")
 
 
337
  elif part.startswith(f"/{self.REASONING_START_MARKER}"):
338
  continue
339
  else:
340
- new_delta['content'] = part
341
-
342
  yield {
343
- "choices": [{"index": 0, "delta": new_delta, "finish_reason": None}],
344
- "model": model_id, "object": "chat.completion.chunk",
345
- "id": chunk_id, "created": chunk_created
 
 
 
 
346
  }
347
  else:
348
  # Standard content chunk
349
  yield {
350
- "choices": [{"index": 0, "delta": delta, "finish_reason": finish_reason}],
351
- "model": model_id, "object": "chat.completion.chunk",
352
- "id": chunk_id, "created": chunk_created
 
 
 
 
353
  }
354
 
355
- def _stream_to_completion_response(self, chunks: List[litellm.ModelResponse]) -> litellm.ModelResponse:
 
 
356
  """
357
  Manually reassembles streaming chunks into a complete response.
358
-
359
  Key improvements:
360
  - Determines finish_reason based on accumulated state (tool_calls vs stop)
361
  - Properly initializes tool_calls with type field
@@ -368,14 +438,16 @@ class QwenCodeProvider(QwenAuthBase, ProviderInterface):
368
  final_message = {"role": "assistant"}
369
  aggregated_tool_calls = {}
370
  usage_data = None
371
- chunk_finish_reason = None # Track finish_reason from chunks (but we'll override)
 
 
372
 
373
  # Get the first chunk for basic response metadata
374
  first_chunk = chunks[0]
375
 
376
  # Process each chunk to aggregate content
377
  for chunk in chunks:
378
- if not hasattr(chunk, 'choices') or not chunk.choices:
379
  continue
380
 
381
  choice = chunk.choices[0]
@@ -399,25 +471,48 @@ class QwenCodeProvider(QwenAuthBase, ProviderInterface):
399
  index = tc_chunk.get("index", 0)
400
  if index not in aggregated_tool_calls:
401
  # Initialize with type field for OpenAI compatibility
402
- aggregated_tool_calls[index] = {"type": "function", "function": {"name": "", "arguments": ""}}
 
 
 
403
  if "id" in tc_chunk:
404
  aggregated_tool_calls[index]["id"] = tc_chunk["id"]
405
  if "type" in tc_chunk:
406
  aggregated_tool_calls[index]["type"] = tc_chunk["type"]
407
  if "function" in tc_chunk:
408
- if "name" in tc_chunk["function"] and tc_chunk["function"]["name"] is not None:
409
- aggregated_tool_calls[index]["function"]["name"] += tc_chunk["function"]["name"]
410
- if "arguments" in tc_chunk["function"] and tc_chunk["function"]["arguments"] is not None:
411
- aggregated_tool_calls[index]["function"]["arguments"] += tc_chunk["function"]["arguments"]
 
 
 
 
 
 
 
 
 
 
412
 
413
  # Aggregate function calls (legacy format)
414
  if "function_call" in delta and delta["function_call"] is not None:
415
  if "function_call" not in final_message:
416
  final_message["function_call"] = {"name": "", "arguments": ""}
417
- if "name" in delta["function_call"] and delta["function_call"]["name"] is not None:
418
- final_message["function_call"]["name"] += delta["function_call"]["name"]
419
- if "arguments" in delta["function_call"] and delta["function_call"]["arguments"] is not None:
420
- final_message["function_call"]["arguments"] += delta["function_call"]["arguments"]
 
 
 
 
 
 
 
 
 
 
421
 
422
  # Track finish_reason from chunks (for reference only)
423
  if choice.get("finish_reason"):
@@ -425,7 +520,7 @@ class QwenCodeProvider(QwenAuthBase, ProviderInterface):
425
 
426
  # Handle usage data from the last chunk that has it
427
  for chunk in reversed(chunks):
428
- if hasattr(chunk, 'usage') and chunk.usage:
429
  usage_data = chunk.usage
430
  break
431
 
@@ -451,7 +546,7 @@ class QwenCodeProvider(QwenAuthBase, ProviderInterface):
451
  final_choice = {
452
  "index": 0,
453
  "message": final_message,
454
- "finish_reason": finish_reason
455
  }
456
 
457
  # Create the final ModelResponse
@@ -461,20 +556,21 @@ class QwenCodeProvider(QwenAuthBase, ProviderInterface):
461
  "created": first_chunk.created,
462
  "model": first_chunk.model,
463
  "choices": [final_choice],
464
- "usage": usage_data
465
  }
466
 
467
  return litellm.ModelResponse(**final_response_data)
468
 
469
- async def acompletion(self, client: httpx.AsyncClient, **kwargs) -> Union[litellm.ModelResponse, AsyncGenerator[litellm.ModelResponse, None]]:
 
 
470
  credential_path = kwargs.pop("credential_identifier")
471
  enable_request_logging = kwargs.pop("enable_request_logging", False)
472
  model = kwargs["model"]
473
 
474
  # Create dedicated file logger for this request
475
  file_logger = _QwenCodeFileLogger(
476
- model_name=model,
477
- enabled=enable_request_logging
478
  )
479
 
480
  async def make_request():
@@ -482,8 +578,8 @@ class QwenCodeProvider(QwenAuthBase, ProviderInterface):
482
  api_base, access_token = await self.get_api_details(credential_path)
483
 
484
  # Strip provider prefix from model name (e.g., "qwen_code/qwen3-coder-plus" -> "qwen3-coder-plus")
485
- model_name = model.split('/')[-1]
486
- kwargs_with_stripped_model = {**kwargs, 'model': model_name}
487
 
488
  # Build clean payload with only supported parameters
489
  payload = self._build_request_payload(**kwargs_with_stripped_model)
@@ -503,7 +599,13 @@ class QwenCodeProvider(QwenAuthBase, ProviderInterface):
503
  file_logger.log_request(payload)
504
  lib_logger.debug(f"Qwen Code Request URL: {url}")
505
 
506
- return client.stream("POST", url, headers=headers, json=payload, timeout=600)
 
 
 
 
 
 
507
 
508
  async def stream_handler(response_stream, attempt=1):
509
  """Handles the streaming response and converts chunks."""
@@ -512,11 +614,17 @@ class QwenCodeProvider(QwenAuthBase, ProviderInterface):
512
  # Check for HTTP errors before processing stream
513
  if response.status_code >= 400:
514
  error_text = await response.aread()
515
- error_text = error_text.decode('utf-8') if isinstance(error_text, bytes) else error_text
 
 
 
 
516
 
517
  # Handle 401: Force token refresh and retry once
518
  if response.status_code == 401 and attempt == 1:
519
- lib_logger.warning("Qwen Code returned 401. Forcing token refresh and retrying once.")
 
 
520
  await self._refresh_token(credential_path, force=True)
521
  retry_stream = await make_request()
522
  async for chunk in stream_handler(retry_stream, attempt=2):
@@ -524,12 +632,15 @@ class QwenCodeProvider(QwenAuthBase, ProviderInterface):
524
  return
525
 
526
  # Handle 429: Rate limit
527
- elif response.status_code == 429 or "slow_down" in error_text.lower():
 
 
 
528
  raise RateLimitError(
529
  f"Qwen Code rate limit exceeded: {error_text}",
530
  llm_provider="qwen_code",
531
  model=model,
532
- response=response
533
  )
534
 
535
  # Handle other errors
@@ -539,28 +650,34 @@ class QwenCodeProvider(QwenAuthBase, ProviderInterface):
539
  raise httpx.HTTPStatusError(
540
  f"HTTP {response.status_code}: {error_text}",
541
  request=response.request,
542
- response=response
543
  )
544
 
545
  # Process successful streaming response
546
  async for line in response.aiter_lines():
547
  file_logger.log_response_chunk(line)
548
- if line.startswith('data: '):
549
  data_str = line[6:]
550
  if data_str == "[DONE]":
551
  break
552
  try:
553
  chunk = json.loads(data_str)
554
- for openai_chunk in self._convert_chunk_to_openai(chunk, model):
 
 
555
  yield litellm.ModelResponse(**openai_chunk)
556
  except json.JSONDecodeError:
557
- lib_logger.warning(f"Could not decode JSON from Qwen Code: {line}")
 
 
558
 
559
  except httpx.HTTPStatusError:
560
  raise # Re-raise HTTP errors we already handled
561
  except Exception as e:
562
  file_logger.log_error(f"Error during Qwen Code stream processing: {e}")
563
- lib_logger.error(f"Error during Qwen Code stream processing: {e}", exc_info=True)
 
 
564
  raise
565
 
566
  async def logging_stream_wrapper():
@@ -578,7 +695,9 @@ class QwenCodeProvider(QwenAuthBase, ProviderInterface):
578
  if kwargs.get("stream"):
579
  return logging_stream_wrapper()
580
  else:
 
581
  async def non_stream_wrapper():
582
  chunks = [chunk async for chunk in logging_stream_wrapper()]
583
  return self._stream_to_completion_response(chunks)
584
- return await non_stream_wrapper()
 
 
10
  from .provider_interface import ProviderInterface
11
  from .qwen_auth_base import QwenAuthBase
12
  from ..model_definitions import ModelDefinitions
13
+ from ..timeout_config import TimeoutConfig
14
+ from ..utils.paths import get_logs_dir
15
  import litellm
16
  from litellm.exceptions import RateLimitError, AuthenticationError
17
  from pathlib import Path
18
  import uuid
19
  from datetime import datetime
20
 
21
+ lib_logger = logging.getLogger("rotator_library")
22
+
23
+
24
+ def _get_qwen_code_logs_dir() -> Path:
25
+ """Get the Qwen Code logs directory."""
26
+ logs_dir = get_logs_dir() / "qwen_code_logs"
27
+ logs_dir.mkdir(parents=True, exist_ok=True)
28
+ return logs_dir
29
 
 
 
30
 
31
  class _QwenCodeFileLogger:
32
  """A simple file logger for a single Qwen Code transaction."""
33
+
34
  def __init__(self, model_name: str, enabled: bool = True):
35
  self.enabled = enabled
36
  if not self.enabled:
 
39
  timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
40
  request_id = str(uuid.uuid4())
41
  # Sanitize model name for directory
42
+ safe_model_name = model_name.replace("/", "_").replace(":", "_")
43
+ self.log_dir = (
44
+ _get_qwen_code_logs_dir() / f"{timestamp}_{safe_model_name}_{request_id}"
45
+ )
46
  try:
47
  self.log_dir.mkdir(parents=True, exist_ok=True)
48
  except Exception as e:
 
51
 
52
  def log_request(self, payload: Dict[str, Any]):
53
  """Logs the request payload sent to Qwen Code."""
54
+ if not self.enabled:
55
+ return
56
  try:
57
+ with open(
58
+ self.log_dir / "request_payload.json", "w", encoding="utf-8"
59
+ ) as f:
60
  json.dump(payload, f, indent=2, ensure_ascii=False)
61
  except Exception as e:
62
  lib_logger.error(f"_QwenCodeFileLogger: Failed to write request: {e}")
63
 
64
  def log_response_chunk(self, chunk: str):
65
  """Logs a raw chunk from the Qwen Code response stream."""
66
+ if not self.enabled:
67
+ return
68
  try:
69
  with open(self.log_dir / "response_stream.log", "a", encoding="utf-8") as f:
70
  f.write(chunk + "\n")
71
  except Exception as e:
72
+ lib_logger.error(
73
+ f"_QwenCodeFileLogger: Failed to write response chunk: {e}"
74
+ )
75
 
76
  def log_error(self, error_message: str):
77
  """Logs an error message."""
78
+ if not self.enabled:
79
+ return
80
  try:
81
  with open(self.log_dir / "error.log", "a", encoding="utf-8") as f:
82
  f.write(f"[{datetime.utcnow().isoformat()}] {error_message}\n")
 
85
 
86
  def log_final_response(self, response_data: Dict[str, Any]):
87
  """Logs the final, reassembled response."""
88
+ if not self.enabled:
89
+ return
90
  try:
91
  with open(self.log_dir / "final_response.json", "w", encoding="utf-8") as f:
92
  json.dump(response_data, f, indent=2, ensure_ascii=False)
93
  except Exception as e:
94
+ lib_logger.error(
95
+ f"_QwenCodeFileLogger: Failed to write final response: {e}"
96
+ )
97
+
98
 
99
+ HARDCODED_MODELS = ["qwen3-coder-plus", "qwen3-coder-flash"]
 
 
 
100
 
101
  # OpenAI-compatible parameters supported by Qwen Code API
102
  SUPPORTED_PARAMS = {
103
+ "model",
104
+ "messages",
105
+ "temperature",
106
+ "top_p",
107
+ "max_tokens",
108
+ "stream",
109
+ "tools",
110
+ "tool_choice",
111
+ "presence_penalty",
112
+ "frequency_penalty",
113
+ "n",
114
+ "stop",
115
+ "seed",
116
+ "response_format",
117
  }
118
 
119
+
120
  class QwenCodeProvider(QwenAuthBase, ProviderInterface):
121
  skip_cost_calculation = True
122
+ REASONING_START_MARKER = "THINK||"
123
 
124
  def __init__(self):
125
  super().__init__()
 
141
  Validates OAuth credentials if applicable.
142
  """
143
  models = []
144
+ env_var_ids = (
145
+ set()
146
+ ) # Track IDs from env vars to prevent hardcoded/dynamic duplicates
147
 
148
  def extract_model_id(item) -> str:
149
  """Extract model ID from various formats (dict, string with/without provider prefix)."""
 
169
  # Track the ID to prevent hardcoded/dynamic duplicates
170
  if model_id:
171
  env_var_ids.add(model_id)
172
+ lib_logger.info(
173
+ f"Loaded {len(static_models)} static models for qwen_code from environment variables"
174
+ )
175
 
176
  # Source 2: Add hardcoded models (only if ID not already in env vars)
177
  for model_id in HARDCODED_MODELS:
 
189
  models_url = f"{api_base.rstrip('/')}/v1/models"
190
 
191
  response = await client.get(
192
+ models_url, headers={"Authorization": f"Bearer {access_token}"}
 
193
  )
194
  response.raise_for_status()
195
 
196
  dynamic_data = response.json()
197
  # Handle both {data: [...]} and direct [...] formats
198
+ model_list = (
199
+ dynamic_data.get("data", dynamic_data)
200
+ if isinstance(dynamic_data, dict)
201
+ else dynamic_data
202
+ )
203
 
204
  dynamic_count = 0
205
  for model in model_list:
 
210
  dynamic_count += 1
211
 
212
  if dynamic_count > 0:
213
+ lib_logger.debug(
214
+ f"Discovered {dynamic_count} additional models for qwen_code from API"
215
+ )
216
 
217
  except Exception as e:
218
  # Silently ignore dynamic discovery errors
 
277
  payload = {k: v for k, v in kwargs.items() if k in SUPPORTED_PARAMS}
278
 
279
  # Always force streaming for internal processing
280
+ payload["stream"] = True
281
 
282
  # Always include usage data in stream
283
+ payload["stream_options"] = {"include_usage": True}
284
 
285
  # Handle tool schema cleaning
286
  if "tools" in payload and payload["tools"]:
 
289
  elif not payload.get("tools"):
290
  # Per Qwen Code API bug (see: https://github.com/qianwen-team/flash-dance/issues/2),
291
  # injecting a dummy tool prevents stream corruption when no tools are provided
292
+ payload["tools"] = [
293
+ {
294
+ "type": "function",
295
+ "function": {
296
+ "name": "do_not_call_me",
297
+ "description": "Do not call this tool.",
298
+ "parameters": {"type": "object", "properties": {}},
299
+ },
300
  }
301
+ ]
302
+ lib_logger.debug(
303
+ "Injected dummy tool to prevent Qwen API stream corruption"
304
+ )
305
 
306
  return payload
307
 
308
  def _convert_chunk_to_openai(self, chunk: Dict[str, Any], model_id: str):
309
  """
310
  Converts a raw Qwen SSE chunk to an OpenAI-compatible chunk.
311
+
312
  CRITICAL FIX: Handle chunks with BOTH usage and choices (final chunk)
313
  without early return to ensure finish_reason is properly processed.
314
  """
 
330
 
331
  # Yield the choice chunk first (contains finish_reason)
332
  yield {
333
+ "choices": [
334
+ {"index": 0, "delta": delta, "finish_reason": finish_reason}
335
+ ],
336
+ "model": model_id,
337
+ "object": "chat.completion.chunk",
338
+ "id": chunk_id,
339
+ "created": chunk_created,
340
  }
341
  # Then yield the usage chunk
342
  yield {
343
+ "choices": [],
344
+ "model": model_id,
345
+ "object": "chat.completion.chunk",
346
+ "id": chunk_id,
347
+ "created": chunk_created,
348
  "usage": {
349
  "prompt_tokens": usage_data.get("prompt_tokens", 0),
350
  "completion_tokens": usage_data.get("completion_tokens", 0),
351
  "total_tokens": usage_data.get("total_tokens", 0),
352
+ },
353
  }
354
  return
355
 
356
  # Handle usage-only chunks
357
  if usage_data:
358
  yield {
359
+ "choices": [],
360
+ "model": model_id,
361
+ "object": "chat.completion.chunk",
362
+ "id": chunk_id,
363
+ "created": chunk_created,
364
  "usage": {
365
  "prompt_tokens": usage_data.get("prompt_tokens", 0),
366
  "completion_tokens": usage_data.get("completion_tokens", 0),
367
  "total_tokens": usage_data.get("total_tokens", 0),
368
+ },
369
  }
370
  return
371
 
 
380
  # Handle <think> tags for reasoning content
381
  content = delta.get("content")
382
  if content and ("<think>" in content or "</think>" in content):
383
+ parts = (
384
+ content.replace("<think>", f"||{self.REASONING_START_MARKER}")
385
+ .replace("</think>", f"||/{self.REASONING_START_MARKER}")
386
+ .split("||")
387
+ )
388
  for part in parts:
389
+ if not part:
390
+ continue
391
+
392
  new_delta = {}
393
  if part.startswith(self.REASONING_START_MARKER):
394
+ new_delta["reasoning_content"] = part.replace(
395
+ self.REASONING_START_MARKER, ""
396
+ )
397
  elif part.startswith(f"/{self.REASONING_START_MARKER}"):
398
  continue
399
  else:
400
+ new_delta["content"] = part
401
+
402
  yield {
403
+ "choices": [
404
+ {"index": 0, "delta": new_delta, "finish_reason": None}
405
+ ],
406
+ "model": model_id,
407
+ "object": "chat.completion.chunk",
408
+ "id": chunk_id,
409
+ "created": chunk_created,
410
  }
411
  else:
412
  # Standard content chunk
413
  yield {
414
+ "choices": [
415
+ {"index": 0, "delta": delta, "finish_reason": finish_reason}
416
+ ],
417
+ "model": model_id,
418
+ "object": "chat.completion.chunk",
419
+ "id": chunk_id,
420
+ "created": chunk_created,
421
  }
422
 
423
+ def _stream_to_completion_response(
424
+ self, chunks: List[litellm.ModelResponse]
425
+ ) -> litellm.ModelResponse:
426
  """
427
  Manually reassembles streaming chunks into a complete response.
428
+
429
  Key improvements:
430
  - Determines finish_reason based on accumulated state (tool_calls vs stop)
431
  - Properly initializes tool_calls with type field
 
438
  final_message = {"role": "assistant"}
439
  aggregated_tool_calls = {}
440
  usage_data = None
441
+ chunk_finish_reason = (
442
+ None # Track finish_reason from chunks (but we'll override)
443
+ )
444
 
445
  # Get the first chunk for basic response metadata
446
  first_chunk = chunks[0]
447
 
448
  # Process each chunk to aggregate content
449
  for chunk in chunks:
450
+ if not hasattr(chunk, "choices") or not chunk.choices:
451
  continue
452
 
453
  choice = chunk.choices[0]
 
471
  index = tc_chunk.get("index", 0)
472
  if index not in aggregated_tool_calls:
473
  # Initialize with type field for OpenAI compatibility
474
+ aggregated_tool_calls[index] = {
475
+ "type": "function",
476
+ "function": {"name": "", "arguments": ""},
477
+ }
478
  if "id" in tc_chunk:
479
  aggregated_tool_calls[index]["id"] = tc_chunk["id"]
480
  if "type" in tc_chunk:
481
  aggregated_tool_calls[index]["type"] = tc_chunk["type"]
482
  if "function" in tc_chunk:
483
+ if (
484
+ "name" in tc_chunk["function"]
485
+ and tc_chunk["function"]["name"] is not None
486
+ ):
487
+ aggregated_tool_calls[index]["function"]["name"] += (
488
+ tc_chunk["function"]["name"]
489
+ )
490
+ if (
491
+ "arguments" in tc_chunk["function"]
492
+ and tc_chunk["function"]["arguments"] is not None
493
+ ):
494
+ aggregated_tool_calls[index]["function"]["arguments"] += (
495
+ tc_chunk["function"]["arguments"]
496
+ )
497
 
498
  # Aggregate function calls (legacy format)
499
  if "function_call" in delta and delta["function_call"] is not None:
500
  if "function_call" not in final_message:
501
  final_message["function_call"] = {"name": "", "arguments": ""}
502
+ if (
503
+ "name" in delta["function_call"]
504
+ and delta["function_call"]["name"] is not None
505
+ ):
506
+ final_message["function_call"]["name"] += delta["function_call"][
507
+ "name"
508
+ ]
509
+ if (
510
+ "arguments" in delta["function_call"]
511
+ and delta["function_call"]["arguments"] is not None
512
+ ):
513
+ final_message["function_call"]["arguments"] += delta[
514
+ "function_call"
515
+ ]["arguments"]
516
 
517
  # Track finish_reason from chunks (for reference only)
518
  if choice.get("finish_reason"):
 
520
 
521
  # Handle usage data from the last chunk that has it
522
  for chunk in reversed(chunks):
523
+ if hasattr(chunk, "usage") and chunk.usage:
524
  usage_data = chunk.usage
525
  break
526
 
 
546
  final_choice = {
547
  "index": 0,
548
  "message": final_message,
549
+ "finish_reason": finish_reason,
550
  }
551
 
552
  # Create the final ModelResponse
 
556
  "created": first_chunk.created,
557
  "model": first_chunk.model,
558
  "choices": [final_choice],
559
+ "usage": usage_data,
560
  }
561
 
562
  return litellm.ModelResponse(**final_response_data)
563
 
564
+ async def acompletion(
565
+ self, client: httpx.AsyncClient, **kwargs
566
+ ) -> Union[litellm.ModelResponse, AsyncGenerator[litellm.ModelResponse, None]]:
567
  credential_path = kwargs.pop("credential_identifier")
568
  enable_request_logging = kwargs.pop("enable_request_logging", False)
569
  model = kwargs["model"]
570
 
571
  # Create dedicated file logger for this request
572
  file_logger = _QwenCodeFileLogger(
573
+ model_name=model, enabled=enable_request_logging
 
574
  )
575
 
576
  async def make_request():
 
578
  api_base, access_token = await self.get_api_details(credential_path)
579
 
580
  # Strip provider prefix from model name (e.g., "qwen_code/qwen3-coder-plus" -> "qwen3-coder-plus")
581
+ model_name = model.split("/")[-1]
582
+ kwargs_with_stripped_model = {**kwargs, "model": model_name}
583
 
584
  # Build clean payload with only supported parameters
585
  payload = self._build_request_payload(**kwargs_with_stripped_model)
 
599
  file_logger.log_request(payload)
600
  lib_logger.debug(f"Qwen Code Request URL: {url}")
601
 
602
+ return client.stream(
603
+ "POST",
604
+ url,
605
+ headers=headers,
606
+ json=payload,
607
+ timeout=TimeoutConfig.streaming(),
608
+ )
609
 
610
  async def stream_handler(response_stream, attempt=1):
611
  """Handles the streaming response and converts chunks."""
 
614
  # Check for HTTP errors before processing stream
615
  if response.status_code >= 400:
616
  error_text = await response.aread()
617
+ error_text = (
618
+ error_text.decode("utf-8")
619
+ if isinstance(error_text, bytes)
620
+ else error_text
621
+ )
622
 
623
  # Handle 401: Force token refresh and retry once
624
  if response.status_code == 401 and attempt == 1:
625
+ lib_logger.warning(
626
+ "Qwen Code returned 401. Forcing token refresh and retrying once."
627
+ )
628
  await self._refresh_token(credential_path, force=True)
629
  retry_stream = await make_request()
630
  async for chunk in stream_handler(retry_stream, attempt=2):
 
632
  return
633
 
634
  # Handle 429: Rate limit
635
+ elif (
636
+ response.status_code == 429
637
+ or "slow_down" in error_text.lower()
638
+ ):
639
  raise RateLimitError(
640
  f"Qwen Code rate limit exceeded: {error_text}",
641
  llm_provider="qwen_code",
642
  model=model,
643
+ response=response,
644
  )
645
 
646
  # Handle other errors
 
650
  raise httpx.HTTPStatusError(
651
  f"HTTP {response.status_code}: {error_text}",
652
  request=response.request,
653
+ response=response,
654
  )
655
 
656
  # Process successful streaming response
657
  async for line in response.aiter_lines():
658
  file_logger.log_response_chunk(line)
659
+ if line.startswith("data: "):
660
  data_str = line[6:]
661
  if data_str == "[DONE]":
662
  break
663
  try:
664
  chunk = json.loads(data_str)
665
+ for openai_chunk in self._convert_chunk_to_openai(
666
+ chunk, model
667
+ ):
668
  yield litellm.ModelResponse(**openai_chunk)
669
  except json.JSONDecodeError:
670
+ lib_logger.warning(
671
+ f"Could not decode JSON from Qwen Code: {line}"
672
+ )
673
 
674
  except httpx.HTTPStatusError:
675
  raise # Re-raise HTTP errors we already handled
676
  except Exception as e:
677
  file_logger.log_error(f"Error during Qwen Code stream processing: {e}")
678
+ lib_logger.error(
679
+ f"Error during Qwen Code stream processing: {e}", exc_info=True
680
+ )
681
  raise
682
 
683
  async def logging_stream_wrapper():
 
695
  if kwargs.get("stream"):
696
  return logging_stream_wrapper()
697
  else:
698
+
699
  async def non_stream_wrapper():
700
  chunks = [chunk async for chunk in logging_stream_wrapper()]
701
  return self._stream_to_completion_response(chunks)
702
+
703
+ return await non_stream_wrapper()
src/rotator_library/timeout_config.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # src/rotator_library/timeout_config.py
2
+ """
3
+ Centralized timeout configuration for HTTP requests.
4
+
5
+ All values can be overridden via environment variables:
6
+ TIMEOUT_CONNECT - Connection establishment timeout (default: 30s)
7
+ TIMEOUT_WRITE - Request body send timeout (default: 30s)
8
+ TIMEOUT_POOL - Connection pool acquisition timeout (default: 60s)
9
+ TIMEOUT_READ_STREAMING - Read timeout between chunks for streaming (default: 180s / 3 min)
10
+ TIMEOUT_READ_NON_STREAMING - Read timeout for non-streaming responses (default: 600s / 10 min)
11
+ """
12
+
13
+ import os
14
+ import logging
15
+ import httpx
16
+
17
+ lib_logger = logging.getLogger("rotator_library")
18
+
19
+
20
+ class TimeoutConfig:
21
+ """
22
+ Centralized timeout configuration for HTTP requests.
23
+
24
+ All values can be overridden via environment variables.
25
+ """
26
+
27
+ # Default values (in seconds)
28
+ _CONNECT = 30.0
29
+ _WRITE = 30.0
30
+ _POOL = 60.0
31
+ _READ_STREAMING = 180.0 # 3 minutes between chunks
32
+ _READ_NON_STREAMING = 600.0 # 10 minutes for full response
33
+
34
+ @classmethod
35
+ def _get_env_float(cls, key: str, default: float) -> float:
36
+ """Get a float value from environment variable, or return default."""
37
+ value = os.environ.get(key)
38
+ if value is not None:
39
+ try:
40
+ return float(value)
41
+ except ValueError:
42
+ lib_logger.warning(
43
+ f"Invalid value for {key}: {value}. Using default: {default}"
44
+ )
45
+ return default
46
+
47
+ @classmethod
48
+ def connect(cls) -> float:
49
+ """Connection establishment timeout."""
50
+ return cls._get_env_float("TIMEOUT_CONNECT", cls._CONNECT)
51
+
52
+ @classmethod
53
+ def write(cls) -> float:
54
+ """Request body send timeout."""
55
+ return cls._get_env_float("TIMEOUT_WRITE", cls._WRITE)
56
+
57
+ @classmethod
58
+ def pool(cls) -> float:
59
+ """Connection pool acquisition timeout."""
60
+ return cls._get_env_float("TIMEOUT_POOL", cls._POOL)
61
+
62
+ @classmethod
63
+ def read_streaming(cls) -> float:
64
+ """Read timeout between chunks for streaming requests."""
65
+ return cls._get_env_float("TIMEOUT_READ_STREAMING", cls._READ_STREAMING)
66
+
67
+ @classmethod
68
+ def read_non_streaming(cls) -> float:
69
+ """Read timeout for non-streaming responses."""
70
+ return cls._get_env_float("TIMEOUT_READ_NON_STREAMING", cls._READ_NON_STREAMING)
71
+
72
+ @classmethod
73
+ def streaming(cls) -> httpx.Timeout:
74
+ """
75
+ Timeout configuration for streaming LLM requests.
76
+
77
+ Uses a shorter read timeout (default 3 min) since we expect
78
+ periodic chunks. If no data arrives for this duration, the
79
+ connection is considered stalled.
80
+ """
81
+ return httpx.Timeout(
82
+ connect=cls.connect(),
83
+ read=cls.read_streaming(),
84
+ write=cls.write(),
85
+ pool=cls.pool(),
86
+ )
87
+
88
+ @classmethod
89
+ def non_streaming(cls) -> httpx.Timeout:
90
+ """
91
+ Timeout configuration for non-streaming LLM requests.
92
+
93
+ Uses a longer read timeout (default 10 min) since the server
94
+ may take significant time to generate the complete response
95
+ before sending anything back.
96
+ """
97
+ return httpx.Timeout(
98
+ connect=cls.connect(),
99
+ read=cls.read_non_streaming(),
100
+ write=cls.write(),
101
+ pool=cls.pool(),
102
+ )
src/rotator_library/usage_manager.py CHANGED
@@ -5,12 +5,15 @@ import logging
5
  import asyncio
6
  import random
7
  from datetime import date, datetime, timezone, time as dt_time
8
- from typing import Any, Dict, List, Optional, Set, Tuple
 
9
  import aiofiles
10
  import litellm
11
 
12
  from .error_handler import ClassifiedError, NoAvailableKeysError, mask_credential
13
  from .providers import PROVIDER_PLUGINS
 
 
14
 
15
  lib_logger = logging.getLogger("rotator_library")
16
  lib_logger.propagate = False
@@ -50,7 +53,7 @@ class UsageManager:
50
 
51
  def __init__(
52
  self,
53
- file_path: str = "key_usage.json",
54
  daily_reset_time_utc: Optional[str] = "03:00",
55
  rotation_tolerance: float = 0.0,
56
  provider_rotation_modes: Optional[Dict[str, str]] = None,
@@ -65,7 +68,8 @@ class UsageManager:
65
  Initialize the UsageManager.
66
 
67
  Args:
68
- file_path: Path to the usage data JSON file
 
69
  daily_reset_time_utc: Time in UTC when daily stats should reset (HH:MM format)
70
  rotation_tolerance: Tolerance for weighted random credential rotation.
71
  - 0.0: Deterministic, least-used credential always selected
@@ -85,7 +89,14 @@ class UsageManager:
85
  Used in sequential mode when priority not in priority_multipliers.
86
  Example: {"antigravity": 2}
87
  """
88
- self.file_path = file_path
 
 
 
 
 
 
 
89
  self.rotation_tolerance = rotation_tolerance
90
  self.provider_rotation_modes = provider_rotation_modes or {}
91
  self.provider_plugins = provider_plugins or PROVIDER_PLUGINS
@@ -103,6 +114,9 @@ class UsageManager:
103
  self._timeout_lock = asyncio.Lock()
104
  self._claimed_on_timeout: Set[str] = set()
105
 
 
 
 
106
  if daily_reset_time_utc:
107
  hour, minute = map(int, daily_reset_time_utc.split(":"))
108
  self.daily_reset_time_utc = dt_time(
@@ -540,27 +554,40 @@ class UsageManager:
540
  self._initialized.set()
541
 
542
  async def _load_usage(self):
543
- """Loads usage data from the JSON file asynchronously."""
544
  async with self._data_lock:
545
  if not os.path.exists(self.file_path):
546
  self._usage_data = {}
547
  return
 
548
  try:
549
  async with aiofiles.open(self.file_path, "r") as f:
550
  content = await f.read()
551
- self._usage_data = json.loads(content)
552
- except (json.JSONDecodeError, IOError, FileNotFoundError):
 
 
 
 
 
 
 
 
 
 
 
553
  self._usage_data = {}
554
 
555
  async def _save_usage(self):
556
- """Saves the current usage data to the JSON file asynchronously."""
557
  if self._usage_data is None:
558
  return
 
559
  async with self._data_lock:
560
  # Add human-readable timestamp fields before saving
561
  self._add_readable_timestamps(self._usage_data)
562
- async with aiofiles.open(self.file_path, "w") as f:
563
- await f.write(json.dumps(self._usage_data, indent=2))
564
 
565
  async def _reset_daily_stats_if_needed(self):
566
  """
 
5
  import asyncio
6
  import random
7
  from datetime import date, datetime, timezone, time as dt_time
8
+ from pathlib import Path
9
+ from typing import Any, Dict, List, Optional, Set, Tuple, Union
10
  import aiofiles
11
  import litellm
12
 
13
  from .error_handler import ClassifiedError, NoAvailableKeysError, mask_credential
14
  from .providers import PROVIDER_PLUGINS
15
+ from .utils.resilient_io import ResilientStateWriter
16
+ from .utils.paths import get_data_file
17
 
18
  lib_logger = logging.getLogger("rotator_library")
19
  lib_logger.propagate = False
 
53
 
54
  def __init__(
55
  self,
56
+ file_path: Optional[Union[str, Path]] = None,
57
  daily_reset_time_utc: Optional[str] = "03:00",
58
  rotation_tolerance: float = 0.0,
59
  provider_rotation_modes: Optional[Dict[str, str]] = None,
 
68
  Initialize the UsageManager.
69
 
70
  Args:
71
+ file_path: Path to the usage data JSON file. If None, uses get_data_file("key_usage.json").
72
+ Can be absolute Path, relative Path, or string.
73
  daily_reset_time_utc: Time in UTC when daily stats should reset (HH:MM format)
74
  rotation_tolerance: Tolerance for weighted random credential rotation.
75
  - 0.0: Deterministic, least-used credential always selected
 
89
  Used in sequential mode when priority not in priority_multipliers.
90
  Example: {"antigravity": 2}
91
  """
92
+ # Resolve file_path - use default if not provided
93
+ if file_path is None:
94
+ self.file_path = str(get_data_file("key_usage.json"))
95
+ elif isinstance(file_path, Path):
96
+ self.file_path = str(file_path)
97
+ else:
98
+ # String path - could be relative or absolute
99
+ self.file_path = file_path
100
  self.rotation_tolerance = rotation_tolerance
101
  self.provider_rotation_modes = provider_rotation_modes or {}
102
  self.provider_plugins = provider_plugins or PROVIDER_PLUGINS
 
114
  self._timeout_lock = asyncio.Lock()
115
  self._claimed_on_timeout: Set[str] = set()
116
 
117
+ # Resilient writer for usage data persistence
118
+ self._state_writer = ResilientStateWriter(file_path, lib_logger)
119
+
120
  if daily_reset_time_utc:
121
  hour, minute = map(int, daily_reset_time_utc.split(":"))
122
  self.daily_reset_time_utc = dt_time(
 
554
  self._initialized.set()
555
 
556
  async def _load_usage(self):
557
+ """Loads usage data from the JSON file asynchronously with resilience."""
558
  async with self._data_lock:
559
  if not os.path.exists(self.file_path):
560
  self._usage_data = {}
561
  return
562
+
563
  try:
564
  async with aiofiles.open(self.file_path, "r") as f:
565
  content = await f.read()
566
+ self._usage_data = json.loads(content) if content.strip() else {}
567
+ except FileNotFoundError:
568
+ # File deleted between exists check and open
569
+ self._usage_data = {}
570
+ except json.JSONDecodeError as e:
571
+ lib_logger.warning(
572
+ f"Corrupted usage file {self.file_path}: {e}. Starting fresh."
573
+ )
574
+ self._usage_data = {}
575
+ except (OSError, PermissionError, IOError) as e:
576
+ lib_logger.warning(
577
+ f"Cannot read usage file {self.file_path}: {e}. Using empty state."
578
+ )
579
  self._usage_data = {}
580
 
581
  async def _save_usage(self):
582
+ """Saves the current usage data using the resilient state writer."""
583
  if self._usage_data is None:
584
  return
585
+
586
  async with self._data_lock:
587
  # Add human-readable timestamp fields before saving
588
  self._add_readable_timestamps(self._usage_data)
589
+ # Hand off to resilient writer - handles retries and disk failures
590
+ self._state_writer.write(self._usage_data)
591
 
592
  async def _reset_daily_stats_if_needed(self):
593
  """
src/rotator_library/utils/__init__.py CHANGED
@@ -1,6 +1,34 @@
1
  # src/rotator_library/utils/__init__.py
2
 
3
  from .headless_detection import is_headless_environment
 
 
 
 
 
 
 
4
  from .reauth_coordinator import get_reauth_coordinator, ReauthCoordinator
 
 
 
 
 
 
 
5
 
6
- __all__ = ["is_headless_environment", "get_reauth_coordinator", "ReauthCoordinator"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  # src/rotator_library/utils/__init__.py
2
 
3
  from .headless_detection import is_headless_environment
4
+ from .paths import (
5
+ get_default_root,
6
+ get_logs_dir,
7
+ get_cache_dir,
8
+ get_oauth_dir,
9
+ get_data_file,
10
+ )
11
  from .reauth_coordinator import get_reauth_coordinator, ReauthCoordinator
12
+ from .resilient_io import (
13
+ BufferedWriteRegistry,
14
+ ResilientStateWriter,
15
+ safe_write_json,
16
+ safe_log_write,
17
+ safe_mkdir,
18
+ )
19
 
20
+ __all__ = [
21
+ "is_headless_environment",
22
+ "get_default_root",
23
+ "get_logs_dir",
24
+ "get_cache_dir",
25
+ "get_oauth_dir",
26
+ "get_data_file",
27
+ "get_reauth_coordinator",
28
+ "ReauthCoordinator",
29
+ "BufferedWriteRegistry",
30
+ "ResilientStateWriter",
31
+ "safe_write_json",
32
+ "safe_log_write",
33
+ "safe_mkdir",
34
+ ]
src/rotator_library/utils/paths.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # src/rotator_library/utils/paths.py
2
+ """
3
+ Centralized path management for the rotator library.
4
+
5
+ Supports two runtime modes:
6
+ 1. PyInstaller EXE -> files in the directory containing the executable
7
+ 2. Script/Library -> files in the current working directory (overridable)
8
+
9
+ Library users can override by passing `data_dir` to RotatingClient.
10
+ """
11
+
12
+ import sys
13
+ from pathlib import Path
14
+ from typing import Optional, Union
15
+
16
+
17
+ def get_default_root() -> Path:
18
+ """
19
+ Get the default root directory for data files.
20
+
21
+ - EXE mode (PyInstaller): directory containing the executable
22
+ - Otherwise: current working directory
23
+
24
+ Returns:
25
+ Path to the root directory
26
+ """
27
+ if getattr(sys, "frozen", False):
28
+ # Running as PyInstaller bundle - use executable's directory
29
+ return Path(sys.executable).parent
30
+ # Running as script or library - use current working directory
31
+ return Path.cwd()
32
+
33
+
34
+ def get_logs_dir(root: Optional[Union[Path, str]] = None) -> Path:
35
+ """
36
+ Get the logs directory, creating it if needed.
37
+
38
+ Args:
39
+ root: Optional root directory. If None, uses get_default_root().
40
+
41
+ Returns:
42
+ Path to the logs directory
43
+ """
44
+ base = Path(root) if root else get_default_root()
45
+ logs_dir = base / "logs"
46
+ logs_dir.mkdir(exist_ok=True)
47
+ return logs_dir
48
+
49
+
50
+ def get_cache_dir(
51
+ root: Optional[Union[Path, str]] = None, subdir: Optional[str] = None
52
+ ) -> Path:
53
+ """
54
+ Get the cache directory, optionally with a subdirectory.
55
+
56
+ Args:
57
+ root: Optional root directory. If None, uses get_default_root().
58
+ subdir: Optional subdirectory name (e.g., "gemini_cli", "antigravity")
59
+
60
+ Returns:
61
+ Path to the cache directory (or subdirectory)
62
+ """
63
+ base = Path(root) if root else get_default_root()
64
+ cache_dir = base / "cache"
65
+ if subdir:
66
+ cache_dir = cache_dir / subdir
67
+ cache_dir.mkdir(parents=True, exist_ok=True)
68
+ return cache_dir
69
+
70
+
71
+ def get_oauth_dir(root: Optional[Union[Path, str]] = None) -> Path:
72
+ """
73
+ Get the OAuth credentials directory, creating it if needed.
74
+
75
+ Args:
76
+ root: Optional root directory. If None, uses get_default_root().
77
+
78
+ Returns:
79
+ Path to the oauth_creds directory
80
+ """
81
+ base = Path(root) if root else get_default_root()
82
+ oauth_dir = base / "oauth_creds"
83
+ oauth_dir.mkdir(exist_ok=True)
84
+ return oauth_dir
85
+
86
+
87
+ def get_data_file(filename: str, root: Optional[Union[Path, str]] = None) -> Path:
88
+ """
89
+ Get the path to a data file in the root directory.
90
+
91
+ Args:
92
+ filename: Name of the file (e.g., "key_usage.json", ".env")
93
+ root: Optional root directory. If None, uses get_default_root().
94
+
95
+ Returns:
96
+ Path to the file (does not create the file)
97
+ """
98
+ base = Path(root) if root else get_default_root()
99
+ return base / filename
src/rotator_library/utils/resilient_io.py ADDED
@@ -0,0 +1,665 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # src/rotator_library/utils/resilient_io.py
2
+ """
3
+ Resilient I/O utilities for handling file operations gracefully.
4
+
5
+ Provides three main patterns:
6
+ 1. BufferedWriteRegistry - Global singleton for buffered writes with periodic
7
+ retry and shutdown flush. Ensures data is saved on app exit (Ctrl+C).
8
+ 2. ResilientStateWriter - For stateful files (usage.json) that should be
9
+ buffered in memory and retried on disk failure.
10
+ 3. safe_write_json (with buffer_on_failure) - For critical files (auth tokens)
11
+ that should be buffered and retried if write fails.
12
+ 4. safe_log_write - For logs that can be dropped on failure.
13
+ """
14
+
15
+ import atexit
16
+ import json
17
+ import os
18
+ import shutil
19
+ import tempfile
20
+ import threading
21
+ import time
22
+ import logging
23
+ from pathlib import Path
24
+ from typing import Any, Callable, Dict, Optional, Tuple, Union
25
+
26
+
27
+ # =============================================================================
28
+ # BUFFERED WRITE REGISTRY (SINGLETON)
29
+ # =============================================================================
30
+
31
+
32
+ class BufferedWriteRegistry:
33
+ """
34
+ Global singleton registry for buffered writes with periodic retry and shutdown flush.
35
+
36
+ This ensures that critical data (auth tokens, usage stats) is saved even if
37
+ disk writes fail temporarily. On app exit (including Ctrl+C), all pending
38
+ writes are flushed.
39
+
40
+ Features:
41
+ - Per-file buffering: each file path has its own pending write
42
+ - Periodic retries: background thread retries failed writes every N seconds
43
+ - Shutdown flush: atexit hook ensures final write attempt on app exit
44
+ - Thread-safe: safe for concurrent access from multiple threads
45
+
46
+ Usage:
47
+ # Get the singleton instance
48
+ registry = BufferedWriteRegistry.get_instance()
49
+
50
+ # Register a pending write (usually called by safe_write_json on failure)
51
+ registry.register_pending(path, data, serializer_fn, options)
52
+
53
+ # Manual flush (optional - atexit handles this automatically)
54
+ results = registry.flush_all()
55
+ """
56
+
57
+ _instance: Optional["BufferedWriteRegistry"] = None
58
+ _instance_lock = threading.Lock()
59
+
60
+ def __init__(self, retry_interval: float = 30.0):
61
+ """
62
+ Initialize the registry. Use get_instance() instead of direct construction.
63
+
64
+ Args:
65
+ retry_interval: Seconds between retry attempts (default: 30)
66
+ """
67
+ self._pending: Dict[str, Tuple[Any, Callable[[Any], str], Dict[str, Any]]] = {}
68
+ self._retry_interval = retry_interval
69
+ self._lock = threading.Lock()
70
+ self._running = False
71
+ self._retry_thread: Optional[threading.Thread] = None
72
+ self._logger = logging.getLogger("rotator_library.resilient_io")
73
+
74
+ # Start background retry thread
75
+ self._start_retry_thread()
76
+
77
+ # Register atexit handler for shutdown flush
78
+ atexit.register(self._atexit_handler)
79
+
80
+ @classmethod
81
+ def get_instance(cls, retry_interval: float = 30.0) -> "BufferedWriteRegistry":
82
+ """
83
+ Get or create the singleton instance.
84
+
85
+ Args:
86
+ retry_interval: Seconds between retry attempts (only used on first call)
87
+
88
+ Returns:
89
+ The singleton BufferedWriteRegistry instance
90
+ """
91
+ if cls._instance is None:
92
+ with cls._instance_lock:
93
+ if cls._instance is None:
94
+ cls._instance = cls(retry_interval)
95
+ return cls._instance
96
+
97
+ def _start_retry_thread(self) -> None:
98
+ """Start the background retry thread."""
99
+ if self._running:
100
+ return
101
+
102
+ self._running = True
103
+ self._retry_thread = threading.Thread(
104
+ target=self._retry_loop,
105
+ name="BufferedWriteRegistry-Retry",
106
+ daemon=True, # Daemon so it doesn't block app exit
107
+ )
108
+ self._retry_thread.start()
109
+
110
+ def _retry_loop(self) -> None:
111
+ """Background thread: periodically retry pending writes."""
112
+ while self._running:
113
+ time.sleep(self._retry_interval)
114
+ if not self._running:
115
+ break
116
+ self._retry_pending()
117
+
118
+ def _retry_pending(self) -> None:
119
+ """Attempt to write all pending files."""
120
+ with self._lock:
121
+ if not self._pending:
122
+ return
123
+
124
+ # Copy paths to avoid modifying dict during iteration
125
+ paths = list(self._pending.keys())
126
+
127
+ for path_str in paths:
128
+ self._try_write(path_str, remove_on_success=True)
129
+
130
+ def register_pending(
131
+ self,
132
+ path: Union[str, Path],
133
+ data: Any,
134
+ serializer: Callable[[Any], str],
135
+ options: Optional[Dict[str, Any]] = None,
136
+ ) -> None:
137
+ """
138
+ Register a pending write for later retry.
139
+
140
+ If a write is already pending for this path, it is replaced with the new data
141
+ (we always want to write the latest state).
142
+
143
+ Args:
144
+ path: File path to write to
145
+ data: Data to serialize and write
146
+ serializer: Function to serialize data to string
147
+ options: Additional options (e.g., secure_permissions)
148
+ """
149
+ path_str = str(Path(path).resolve())
150
+ with self._lock:
151
+ self._pending[path_str] = (data, serializer, options or {})
152
+ self._logger.debug(f"Registered pending write for {Path(path).name}")
153
+
154
+ def unregister(self, path: Union[str, Path]) -> None:
155
+ """
156
+ Remove a pending write (called when write succeeds elsewhere).
157
+
158
+ Args:
159
+ path: File path to remove from pending
160
+ """
161
+ path_str = str(Path(path).resolve())
162
+ with self._lock:
163
+ self._pending.pop(path_str, None)
164
+
165
+ def _try_write(self, path_str: str, remove_on_success: bool = True) -> bool:
166
+ """
167
+ Attempt to write a pending file.
168
+
169
+ Args:
170
+ path_str: Resolved path string
171
+ remove_on_success: Remove from pending if successful
172
+
173
+ Returns:
174
+ True if write succeeded, False otherwise
175
+ """
176
+ with self._lock:
177
+ if path_str not in self._pending:
178
+ return True
179
+ data, serializer, options = self._pending[path_str]
180
+
181
+ path = Path(path_str)
182
+ try:
183
+ # Ensure directory exists
184
+ path.parent.mkdir(parents=True, exist_ok=True)
185
+
186
+ # Serialize data
187
+ content = serializer(data)
188
+
189
+ # Atomic write
190
+ tmp_fd = None
191
+ tmp_path = None
192
+ try:
193
+ tmp_fd, tmp_path = tempfile.mkstemp(
194
+ dir=path.parent, prefix=".tmp_", suffix=".json", text=True
195
+ )
196
+ with os.fdopen(tmp_fd, "w", encoding="utf-8") as f:
197
+ f.write(content)
198
+ tmp_fd = None
199
+
200
+ # Set secure permissions if requested
201
+ if options.get("secure_permissions"):
202
+ try:
203
+ os.chmod(tmp_path, 0o600)
204
+ except (OSError, AttributeError):
205
+ pass
206
+
207
+ shutil.move(tmp_path, path)
208
+ tmp_path = None
209
+
210
+ finally:
211
+ if tmp_fd is not None:
212
+ try:
213
+ os.close(tmp_fd)
214
+ except OSError:
215
+ pass
216
+ if tmp_path and os.path.exists(tmp_path):
217
+ try:
218
+ os.unlink(tmp_path)
219
+ except OSError:
220
+ pass
221
+
222
+ # Success - remove from pending
223
+ if remove_on_success:
224
+ with self._lock:
225
+ self._pending.pop(path_str, None)
226
+
227
+ self._logger.debug(f"Retry succeeded for {path.name}")
228
+ return True
229
+
230
+ except (OSError, PermissionError, IOError) as e:
231
+ self._logger.debug(f"Retry failed for {path.name}: {e}")
232
+ return False
233
+
234
+ def flush_all(self) -> Dict[str, bool]:
235
+ """
236
+ Attempt to write all pending files immediately.
237
+
238
+ Returns:
239
+ Dict mapping file paths to success status
240
+ """
241
+ with self._lock:
242
+ paths = list(self._pending.keys())
243
+
244
+ results = {}
245
+ for path_str in paths:
246
+ results[path_str] = self._try_write(path_str, remove_on_success=True)
247
+
248
+ return results
249
+
250
+ def _atexit_handler(self) -> None:
251
+ """Called on app exit to flush pending writes."""
252
+ self._running = False
253
+
254
+ with self._lock:
255
+ pending_count = len(self._pending)
256
+
257
+ if pending_count == 0:
258
+ return
259
+
260
+ self._logger.info(f"Flushing {pending_count} pending write(s) on shutdown...")
261
+ results = self.flush_all()
262
+
263
+ succeeded = sum(1 for v in results.values() if v)
264
+ failed = pending_count - succeeded
265
+
266
+ if failed > 0:
267
+ self._logger.warning(
268
+ f"Shutdown flush: {succeeded} succeeded, {failed} failed"
269
+ )
270
+ for path_str, success in results.items():
271
+ if not success:
272
+ self._logger.warning(f" Failed to save: {Path(path_str).name}")
273
+ else:
274
+ self._logger.info(f"Shutdown flush: all {succeeded} write(s) succeeded")
275
+
276
+ def get_pending_count(self) -> int:
277
+ """Get the number of pending writes."""
278
+ with self._lock:
279
+ return len(self._pending)
280
+
281
+ def get_pending_paths(self) -> list:
282
+ """Get list of paths with pending writes (for monitoring)."""
283
+ with self._lock:
284
+ return [Path(p).name for p in self._pending.keys()]
285
+
286
+ def shutdown(self) -> Dict[str, bool]:
287
+ """
288
+ Manually trigger shutdown: stop retry thread and flush all pending writes.
289
+
290
+ Returns:
291
+ Dict mapping file paths to success status
292
+ """
293
+ self._running = False
294
+ if self._retry_thread and self._retry_thread.is_alive():
295
+ self._retry_thread.join(timeout=1.0)
296
+ return self.flush_all()
297
+
298
+
299
+ # =============================================================================
300
+ # RESILIENT STATE WRITER
301
+ # =============================================================================
302
+
303
+
304
+ class ResilientStateWriter:
305
+ """
306
+ Manages resilient writes for stateful files (usage stats, credentials, cache).
307
+
308
+ Design:
309
+ - Caller hands off data via write() - always succeeds (memory update)
310
+ - Attempts disk write immediately
311
+ - If disk fails, retries periodically in background
312
+ - On recovery, writes full current state (not just new data)
313
+
314
+ Thread-safe for use in async contexts with sync file I/O.
315
+
316
+ Usage:
317
+ writer = ResilientStateWriter("data.json", logger)
318
+ writer.write({"key": "value"}) # Always succeeds
319
+ # ... later ...
320
+ if not writer.is_healthy:
321
+ logger.warning("Disk writes failing, data in memory only")
322
+ """
323
+
324
+ def __init__(
325
+ self,
326
+ path: Union[str, Path],
327
+ logger: logging.Logger,
328
+ retry_interval: float = 30.0,
329
+ serializer: Optional[Callable[[Any], str]] = None,
330
+ ):
331
+ """
332
+ Initialize the resilient writer.
333
+
334
+ Args:
335
+ path: File path to write to
336
+ logger: Logger for warnings/errors
337
+ retry_interval: Seconds between retry attempts when disk is unhealthy
338
+ serializer: Custom serializer function (defaults to JSON with indent=2)
339
+ """
340
+ self.path = Path(path)
341
+ self.logger = logger
342
+ self.retry_interval = retry_interval
343
+ self._serializer = serializer or (lambda d: json.dumps(d, indent=2))
344
+
345
+ self._current_state: Optional[Any] = None
346
+ self._disk_healthy = True
347
+ self._last_attempt: float = 0
348
+ self._last_success: Optional[float] = None
349
+ self._failure_count = 0
350
+ self._lock = threading.Lock()
351
+
352
+ def write(self, data: Any) -> bool:
353
+ """
354
+ Update state and attempt disk write.
355
+
356
+ Always updates in-memory state (guaranteed to succeed).
357
+ Attempts disk write - if disk is unhealthy, respects retry_interval
358
+ before attempting again to avoid flooding with failed writes.
359
+
360
+ Args:
361
+ data: Data to persist (must be serializable)
362
+
363
+ Returns:
364
+ True if disk write succeeded, False if failed (data still in memory)
365
+ """
366
+ with self._lock:
367
+ self._current_state = data
368
+
369
+ # If disk is unhealthy, only retry after retry_interval has passed
370
+ if not self._disk_healthy:
371
+ now = time.time()
372
+ if now - self._last_attempt < self.retry_interval:
373
+ # Too soon to retry, data is safe in memory
374
+ return False
375
+
376
+ return self._try_disk_write()
377
+
378
+ def retry_if_needed(self) -> bool:
379
+ """
380
+ Retry disk write if unhealthy and retry interval has passed.
381
+
382
+ Call this periodically (e.g., on each save attempt) to recover
383
+ from transient disk failures.
384
+
385
+ Returns:
386
+ True if healthy (no retry needed or retry succeeded)
387
+ """
388
+ with self._lock:
389
+ if self._disk_healthy:
390
+ return True
391
+
392
+ if self._current_state is None:
393
+ return True
394
+
395
+ now = time.time()
396
+ if now - self._last_attempt < self.retry_interval:
397
+ return False
398
+
399
+ return self._try_disk_write()
400
+
401
+ def _try_disk_write(self) -> bool:
402
+ """
403
+ Attempt atomic write to disk. Updates health status.
404
+
405
+ Uses tempfile + move pattern for atomic writes on POSIX systems.
406
+ On Windows, uses direct write (still safe for our use case).
407
+
408
+ Also registers/unregisters with BufferedWriteRegistry for shutdown flush.
409
+ """
410
+ if self._current_state is None:
411
+ return True
412
+
413
+ self._last_attempt = time.time()
414
+
415
+ try:
416
+ # Ensure directory exists
417
+ self.path.parent.mkdir(parents=True, exist_ok=True)
418
+
419
+ # Serialize data
420
+ content = self._serializer(self._current_state)
421
+
422
+ # Atomic write: write to temp file, then move
423
+ tmp_fd = None
424
+ tmp_path = None
425
+ try:
426
+ tmp_fd, tmp_path = tempfile.mkstemp(
427
+ dir=self.path.parent, prefix=".tmp_", suffix=".json", text=True
428
+ )
429
+
430
+ with os.fdopen(tmp_fd, "w", encoding="utf-8") as f:
431
+ f.write(content)
432
+ tmp_fd = None # fdopen closes the fd
433
+
434
+ # Atomic move
435
+ shutil.move(tmp_path, self.path)
436
+ tmp_path = None
437
+
438
+ finally:
439
+ # Cleanup on failure
440
+ if tmp_fd is not None:
441
+ try:
442
+ os.close(tmp_fd)
443
+ except OSError:
444
+ pass
445
+ if tmp_path and os.path.exists(tmp_path):
446
+ try:
447
+ os.unlink(tmp_path)
448
+ except OSError:
449
+ pass
450
+
451
+ # Success - update health and unregister from shutdown flush
452
+ self._disk_healthy = True
453
+ self._last_success = time.time()
454
+ self._failure_count = 0
455
+ BufferedWriteRegistry.get_instance().unregister(self.path)
456
+ return True
457
+
458
+ except (OSError, PermissionError, IOError) as e:
459
+ self._disk_healthy = False
460
+ self._failure_count += 1
461
+
462
+ # Register with BufferedWriteRegistry for shutdown flush
463
+ registry = BufferedWriteRegistry.get_instance()
464
+ registry.register_pending(
465
+ self.path,
466
+ self._current_state,
467
+ self._serializer,
468
+ {}, # No special options for ResilientStateWriter
469
+ )
470
+
471
+ # Log warning (rate-limited to avoid flooding)
472
+ if self._failure_count == 1 or self._failure_count % 10 == 0:
473
+ self.logger.warning(
474
+ f"Failed to write {self.path.name}: {e}. "
475
+ f"Data retained in memory (failure #{self._failure_count})."
476
+ )
477
+ return False
478
+
479
+ @property
480
+ def is_healthy(self) -> bool:
481
+ """Check if disk writes are currently working."""
482
+ return self._disk_healthy
483
+
484
+ @property
485
+ def current_state(self) -> Optional[Any]:
486
+ """Get the current in-memory state (for inspection/debugging)."""
487
+ return self._current_state
488
+
489
+ def get_health_info(self) -> Dict[str, Any]:
490
+ """
491
+ Get detailed health information for monitoring.
492
+
493
+ Returns dict with:
494
+ - healthy: bool
495
+ - failure_count: int
496
+ - last_success: Optional[float] (timestamp)
497
+ - last_attempt: float (timestamp)
498
+ - path: str
499
+ """
500
+ return {
501
+ "healthy": self._disk_healthy,
502
+ "failure_count": self._failure_count,
503
+ "last_success": self._last_success,
504
+ "last_attempt": self._last_attempt,
505
+ "path": str(self.path),
506
+ }
507
+
508
+
509
+ def safe_write_json(
510
+ path: Union[str, Path],
511
+ data: Dict[str, Any],
512
+ logger: logging.Logger,
513
+ atomic: bool = True,
514
+ indent: int = 2,
515
+ ensure_ascii: bool = True,
516
+ secure_permissions: bool = False,
517
+ buffer_on_failure: bool = False,
518
+ ) -> bool:
519
+ """
520
+ Write JSON data to file with error handling and optional buffering.
521
+
522
+ When buffer_on_failure is True, failed writes are registered with the
523
+ BufferedWriteRegistry for periodic retry and shutdown flush. This ensures
524
+ critical data (like auth tokens) is eventually saved.
525
+
526
+ Args:
527
+ path: File path to write to
528
+ data: JSON-serializable data
529
+ logger: Logger for warnings
530
+ atomic: Use atomic write pattern (tempfile + move)
531
+ indent: JSON indentation level (default: 2)
532
+ ensure_ascii: Escape non-ASCII characters (default: True)
533
+ secure_permissions: Set file permissions to 0o600 (default: False)
534
+ buffer_on_failure: Register with BufferedWriteRegistry on failure (default: False)
535
+
536
+ Returns:
537
+ True on success, False on failure (never raises)
538
+ """
539
+ path = Path(path)
540
+
541
+ # Create serializer function that matches the requested formatting
542
+ def serializer(d: Any) -> str:
543
+ return json.dumps(d, indent=indent, ensure_ascii=ensure_ascii)
544
+
545
+ try:
546
+ path.parent.mkdir(parents=True, exist_ok=True)
547
+ content = serializer(data)
548
+
549
+ if atomic:
550
+ tmp_fd = None
551
+ tmp_path = None
552
+ try:
553
+ tmp_fd, tmp_path = tempfile.mkstemp(
554
+ dir=path.parent, prefix=".tmp_", suffix=".json", text=True
555
+ )
556
+ with os.fdopen(tmp_fd, "w", encoding="utf-8") as f:
557
+ f.write(content)
558
+ tmp_fd = None
559
+
560
+ # Set secure permissions if requested (before move for security)
561
+ if secure_permissions:
562
+ try:
563
+ os.chmod(tmp_path, 0o600)
564
+ except (OSError, AttributeError):
565
+ # Windows may not support chmod, ignore
566
+ pass
567
+
568
+ shutil.move(tmp_path, path)
569
+ tmp_path = None
570
+ finally:
571
+ if tmp_fd is not None:
572
+ try:
573
+ os.close(tmp_fd)
574
+ except OSError:
575
+ pass
576
+ if tmp_path and os.path.exists(tmp_path):
577
+ try:
578
+ os.unlink(tmp_path)
579
+ except OSError:
580
+ pass
581
+ else:
582
+ with open(path, "w", encoding="utf-8") as f:
583
+ f.write(content)
584
+
585
+ # Set secure permissions if requested
586
+ if secure_permissions:
587
+ try:
588
+ os.chmod(path, 0o600)
589
+ except (OSError, AttributeError):
590
+ pass
591
+
592
+ # Success - remove from pending if it was there
593
+ if buffer_on_failure:
594
+ BufferedWriteRegistry.get_instance().unregister(path)
595
+
596
+ return True
597
+
598
+ except (OSError, PermissionError, IOError, TypeError, ValueError) as e:
599
+ logger.warning(f"Failed to write JSON to {path}: {e}")
600
+
601
+ # Register for retry if buffering is enabled
602
+ if buffer_on_failure:
603
+ registry = BufferedWriteRegistry.get_instance()
604
+ registry.register_pending(
605
+ path,
606
+ data,
607
+ serializer,
608
+ {"secure_permissions": secure_permissions},
609
+ )
610
+ logger.debug(f"Buffered {path.name} for retry on next interval or shutdown")
611
+
612
+ return False
613
+
614
+
615
+ def safe_log_write(
616
+ path: Union[str, Path],
617
+ content: str,
618
+ logger: logging.Logger,
619
+ mode: str = "a",
620
+ ) -> bool:
621
+ """
622
+ Write content to log file with error handling. No buffering or retry.
623
+
624
+ Suitable for log files where occasional loss is acceptable.
625
+ Creates parent directories if needed.
626
+
627
+ Args:
628
+ path: File path to write to
629
+ content: String content to write
630
+ logger: Logger for warnings
631
+ mode: File mode ('a' for append, 'w' for overwrite)
632
+
633
+ Returns:
634
+ True on success, False on failure (never raises)
635
+ """
636
+ path = Path(path)
637
+
638
+ try:
639
+ path.parent.mkdir(parents=True, exist_ok=True)
640
+ with open(path, mode, encoding="utf-8") as f:
641
+ f.write(content)
642
+ return True
643
+
644
+ except (OSError, PermissionError, IOError) as e:
645
+ logger.warning(f"Failed to write log to {path}: {e}")
646
+ return False
647
+
648
+
649
+ def safe_mkdir(path: Union[str, Path], logger: logging.Logger) -> bool:
650
+ """
651
+ Create directory with error handling.
652
+
653
+ Args:
654
+ path: Directory path to create
655
+ logger: Logger for warnings
656
+
657
+ Returns:
658
+ True on success (or already exists), False on failure
659
+ """
660
+ try:
661
+ Path(path).mkdir(parents=True, exist_ok=True)
662
+ return True
663
+ except (OSError, PermissionError) as e:
664
+ logger.warning(f"Failed to create directory {path}: {e}")
665
+ return False