Spaces:
Paused
Paused
Merge branch 'dev' into main
Browse files- DOCUMENTATION.md +312 -3
- README.md +590 -582
- src/proxy_app/detailed_logger.py +60 -30
- src/proxy_app/launcher_tui.py +65 -34
- src/proxy_app/main.py +23 -15
- src/proxy_app/settings_tool.py +707 -84
- src/rotator_library/client.py +27 -5
- src/rotator_library/credential_manager.py +70 -34
- src/rotator_library/credential_tool.py +795 -769
- src/rotator_library/failure_logger.py +97 -26
- src/rotator_library/providers/antigravity_auth_base.py +620 -3
- src/rotator_library/providers/antigravity_provider.py +170 -557
- src/rotator_library/providers/gemini_auth_base.py +626 -4
- src/rotator_library/providers/gemini_cli_provider.py +577 -562
- src/rotator_library/providers/google_oauth_base.py +775 -174
- src/rotator_library/providers/iflow_auth_base.py +652 -178
- src/rotator_library/providers/iflow_provider.py +173 -74
- src/rotator_library/providers/provider_cache.py +161 -133
- src/rotator_library/providers/qwen_auth_base.py +576 -176
- src/rotator_library/providers/qwen_code_provider.py +209 -90
- src/rotator_library/timeout_config.py +102 -0
- src/rotator_library/usage_manager.py +37 -10
- src/rotator_library/utils/__init__.py +29 -1
- src/rotator_library/utils/paths.py +99 -0
- src/rotator_library/utils/resilient_io.py +665 -0
DOCUMENTATION.md
CHANGED
|
@@ -856,6 +856,142 @@ class AntigravityAuthBase(GoogleOAuthBase):
|
|
| 856 |
- Headless environment detection
|
| 857 |
- Sequential refresh queue processing
|
| 858 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 859 |
---
|
| 860 |
|
| 861 |
|
|
@@ -877,8 +1013,8 @@ The `GeminiCliProvider` is the most complex implementation, mimicking the Google
|
|
| 877 |
|
| 878 |
#### Authentication (`gemini_auth_base.py`)
|
| 879 |
|
| 880 |
-
* **Device Flow**: Uses a standard OAuth 2.0 flow. The `credential_tool` spins up a local web server (`localhost:8085`) to capture the callback from Google's auth page.
|
| 881 |
-
* **Token Lifecycle**:
|
| 882 |
* **Proactive Refresh**: Tokens are refreshed 5 minutes before expiry.
|
| 883 |
* **Atomic Writes**: Credential files are updated using a temp-file-and-move strategy to prevent corruption during writes.
|
| 884 |
* **Revocation Handling**: If a `400` or `401` occurs during refresh, the token is marked as revoked, preventing infinite retry loops.
|
|
@@ -907,7 +1043,7 @@ The provider employs a sophisticated, cached discovery mechanism to find a valid
|
|
| 907 |
### 3.3. iFlow (`iflow_provider.py`)
|
| 908 |
|
| 909 |
* **Hybrid Auth**: Uses a custom OAuth flow (Authorization Code) to obtain an `access_token`. However, the *actual* API calls use a separate `apiKey` that is retrieved from the user's profile (`/api/oauth/getUserInfo`) using the access token.
|
| 910 |
-
* **Callback Server**: The auth flow spins up a local server
|
| 911 |
* **Token Management**: Automatically refreshes the OAuth token and re-fetches the API key if needed.
|
| 912 |
* **Schema Cleaning**: Similar to Qwen, it aggressively sanitizes tool schemas to prevent 400 errors.
|
| 913 |
* **Dedicated Logging**: Implements `_IFlowFileLogger` to capture raw chunks for debugging proprietary API behaviors.
|
|
@@ -935,4 +1071,177 @@ To facilitate robust debugging, the proxy includes a comprehensive transaction l
|
|
| 935 |
|
| 936 |
This level of detail allows developers to trace exactly why a request failed or why a specific key was rotated.
|
| 937 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 938 |
|
|
|
|
| 856 |
- Headless environment detection
|
| 857 |
- Sequential refresh queue processing
|
| 858 |
|
| 859 |
+
#### OAuth Callback Port Configuration
|
| 860 |
+
|
| 861 |
+
Each OAuth provider uses a local callback server during authentication. The callback port can be customized via environment variables to avoid conflicts with other services.
|
| 862 |
+
|
| 863 |
+
**Default Ports:**
|
| 864 |
+
|
| 865 |
+
| Provider | Default Port | Environment Variable |
|
| 866 |
+
|----------|-------------|---------------------|
|
| 867 |
+
| Gemini CLI | 8085 | `GEMINI_CLI_OAUTH_PORT` |
|
| 868 |
+
| Antigravity | 51121 | `ANTIGRAVITY_OAUTH_PORT` |
|
| 869 |
+
| iFlow | 11451 | `IFLOW_OAUTH_PORT` |
|
| 870 |
+
|
| 871 |
+
**Configuration Methods:**
|
| 872 |
+
|
| 873 |
+
1. **Via TUI Settings Menu:**
|
| 874 |
+
- Main Menu → `4. View Provider & Advanced Settings` → `1. Launch Settings Tool`
|
| 875 |
+
- Select the provider (Gemini CLI, Antigravity, or iFlow)
|
| 876 |
+
- Modify the `*_OAUTH_PORT` setting
|
| 877 |
+
- Use "Reset to Default" to restore the original port
|
| 878 |
+
|
| 879 |
+
2. **Via `.env` file:**
|
| 880 |
+
```env
|
| 881 |
+
# Custom OAuth callback ports (optional)
|
| 882 |
+
GEMINI_CLI_OAUTH_PORT=8085
|
| 883 |
+
ANTIGRAVITY_OAUTH_PORT=51121
|
| 884 |
+
IFLOW_OAUTH_PORT=11451
|
| 885 |
+
```
|
| 886 |
+
|
| 887 |
+
**When to Change Ports:**
|
| 888 |
+
|
| 889 |
+
- If the default port conflicts with another service on your system
|
| 890 |
+
- If running multiple proxy instances on the same machine
|
| 891 |
+
- If firewall rules require specific port ranges
|
| 892 |
+
|
| 893 |
+
**Note:** Port changes take effect on the next OAuth authentication attempt. Existing tokens are not affected.
|
| 894 |
+
|
| 895 |
+
---
|
| 896 |
+
|
| 897 |
+
### 2.14. HTTP Timeout Configuration (`timeout_config.py`)
|
| 898 |
+
|
| 899 |
+
Centralized timeout configuration for all HTTP requests to LLM providers.
|
| 900 |
+
|
| 901 |
+
#### Purpose
|
| 902 |
+
|
| 903 |
+
The `TimeoutConfig` class provides fine-grained control over HTTP timeouts for streaming and non-streaming LLM requests. This addresses the common issue of proxy hangs when upstream providers stall during connection establishment or response generation.
|
| 904 |
+
|
| 905 |
+
#### Timeout Types Explained
|
| 906 |
+
|
| 907 |
+
| Timeout | Description |
|
| 908 |
+
|---------|-------------|
|
| 909 |
+
| **connect** | Maximum time to establish a TCP/TLS connection to the upstream server |
|
| 910 |
+
| **read** | Maximum time to wait between receiving data chunks (resets on each chunk for streaming) |
|
| 911 |
+
| **write** | Maximum time to wait while sending the request body |
|
| 912 |
+
| **pool** | Maximum time to wait for a connection from the connection pool |
|
| 913 |
+
|
| 914 |
+
#### Default Values
|
| 915 |
+
|
| 916 |
+
| Setting | Streaming | Non-Streaming | Rationale |
|
| 917 |
+
|---------|-----------|---------------|-----------|
|
| 918 |
+
| **connect** | 30s | 30s | Fast fail if server is unreachable |
|
| 919 |
+
| **read** | 180s (3 min) | 600s (10 min) | Streaming expects periodic chunks; non-streaming may wait for full generation |
|
| 920 |
+
| **write** | 30s | 30s | Request bodies are typically small |
|
| 921 |
+
| **pool** | 60s | 60s | Reasonable wait for connection pool |
|
| 922 |
+
|
| 923 |
+
#### Environment Variable Overrides
|
| 924 |
+
|
| 925 |
+
All timeout values can be customized via environment variables:
|
| 926 |
+
|
| 927 |
+
```env
|
| 928 |
+
# Connection establishment timeout (seconds)
|
| 929 |
+
TIMEOUT_CONNECT=30
|
| 930 |
+
|
| 931 |
+
# Request body send timeout (seconds)
|
| 932 |
+
TIMEOUT_WRITE=30
|
| 933 |
+
|
| 934 |
+
# Connection pool acquisition timeout (seconds)
|
| 935 |
+
TIMEOUT_POOL=60
|
| 936 |
+
|
| 937 |
+
# Read timeout between chunks for streaming requests (seconds)
|
| 938 |
+
# If no data arrives for this duration, the connection is considered stalled
|
| 939 |
+
TIMEOUT_READ_STREAMING=180
|
| 940 |
+
|
| 941 |
+
# Read timeout for non-streaming responses (seconds)
|
| 942 |
+
# Longer to accommodate models that take time to generate full responses
|
| 943 |
+
TIMEOUT_READ_NON_STREAMING=600
|
| 944 |
+
```
|
| 945 |
+
|
| 946 |
+
#### Streaming vs Non-Streaming Behavior
|
| 947 |
+
|
| 948 |
+
**Streaming Requests** (`TimeoutConfig.streaming()`):
|
| 949 |
+
- Uses shorter read timeout (default 3 minutes)
|
| 950 |
+
- Timer resets every time a chunk arrives
|
| 951 |
+
- If no data arrives for 3 minutes, the connection is considered dead and the request fails over to the next credential
|
| 952 |
+
- Appropriate for chat completions where tokens should arrive periodically
|
| 953 |
+
|
| 954 |
+
**Non-Streaming Requests** (`TimeoutConfig.non_streaming()`):
|
| 955 |
+
- Uses longer read timeout (default 10 minutes)
|
| 956 |
+
- Server may take significant time to generate the complete response before sending anything
|
| 957 |
+
- Complex reasoning tasks or large outputs may legitimately take several minutes
|
| 958 |
+
- Only used by Antigravity provider's `_handle_non_streaming()` method
|
| 959 |
+
|
| 960 |
+
#### Provider Usage
|
| 961 |
+
|
| 962 |
+
The following providers use `TimeoutConfig`:
|
| 963 |
+
|
| 964 |
+
| Provider | Method | Timeout Type |
|
| 965 |
+
|----------|--------|--------------|
|
| 966 |
+
| `antigravity_provider.py` | `_handle_non_streaming()` | `non_streaming()` |
|
| 967 |
+
| `antigravity_provider.py` | `_handle_streaming()` | `streaming()` |
|
| 968 |
+
| `gemini_cli_provider.py` | `acompletion()` | `streaming()` |
|
| 969 |
+
| `iflow_provider.py` | `acompletion()` | `streaming()` |
|
| 970 |
+
| `qwen_code_provider.py` | `acompletion()` | `streaming()` |
|
| 971 |
+
|
| 972 |
+
**Note:** iFlow, Qwen Code, and Gemini CLI providers always use streaming internally (even for non-streaming requests), aggregating chunks into a complete response. Only Antigravity has a true non-streaming path.
|
| 973 |
+
|
| 974 |
+
#### Tuning Recommendations
|
| 975 |
+
|
| 976 |
+
| Use Case | Recommendation |
|
| 977 |
+
|----------|----------------|
|
| 978 |
+
| **Long thinking tasks** | Increase `TIMEOUT_READ_STREAMING` to 300-360s |
|
| 979 |
+
| **Unstable network** | Increase `TIMEOUT_CONNECT` to 60s |
|
| 980 |
+
| **High concurrency** | Increase `TIMEOUT_POOL` if seeing pool exhaustion |
|
| 981 |
+
| **Large context/output** | Increase `TIMEOUT_READ_NON_STREAMING` to 900s+ |
|
| 982 |
+
|
| 983 |
+
#### Example Configuration
|
| 984 |
+
|
| 985 |
+
```env
|
| 986 |
+
# For environments with complex reasoning tasks
|
| 987 |
+
TIMEOUT_READ_STREAMING=300
|
| 988 |
+
TIMEOUT_READ_NON_STREAMING=900
|
| 989 |
+
|
| 990 |
+
# For unstable network conditions
|
| 991 |
+
TIMEOUT_CONNECT=60
|
| 992 |
+
TIMEOUT_POOL=120
|
| 993 |
+
```
|
| 994 |
+
|
| 995 |
---
|
| 996 |
|
| 997 |
|
|
|
|
| 1013 |
|
| 1014 |
#### Authentication (`gemini_auth_base.py`)
|
| 1015 |
|
| 1016 |
+
* **Device Flow**: Uses a standard OAuth 2.0 flow. The `credential_tool` spins up a local web server (default: `localhost:8085`, configurable via `GEMINI_CLI_OAUTH_PORT`) to capture the callback from Google's auth page.
|
| 1017 |
+
* **Token Lifecycle**:
|
| 1018 |
* **Proactive Refresh**: Tokens are refreshed 5 minutes before expiry.
|
| 1019 |
* **Atomic Writes**: Credential files are updated using a temp-file-and-move strategy to prevent corruption during writes.
|
| 1020 |
* **Revocation Handling**: If a `400` or `401` occurs during refresh, the token is marked as revoked, preventing infinite retry loops.
|
|
|
|
| 1043 |
### 3.3. iFlow (`iflow_provider.py`)
|
| 1044 |
|
| 1045 |
* **Hybrid Auth**: Uses a custom OAuth flow (Authorization Code) to obtain an `access_token`. However, the *actual* API calls use a separate `apiKey` that is retrieved from the user's profile (`/api/oauth/getUserInfo`) using the access token.
|
| 1046 |
+
* **Callback Server**: The auth flow spins up a local server (default: port `11451`, configurable via `IFLOW_OAUTH_PORT`) to capture the redirect.
|
| 1047 |
* **Token Management**: Automatically refreshes the OAuth token and re-fetches the API key if needed.
|
| 1048 |
* **Schema Cleaning**: Similar to Qwen, it aggressively sanitizes tool schemas to prevent 400 errors.
|
| 1049 |
* **Dedicated Logging**: Implements `_IFlowFileLogger` to capture raw chunks for debugging proprietary API behaviors.
|
|
|
|
| 1071 |
|
| 1072 |
This level of detail allows developers to trace exactly why a request failed or why a specific key was rotated.
|
| 1073 |
|
| 1074 |
+
---
|
| 1075 |
+
|
| 1076 |
+
## 5. Runtime Resilience
|
| 1077 |
+
|
| 1078 |
+
The proxy is engineered to maintain high availability even in the face of runtime filesystem disruptions. This "Runtime Resilience" capability ensures that the service continues to process API requests even if data files or directories are deleted while the application is running.
|
| 1079 |
+
|
| 1080 |
+
### 5.1. Centralized Resilient I/O (`resilient_io.py`)
|
| 1081 |
+
|
| 1082 |
+
All file operations are centralized in a single utility module that provides consistent error handling, graceful degradation, and automatic retry with shutdown flush:
|
| 1083 |
+
|
| 1084 |
+
#### `BufferedWriteRegistry` (Singleton)
|
| 1085 |
+
|
| 1086 |
+
Global registry for buffered writes with periodic retry and shutdown flush. Ensures critical data is saved even if disk writes fail temporarily:
|
| 1087 |
+
|
| 1088 |
+
- **Per-file buffering**: Each file path has its own pending write (latest data always wins)
|
| 1089 |
+
- **Periodic retries**: Background thread retries failed writes every 30 seconds
|
| 1090 |
+
- **Shutdown flush**: `atexit` hook ensures final write attempt on app exit (Ctrl+C)
|
| 1091 |
+
- **Thread-safe**: Safe for concurrent access from multiple threads
|
| 1092 |
+
|
| 1093 |
+
```python
|
| 1094 |
+
# Get the singleton instance
|
| 1095 |
+
registry = BufferedWriteRegistry.get_instance()
|
| 1096 |
+
|
| 1097 |
+
# Check pending writes (for monitoring)
|
| 1098 |
+
pending_count = registry.get_pending_count()
|
| 1099 |
+
pending_files = registry.get_pending_paths()
|
| 1100 |
+
|
| 1101 |
+
# Manual flush (optional - atexit handles this automatically)
|
| 1102 |
+
results = registry.flush_all() # Returns {path: success_bool}
|
| 1103 |
+
|
| 1104 |
+
# Manual shutdown (if needed before atexit)
|
| 1105 |
+
results = registry.shutdown()
|
| 1106 |
+
```
|
| 1107 |
+
|
| 1108 |
+
#### `ResilientStateWriter`
|
| 1109 |
+
|
| 1110 |
+
For stateful files that must persist (usage stats):
|
| 1111 |
+
- **Memory-first**: Always updates in-memory state before attempting disk write
|
| 1112 |
+
- **Atomic writes**: Uses tempfile + move pattern to prevent corruption
|
| 1113 |
+
- **Automatic retry with backoff**: If disk fails, waits `retry_interval` seconds before trying again
|
| 1114 |
+
- **Shutdown integration**: Registers with `BufferedWriteRegistry` on failure for final flush
|
| 1115 |
+
- **Health monitoring**: Exposes an `is_healthy` property so callers can detect when disk writes are failing
|
| 1116 |
+
|
| 1117 |
+
```python
|
| 1118 |
+
writer = ResilientStateWriter("data.json", logger, retry_interval=30.0)
|
| 1119 |
+
writer.write({"key": "value"}) # Always succeeds (memory update)
|
| 1120 |
+
if not writer.is_healthy:
|
| 1121 |
+
logger.warning("Disk writes failing, data in memory only")
|
| 1122 |
+
# On next write() call after retry_interval, disk write is attempted again
|
| 1123 |
+
# On app exit (Ctrl+C), BufferedWriteRegistry attempts final save
|
| 1124 |
+
```
|
| 1125 |
+
|
| 1126 |
+
#### `safe_write_json()`
|
| 1127 |
+
|
| 1128 |
+
For JSON writes with configurable options (credentials, cache):
|
| 1129 |
+
|
| 1130 |
+
| Parameter | Default | Description |
|
| 1131 |
+
|-----------|---------|-------------|
|
| 1132 |
+
| `path` | required | File path to write to |
|
| 1133 |
+
| `data` | required | JSON-serializable data |
|
| 1134 |
+
| `logger` | required | Logger for warnings |
|
| 1135 |
+
| `atomic` | `True` | Use atomic write pattern (tempfile + move) |
|
| 1136 |
+
| `indent` | `2` | JSON indentation level |
|
| 1137 |
+
| `ensure_ascii` | `True` | Escape non-ASCII characters |
|
| 1138 |
+
| `secure_permissions` | `False` | Set file permissions to 0o600 |
|
| 1139 |
+
| `buffer_on_failure` | `False` | Register with BufferedWriteRegistry on failure |
|
| 1140 |
+
|
| 1141 |
+
When `buffer_on_failure=True`:
|
| 1142 |
+
- Failed writes are registered with `BufferedWriteRegistry`
|
| 1143 |
+
- Data is retried every 30 seconds in background
|
| 1144 |
+
- On app exit, final write attempt is made automatically
|
| 1145 |
+
- Success unregisters the pending write
|
| 1146 |
+
|
| 1147 |
+
```python
|
| 1148 |
+
# For critical data (auth tokens) - use buffer_on_failure
|
| 1149 |
+
safe_write_json(path, creds, logger, secure_permissions=True, buffer_on_failure=True)
|
| 1150 |
+
|
| 1151 |
+
# For non-critical data (logs) - no buffering needed
|
| 1152 |
+
safe_write_json(path, data, logger)
|
| 1153 |
+
```
|
| 1154 |
+
|
| 1155 |
+
#### `safe_log_write()`
|
| 1156 |
+
|
| 1157 |
+
For log files where occasional loss is acceptable:
|
| 1158 |
+
- Fire-and-forget pattern
|
| 1159 |
+
- Creates parent directories if needed
|
| 1160 |
+
- Returns `True`/`False`, never raises
|
| 1161 |
+
- **No buffering** - logs are dropped on failure
|
| 1162 |
+
|
| 1163 |
+
#### `safe_mkdir()`
|
| 1164 |
+
|
| 1165 |
+
For directory creation with error handling.
|
| 1166 |
+
|
| 1167 |
+
### 5.2. Resilience Hierarchy
|
| 1168 |
+
|
| 1169 |
+
The system follows a strict hierarchy of survival:
|
| 1170 |
+
|
| 1171 |
+
1. **Core API Handling (Level 1)**: The Python runtime keeps all necessary code in memory. Deleting source code files while the proxy is running will **not** crash active requests.
|
| 1172 |
+
|
| 1173 |
+
2. **Credential Management (Level 2)**: OAuth tokens are cached in memory first. If credential files are deleted, the proxy continues using cached tokens. If a token refresh succeeds but the file cannot be written, the new token is buffered for retry and saved on shutdown.
|
| 1174 |
+
|
| 1175 |
+
3. **Usage Tracking (Level 3)**: Usage statistics (`key_usage.json`) are maintained in memory via `ResilientStateWriter`. If the file is deleted, the system tracks usage internally and attempts to recreate the file on the next save interval. Pending writes are flushed on shutdown.
|
| 1176 |
+
|
| 1177 |
+
4. **Provider Cache (Level 4)**: The provider cache tracks disk health and continues operating in memory-only mode if disk writes fail. Has its own shutdown mechanism.
|
| 1178 |
+
|
| 1179 |
+
5. **Logging (Level 5)**: Logging is treated as non-critical. If the `logs/` directory is removed, the system attempts to recreate it. If creation fails, logging degrades gracefully without interrupting the request flow. **No buffering or retry**.
|
| 1180 |
+
|
| 1181 |
+
### 5.3. Component Integration
|
| 1182 |
+
|
| 1183 |
+
| Component | Utility Used | Behavior on Disk Failure | Shutdown Flush |
|
| 1184 |
+
|-----------|--------------|--------------------------|----------------|
|
| 1185 |
+
| `UsageManager` | `ResilientStateWriter` | Continues in memory, retries after 30s | Yes (via registry) |
|
| 1186 |
+
| `GoogleOAuthBase` | `safe_write_json(buffer_on_failure=True)` | Memory cache preserved, buffered for retry | Yes (via registry) |
|
| 1187 |
+
| `QwenAuthBase` | `safe_write_json(buffer_on_failure=True)` | Memory cache preserved, buffered for retry | Yes (via registry) |
|
| 1188 |
+
| `IFlowAuthBase` | `safe_write_json(buffer_on_failure=True)` | Memory cache preserved, buffered for retry | Yes (via registry) |
|
| 1189 |
+
| `ProviderCache` | `safe_write_json` + own shutdown | Retries via own background loop | Yes (own mechanism) |
|
| 1190 |
+
| `DetailedLogger` | `safe_write_json` | Logs dropped, no crash | No |
|
| 1191 |
+
| `failure_logger` | Python `logging.RotatingFileHandler` | Falls back to NullHandler | No |
|
| 1192 |
+
|
| 1193 |
+
### 5.4. Shutdown Behavior
|
| 1194 |
+
|
| 1195 |
+
When the application exits (including Ctrl+C):
|
| 1196 |
+
|
| 1197 |
+
1. **atexit handler fires**: `BufferedWriteRegistry._atexit_handler()` is called
|
| 1198 |
+
2. **Pending writes counted**: Registry checks how many files have pending writes
|
| 1199 |
+
3. **Flush attempted**: Each pending file gets a final write attempt
|
| 1200 |
+
4. **Results logged**:
|
| 1201 |
+
- Success: `"Shutdown flush: all N write(s) succeeded"`
|
| 1202 |
+
- Partial: `"Shutdown flush: X succeeded, Y failed"` with failed file names
|
| 1203 |
+
|
| 1204 |
+
**Console output example:**
|
| 1205 |
+
```
|
| 1206 |
+
INFO:rotator_library.resilient_io:Flushing 2 pending write(s) on shutdown...
|
| 1207 |
+
INFO:rotator_library.resilient_io:Shutdown flush: all 2 write(s) succeeded
|
| 1208 |
+
```
|
| 1209 |
+
|
| 1210 |
+
### 5.5. "Develop While Running"
|
| 1211 |
+
|
| 1212 |
+
This architecture supports a robust development workflow:
|
| 1213 |
+
|
| 1214 |
+
- **Log Cleanup**: You can safely run `rm -rf logs/` while the proxy is serving traffic. The system will recreate the directory structure on the next request.
|
| 1215 |
+
- **Config Reset**: Deleting `key_usage.json` resets the persistence layer, but the running instance preserves its current in-memory counts for load balancing consistency.
|
| 1216 |
+
- **File Recovery**: If a critical file or its containing directory is deleted, the system attempts to re-create the directory structure before each write operation.
|
| 1217 |
+
- **Safe Exit**: Ctrl+C triggers graceful shutdown with final data flush attempt.
|
| 1218 |
+
|
| 1219 |
+
### 5.6. Graceful Degradation & Data Loss
|
| 1220 |
+
|
| 1221 |
+
While functionality is preserved, persistence may be compromised during filesystem failures:
|
| 1222 |
+
|
| 1223 |
+
- **Logs**: If disk writes fail, detailed request logs may be lost (no buffering).
|
| 1224 |
+
- **Usage Stats**: Buffered in memory and flushed on shutdown. Data loss only if shutdown flush also fails.
|
| 1225 |
+
- **Credentials**: Buffered in memory and flushed on shutdown. Re-authentication only needed if shutdown flush fails.
|
| 1226 |
+
- **Cache**: Provider cache entries may need to be regenerated after restart if its own shutdown mechanism fails.
|
| 1227 |
+
|
| 1228 |
+
### 5.7. Monitoring Disk Health
|
| 1229 |
+
|
| 1230 |
+
Components expose health information for monitoring:
|
| 1231 |
+
|
| 1232 |
+
```python
|
| 1233 |
+
# BufferedWriteRegistry
|
| 1234 |
+
registry = BufferedWriteRegistry.get_instance()
|
| 1235 |
+
pending = registry.get_pending_count() # Number of files with pending writes
|
| 1236 |
+
files = registry.get_pending_paths() # List of pending file names
|
| 1237 |
+
|
| 1238 |
+
# UsageManager
|
| 1239 |
+
writer = usage_manager._state_writer
|
| 1240 |
+
health = writer.get_health_info()
|
| 1241 |
+
# Returns: {"healthy": True, "failure_count": 0, "last_success": 1234567890.0, ...}
|
| 1242 |
+
|
| 1243 |
+
# ProviderCache
|
| 1244 |
+
stats = cache.get_stats()
|
| 1245 |
+
# Includes: {"disk_available": True, "disk_errors": 0, ...}
|
| 1246 |
+
```
|
| 1247 |
|
README.md
CHANGED
|
@@ -1,755 +1,763 @@
|
|
| 1 |
-
# Universal LLM API Proxy & Resilience Library
|
|
|
|
| 2 |
[](https://deepwiki.com/Mirrowel/LLM-API-Key-Proxy) [](https://zread.ai/Mirrowel/LLM-API-Key-Proxy)
|
| 3 |
|
|
|
|
| 4 |
|
| 5 |
-
|
| 6 |
|
| 7 |
-
This project
|
|
|
|
|
|
|
| 8 |
|
| 9 |
-
|
| 10 |
-
2. **A Resilience & Key Management Library**: The core engine that powers the proxy. This reusable Python library intelligently manages a pool of API keys to ensure your application is highly available and resilient to transient provider errors or performance issues.
|
| 11 |
-
|
| 12 |
-
## Features
|
| 13 |
|
| 14 |
-
|
| 15 |
-
- **High Availability**: The underlying library ensures your application remains operational by gracefully handling transient provider errors and API key-specific issues.
|
| 16 |
-
- **Resilient Performance**: A global timeout on all requests prevents your application from hanging on unresponsive provider APIs.
|
| 17 |
-
- **Advanced Concurrency Control**: A single API key can be used for multiple concurrent requests. By default, it supports concurrent requests to *different* models. With configuration (`MAX_CONCURRENT_REQUESTS_PER_KEY_<PROVIDER>`), it can also support multiple concurrent requests to the *same* model using the same key.
|
| 18 |
-
- **Intelligent Key Management**: Optimizes request distribution across your pool of keys by selecting the best available one for each call.
|
| 19 |
-
- **Automated OAuth Discovery**: Automatically discovers, validates, and manages OAuth credentials from standard provider directories (e.g., `~/.gemini/`, `~/.qwen/`, `~/.iflow/`).
|
| 20 |
-
- **Stateless Deployment Support**: Deploy easily to platforms like Railway, Render, or Vercel. The new export tool converts complex OAuth credentials (Gemini CLI, Qwen, iFlow) into simple environment variables, removing the need for persistent storage or file uploads.
|
| 21 |
-
- **Batch Request Processing**: Efficiently aggregates multiple embedding requests into single batch API calls, improving throughput and reducing rate limit hits.
|
| 22 |
-
- **New Provider Support**: Full support for **iFlow** (API Key & OAuth), **Qwen Code** (API Key & OAuth), and **NVIDIA NIM** with DeepSeek thinking support, including special handling for their API quirks (tool schema cleaning, reasoning support, dedicated logging).
|
| 23 |
-
- **Duplicate Credential Detection**: Intelligently detects if multiple local credential files belong to the same user account and logs a warning, preventing redundancy in your key pool.
|
| 24 |
-
- **Escalating Per-Model Cooldowns**: If a key fails for a specific model, it's placed on a temporary, escalating cooldown for that model, allowing it to be used with others.
|
| 25 |
-
- **Automatic Daily Resets**: Cooldowns and usage statistics are automatically reset daily, making the system self-maintaining.
|
| 26 |
-
- **Detailed Request Logging**: Enable comprehensive logging for debugging. Each request gets its own directory with full request/response details, streaming chunks, and performance metadata.
|
| 27 |
-
- **Provider Agnostic**: Compatible with any provider supported by `litellm`.
|
| 28 |
-
- **OpenAI-Compatible Proxy**: Offers a familiar API interface with additional endpoints for model and provider discovery.
|
| 29 |
-
- **Advanced Model Filtering**: Supports both blacklists and whitelists to give you fine-grained control over which models are available through the proxy.
|
| 30 |
-
|
| 31 |
-
- **🆕 Antigravity Provider**: Full support for Google's internal Antigravity API, providing access to Gemini 3 and Claude models with advanced features:
|
| 32 |
-
- **🚀 Claude Opus 4.5** - Anthropic's most powerful model (thinking mode only)
|
| 33 |
-
- **Claude Sonnet 4.5** - Supports both thinking and non-thinking modes
|
| 34 |
-
- **Gemini 3 Pro** - With thinkingLevel support (low/high)
|
| 35 |
-
- Credential prioritization with automatic paid/free tier detection
|
| 36 |
-
- Thought signature caching for multi-turn conversations
|
| 37 |
-
- Tool hallucination prevention via parameter signature injection
|
| 38 |
-
- Automatic thinking block sanitization for Claude models (with recovery strategies)
|
| 39 |
-
- Note: Claude thinking mode requires careful conversation state management (see [Antigravity documentation](DOCUMENTATION.md#antigravity-claude-extended-thinking-sanitization) for details)
|
| 40 |
-
- **🆕 Credential Prioritization**: Automatic tier detection and priority-based credential selection ensures paid-tier credentials are used for premium models that require them.
|
| 41 |
-
- **🆕 Sequential Rotation Mode**: Choose between balanced (distribute load evenly) or sequential (use until exhausted) credential rotation strategies. Sequential mode maximizes cache hit rates for providers like Antigravity.
|
| 42 |
-
- **🆕 Per-Model Quota Tracking**: Granular per-model usage tracking with authoritative quota reset timestamps from provider error responses. Each model maintains its own window with `window_start_ts` and `quota_reset_ts`.
|
| 43 |
-
- **🆕 Model Quota Groups**: Group models that share quota limits (e.g., Claude Sonnet and Opus). When one model in a group hits quota, all receive the same cooldown timestamp.
|
| 44 |
-
- **🆕 Priority-Based Concurrency**: Assign credentials to priority tiers (1=highest) with configurable concurrency multipliers. Paid-tier credentials can handle more concurrent requests than free-tier ones.
|
| 45 |
-
- **🆕 Provider-Specific Quota Parsing**: Extended provider interface with `parse_quota_error()` method to extract precise retry-after times from provider-specific error formats (e.g., Google RPC format).
|
| 46 |
-
- **🆕 Flexible Rolling Windows**: Support for provider-specific quota reset configurations (5-hour, 7-day, etc.) replacing hardcoded daily resets.
|
| 47 |
-
- **🆕 Weighted Random Rotation**: Configurable credential rotation strategy — choose between deterministic (perfect balance) or weighted random (unpredictable, harder to fingerprint) selection.
|
| 48 |
-
- **🆕 Enhanced Gemini CLI**: Improved project discovery, paid vs free tier detection, and Gemini 3 support with thoughtSignature caching.
|
| 49 |
-
- **🆕 Temperature Override**: Global temperature=0 override option to prevent tool hallucination issues with low-temperature settings.
|
| 50 |
-
- **🆕 Provider Cache System**: Modular caching system for preserving conversation state (thought signatures, thinking content) across requests.
|
| 51 |
-
- **🆕 Refactored OAuth Base**: Shared [`GoogleOAuthBase`](src/rotator_library/providers/google_oauth_base.py) class eliminates code duplication across OAuth providers.
|
| 52 |
-
|
| 53 |
-
- **🆕 Interactive Launcher TUI**: Beautiful, cross-platform TUI for configuration and management with an integrated settings tool for advanced configuration.
|
| 54 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 55 |
|
| 56 |
---
|
| 57 |
|
| 58 |
-
##
|
| 59 |
|
| 60 |
-
### Windows
|
| 61 |
|
| 62 |
-
1.
|
| 63 |
-
2.
|
| 64 |
-
3.
|
| 65 |
-
- 🚀 Run the proxy server with your configured settings
|
| 66 |
-
- ⚙️ Configure proxy settings (Host, Port, PROXY_API_KEY, Request Logging)
|
| 67 |
-
- 🔑 Manage credentials (add/edit API keys & OAuth credentials)
|
| 68 |
-
- 📊 View provider status and advanced settings
|
| 69 |
-
- 🔧 Configure advanced settings interactively (custom API bases, model definitions, concurrency limits)
|
| 70 |
-
- 🔄 Reload configuration without restarting
|
| 71 |
|
| 72 |
-
|
| 73 |
|
| 74 |
### macOS / Linux
|
| 75 |
|
| 76 |
-
**Option A: Using the Executable (Recommended)**
|
| 77 |
-
If you downloaded the pre-compiled binary for your platform, no Python installation is required.
|
| 78 |
-
|
| 79 |
-
1. **Download the latest release** from the GitHub Releases page.
|
| 80 |
-
2. Open a terminal and make the binary executable:
|
| 81 |
-
```bash
|
| 82 |
-
chmod +x proxy_app
|
| 83 |
-
```
|
| 84 |
-
3. **Run the Interactive Launcher**:
|
| 85 |
-
```bash
|
| 86 |
-
./proxy_app
|
| 87 |
-
```
|
| 88 |
-
This launches the TUI where you can configure and run the proxy.
|
| 89 |
-
|
| 90 |
-
4. **Or run directly with arguments** to bypass the launcher:
|
| 91 |
-
```bash
|
| 92 |
-
./proxy_app --host 0.0.0.0 --port 8000
|
| 93 |
-
```
|
| 94 |
-
|
| 95 |
-
**Option B: Manual Setup (Source Code)**
|
| 96 |
-
If you are running from source, use these commands:
|
| 97 |
-
|
| 98 |
-
**1. Install Dependencies**
|
| 99 |
```bash
|
| 100 |
-
#
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
pip install -r requirements.txt
|
| 104 |
```
|
| 105 |
|
| 106 |
-
|
|
|
|
| 107 |
```bash
|
| 108 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 109 |
python src/proxy_app/main.py
|
| 110 |
```
|
| 111 |
|
| 112 |
-
**
|
| 113 |
-
```bash
|
| 114 |
-
export PYTHONPATH=$PYTHONPATH:$(pwd)/src
|
| 115 |
-
python src/proxy_app/main.py --host 0.0.0.0 --port 8000
|
| 116 |
-
```
|
| 117 |
-
*To enable logging, add `--enable-request-logging` to the command.*
|
| 118 |
|
| 119 |
---
|
| 120 |
|
| 121 |
-
##
|
| 122 |
-
|
| 123 |
-
The proxy now includes a powerful **interactive Text User Interface (TUI)** that makes configuration and management effortless.
|
| 124 |
-
|
| 125 |
-
### Features
|
| 126 |
-
|
| 127 |
-
- **🎯 Main Menu**:
|
| 128 |
-
- Run proxy server with saved settings
|
| 129 |
-
- Configure proxy settings (host, port, API key, logging)
|
| 130 |
-
- Manage credentials (API keys & OAuth)
|
| 131 |
-
- View provider & advanced settings status
|
| 132 |
-
- Reload configuration
|
| 133 |
-
|
| 134 |
-
- **🔧 Advanced Settings Tool**:
|
| 135 |
-
- Configure custom OpenAI-compatible providers
|
| 136 |
-
- Define provider models (simple or advanced JSON format)
|
| 137 |
-
- Set concurrency limits per provider
|
| 138 |
-
- Configure rotation modes (balanced vs sequential)
|
| 139 |
-
- Manage priority-based concurrency multipliers
|
| 140 |
-
- Interactive numbered menus for easy selection
|
| 141 |
-
- Pending changes system with save/discard options
|
| 142 |
-
|
| 143 |
-
- **📊 Status Dashboard**:
|
| 144 |
-
- Shows configured providers and credential counts
|
| 145 |
-
- Displays custom providers and API bases
|
| 146 |
-
- Shows active advanced settings
|
| 147 |
-
- Real-time configuration status
|
| 148 |
-
|
| 149 |
-
### How to Use
|
| 150 |
-
|
| 151 |
-
**Running without arguments launches the TUI:**
|
| 152 |
-
```bash
|
| 153 |
-
# Windows
|
| 154 |
-
proxy_app.exe
|
| 155 |
|
| 156 |
-
|
| 157 |
-
./proxy_app
|
| 158 |
|
| 159 |
-
|
| 160 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 161 |
```
|
| 162 |
|
| 163 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 164 |
```bash
|
| 165 |
-
|
| 166 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 167 |
```
|
| 168 |
|
| 169 |
-
|
| 170 |
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
- **`.env`**: Stores all credentials and advanced settings (PROXY_API_KEY, provider credentials, custom settings)
|
| 174 |
|
| 175 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 176 |
|
| 177 |
-
|
| 178 |
|
| 179 |
-
|
|
|
|
| 180 |
|
| 181 |
-
|
| 182 |
|
| 183 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 184 |
|
| 185 |
-
|
| 186 |
|
| 187 |
-
|
| 188 |
-
```bash
|
| 189 |
-
# Clone the repository
|
| 190 |
-
git clone https://github.com/Mirrowel/LLM-API-Key-Proxy.git
|
| 191 |
-
cd LLM-API-Key-Proxy
|
| 192 |
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 196 |
|
| 197 |
-
|
| 198 |
-
pip install -r requirements.txt
|
| 199 |
-
```
|
| 200 |
|
| 201 |
-
|
| 202 |
-
```powershell
|
| 203 |
-
# Clone the repository
|
| 204 |
-
git clone https://github.com/Mirrowel/LLM-API-Key-Proxy.git
|
| 205 |
-
cd LLM-API-Key-Proxy
|
| 206 |
|
| 207 |
-
#
|
| 208 |
-
python -m venv venv
|
| 209 |
-
.\venv\Scripts\Activate.ps1
|
| 210 |
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
|
|
|
|
|
|
| 214 |
|
| 215 |
-
|
|
|
|
|
|
|
| 216 |
|
| 217 |
-
|
| 218 |
|
| 219 |
-
**Linux/macOS:**
|
| 220 |
```bash
|
| 221 |
-
|
| 222 |
```
|
| 223 |
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 227 |
```
|
| 228 |
|
| 229 |
-
|
| 230 |
|
| 231 |
-
|
| 232 |
|
| 233 |
-
The
|
| 234 |
|
| 235 |
-
|
| 236 |
-
2. **OAuth Credentials**: For services that use OAuth 2.0, like the Gemini CLI.
|
| 237 |
|
| 238 |
-
###
|
| 239 |
|
| 240 |
-
|
| 241 |
-
-
|
| 242 |
-
-
|
|
|
|
|
|
|
|
|
|
| 243 |
|
| 244 |
-
|
| 245 |
|
| 246 |
-
|
|
|
|
| 247 |
|
| 248 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 249 |
|
| 250 |
-
|
| 251 |
-
|
|
|
|
|
|
|
|
|
|
| 252 |
```
|
| 253 |
|
| 254 |
-
|
| 255 |
-
```bash
|
| 256 |
-
python src/proxy_app/main.py
|
| 257 |
-
# Then select "3. 🔑 Manage Credentials"
|
| 258 |
-
```
|
| 259 |
|
| 260 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 261 |
|
| 262 |
-
|
| 263 |
-
- Automatically opens your browser for authentication
|
| 264 |
-
- Handles the entire OAuth flow including callbacks
|
| 265 |
-
- Saves credentials to the local `oauth_creds/` directory
|
| 266 |
-
- For Gemini CLI: Automatically discovers or creates a Google Cloud project
|
| 267 |
-
- For Antigravity: Similar to Gemini CLI with Antigravity-specific scopes
|
| 268 |
-
- For Qwen Code: Uses Device Code flow (you'll enter a code in your browser)
|
| 269 |
-
- For iFlow: Starts a local callback server on port 11451
|
| 270 |
|
| 271 |
-
|
| 272 |
-
- Interactive prompts guide you through the process
|
| 273 |
-
- Automatically saves to your `.env` file
|
| 274 |
-
- Supports multiple keys per provider (numbered automatically)
|
| 275 |
|
| 276 |
-
|
| 277 |
-
- Converts file-based OAuth credentials into environment variables
|
| 278 |
-
- Essential for platforms without persistent file storage
|
| 279 |
-
- Generates a ready-to-paste `.env` block for each credential
|
| 280 |
|
| 281 |
-
|
| 282 |
|
| 283 |
-
|
| 284 |
|
| 285 |
-
|
| 286 |
-
|
| 287 |
-
|
| 288 |
-
|
| 289 |
-
|
| 290 |
|
| 291 |
-
|
| 292 |
-
```bash
|
| 293 |
-
python -m rotator_library.credential_tool
|
| 294 |
-
# Select "Export Gemini CLI to .env" (or Qwen/iFlow)
|
| 295 |
-
# Choose your credential file
|
| 296 |
-
```
|
| 297 |
|
| 298 |
-
|
| 299 |
-
|
| 300 |
-
|
|
|
|
| 301 |
|
| 302 |
-
|
| 303 |
-
- Add each variable to your platform's environment settings
|
| 304 |
-
- Set `SKIP_OAUTH_INIT_CHECK=true` to skip interactive validation
|
| 305 |
-
- No credential files needed; everything loads from environment variables
|
| 306 |
|
| 307 |
-
|
| 308 |
|
| 309 |
-
|
| 310 |
|
| 311 |
-
- **
|
| 312 |
-
- **
|
| 313 |
-
- **
|
| 314 |
-
- **
|
| 315 |
-
- **
|
| 316 |
|
| 317 |
-
|
| 318 |
-
|
| 319 |
-
# A secret key for your proxy server to authenticate requests.
|
| 320 |
-
# This can be any secret string you choose.
|
| 321 |
-
PROXY_API_KEY="a-very-secret-and-unique-key"
|
| 322 |
-
|
| 323 |
-
# --- Provider API Keys (Optional) ---
|
| 324 |
-
# The proxy automatically finds keys in your environment variables.
|
| 325 |
-
# You can also define them here. Add multiple keys by numbering them (_1, _2).
|
| 326 |
-
GEMINI_API_KEY_1="YOUR_GEMINI_API_KEY_1"
|
| 327 |
-
GEMINI_API_KEY_2="YOUR_GEMINI_API_KEY_2"
|
| 328 |
-
OPENROUTER_API_KEY_1="YOUR_OPENROUTER_API_KEY_1"
|
| 329 |
-
|
| 330 |
-
# --- OAuth Credentials (Optional) ---
|
| 331 |
-
# The proxy automatically finds credentials in standard system paths.
|
| 332 |
-
# You can override this by specifying a path to your credential file.
|
| 333 |
-
GEMINI_CLI_OAUTH_1="/path/to/your/specific/gemini_creds.json"
|
| 334 |
-
|
| 335 |
-
# --- Gemini CLI: Stateless Deployment Support ---
|
| 336 |
-
# For hosts without file persistence (Railway, Render, etc.), you can provide
|
| 337 |
-
# Gemini CLI credentials directly via environment variables:
|
| 338 |
-
GEMINI_CLI_ACCESS_TOKEN="ya29.your-access-token"
|
| 339 |
-
GEMINI_CLI_REFRESH_TOKEN="1//your-refresh-token"
|
| 340 |
-
GEMINI_CLI_EXPIRY_DATE="1234567890000"
|
| 341 |
-
GEMINI_CLI_EMAIL="your-email@gmail.com"
|
| 342 |
-
# Optional: GEMINI_CLI_PROJECT_ID, GEMINI_CLI_CLIENT_ID, etc.
|
| 343 |
-
# See IMPLEMENTATION_SUMMARY.md for full list of supported variables
|
| 344 |
-
|
| 345 |
-
# --- Dual Authentication Support ---
|
| 346 |
-
# Some providers (qwen_code, iflow) support BOTH OAuth and direct API keys.
|
| 347 |
-
# You can use either method, or mix both for credential rotation:
|
| 348 |
-
QWEN_CODE_API_KEY_1="your-qwen-api-key" # Direct API key
|
| 349 |
-
# AND/OR use OAuth: oauth_creds/qwen_code_oauth_1.json
|
| 350 |
-
IFLOW_API_KEY_1="sk-your-iflow-key" # Direct API key
|
| 351 |
-
# AND/OR use OAuth: oauth_creds/iflow_oauth_1.json
|
| 352 |
-
```
|
| 353 |
|
| 354 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 355 |
|
| 356 |
-
|
| 357 |
|
| 358 |
-
|
|
|
|
| 359 |
|
| 360 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 361 |
|
| 362 |
-
|
| 363 |
|
| 364 |
-
|
|
|
|
| 365 |
|
| 366 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 367 |
|
| 368 |
-
|
| 369 |
-
python src/proxy_app/main.py
|
| 370 |
-
```
|
| 371 |
-
This launches the interactive TUI launcher by default. To run the proxy directly, use:
|
| 372 |
|
| 373 |
-
|
| 374 |
-
|
| 375 |
-
```
|
| 376 |
|
| 377 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 378 |
|
| 379 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 380 |
|
| 381 |
-
|
|
|
|
|
|
|
|
|
|
| 382 |
|
| 383 |
-
|
| 384 |
-
|
| 385 |
-
|
|
|
|
| 386 |
|
| 387 |
-
|
| 388 |
-
|
| 389 |
-
|
| 390 |
-
|
| 391 |
-
|
| 392 |
-
|
| 393 |
-
|
| 394 |
-
|
| 395 |
-
|
| 396 |
-
``
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 397 |
|
| 398 |
---
|
| 399 |
|
| 400 |
-
## Advanced
|
| 401 |
|
| 402 |
-
|
|
|
|
| 403 |
|
| 404 |
-
|
| 405 |
|
| 406 |
-
|
| 407 |
-
|
|
|
|
|
|
|
|
|
|
| 408 |
|
| 409 |
-
#
|
| 410 |
-
client = openai.OpenAI(
|
| 411 |
-
base_url="http://127.0.0.1:8000/v1",
|
| 412 |
-
api_key="a-very-secret-and-unique-key" # Use your PROXY_API_KEY here
|
| 413 |
-
)
|
| 414 |
|
| 415 |
-
|
| 416 |
-
|
| 417 |
-
|
| 418 |
-
|
| 419 |
-
|
| 420 |
-
|
| 421 |
-
)
|
| 422 |
|
| 423 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 424 |
```
|
| 425 |
|
| 426 |
-
###
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 427 |
|
| 428 |
-
|
| 429 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 430 |
|
| 431 |
-
|
| 432 |
-
|
| 433 |
-
|
| 434 |
-
|
| 435 |
-
|
| 436 |
-
|
| 437 |
-
|
| 438 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 439 |
```
|
| 440 |
|
| 441 |
-
###
|
| 442 |
|
| 443 |
-
|
| 444 |
-
|
| 445 |
-
|
| 446 |
-
- `GET /v1/providers`: Returns a list of all configured providers.
|
| 447 |
-
- `POST /v1/token-count`: Calculates the token count for a given message payload.
|
| 448 |
|
| 449 |
-
-
|
|
|
|
|
|
|
| 450 |
|
| 451 |
-
##
|
| 452 |
|
| 453 |
-
|
| 454 |
|
| 455 |
-
|
| 456 |
-
|
| 457 |
-
|
| 458 |
-
- **Benefits**: Significantly reduces the number of HTTP requests to providers, helping you stay within rate limits while improving throughput.
|
| 459 |
|
| 460 |
-
#
|
|
|
|
|
|
|
| 461 |
|
| 462 |
-
|
| 463 |
|
| 464 |
-
|
| 465 |
-
2. **Resilience & Deadlines**: Every request has a strict deadline (`global_timeout`). If a provider is slow or fails, the proxy retries with a different key immediately, ensuring your application never hangs.
|
| 466 |
-
3. **Batching**: High-volume embedding requests are automatically aggregated into optimized batches, reducing API calls and staying within rate limits.
|
| 467 |
-
4. **Deep Observability**: (Optional) Detailed logs capture every byte of the transaction, including raw streaming chunks, for precise debugging of complex agentic interactions.
|
| 468 |
|
| 469 |
-
|
|
|
|
|
|
|
|
|
|
| 470 |
|
| 471 |
-
|
| 472 |
|
| 473 |
-
|
| 474 |
-
|
| 475 |
-
- `--enable-request-logging`: A flag to enable detailed, per-request logging. When active, the proxy creates a unique directory for each transaction in the `logs/detailed_logs/` folder, containing the full request, response, streaming chunks, and performance metadata. This is highly recommended for debugging.
|
| 476 |
|
| 477 |
-
|
| 478 |
|
| 479 |
-
|
| 480 |
-
|
| 481 |
-
|
| 482 |
-
|
| 483 |
-
|
|
|
|
|
|
|
| 484 |
|
| 485 |
-
|
| 486 |
-
-
|
| 487 |
-
-
|
| 488 |
-
-
|
| 489 |
-
- **Reasoning Support**: Parses `<think>` tags in responses and exposes them as `reasoning_content` (similar to OpenAI's o1 format).
|
| 490 |
-
- **Dedicated Logging**: Optional per-request file logging to `logs/qwen_code_logs/` for debugging.
|
| 491 |
-
- **Custom Models**: Define additional models via `QWEN_CODE_MODELS` environment variable (JSON array format).
|
| 492 |
|
| 493 |
-
|
| 494 |
-
- **Dual Authentication**: Use either standard API keys or OAuth 2.0 Authorization Code Flow.
|
| 495 |
-
- **Hybrid Auth**: OAuth flow provides an access token, but actual API calls use a separate `apiKey` retrieved from user profile.
|
| 496 |
-
- **Local Callback Server**: OAuth flow runs a temporary server on port 11451 to capture the redirect.
|
| 497 |
-
- **Schema Cleaning**: Same as Qwen Code - removes unsupported properties from tool schemas.
|
| 498 |
-
- **Stream Stability**: Injects placeholder tools to stabilize streaming for empty tool lists.
|
| 499 |
-
- **Dedicated Logging**: Optional per-request file logging to `logs/iflow_logs/` for debugging proprietary API behaviors.
|
| 500 |
-
- **Custom Models**: Define additional models via `IFLOW_MODELS` environment variable (JSON array format).
|
| 501 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 502 |
|
| 503 |
-
|
| 504 |
|
| 505 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 506 |
|
| 507 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 508 |
|
| 509 |
-
|
| 510 |
-
|
| 511 |
-
|
| 512 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 513 |
|
| 514 |
-
|
| 515 |
-
```env
|
| 516 |
-
SKIP_OAUTH_INIT_CHECK=true
|
| 517 |
|
|
|
|
|
|
|
| 518 |
|
| 519 |
-
|
| 520 |
-
The newest and most sophisticated provider, offering access to cutting-edge models via Google's internal Antigravity API.
|
| 521 |
|
| 522 |
**Supported Models:**
|
| 523 |
-
-
|
| 524 |
-
-
|
| 525 |
-
-
|
| 526 |
-
-
|
| 527 |
|
| 528 |
-
**
|
| 529 |
-
|
| 530 |
-
|
| 531 |
-
|
| 532 |
-
- **Automatic Fallback**: Tries sandbox endpoints before falling back to production
|
| 533 |
-
- **Schema Cleaning**: Handles Claude-specific tool schema requirements
|
| 534 |
|
| 535 |
-
**
|
| 536 |
-
-
|
| 537 |
-
-
|
| 538 |
-
-
|
|
|
|
| 539 |
|
| 540 |
**Environment Variables:**
|
| 541 |
```env
|
| 542 |
-
|
| 543 |
-
|
| 544 |
-
|
| 545 |
-
|
| 546 |
-
ANTIGRAVITY_EMAIL="user@gmail.com"
|
| 547 |
|
| 548 |
# Feature toggles
|
| 549 |
-
ANTIGRAVITY_ENABLE_SIGNATURE_CACHE=true
|
| 550 |
-
ANTIGRAVITY_GEMINI3_TOOL_FIX=true
|
| 551 |
```
|
| 552 |
|
|
|
|
| 553 |
|
| 554 |
-
|
| 555 |
-
|
| 556 |
-
#### Credential Rotation Modes
|
| 557 |
-
|
| 558 |
-
- **`ROTATION_MODE_<PROVIDER>`**: Controls how credentials are rotated when multiple are available. Default: `balanced` (except Antigravity which defaults to `sequential`).
|
| 559 |
-
- `balanced`: Rotate credentials evenly across requests to distribute load. Best for per-minute rate limits.
|
| 560 |
-
- `sequential`: Use one credential until exhausted (429 error), then switch to next. Best for daily/weekly quotas.
|
| 561 |
-
```env
|
| 562 |
-
ROTATION_MODE_GEMINI=sequential # Use Gemini keys until quota exhausted
|
| 563 |
-
ROTATION_MODE_OPENAI=balanced # Distribute load across OpenAI keys (default)
|
| 564 |
-
ROTATION_MODE_ANTIGRAVITY=balanced # Override Antigravity's sequential default
|
| 565 |
-
```
|
| 566 |
-
|
| 567 |
-
#### Priority-Based Concurrency Multipliers
|
| 568 |
-
|
| 569 |
-
- **`CONCURRENCY_MULTIPLIER_<PROVIDER>_PRIORITY_<N>`**: Assign concurrency multipliers to priority tiers. Higher-tier credentials handle more concurrent requests.
|
| 570 |
-
```env
|
| 571 |
-
# Universal multipliers (apply to all rotation modes)
|
| 572 |
-
CONCURRENCY_MULTIPLIER_ANTIGRAVITY_PRIORITY_1=10 # 10x for paid ultra tier
|
| 573 |
-
CONCURRENCY_MULTIPLIER_ANTIGRAVITY_PRIORITY_3=1 # 1x for lower tiers
|
| 574 |
-
|
| 575 |
-
# Mode-specific overrides
|
| 576 |
-
CONCURRENCY_MULTIPLIER_ANTIGRAVITY_PRIORITY_2_BALANCED=1 # P2 = 1x in balanced mode only
|
| 577 |
-
```
|
| 578 |
-
|
| 579 |
-
**Provider Defaults** (built into provider classes):
|
| 580 |
-
- **Antigravity**: Priority 1: 5x, Priority 2: 3x, Priority 3+: 2x (sequential) or 1x (balanced)
|
| 581 |
-
- **Gemini CLI**: Priority 1: 5x, Priority 2: 3x, Others: 1x
|
| 582 |
-
|
| 583 |
-
#### Model Quota Groups
|
| 584 |
-
|
| 585 |
-
- **`QUOTA_GROUPS_<PROVIDER>_<GROUP>`**: Define models that share quota/cooldown timing. When one model hits quota, all in the group receive the same cooldown timestamp.
|
| 586 |
-
```env
|
| 587 |
-
QUOTA_GROUPS_ANTIGRAVITY_CLAUDE="claude-sonnet-4-5,claude-opus-4-5"
|
| 588 |
-
QUOTA_GROUPS_ANTIGRAVITY_GEMINI="gemini-3-pro-preview,gemini-3-pro-image-preview"
|
| 589 |
-
|
| 590 |
-
# To disable a default group:
|
| 591 |
-
QUOTA_GROUPS_ANTIGRAVITY_CLAUDE=""
|
| 592 |
-
```
|
| 593 |
-
|
| 594 |
-
**Default Groups**:
|
| 595 |
-
- **Antigravity**: Claude group (Sonnet 4.5 + Opus 4.5) with Opus counting 2x vs Sonnet
|
| 596 |
-
|
| 597 |
-
#### Concurrency Control
|
| 598 |
-
|
| 599 |
-
- **`MAX_CONCURRENT_REQUESTS_PER_KEY_<PROVIDER>`**: Set the maximum number of simultaneous requests allowed per API key for a specific provider. Default is `1` (no concurrency). Useful for high-throughput providers.
|
| 600 |
-
```env
|
| 601 |
-
MAX_CONCURRENT_REQUESTS_PER_KEY_OPENAI=3
|
| 602 |
-
MAX_CONCURRENT_REQUESTS_PER_KEY_ANTHROPIC=2
|
| 603 |
-
MAX_CONCURRENT_REQUESTS_PER_KEY_GEMINI=1
|
| 604 |
-
```
|
| 605 |
-
|
| 606 |
-
#### Custom Model Lists
|
| 607 |
-
|
| 608 |
-
For providers that support custom model definitions (Qwen Code, iFlow), you can override the default model list:
|
| 609 |
-
|
| 610 |
-
- **`QWEN_CODE_MODELS`**: JSON array of custom Qwen Code models. These models take priority over hardcoded defaults.
|
| 611 |
-
```env
|
| 612 |
-
QWEN_CODE_MODELS='["qwen3-coder-plus", "qwen3-coder-flash", "custom-model-id"]'
|
| 613 |
-
```
|
| 614 |
-
|
| 615 |
-
- **`IFLOW_MODELS`**: JSON array of custom iFlow models. These models take priority over hardcoded defaults.
|
| 616 |
-
```env
|
| 617 |
-
IFLOW_MODELS='["glm-4.6", "qwen3-coder-plus", "deepseek-v3.2"]'
|
| 618 |
-
```
|
| 619 |
-
|
| 620 |
-
#### Provider-Specific Settings
|
| 621 |
-
|
| 622 |
-
- **`GEMINI_CLI_PROJECT_ID`**: Manually specify a Google Cloud Project ID for Gemini CLI OAuth. Only needed if automatic discovery fails.
|
| 623 |
-
|
| 624 |
-
|
| 625 |
-
#### Antigravity Provider
|
| 626 |
-
|
| 627 |
-
- **`ANTIGRAVITY_OAUTH_1`**: Path to Antigravity OAuth credential file (auto-discovered from `~/.antigravity/` or use the credential tool).
|
| 628 |
-
```env
|
| 629 |
-
ANTIGRAVITY_OAUTH_1="/path/to/your/antigravity_creds.json"
|
| 630 |
-
```
|
| 631 |
-
|
| 632 |
-
- **Stateless Deployment** (Environment Variables):
|
| 633 |
-
```env
|
| 634 |
-
ANTIGRAVITY_ACCESS_TOKEN="ya29.your-access-token"
|
| 635 |
-
|
| 636 |
-
|
| 637 |
-
#### Credential Rotation Strategy
|
| 638 |
-
|
| 639 |
-
- **`ROTATION_TOLERANCE`**: Controls how credentials are selected for requests. Set via environment variable or programmatically.
|
| 640 |
-
- `0.0`: **Deterministic** - Always selects the least-used credential for perfect load balance
|
| 641 |
-
- `3.0` (default, recommended): **Weighted Random** - Randomly selects with bias toward less-used credentials. Provides unpredictability (harder to fingerprint/detect) while maintaining good balance
|
| 642 |
-
- `5.0+`: **High Randomness** - Maximum unpredictability, even heavily-used credentials can be selected
|
| 643 |
-
|
| 644 |
-
```env
|
| 645 |
-
# For maximum security/unpredictability (recommended for production)
|
| 646 |
-
ROTATION_TOLERANCE=3.0
|
| 647 |
-
|
| 648 |
-
# For perfect load balancing (deterministic, least-used selection)
|
| 649 |
-
ROTATION_TOLERANCE=0.0
|
| 650 |
-
```
|
| 651 |
-
|
| 652 |
-
**Why use weighted random?**
|
| 653 |
-
- Makes traffic patterns less predictable
|
| 654 |
-
- Still maintains good load distribution across keys
|
| 655 |
-
- Recommended for production environments with multiple credentials
|
| 656 |
-
|
| 657 |
-
|
| 658 |
-
ANTIGRAVITY_REFRESH_TOKEN="1//your-refresh-token"
|
| 659 |
-
ANTIGRAVITY_EXPIRY_DATE="1234567890000"
|
| 660 |
-
ANTIGRAVITY_EMAIL="your-email@gmail.com"
|
| 661 |
-
```
|
| 662 |
-
|
| 663 |
-
- **`ANTIGRAVITY_ENABLE_SIGNATURE_CACHE`**: Enable/disable thought signature caching for Gemini 3 multi-turn conversations. Default: `true`.
|
| 664 |
-
```env
|
| 665 |
-
ANTIGRAVITY_ENABLE_SIGNATURE_CACHE=true
|
| 666 |
-
```
|
| 667 |
-
|
| 668 |
-
- **`ANTIGRAVITY_GEMINI3_TOOL_FIX`**: Enable/disable tool hallucination prevention for Gemini 3 models. Default: `true`.
|
| 669 |
-
```env
|
| 670 |
-
ANTIGRAVITY_GEMINI3_TOOL_FIX=true
|
| 671 |
-
```
|
| 672 |
-
|
| 673 |
-
#### Temperature Override (Global)
|
| 674 |
-
|
| 675 |
-
- **`OVERRIDE_TEMPERATURE_ZERO`**: Prevents tool hallucination caused by temperature=0 settings. Modes:
|
| 676 |
-
- `"remove"`: Deletes temperature=0 from requests (lets provider use default)
|
| 677 |
-
- `"set"`: Changes temperature=0 to temperature=1.0
|
| 678 |
-
- `"false"` or unset: Disabled (default)
|
| 679 |
|
| 680 |
-
|
|
|
|
| 681 |
|
| 682 |
-
|
| 683 |
-
```env
|
| 684 |
-
GEMINI_CLI_PROJECT_ID="your-gcp-project-id"
|
| 685 |
-
```
|
| 686 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 687 |
|
| 688 |
-
|
| 689 |
-
|
| 690 |
-
|
|
|
|
|
|
|
| 691 |
|
| 692 |
-
|
| 693 |
-
```bash
|
| 694 |
-
python src/proxy_app/main.py --host 127.0.0.1 --port 9999 --enable-request-logging
|
| 695 |
-
```
|
| 696 |
|
|
|
|
|
|
|
| 697 |
|
| 698 |
-
|
| 699 |
|
| 700 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 701 |
|
| 702 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 703 |
|
| 704 |
-
|
| 705 |
|
| 706 |
-
|
| 707 |
-
|
| 708 |
-
- **All keys on cooldown**: If you see a message that all keys are on cooldown, it means all your keys for a specific provider have recently failed. If you have logging enabled (`--enable-request-logging`), check the `logs/detailed_logs/` directory to find the logs for the failed requests and inspect the `final_response.json` to see the underlying error from the provider.
|
| 709 |
|
| 710 |
-
|
| 711 |
|
| 712 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 713 |
|
| 714 |
-
|
| 715 |
-
|
|
|
|
|
|
|
|
|
|
| 716 |
|
| 717 |
-
|
|
|
|
| 718 |
|
| 719 |
-
|
| 720 |
|
| 721 |
-
|
| 722 |
|
| 723 |
-
|
|
|
|
| 724 |
|
| 725 |
-
|
| 726 |
-
2. **Blacklist Check**: For any model *not* on the whitelist, the proxy checks the blacklist (`IGNORE_MODELS_<PROVIDER>`). If the model is on the blacklist, it will be hidden.
|
| 727 |
-
3. **Default**: If a model is on neither list, it will be available.
|
| 728 |
|
| 729 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 730 |
|
| 731 |
-
|
| 732 |
|
| 733 |
-
|
| 734 |
|
| 735 |
-
|
| 736 |
-
|
| 737 |
-
|
| 738 |
-
|
| 739 |
|
| 740 |
-
|
| 741 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 742 |
```
|
| 743 |
|
| 744 |
-
|
|
|
|
|
|
|
|
|
|
| 745 |
|
| 746 |
-
|
|
|
|
| 747 |
|
| 748 |
-
|
| 749 |
-
|
| 750 |
-
|
| 751 |
-
IGNORE_MODELS_OPENAI="*-preview*"
|
| 752 |
|
| 753 |
-
|
| 754 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 755 |
```
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Universal LLM API Proxy & Resilience Library
|
| 2 |
+
[](https://ko-fi.com/C0C0UZS4P)
|
| 3 |
[](https://deepwiki.com/Mirrowel/LLM-API-Key-Proxy) [](https://zread.ai/Mirrowel/LLM-API-Key-Proxy)
|
| 4 |
|
| 5 |
+
**One proxy. Any LLM provider. Zero code changes.**
|
| 6 |
|
| 7 |
+
A self-hosted proxy that provides a single, OpenAI-compatible API endpoint for all your LLM providers. Works with any application that supports custom OpenAI base URLs—no code changes required in your existing tools.
|
| 8 |
|
| 9 |
+
This project consists of two components:
|
| 10 |
+
1. **The API Proxy** — A FastAPI application providing a universal `/v1/chat/completions` endpoint
|
| 11 |
+
2. **The Resilience Library** — A reusable Python library for intelligent API key management, rotation, and failover
|
| 12 |
|
| 13 |
+
---
|
|
|
|
|
|
|
|
|
|
| 14 |
|
| 15 |
+
## Why Use This?
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
|
| 17 |
+
- **Universal Compatibility** — Works with any app supporting OpenAI-compatible APIs: Opencode, Continue, Roo/Kilo Code, JanitorAI, SillyTavern, custom applications, and more
|
| 18 |
+
- **One Endpoint, Many Providers** — Configure Gemini, OpenAI, Anthropic, and [any LiteLLM-supported provider](https://docs.litellm.ai/docs/providers) once. Access them all through a single API key
|
| 19 |
+
- **Built-in Resilience** — Automatic key rotation, failover on errors, rate limit handling, and intelligent cooldowns
|
| 20 |
+
- **Exclusive Provider Support** — Includes custom providers not available elsewhere: **Antigravity** (Gemini 3 + Claude Sonnet/Opus 4.5), **Gemini CLI**, **Qwen Code**, and **iFlow**
|
| 21 |
|
| 22 |
---
|
| 23 |
|
| 24 |
+
## Quick Start
|
| 25 |
|
| 26 |
+
### Windows
|
| 27 |
|
| 28 |
+
1. **Download** the latest release from [GitHub Releases](https://github.com/Mirrowel/LLM-API-Key-Proxy/releases/latest)
|
| 29 |
+
2. **Unzip** the downloaded file
|
| 30 |
+
3. **Run** `proxy_app.exe` — the interactive TUI launcher opens
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
|
| 32 |
+
<!-- TODO: Add TUI main menu screenshot here -->
|
| 33 |
|
| 34 |
### macOS / Linux
|
| 35 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
```bash
|
| 37 |
+
# Download and extract the release for your platform
|
| 38 |
+
chmod +x proxy_app
|
| 39 |
+
./proxy_app
|
|
|
|
| 40 |
```
|
| 41 |
|
| 42 |
+
### From Source
|
| 43 |
+
|
| 44 |
```bash
|
| 45 |
+
git clone https://github.com/Mirrowel/LLM-API-Key-Proxy.git
|
| 46 |
+
cd LLM-API-Key-Proxy
|
| 47 |
+
python3 -m venv venv
|
| 48 |
+
source venv/bin/activate # Windows: venv\Scripts\activate
|
| 49 |
+
pip install -r requirements.txt
|
| 50 |
python src/proxy_app/main.py
|
| 51 |
```
|
| 52 |
|
| 53 |
+
> **Tip:** Running with command-line arguments (e.g., `--host 0.0.0.0 --port 8000`) bypasses the TUI and starts the proxy directly.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 54 |
|
| 55 |
---
|
| 56 |
|
| 57 |
+
## Connecting to the Proxy
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 58 |
|
| 59 |
+
Once the proxy is running, configure your application with these settings:
|
|
|
|
| 60 |
|
| 61 |
+
| Setting | Value |
|
| 62 |
+
|---------|-------|
|
| 63 |
+
| **Base URL / API Endpoint** | `http://127.0.0.1:8000/v1` |
|
| 64 |
+
| **API Key** | Your `PROXY_API_KEY` |
|
| 65 |
+
|
| 66 |
+
### Model Format: `provider/model_name`
|
| 67 |
+
|
| 68 |
+
**Important:** Models must be specified in the format `provider/model_name`. The `provider/` prefix tells the proxy which backend to route the request to.
|
| 69 |
+
|
| 70 |
+
```
|
| 71 |
+
gemini/gemini-2.5-flash ← Gemini API
|
| 72 |
+
openai/gpt-4o ← OpenAI API
|
| 73 |
+
anthropic/claude-3-5-sonnet ← Anthropic API
|
| 74 |
+
openrouter/anthropic/claude-3-opus ← OpenRouter
|
| 75 |
+
gemini_cli/gemini-2.5-pro ← Gemini CLI (OAuth)
|
| 76 |
+
antigravity/gemini-3-pro-preview ← Antigravity (Gemini 3, Claude Opus 4.5)
|
| 77 |
+
```
|
| 78 |
+
|
| 79 |
+
### Usage Examples
|
| 80 |
+
|
| 81 |
+
<details>
|
| 82 |
+
<summary><b>Python (OpenAI Library)</b></summary>
|
| 83 |
+
|
| 84 |
+
```python
|
| 85 |
+
from openai import OpenAI
|
| 86 |
+
|
| 87 |
+
client = OpenAI(
|
| 88 |
+
base_url="http://127.0.0.1:8000/v1",
|
| 89 |
+
api_key="your-proxy-api-key"
|
| 90 |
+
)
|
| 91 |
+
|
| 92 |
+
response = client.chat.completions.create(
|
| 93 |
+
model="gemini/gemini-2.5-flash", # provider/model format
|
| 94 |
+
messages=[{"role": "user", "content": "Hello!"}]
|
| 95 |
+
)
|
| 96 |
+
print(response.choices[0].message.content)
|
| 97 |
```
|
| 98 |
|
| 99 |
+
</details>
|
| 100 |
+
|
| 101 |
+
<details>
|
| 102 |
+
<summary><b>curl</b></summary>
|
| 103 |
+
|
| 104 |
```bash
|
| 105 |
+
curl -X POST http://127.0.0.1:8000/v1/chat/completions \
|
| 106 |
+
-H "Content-Type: application/json" \
|
| 107 |
+
-H "Authorization: Bearer your-proxy-api-key" \
|
| 108 |
+
-d '{
|
| 109 |
+
"model": "gemini/gemini-2.5-flash",
|
| 110 |
+
"messages": [{"role": "user", "content": "What is the capital of France?"}]
|
| 111 |
+
}'
|
| 112 |
```
|
| 113 |
|
| 114 |
+
</details>
|
| 115 |
|
| 116 |
+
<details>
|
| 117 |
+
<summary><b>JanitorAI / SillyTavern / Other Chat UIs</b></summary>
|
|
|
|
| 118 |
|
| 119 |
+
1. Go to **API Settings**
|
| 120 |
+
2. Select **"Proxy"** or **"Custom OpenAI"** mode
|
| 121 |
+
3. Configure:
|
| 122 |
+
- **API URL:** `http://127.0.0.1:8000/v1`
|
| 123 |
+
- **API Key:** Your `PROXY_API_KEY`
|
| 124 |
+
- **Model:** `provider/model_name` (e.g., `gemini/gemini-2.5-flash`)
|
| 125 |
+
4. Save and start chatting
|
| 126 |
|
| 127 |
+
</details>
|
| 128 |
|
| 129 |
+
<details>
|
| 130 |
+
<summary><b>Continue / Cursor / IDE Extensions</b></summary>
|
| 131 |
|
| 132 |
+
In your configuration file (e.g., `config.json`):
|
| 133 |
|
| 134 |
+
```json
|
| 135 |
+
{
|
| 136 |
+
"models": [{
|
| 137 |
+
"title": "Gemini via Proxy",
|
| 138 |
+
"provider": "openai",
|
| 139 |
+
"model": "gemini/gemini-2.5-flash",
|
| 140 |
+
"apiBase": "http://127.0.0.1:8000/v1",
|
| 141 |
+
"apiKey": "your-proxy-api-key"
|
| 142 |
+
}]
|
| 143 |
+
}
|
| 144 |
+
```
|
| 145 |
|
| 146 |
+
</details>
|
| 147 |
|
| 148 |
+
### API Endpoints
|
|
|
|
|
|
|
|
|
|
|
|
|
| 149 |
|
| 150 |
+
| Endpoint | Description |
|
| 151 |
+
|----------|-------------|
|
| 152 |
+
| `GET /` | Status check — confirms proxy is running |
|
| 153 |
+
| `POST /v1/chat/completions` | Chat completions (main endpoint) |
|
| 154 |
+
| `POST /v1/embeddings` | Text embeddings |
|
| 155 |
+
| `GET /v1/models` | List all available models with pricing & capabilities |
|
| 156 |
+
| `GET /v1/models/{model_id}` | Get details for a specific model |
|
| 157 |
+
| `GET /v1/providers` | List configured providers |
|
| 158 |
+
| `POST /v1/token-count` | Calculate token count for a payload |
|
| 159 |
+
| `POST /v1/cost-estimate` | Estimate cost based on token counts |
|
| 160 |
|
| 161 |
+
> **Tip:** The `/v1/models` endpoint is useful for discovering available models in your client. Many apps can fetch this list automatically. Add `?enriched=false` for a minimal response without pricing data.
|
|
|
|
|
|
|
| 162 |
|
| 163 |
+
---
|
|
|
|
|
|
|
|
|
|
|
|
|
| 164 |
|
| 165 |
+
## Managing Credentials
|
|
|
|
|
|
|
| 166 |
|
| 167 |
+
The proxy includes an interactive tool for managing all your API keys and OAuth credentials.
|
| 168 |
+
|
| 169 |
+
### Using the TUI
|
| 170 |
+
|
| 171 |
+
<!-- TODO: Add TUI credentials menu screenshot here -->
|
| 172 |
|
| 173 |
+
1. Run the proxy without arguments to open the TUI
|
| 174 |
+
2. Select **"🔑 Manage Credentials"**
|
| 175 |
+
3. Choose to add API keys or OAuth credentials
|
| 176 |
|
| 177 |
+
### Using the Command Line
|
| 178 |
|
|
|
|
| 179 |
```bash
|
| 180 |
+
python -m rotator_library.credential_tool
|
| 181 |
```
|
| 182 |
|
| 183 |
+
### Credential Types
|
| 184 |
+
|
| 185 |
+
| Type | Providers | How to Add |
|
| 186 |
+
|------|-----------|------------|
|
| 187 |
+
| **API Keys** | Gemini, OpenAI, Anthropic, OpenRouter, Groq, Mistral, NVIDIA, Cohere, Chutes | Enter key in TUI or add to `.env` |
|
| 188 |
+
| **OAuth** | Gemini CLI, Antigravity, Qwen Code, iFlow | Interactive browser login via credential tool |
|
| 189 |
+
|
| 190 |
+
### The `.env` File
|
| 191 |
+
|
| 192 |
+
Credentials are stored in a `.env` file. You can edit it directly or use the TUI:
|
| 193 |
+
|
| 194 |
+
```env
|
| 195 |
+
# Required: Authentication key for YOUR proxy
|
| 196 |
+
PROXY_API_KEY="your-secret-proxy-key"
|
| 197 |
+
|
| 198 |
+
# Provider API Keys (add multiple with _1, _2, etc.)
|
| 199 |
+
GEMINI_API_KEY_1="your-gemini-key"
|
| 200 |
+
GEMINI_API_KEY_2="another-gemini-key"
|
| 201 |
+
OPENAI_API_KEY_1="your-openai-key"
|
| 202 |
+
ANTHROPIC_API_KEY_1="your-anthropic-key"
|
| 203 |
```
|
| 204 |
|
| 205 |
+
> Copy `.env.example` to `.env` as a starting point.
|
| 206 |
|
| 207 |
+
---
|
| 208 |
|
| 209 |
+
## The Resilience Library
|
| 210 |
|
| 211 |
+
The proxy is powered by a standalone Python library that you can use directly in your own applications.
|
|
|
|
| 212 |
|
| 213 |
+
### Key Features
|
| 214 |
|
| 215 |
+
- **Async-native** with `asyncio` and `httpx`
|
| 216 |
+
- **Intelligent key selection** with tiered, model-aware locking
|
| 217 |
+
- **Deadline-driven requests** with configurable global timeout
|
| 218 |
+
- **Automatic failover** between keys on errors
|
| 219 |
+
- **OAuth support** for Gemini CLI, Antigravity, Qwen, iFlow
|
| 220 |
+
- **Stateless deployment ready** — load credentials from environment variables
|
| 221 |
|
| 222 |
+
### Basic Usage
|
| 223 |
|
| 224 |
+
```python
|
| 225 |
+
from rotator_library import RotatingClient
|
| 226 |
|
| 227 |
+
client = RotatingClient(
|
| 228 |
+
api_keys={"gemini": ["key1", "key2"], "openai": ["key3"]},
|
| 229 |
+
global_timeout=30,
|
| 230 |
+
max_retries=2
|
| 231 |
+
)
|
| 232 |
|
| 233 |
+
async with client:
|
| 234 |
+
response = await client.acompletion(
|
| 235 |
+
model="gemini/gemini-2.5-flash",
|
| 236 |
+
messages=[{"role": "user", "content": "Hello!"}]
|
| 237 |
+
)
|
| 238 |
```
|
| 239 |
|
| 240 |
+
### Library Documentation
|
|
|
|
|
|
|
|
|
|
|
|
|
| 241 |
|
| 242 |
+
See the [Library README](src/rotator_library/README.md) for complete documentation including:
|
| 243 |
+
- All initialization parameters
|
| 244 |
+
- Streaming support
|
| 245 |
+
- Error handling and cooldown strategies
|
| 246 |
+
- Provider plugin system
|
| 247 |
+
- Credential prioritization
|
| 248 |
|
| 249 |
+
---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 250 |
|
| 251 |
+
## Interactive TUI
|
|
|
|
|
|
|
|
|
|
| 252 |
|
| 253 |
+
The proxy includes a powerful text-based UI for configuration and management.
|
|
|
|
|
|
|
|
|
|
| 254 |
|
| 255 |
+
<!-- TODO: Add TUI main menu screenshot here -->
|
| 256 |
|
| 257 |
+
### TUI Features
|
| 258 |
|
| 259 |
+
- **🚀 Run Proxy** — Start the server with saved settings
|
| 260 |
+
- **⚙️ Configure Settings** — Host, port, API key, request logging
|
| 261 |
+
- **🔑 Manage Credentials** — Add/edit API keys and OAuth credentials
|
| 262 |
+
- **📊 View Status** — See configured providers and credential counts
|
| 263 |
+
- **🔧 Advanced Settings** — Custom providers, model definitions, concurrency
|
| 264 |
|
| 265 |
+
### Configuration Files
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 266 |
|
| 267 |
+
| File | Contents |
|
| 268 |
+
|------|----------|
|
| 269 |
+
| `.env` | All credentials and advanced settings |
|
| 270 |
+
| `launcher_config.json` | TUI-specific settings (host, port, logging) |
|
| 271 |
|
| 272 |
+
---
|
|
|
|
|
|
|
|
|
|
| 273 |
|
| 274 |
+
## Features
|
| 275 |
|
| 276 |
+
### Core Capabilities
|
| 277 |
|
| 278 |
+
- **Universal OpenAI-compatible endpoint** for all providers
|
| 279 |
+
- **Multi-provider support** via [LiteLLM](https://docs.litellm.ai/docs/providers) fallback
|
| 280 |
+
- **Automatic key rotation** and load balancing
|
| 281 |
+
- **Interactive TUI** for easy configuration
|
| 282 |
+
- **Detailed request logging** for debugging
|
| 283 |
|
| 284 |
+
<details>
|
| 285 |
+
<summary><b>🛡️ Resilience & High Availability</b></summary>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 286 |
|
| 287 |
+
- **Global timeout** with deadline-driven retries
|
| 288 |
+
- **Escalating cooldowns** per model (10s → 30s → 60s → 120s)
|
| 289 |
+
- **Key-level lockouts** for consistently failing keys
|
| 290 |
+
- **Stream error detection** and graceful recovery
|
| 291 |
+
- **Batch embedding aggregation** for improved throughput
|
| 292 |
+
- **Automatic daily resets** for cooldowns and usage stats
|
| 293 |
|
| 294 |
+
</details>
|
| 295 |
|
| 296 |
+
<details>
|
| 297 |
+
<summary><b>🔑 Credential Management</b></summary>
|
| 298 |
|
| 299 |
+
- **Auto-discovery** of API keys from environment variables
|
| 300 |
+
- **OAuth discovery** from standard paths (`~/.gemini/`, `~/.qwen/`, `~/.iflow/`)
|
| 301 |
+
- **Duplicate detection** warns when same account added multiple times
|
| 302 |
+
- **Credential prioritization** — paid tier used before free tier
|
| 303 |
+
- **Stateless deployment** — export OAuth to environment variables
|
| 304 |
+
- **Local-first storage** — credentials isolated in `oauth_creds/` directory
|
| 305 |
|
| 306 |
+
</details>
|
| 307 |
|
| 308 |
+
<details>
|
| 309 |
+
<summary><b>⚙️ Advanced Configuration</b></summary>
|
| 310 |
|
| 311 |
+
- **Model whitelists/blacklists** with wildcard support
|
| 312 |
+
- **Per-provider concurrency limits** (`MAX_CONCURRENT_REQUESTS_PER_KEY_<PROVIDER>`)
|
| 313 |
+
- **Rotation modes** — balanced (distribute load) or sequential (use until exhausted)
|
| 314 |
+
- **Priority multipliers** — higher concurrency for paid credentials
|
| 315 |
+
- **Model quota groups** — shared cooldowns for related models
|
| 316 |
+
- **Temperature override** — prevent tool hallucination issues
|
| 317 |
+
- **Weighted random rotation** — unpredictable selection patterns
|
| 318 |
|
| 319 |
+
</details>
|
|
|
|
|
|
|
|
|
|
| 320 |
|
| 321 |
+
<details>
|
| 322 |
+
<summary><b>🔌 Provider-Specific Features</b></summary>
|
|
|
|
| 323 |
|
| 324 |
+
**Gemini CLI:**
|
| 325 |
+
- Zero-config Google Cloud project discovery
|
| 326 |
+
- Internal API access with higher rate limits
|
| 327 |
+
- Automatic fallback to preview models on rate limit
|
| 328 |
+
- Paid vs free tier detection
|
| 329 |
|
| 330 |
+
**Antigravity:**
|
| 331 |
+
- Gemini 3 Pro with `thinkingLevel` support
|
| 332 |
+
- Claude Opus 4.5 (thinking mode)
|
| 333 |
+
- Claude Sonnet 4.5 (thinking and non-thinking)
|
| 334 |
+
- Thought signature caching for multi-turn conversations
|
| 335 |
+
- Tool hallucination prevention
|
| 336 |
|
| 337 |
+
**Qwen Code:**
|
| 338 |
+
- Dual auth (API key + OAuth Device Flow)
|
| 339 |
+
- `<think>` tag parsing as `reasoning_content`
|
| 340 |
+
- Tool schema cleaning
|
| 341 |
|
| 342 |
+
**iFlow:**
|
| 343 |
+
- Dual auth (API key + OAuth Authorization Code)
|
| 344 |
+
- Hybrid auth with separate API key fetch
|
| 345 |
+
- Tool schema cleaning
|
| 346 |
|
| 347 |
+
**NVIDIA NIM:**
|
| 348 |
+
- Dynamic model discovery
|
| 349 |
+
- DeepSeek thinking support
|
| 350 |
+
|
| 351 |
+
</details>
|
| 352 |
+
|
| 353 |
+
<details>
|
| 354 |
+
<summary><b>📝 Logging & Debugging</b></summary>
|
| 355 |
+
|
| 356 |
+
- **Per-request file logging** with `--enable-request-logging`
|
| 357 |
+
- **Unique request directories** with full transaction details
|
| 358 |
+
- **Streaming chunk capture** for debugging
|
| 359 |
+
- **Performance metadata** (duration, tokens, model used)
|
| 360 |
+
- **Provider-specific logs** for Qwen, iFlow, Antigravity
|
| 361 |
+
|
| 362 |
+
</details>
|
| 363 |
|
| 364 |
---
|
| 365 |
|
| 366 |
+
## Advanced Configuration
|
| 367 |
|
| 368 |
+
<details>
|
| 369 |
+
<summary><b>Environment Variables Reference</b></summary>
|
| 370 |
|
| 371 |
+
### Proxy Settings
|
| 372 |
|
| 373 |
+
| Variable | Description | Default |
|
| 374 |
+
|----------|-------------|---------|
|
| 375 |
+
| `PROXY_API_KEY` | Authentication key for your proxy | Required |
|
| 376 |
+
| `OAUTH_REFRESH_INTERVAL` | Token refresh check interval (seconds) | `600` |
|
| 377 |
+
| `SKIP_OAUTH_INIT_CHECK` | Skip interactive OAuth setup on startup | `false` |
|
| 378 |
|
| 379 |
+
### Per-Provider Settings
|
|
|
|
|
|
|
|
|
|
|
|
|
| 380 |
|
| 381 |
+
| Pattern | Description | Example |
|
| 382 |
+
|---------|-------------|---------|
|
| 383 |
+
| `<PROVIDER>_API_KEY_<N>` | API key for provider | `GEMINI_API_KEY_1` |
|
| 384 |
+
| `MAX_CONCURRENT_REQUESTS_PER_KEY_<PROVIDER>` | Concurrent request limit | `MAX_CONCURRENT_REQUESTS_PER_KEY_OPENAI=3` |
|
| 385 |
+
| `ROTATION_MODE_<PROVIDER>` | `balanced` or `sequential` | `ROTATION_MODE_GEMINI=sequential` |
|
| 386 |
+
| `IGNORE_MODELS_<PROVIDER>` | Blacklist (comma-separated, supports `*`) | `IGNORE_MODELS_OPENAI=*-preview*` |
|
| 387 |
+
| `WHITELIST_MODELS_<PROVIDER>` | Whitelist (overrides blacklist) | `WHITELIST_MODELS_GEMINI=gemini-2.5-pro` |
|
| 388 |
|
| 389 |
+
### Advanced Features
|
| 390 |
+
|
| 391 |
+
| Variable | Description |
|
| 392 |
+
|----------|-------------|
|
| 393 |
+
| `ROTATION_TOLERANCE` | `0.0`=deterministic, `3.0`=weighted random (default) |
|
| 394 |
+
| `CONCURRENCY_MULTIPLIER_<PROVIDER>_PRIORITY_<N>` | Concurrency multiplier per priority tier |
|
| 395 |
+
| `QUOTA_GROUPS_<PROVIDER>_<GROUP>` | Models sharing quota limits |
|
| 396 |
+
| `OVERRIDE_TEMPERATURE_ZERO` | `remove` or `set` to prevent tool hallucination |
|
| 397 |
+
|
| 398 |
+
</details>
|
| 399 |
+
|
| 400 |
+
<details>
|
| 401 |
+
<summary><b>Model Filtering (Whitelists & Blacklists)</b></summary>
|
| 402 |
+
|
| 403 |
+
Control which models are exposed through your proxy.
|
| 404 |
+
|
| 405 |
+
### Blacklist Only
|
| 406 |
+
```env
|
| 407 |
+
# Hide all preview models
|
| 408 |
+
IGNORE_MODELS_OPENAI="*-preview*"
|
| 409 |
```
|
| 410 |
|
| 411 |
+
### Pure Whitelist Mode
|
| 412 |
+
```env
|
| 413 |
+
# Block all, then allow specific models
|
| 414 |
+
IGNORE_MODELS_GEMINI="*"
|
| 415 |
+
WHITELIST_MODELS_GEMINI="gemini-2.5-pro,gemini-2.5-flash"
|
| 416 |
+
```
|
| 417 |
|
| 418 |
+
### Exemption Mode
|
| 419 |
+
```env
|
| 420 |
+
# Block preview models, but allow one specific preview
|
| 421 |
+
IGNORE_MODELS_OPENAI="*-preview*"
|
| 422 |
+
WHITELIST_MODELS_OPENAI="gpt-4o-2024-08-06-preview"
|
| 423 |
+
```
|
| 424 |
|
| 425 |
+
**Logic order:** Whitelist check → Blacklist check → Default allow
|
| 426 |
+
|
| 427 |
+
</details>
|
| 428 |
+
|
| 429 |
+
<details>
|
| 430 |
+
<summary><b>Concurrency & Rotation Settings</b></summary>
|
| 431 |
+
|
| 432 |
+
### Concurrency Limits
|
| 433 |
+
|
| 434 |
+
```env
|
| 435 |
+
# Allow 3 concurrent requests per OpenAI key
|
| 436 |
+
MAX_CONCURRENT_REQUESTS_PER_KEY_OPENAI=3
|
| 437 |
+
|
| 438 |
+
# Default is 1 (no concurrency)
|
| 439 |
+
MAX_CONCURRENT_REQUESTS_PER_KEY_GEMINI=1
|
| 440 |
```
|
| 441 |
|
| 442 |
+
### Rotation Modes
|
| 443 |
|
| 444 |
+
```env
|
| 445 |
+
# balanced (default): Distribute load evenly - best for per-minute rate limits
|
| 446 |
+
ROTATION_MODE_OPENAI=balanced
|
|
|
|
|
|
|
| 447 |
|
| 448 |
+
# sequential: Use until exhausted - best for daily/weekly quotas
|
| 449 |
+
ROTATION_MODE_GEMINI=sequential
|
| 450 |
+
```
|
| 451 |
|
| 452 |
+
### Priority Multipliers
|
| 453 |
|
| 454 |
+
Paid credentials can handle more concurrent requests:
|
| 455 |
|
| 456 |
+
```env
|
| 457 |
+
# Priority 1 (paid ultra): 10x concurrency
|
| 458 |
+
CONCURRENCY_MULTIPLIER_ANTIGRAVITY_PRIORITY_1=10
|
|
|
|
| 459 |
|
| 460 |
+
# Priority 2 (standard paid): 3x
|
| 461 |
+
CONCURRENCY_MULTIPLIER_ANTIGRAVITY_PRIORITY_2=3
|
| 462 |
+
```
|
| 463 |
|
| 464 |
+
### Model Quota Groups
|
| 465 |
|
| 466 |
+
Models sharing quota limits:
|
|
|
|
|
|
|
|
|
|
| 467 |
|
| 468 |
+
```env
|
| 469 |
+
# Claude models share quota - when one hits limit, both cool down
|
| 470 |
+
QUOTA_GROUPS_ANTIGRAVITY_CLAUDE="claude-sonnet-4-5,claude-opus-4-5"
|
| 471 |
+
```
|
| 472 |
|
| 473 |
+
</details>
|
| 474 |
|
| 475 |
+
<details>
|
| 476 |
+
<summary><b>Timeout Configuration</b></summary>
|
|
|
|
| 477 |
|
| 478 |
+
Fine-grained control over HTTP timeouts:
|
| 479 |
|
| 480 |
+
```env
|
| 481 |
+
TIMEOUT_CONNECT=30 # Connection establishment
|
| 482 |
+
TIMEOUT_WRITE=30 # Request body send
|
| 483 |
+
TIMEOUT_POOL=60 # Connection pool acquisition
|
| 484 |
+
TIMEOUT_READ_STREAMING=180 # Between streaming chunks (3 min)
|
| 485 |
+
TIMEOUT_READ_NON_STREAMING=600 # Full response wait (10 min)
|
| 486 |
+
```
|
| 487 |
|
| 488 |
+
**Recommendations:**
|
| 489 |
+
- Long thinking tasks: Increase `TIMEOUT_READ_STREAMING` to 300-360s
|
| 490 |
+
- Unstable network: Increase `TIMEOUT_CONNECT` to 60s
|
| 491 |
+
- Large outputs: Increase `TIMEOUT_READ_NON_STREAMING` to 900s+
|
|
|
|
|
|
|
|
|
|
| 492 |
|
| 493 |
+
</details>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 494 |
|
| 495 |
+
---
|
| 496 |
+
|
| 497 |
+
## OAuth Providers
|
| 498 |
+
|
| 499 |
+
<details>
|
| 500 |
+
<summary><b>Gemini CLI</b></summary>
|
| 501 |
|
| 502 |
+
Uses Google OAuth to access internal Gemini endpoints with higher rate limits.
|
| 503 |
|
| 504 |
+
**Setup:**
|
| 505 |
+
1. Run `python -m rotator_library.credential_tool`
|
| 506 |
+
2. Select "Add OAuth Credential" → "Gemini CLI"
|
| 507 |
+
3. Complete browser authentication
|
| 508 |
+
4. Credentials saved to `oauth_creds/gemini_cli_oauth_1.json`
|
| 509 |
|
| 510 |
+
**Features:**
|
| 511 |
+
- Zero-config project discovery
|
| 512 |
+
- Automatic free-tier project onboarding
|
| 513 |
+
- Paid vs free tier detection
|
| 514 |
+
- Smart fallback on rate limits
|
| 515 |
|
| 516 |
+
**Environment Variables (for stateless deployment):**
|
| 517 |
+
```env
|
| 518 |
+
GEMINI_CLI_ACCESS_TOKEN="ya29.your-access-token"
|
| 519 |
+
GEMINI_CLI_REFRESH_TOKEN="1//your-refresh-token"
|
| 520 |
+
GEMINI_CLI_EXPIRY_DATE="1234567890000"
|
| 521 |
+
GEMINI_CLI_EMAIL="your-email@gmail.com"
|
| 522 |
+
GEMINI_CLI_PROJECT_ID="your-gcp-project-id" # Optional
|
| 523 |
+
```
|
| 524 |
|
| 525 |
+
</details>
|
|
|
|
|
|
|
| 526 |
|
| 527 |
+
<details>
|
| 528 |
+
<summary><b>Antigravity (Gemini 3 + Claude Opus 4.5)</b></summary>
|
| 529 |
|
| 530 |
+
Access Google's internal Antigravity API for cutting-edge models.
|
|
|
|
| 531 |
|
| 532 |
**Supported Models:**
|
| 533 |
+
- **Gemini 3 Pro** — with `thinkingLevel` support (low/high)
|
| 534 |
+
- **Claude Opus 4.5** — Anthropic's most powerful model (thinking mode only)
|
| 535 |
+
- **Claude Sonnet 4.5** — supports both thinking and non-thinking modes
|
| 536 |
+
- Gemini 2.5 Pro/Flash
|
| 537 |
|
| 538 |
+
**Setup:**
|
| 539 |
+
1. Run `python -m rotator_library.credential_tool`
|
| 540 |
+
2. Select "Add OAuth Credential" → "Antigravity"
|
| 541 |
+
3. Complete browser authentication
|
|
|
|
|
|
|
| 542 |
|
| 543 |
+
**Advanced Features:**
|
| 544 |
+
- Thought signature caching for multi-turn conversations
|
| 545 |
+
- Tool hallucination prevention via parameter signature injection
|
| 546 |
+
- Automatic thinking block sanitization for Claude
|
| 547 |
+
- Credential prioritization (paid resets every 5 hours, free weekly)
|
| 548 |
|
| 549 |
**Environment Variables:**
|
| 550 |
```env
|
| 551 |
+
ANTIGRAVITY_ACCESS_TOKEN="ya29.your-access-token"
|
| 552 |
+
ANTIGRAVITY_REFRESH_TOKEN="1//your-refresh-token"
|
| 553 |
+
ANTIGRAVITY_EXPIRY_DATE="1234567890000"
|
| 554 |
+
ANTIGRAVITY_EMAIL="your-email@gmail.com"
|
|
|
|
| 555 |
|
| 556 |
# Feature toggles
|
| 557 |
+
ANTIGRAVITY_ENABLE_SIGNATURE_CACHE=true
|
| 558 |
+
ANTIGRAVITY_GEMINI3_TOOL_FIX=true
|
| 559 |
```
|
| 560 |
|
| 561 |
+
> **Note:** Gemini 3 models require a paid-tier Google Cloud project.
|
| 562 |
|
| 563 |
+
</details>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 564 |
|
| 565 |
+
<details>
|
| 566 |
+
<summary><b>Qwen Code</b></summary>
|
| 567 |
|
| 568 |
+
Uses OAuth Device Flow for Qwen/Dashscope APIs.
|
|
|
|
|
|
|
|
|
|
| 569 |
|
| 570 |
+
**Setup:**
|
| 571 |
+
1. Run the credential tool
|
| 572 |
+
2. Select "Add OAuth Credential" → "Qwen Code"
|
| 573 |
+
3. Enter the code displayed in your browser
|
| 574 |
+
4. Or add API key directly: `QWEN_CODE_API_KEY_1="your-key"`
|
| 575 |
|
| 576 |
+
**Features:**
|
| 577 |
+
- Dual auth (API key or OAuth)
|
| 578 |
+
- `<think>` tag parsing as `reasoning_content`
|
| 579 |
+
- Automatic tool schema cleaning
|
| 580 |
+
- Custom models via `QWEN_CODE_MODELS` env var
|
| 581 |
|
| 582 |
+
</details>
|
|
|
|
|
|
|
|
|
|
| 583 |
|
| 584 |
+
<details>
|
| 585 |
+
<summary><b>iFlow</b></summary>
|
| 586 |
|
| 587 |
+
Uses OAuth Authorization Code flow with local callback server.
|
| 588 |
|
| 589 |
+
**Setup:**
|
| 590 |
+
1. Run the credential tool
|
| 591 |
+
2. Select "Add OAuth Credential" → "iFlow"
|
| 592 |
+
3. Complete browser authentication (callback on port 11451)
|
| 593 |
+
4. Or add API key directly: `IFLOW_API_KEY_1="sk-your-key"`
|
| 594 |
|
| 595 |
+
**Features:**
|
| 596 |
+
- Dual auth (API key or OAuth)
|
| 597 |
+
- Hybrid auth (OAuth token fetches separate API key)
|
| 598 |
+
- Automatic tool schema cleaning
|
| 599 |
+
- Custom models via `IFLOW_MODELS` env var
|
| 600 |
|
| 601 |
+
</details>
|
| 602 |
|
| 603 |
+
<details>
|
| 604 |
+
<summary><b>Stateless Deployment (Export to Environment Variables)</b></summary>
|
|
|
|
| 605 |
|
| 606 |
+
For platforms without file persistence (Railway, Render, Vercel):
|
| 607 |
|
| 608 |
+
1. **Set up credentials locally:**
|
| 609 |
+
```bash
|
| 610 |
+
python -m rotator_library.credential_tool
|
| 611 |
+
# Complete OAuth flows
|
| 612 |
+
```
|
| 613 |
|
| 614 |
+
2. **Export to environment variables:**
|
| 615 |
+
```bash
|
| 616 |
+
python -m rotator_library.credential_tool
|
| 617 |
+
# Select "Export [Provider] to .env"
|
| 618 |
+
```
|
| 619 |
|
| 620 |
+
3. **Copy generated variables to your platform:**
|
| 621 |
+
The tool creates files like `gemini_cli_credential_1.env` containing all necessary variables.
|
| 622 |
|
| 623 |
+
4. **Set `SKIP_OAUTH_INIT_CHECK=true`** to skip interactive validation on startup.
|
| 624 |
|
| 625 |
+
</details>
|
| 626 |
|
| 627 |
+
<details>
|
| 628 |
+
<summary><b>OAuth Callback Port Configuration</b></summary>
|
| 629 |
|
| 630 |
+
Customize OAuth callback ports if defaults conflict:
|
|
|
|
|
|
|
| 631 |
|
| 632 |
+
| Provider | Default Port | Environment Variable |
|
| 633 |
+
|----------|-------------|---------------------|
|
| 634 |
+
| Gemini CLI | 8085 | `GEMINI_CLI_OAUTH_PORT` |
|
| 635 |
+
| Antigravity | 51121 | `ANTIGRAVITY_OAUTH_PORT` |
|
| 636 |
+
| iFlow | 11451 | `IFLOW_OAUTH_PORT` |
|
| 637 |
|
| 638 |
+
</details>
|
| 639 |
|
| 640 |
+
---
|
| 641 |
|
| 642 |
+
## Deployment
|
| 643 |
+
|
| 644 |
+
<details>
|
| 645 |
+
<summary><b>Command-Line Arguments</b></summary>
|
| 646 |
|
| 647 |
+
```bash
|
| 648 |
+
python src/proxy_app/main.py [OPTIONS]
|
| 649 |
+
|
| 650 |
+
Options:
|
| 651 |
+
--host TEXT Host to bind (default: 0.0.0.0)
|
| 652 |
+
--port INTEGER Port to run on (default: 8000)
|
| 653 |
+
--enable-request-logging Enable detailed per-request logging
|
| 654 |
+
--add-credential Launch interactive credential setup tool
|
| 655 |
```
|
| 656 |
|
| 657 |
+
**Examples:**
|
| 658 |
+
```bash
|
| 659 |
+
# Run on custom port
|
| 660 |
+
python src/proxy_app/main.py --host 127.0.0.1 --port 9000
|
| 661 |
|
| 662 |
+
# Run with logging
|
| 663 |
+
python src/proxy_app/main.py --enable-request-logging
|
| 664 |
|
| 665 |
+
# Add credentials without starting proxy
|
| 666 |
+
python src/proxy_app/main.py --add-credential
|
| 667 |
+
```
|
|
|
|
| 668 |
|
| 669 |
+
</details>
|
| 670 |
+
|
| 671 |
+
<details>
|
| 672 |
+
<summary><b>Render / Railway / Vercel</b></summary>
|
| 673 |
+
|
| 674 |
+
See the [Deployment Guide](Deployment%20guide.md) for complete instructions.
|
| 675 |
+
|
| 676 |
+
**Quick Setup:**
|
| 677 |
+
1. Fork the repository
|
| 678 |
+
2. Create a `.env` file with your credentials
|
| 679 |
+
3. Create a new Web Service pointing to your repo
|
| 680 |
+
4. Set build command: `pip install -r requirements.txt`
|
| 681 |
+
5. Set start command: `uvicorn src.proxy_app.main:app --host 0.0.0.0 --port $PORT`
|
| 682 |
+
6. Upload `.env` as a secret file
|
| 683 |
+
|
| 684 |
+
**OAuth Credentials:**
|
| 685 |
+
Export OAuth credentials to environment variables using the credential tool, then add them to your platform's environment settings.
|
| 686 |
+
|
| 687 |
+
</details>
|
| 688 |
+
|
| 689 |
+
<details>
|
| 690 |
+
<summary><b>Custom VPS / Docker</b></summary>
|
| 691 |
+
|
| 692 |
+
**Option 1: Authenticate locally, deploy credentials**
|
| 693 |
+
1. Complete OAuth flows on your local machine
|
| 694 |
+
2. Export to environment variables
|
| 695 |
+
3. Deploy `.env` to your server
|
| 696 |
+
|
| 697 |
+
**Option 2: SSH Port Forwarding**
|
| 698 |
+
```bash
|
| 699 |
+
# Forward callback ports through SSH
|
| 700 |
+
ssh -L 51121:localhost:51121 -L 8085:localhost:8085 user@your-vps
|
| 701 |
+
|
| 702 |
+
# Then run credential tool on the VPS
|
| 703 |
+
```
|
| 704 |
+
|
| 705 |
+
**Systemd Service:**
|
| 706 |
+
```ini
|
| 707 |
+
[Unit]
|
| 708 |
+
Description=LLM API Key Proxy
|
| 709 |
+
After=network.target
|
| 710 |
+
|
| 711 |
+
[Service]
|
| 712 |
+
Type=simple
|
| 713 |
+
WorkingDirectory=/path/to/LLM-API-Key-Proxy
|
| 714 |
+
ExecStart=/path/to/python -m uvicorn src.proxy_app.main:app --host 0.0.0.0 --port 8000
|
| 715 |
+
Restart=always
|
| 716 |
+
|
| 717 |
+
[Install]
|
| 718 |
+
WantedBy=multi-user.target
|
| 719 |
```
|
| 720 |
+
|
| 721 |
+
See [VPS Deployment](Deployment%20guide.md#appendix-deploying-to-a-custom-vps) for complete guide.
|
| 722 |
+
|
| 723 |
+
</details>
|
| 724 |
+
|
| 725 |
+
---
|
| 726 |
+
|
| 727 |
+
## Troubleshooting
|
| 728 |
+
|
| 729 |
+
| Issue | Solution |
|
| 730 |
+
|-------|----------|
|
| 731 |
+
| `401 Unauthorized` | Verify `PROXY_API_KEY` matches your `Authorization: Bearer` header exactly |
|
| 732 |
+
| `500 Internal Server Error` | Check provider key validity; enable `--enable-request-logging` for details |
|
| 733 |
+
| All keys on cooldown | All keys failed recently; check `logs/detailed_logs/` for upstream errors |
|
| 734 |
+
| Model not found | Verify format is `provider/model_name` (e.g., `gemini/gemini-2.5-flash`) |
|
| 735 |
+
| OAuth callback failed | Ensure callback port (8085, 51121, 11451) isn't blocked by firewall |
|
| 736 |
+
| Streaming hangs | Increase `TIMEOUT_READ_STREAMING`; check provider status |
|
| 737 |
+
|
| 738 |
+
**Detailed Logs:**
|
| 739 |
+
|
| 740 |
+
When `--enable-request-logging` is enabled, check `logs/detailed_logs/` for:
|
| 741 |
+
- `request.json` — Exact request payload
|
| 742 |
+
- `final_response.json` — Complete response or error
|
| 743 |
+
- `streaming_chunks.jsonl` — All SSE chunks received
|
| 744 |
+
- `metadata.json` — Performance metrics
|
| 745 |
+
|
| 746 |
+
---
|
| 747 |
+
|
| 748 |
+
## Documentation
|
| 749 |
+
|
| 750 |
+
| Document | Description |
|
| 751 |
+
|----------|-------------|
|
| 752 |
+
| [Technical Documentation](DOCUMENTATION.md) | Architecture, internals, provider implementations |
|
| 753 |
+
| [Library README](src/rotator_library/README.md) | Using the resilience library directly |
|
| 754 |
+
| [Deployment Guide](Deployment%20guide.md) | Hosting on Render, Railway, VPS |
|
| 755 |
+
| [.env.example](.env.example) | Complete environment variable reference |
|
| 756 |
+
|
| 757 |
+
---
|
| 758 |
+
|
| 759 |
+
## License
|
| 760 |
+
|
| 761 |
+
This project is dual-licensed:
|
| 762 |
+
- **Proxy Application** (`src/proxy_app/`) — [MIT License](src/proxy_app/LICENSE)
|
| 763 |
+
- **Resilience Library** (`src/rotator_library/`) — [LGPL-3.0](src/rotator_library/COPYING.LESSER)
|
src/proxy_app/detailed_logger.py
CHANGED
|
@@ -3,16 +3,33 @@ import time
|
|
| 3 |
import uuid
|
| 4 |
from datetime import datetime
|
| 5 |
from pathlib import Path
|
| 6 |
-
from typing import Any, Dict, Optional
|
| 7 |
import logging
|
| 8 |
|
| 9 |
-
|
| 10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
|
| 12 |
class DetailedLogger:
|
| 13 |
"""
|
| 14 |
Logs comprehensive details of each API transaction to a unique, timestamped directory.
|
|
|
|
|
|
|
|
|
|
| 15 |
"""
|
|
|
|
| 16 |
def __init__(self):
|
| 17 |
"""
|
| 18 |
Initializes the logger for a single request, creating a unique directory to store all related log files.
|
|
@@ -20,17 +37,26 @@ class DetailedLogger:
|
|
| 20 |
self.start_time = time.time()
|
| 21 |
self.request_id = str(uuid.uuid4())
|
| 22 |
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
| 23 |
-
self.log_dir =
|
| 24 |
-
self.log_dir.mkdir(parents=True, exist_ok=True)
|
| 25 |
self.streaming = False
|
|
|
|
| 26 |
|
| 27 |
def _write_json(self, filename: str, data: Dict[str, Any]):
|
| 28 |
"""Helper to write data to a JSON file in the log directory."""
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
|
| 35 |
def log_request(self, headers: Dict[str, Any], body: Dict[str, Any]):
|
| 36 |
"""Logs the initial request details."""
|
|
@@ -39,23 +65,22 @@ class DetailedLogger:
|
|
| 39 |
"request_id": self.request_id,
|
| 40 |
"timestamp_utc": datetime.utcnow().isoformat(),
|
| 41 |
"headers": dict(headers),
|
| 42 |
-
"body": body
|
| 43 |
}
|
| 44 |
self._write_json("request.json", request_data)
|
| 45 |
|
| 46 |
def log_stream_chunk(self, chunk: Dict[str, Any]):
|
| 47 |
"""Logs an individual chunk from a streaming response to a JSON Lines file."""
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
def log_final_response(self, status_code: int, headers: Optional[Dict[str, Any]], body: Dict[str, Any]):
|
| 59 |
"""Logs the complete final response, either from a non-streaming call or after reassembling a stream."""
|
| 60 |
end_time = time.time()
|
| 61 |
duration_ms = (end_time - self.start_time) * 1000
|
|
@@ -66,7 +91,7 @@ class DetailedLogger:
|
|
| 66 |
"status_code": status_code,
|
| 67 |
"duration_ms": round(duration_ms),
|
| 68 |
"headers": dict(headers) if headers else None,
|
| 69 |
-
"body": body
|
| 70 |
}
|
| 71 |
self._write_json("final_response.json", response_data)
|
| 72 |
self._log_metadata(response_data)
|
|
@@ -75,10 +100,10 @@ class DetailedLogger:
|
|
| 75 |
"""Recursively searches for and extracts 'reasoning' fields from the response body."""
|
| 76 |
if not isinstance(response_body, dict):
|
| 77 |
return None
|
| 78 |
-
|
| 79 |
if "reasoning" in response_body:
|
| 80 |
return response_body["reasoning"]
|
| 81 |
-
|
| 82 |
if "choices" in response_body and response_body["choices"]:
|
| 83 |
message = response_body["choices"][0].get("message", {})
|
| 84 |
if "reasoning" in message:
|
|
@@ -93,8 +118,13 @@ class DetailedLogger:
|
|
| 93 |
usage = response_data.get("body", {}).get("usage") or {}
|
| 94 |
model = response_data.get("body", {}).get("model", "N/A")
|
| 95 |
finish_reason = "N/A"
|
| 96 |
-
if
|
| 97 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 98 |
|
| 99 |
metadata = {
|
| 100 |
"request_id": self.request_id,
|
|
@@ -110,12 +140,12 @@ class DetailedLogger:
|
|
| 110 |
},
|
| 111 |
"finish_reason": finish_reason,
|
| 112 |
"reasoning_found": False,
|
| 113 |
-
"reasoning_content": None
|
| 114 |
}
|
| 115 |
|
| 116 |
reasoning = self._extract_reasoning(response_data.get("body", {}))
|
| 117 |
if reasoning:
|
| 118 |
metadata["reasoning_found"] = True
|
| 119 |
metadata["reasoning_content"] = reasoning
|
| 120 |
-
|
| 121 |
-
self._write_json("metadata.json", metadata)
|
|
|
|
| 3 |
import uuid
|
| 4 |
from datetime import datetime
|
| 5 |
from pathlib import Path
|
| 6 |
+
from typing import Any, Dict, Optional
|
| 7 |
import logging
|
| 8 |
|
| 9 |
+
from rotator_library.utils.resilient_io import (
|
| 10 |
+
safe_write_json,
|
| 11 |
+
safe_log_write,
|
| 12 |
+
safe_mkdir,
|
| 13 |
+
)
|
| 14 |
+
from rotator_library.utils.paths import get_logs_dir
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def _get_detailed_logs_dir() -> Path:
|
| 18 |
+
"""Get the detailed logs directory, creating it if needed."""
|
| 19 |
+
logs_dir = get_logs_dir()
|
| 20 |
+
detailed_dir = logs_dir / "detailed_logs"
|
| 21 |
+
detailed_dir.mkdir(parents=True, exist_ok=True)
|
| 22 |
+
return detailed_dir
|
| 23 |
+
|
| 24 |
|
| 25 |
class DetailedLogger:
|
| 26 |
"""
|
| 27 |
Logs comprehensive details of each API transaction to a unique, timestamped directory.
|
| 28 |
+
|
| 29 |
+
Uses fire-and-forget logging - if disk writes fail, logs are dropped (not buffered)
|
| 30 |
+
to prevent memory issues, especially with streaming responses.
|
| 31 |
"""
|
| 32 |
+
|
| 33 |
def __init__(self):
|
| 34 |
"""
|
| 35 |
Initializes the logger for a single request, creating a unique directory to store all related log files.
|
|
|
|
| 37 |
self.start_time = time.time()
|
| 38 |
self.request_id = str(uuid.uuid4())
|
| 39 |
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
| 40 |
+
self.log_dir = _get_detailed_logs_dir() / f"{timestamp}_{self.request_id}"
|
|
|
|
| 41 |
self.streaming = False
|
| 42 |
+
self._dir_available = safe_mkdir(self.log_dir, logging)
|
| 43 |
|
| 44 |
def _write_json(self, filename: str, data: Dict[str, Any]):
|
| 45 |
"""Helper to write data to a JSON file in the log directory."""
|
| 46 |
+
if not self._dir_available:
|
| 47 |
+
# Try to create directory again in case it was recreated
|
| 48 |
+
self._dir_available = safe_mkdir(self.log_dir, logging)
|
| 49 |
+
if not self._dir_available:
|
| 50 |
+
return
|
| 51 |
+
|
| 52 |
+
safe_write_json(
|
| 53 |
+
self.log_dir / filename,
|
| 54 |
+
data,
|
| 55 |
+
logging,
|
| 56 |
+
atomic=False,
|
| 57 |
+
indent=4,
|
| 58 |
+
ensure_ascii=False,
|
| 59 |
+
)
|
| 60 |
|
| 61 |
def log_request(self, headers: Dict[str, Any], body: Dict[str, Any]):
|
| 62 |
"""Logs the initial request details."""
|
|
|
|
| 65 |
"request_id": self.request_id,
|
| 66 |
"timestamp_utc": datetime.utcnow().isoformat(),
|
| 67 |
"headers": dict(headers),
|
| 68 |
+
"body": body,
|
| 69 |
}
|
| 70 |
self._write_json("request.json", request_data)
|
| 71 |
|
| 72 |
def log_stream_chunk(self, chunk: Dict[str, Any]):
|
| 73 |
"""Logs an individual chunk from a streaming response to a JSON Lines file."""
|
| 74 |
+
if not self._dir_available:
|
| 75 |
+
return
|
| 76 |
+
|
| 77 |
+
log_entry = {"timestamp_utc": datetime.utcnow().isoformat(), "chunk": chunk}
|
| 78 |
+
content = json.dumps(log_entry, ensure_ascii=False) + "\n"
|
| 79 |
+
safe_log_write(self.log_dir / "streaming_chunks.jsonl", content, logging)
|
| 80 |
+
|
| 81 |
+
def log_final_response(
|
| 82 |
+
self, status_code: int, headers: Optional[Dict[str, Any]], body: Dict[str, Any]
|
| 83 |
+
):
|
|
|
|
| 84 |
"""Logs the complete final response, either from a non-streaming call or after reassembling a stream."""
|
| 85 |
end_time = time.time()
|
| 86 |
duration_ms = (end_time - self.start_time) * 1000
|
|
|
|
| 91 |
"status_code": status_code,
|
| 92 |
"duration_ms": round(duration_ms),
|
| 93 |
"headers": dict(headers) if headers else None,
|
| 94 |
+
"body": body,
|
| 95 |
}
|
| 96 |
self._write_json("final_response.json", response_data)
|
| 97 |
self._log_metadata(response_data)
|
|
|
|
| 100 |
"""Recursively searches for and extracts 'reasoning' fields from the response body."""
|
| 101 |
if not isinstance(response_body, dict):
|
| 102 |
return None
|
| 103 |
+
|
| 104 |
if "reasoning" in response_body:
|
| 105 |
return response_body["reasoning"]
|
| 106 |
+
|
| 107 |
if "choices" in response_body and response_body["choices"]:
|
| 108 |
message = response_body["choices"][0].get("message", {})
|
| 109 |
if "reasoning" in message:
|
|
|
|
| 118 |
usage = response_data.get("body", {}).get("usage") or {}
|
| 119 |
model = response_data.get("body", {}).get("model", "N/A")
|
| 120 |
finish_reason = "N/A"
|
| 121 |
+
if (
|
| 122 |
+
"choices" in response_data.get("body", {})
|
| 123 |
+
and response_data["body"]["choices"]
|
| 124 |
+
):
|
| 125 |
+
finish_reason = response_data["body"]["choices"][0].get(
|
| 126 |
+
"finish_reason", "N/A"
|
| 127 |
+
)
|
| 128 |
|
| 129 |
metadata = {
|
| 130 |
"request_id": self.request_id,
|
|
|
|
| 140 |
},
|
| 141 |
"finish_reason": finish_reason,
|
| 142 |
"reasoning_found": False,
|
| 143 |
+
"reasoning_content": None,
|
| 144 |
}
|
| 145 |
|
| 146 |
reasoning = self._extract_reasoning(response_data.get("body", {}))
|
| 147 |
if reasoning:
|
| 148 |
metadata["reasoning_found"] = True
|
| 149 |
metadata["reasoning_content"] = reasoning
|
| 150 |
+
|
| 151 |
+
self._write_json("metadata.json", metadata)
|
src/proxy_app/launcher_tui.py
CHANGED
|
@@ -16,6 +16,20 @@ from dotenv import load_dotenv, set_key
|
|
| 16 |
console = Console()
|
| 17 |
|
| 18 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
def clear_screen():
|
| 20 |
"""
|
| 21 |
Cross-platform terminal clear that works robustly on both
|
|
@@ -74,7 +88,7 @@ class LauncherConfig:
|
|
| 74 |
@staticmethod
|
| 75 |
def update_proxy_api_key(new_key: str):
|
| 76 |
"""Update PROXY_API_KEY in .env only"""
|
| 77 |
-
env_file =
|
| 78 |
set_key(str(env_file), "PROXY_API_KEY", new_key)
|
| 79 |
load_dotenv(dotenv_path=env_file, override=True)
|
| 80 |
|
|
@@ -85,7 +99,7 @@ class SettingsDetector:
|
|
| 85 |
@staticmethod
|
| 86 |
def _load_local_env() -> dict:
|
| 87 |
"""Load environment variables from local .env file only"""
|
| 88 |
-
env_file =
|
| 89 |
env_dict = {}
|
| 90 |
if not env_file.exists():
|
| 91 |
return env_dict
|
|
@@ -107,7 +121,7 @@ class SettingsDetector:
|
|
| 107 |
|
| 108 |
@staticmethod
|
| 109 |
def get_all_settings() -> dict:
|
| 110 |
-
"""Returns comprehensive settings overview"""
|
| 111 |
return {
|
| 112 |
"credentials": SettingsDetector.detect_credentials(),
|
| 113 |
"custom_bases": SettingsDetector.detect_custom_api_bases(),
|
|
@@ -117,6 +131,17 @@ class SettingsDetector:
|
|
| 117 |
"provider_settings": SettingsDetector.detect_provider_settings(),
|
| 118 |
}
|
| 119 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 120 |
@staticmethod
|
| 121 |
def detect_credentials() -> dict:
|
| 122 |
"""Detect API keys and OAuth credentials"""
|
|
@@ -260,7 +285,7 @@ class LauncherTUI:
|
|
| 260 |
self.console = Console()
|
| 261 |
self.config = LauncherConfig()
|
| 262 |
self.running = True
|
| 263 |
-
self.env_file =
|
| 264 |
# Load .env file to ensure environment variables are available
|
| 265 |
load_dotenv(dotenv_path=self.env_file, override=True)
|
| 266 |
|
|
@@ -277,8 +302,8 @@ class LauncherTUI:
|
|
| 277 |
"""Display main menu and handle selection"""
|
| 278 |
clear_screen()
|
| 279 |
|
| 280 |
-
# Detect
|
| 281 |
-
settings = SettingsDetector.
|
| 282 |
credentials = settings["credentials"]
|
| 283 |
custom_bases = settings["custom_bases"]
|
| 284 |
|
|
@@ -363,18 +388,17 @@ class LauncherTUI:
|
|
| 363 |
self.console.print("━" * 70)
|
| 364 |
provider_count = len(credentials)
|
| 365 |
custom_count = len(custom_bases)
|
| 366 |
-
|
|
|
|
|
|
|
|
|
|
| 367 |
has_advanced = bool(
|
| 368 |
settings["model_definitions"]
|
| 369 |
or settings["concurrency_limits"]
|
| 370 |
or settings["model_filters"]
|
| 371 |
-
or provider_settings
|
| 372 |
)
|
| 373 |
-
|
| 374 |
-
self.console.print(f" Providers: {provider_count} configured")
|
| 375 |
-
self.console.print(f" Custom Providers: {custom_count} configured")
|
| 376 |
self.console.print(
|
| 377 |
-
f" Advanced Settings: {'Active (view in menu 4)' if has_advanced else 'None'}"
|
| 378 |
)
|
| 379 |
|
| 380 |
# Show menu
|
|
@@ -418,7 +442,7 @@ class LauncherTUI:
|
|
| 418 |
elif choice == "4":
|
| 419 |
self.show_provider_settings_menu()
|
| 420 |
elif choice == "5":
|
| 421 |
-
load_dotenv(dotenv_path=
|
| 422 |
self.config = LauncherConfig() # Reload config
|
| 423 |
self.console.print("\n[green]✅ Configuration reloaded![/green]")
|
| 424 |
elif choice == "6":
|
|
@@ -659,13 +683,14 @@ class LauncherTUI:
|
|
| 659 |
"""Display provider/advanced settings (read-only + launch tool)"""
|
| 660 |
clear_screen()
|
| 661 |
|
| 662 |
-
settings
|
|
|
|
|
|
|
| 663 |
credentials = settings["credentials"]
|
| 664 |
custom_bases = settings["custom_bases"]
|
| 665 |
model_defs = settings["model_definitions"]
|
| 666 |
concurrency = settings["concurrency_limits"]
|
| 667 |
filters = settings["model_filters"]
|
| 668 |
-
provider_settings = settings.get("provider_settings", {})
|
| 669 |
|
| 670 |
self.console.print(
|
| 671 |
Panel.fit(
|
|
@@ -740,23 +765,13 @@ class LauncherTUI:
|
|
| 740 |
status = " + ".join(status_parts) if status_parts else "None"
|
| 741 |
self.console.print(f" • {provider:15} ✅ {status}")
|
| 742 |
|
| 743 |
-
# Provider-Specific Settings
|
| 744 |
self.console.print()
|
| 745 |
self.console.print("[bold]🔬 Provider-Specific Settings[/bold]")
|
| 746 |
self.console.print("━" * 70)
|
| 747 |
-
|
| 748 |
-
|
| 749 |
-
|
| 750 |
-
from .settings_tool import PROVIDER_SETTINGS_MAP
|
| 751 |
-
for provider in PROVIDER_SETTINGS_MAP.keys():
|
| 752 |
-
display_name = provider.replace("_", " ").title()
|
| 753 |
-
modified = provider_settings.get(provider, 0)
|
| 754 |
-
if modified > 0:
|
| 755 |
-
self.console.print(
|
| 756 |
-
f" • {display_name:20} [yellow]{modified} setting{'s' if modified > 1 else ''} modified[/yellow]"
|
| 757 |
-
)
|
| 758 |
-
else:
|
| 759 |
-
self.console.print(f" • {display_name:20} [dim]using defaults[/dim]")
|
| 760 |
|
| 761 |
# Actions
|
| 762 |
self.console.print()
|
|
@@ -823,15 +838,31 @@ class LauncherTUI:
|
|
| 823 |
# Run the tool with from_launcher=True to skip duplicate loading screen
|
| 824 |
run_credential_tool(from_launcher=True)
|
| 825 |
# Reload environment after credential tool
|
| 826 |
-
load_dotenv(dotenv_path=
|
| 827 |
|
| 828 |
def launch_settings_tool(self):
|
| 829 |
"""Launch settings configuration tool"""
|
| 830 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 831 |
|
| 832 |
run_settings_tool()
|
| 833 |
# Reload environment after settings tool
|
| 834 |
-
load_dotenv(dotenv_path=
|
| 835 |
|
| 836 |
def show_about(self):
|
| 837 |
"""Display About page with project information"""
|
|
@@ -919,9 +950,9 @@ class LauncherTUI:
|
|
| 919 |
)
|
| 920 |
|
| 921 |
ensure_env_defaults()
|
| 922 |
-
load_dotenv(dotenv_path=
|
| 923 |
run_credential_tool()
|
| 924 |
-
load_dotenv(dotenv_path=
|
| 925 |
|
| 926 |
# Check again after credential tool
|
| 927 |
if not os.getenv("PROXY_API_KEY"):
|
|
|
|
| 16 |
console = Console()
|
| 17 |
|
| 18 |
|
| 19 |
+
def _get_env_file() -> Path:
|
| 20 |
+
"""
|
| 21 |
+
Get .env file path (lightweight - no heavy imports).
|
| 22 |
+
|
| 23 |
+
Returns:
|
| 24 |
+
Path to .env file - EXE directory if frozen, else current working directory
|
| 25 |
+
"""
|
| 26 |
+
if getattr(sys, "frozen", False):
|
| 27 |
+
# Running as PyInstaller EXE - use EXE's directory
|
| 28 |
+
return Path(sys.executable).parent / ".env"
|
| 29 |
+
# Running as script - use current working directory
|
| 30 |
+
return Path.cwd() / ".env"
|
| 31 |
+
|
| 32 |
+
|
| 33 |
def clear_screen():
|
| 34 |
"""
|
| 35 |
Cross-platform terminal clear that works robustly on both
|
|
|
|
| 88 |
@staticmethod
|
| 89 |
def update_proxy_api_key(new_key: str):
|
| 90 |
"""Update PROXY_API_KEY in .env only"""
|
| 91 |
+
env_file = _get_env_file()
|
| 92 |
set_key(str(env_file), "PROXY_API_KEY", new_key)
|
| 93 |
load_dotenv(dotenv_path=env_file, override=True)
|
| 94 |
|
|
|
|
| 99 |
@staticmethod
|
| 100 |
def _load_local_env() -> dict:
|
| 101 |
"""Load environment variables from local .env file only"""
|
| 102 |
+
env_file = _get_env_file()
|
| 103 |
env_dict = {}
|
| 104 |
if not env_file.exists():
|
| 105 |
return env_dict
|
|
|
|
| 121 |
|
| 122 |
@staticmethod
|
| 123 |
def get_all_settings() -> dict:
|
| 124 |
+
"""Returns comprehensive settings overview (includes provider_settings which triggers heavy imports)"""
|
| 125 |
return {
|
| 126 |
"credentials": SettingsDetector.detect_credentials(),
|
| 127 |
"custom_bases": SettingsDetector.detect_custom_api_bases(),
|
|
|
|
| 131 |
"provider_settings": SettingsDetector.detect_provider_settings(),
|
| 132 |
}
|
| 133 |
|
| 134 |
+
@staticmethod
|
| 135 |
+
def get_basic_settings() -> dict:
|
| 136 |
+
"""Returns basic settings overview without provider_settings (avoids heavy imports)"""
|
| 137 |
+
return {
|
| 138 |
+
"credentials": SettingsDetector.detect_credentials(),
|
| 139 |
+
"custom_bases": SettingsDetector.detect_custom_api_bases(),
|
| 140 |
+
"model_definitions": SettingsDetector.detect_model_definitions(),
|
| 141 |
+
"concurrency_limits": SettingsDetector.detect_concurrency_limits(),
|
| 142 |
+
"model_filters": SettingsDetector.detect_model_filters(),
|
| 143 |
+
}
|
| 144 |
+
|
| 145 |
@staticmethod
|
| 146 |
def detect_credentials() -> dict:
|
| 147 |
"""Detect API keys and OAuth credentials"""
|
|
|
|
| 285 |
self.console = Console()
|
| 286 |
self.config = LauncherConfig()
|
| 287 |
self.running = True
|
| 288 |
+
self.env_file = _get_env_file()
|
| 289 |
# Load .env file to ensure environment variables are available
|
| 290 |
load_dotenv(dotenv_path=self.env_file, override=True)
|
| 291 |
|
|
|
|
| 302 |
"""Display main menu and handle selection"""
|
| 303 |
clear_screen()
|
| 304 |
|
| 305 |
+
# Detect basic settings (excludes provider_settings to avoid heavy imports)
|
| 306 |
+
settings = SettingsDetector.get_basic_settings()
|
| 307 |
credentials = settings["credentials"]
|
| 308 |
custom_bases = settings["custom_bases"]
|
| 309 |
|
|
|
|
| 388 |
self.console.print("━" * 70)
|
| 389 |
provider_count = len(credentials)
|
| 390 |
custom_count = len(custom_bases)
|
| 391 |
+
|
| 392 |
+
self.console.print(f" Providers: {provider_count} configured")
|
| 393 |
+
self.console.print(f" Custom Providers: {custom_count} configured")
|
| 394 |
+
# Note: provider_settings detection is deferred to avoid heavy imports on startup
|
| 395 |
has_advanced = bool(
|
| 396 |
settings["model_definitions"]
|
| 397 |
or settings["concurrency_limits"]
|
| 398 |
or settings["model_filters"]
|
|
|
|
| 399 |
)
|
|
|
|
|
|
|
|
|
|
| 400 |
self.console.print(
|
| 401 |
+
f" Advanced Settings: {'Active (view in menu 4)' if has_advanced else 'None (view menu 4 for details)'}"
|
| 402 |
)
|
| 403 |
|
| 404 |
# Show menu
|
|
|
|
| 442 |
elif choice == "4":
|
| 443 |
self.show_provider_settings_menu()
|
| 444 |
elif choice == "5":
|
| 445 |
+
load_dotenv(dotenv_path=_get_env_file(), override=True)
|
| 446 |
self.config = LauncherConfig() # Reload config
|
| 447 |
self.console.print("\n[green]✅ Configuration reloaded![/green]")
|
| 448 |
elif choice == "6":
|
|
|
|
| 683 |
"""Display provider/advanced settings (read-only + launch tool)"""
|
| 684 |
clear_screen()
|
| 685 |
|
| 686 |
+
# Use basic settings to avoid heavy imports - provider_settings deferred to Settings Tool
|
| 687 |
+
settings = SettingsDetector.get_basic_settings()
|
| 688 |
+
|
| 689 |
credentials = settings["credentials"]
|
| 690 |
custom_bases = settings["custom_bases"]
|
| 691 |
model_defs = settings["model_definitions"]
|
| 692 |
concurrency = settings["concurrency_limits"]
|
| 693 |
filters = settings["model_filters"]
|
|
|
|
| 694 |
|
| 695 |
self.console.print(
|
| 696 |
Panel.fit(
|
|
|
|
| 765 |
status = " + ".join(status_parts) if status_parts else "None"
|
| 766 |
self.console.print(f" • {provider:15} ✅ {status}")
|
| 767 |
|
| 768 |
+
# Provider-Specific Settings (deferred to Settings Tool to avoid heavy imports)
|
| 769 |
self.console.print()
|
| 770 |
self.console.print("[bold]🔬 Provider-Specific Settings[/bold]")
|
| 771 |
self.console.print("━" * 70)
|
| 772 |
+
self.console.print(
|
| 773 |
+
" [dim]Launch Settings Tool to view/configure provider-specific settings[/dim]"
|
| 774 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 775 |
|
| 776 |
# Actions
|
| 777 |
self.console.print()
|
|
|
|
| 838 |
# Run the tool with from_launcher=True to skip duplicate loading screen
|
| 839 |
run_credential_tool(from_launcher=True)
|
| 840 |
# Reload environment after credential tool
|
| 841 |
+
load_dotenv(dotenv_path=_get_env_file(), override=True)
|
| 842 |
|
| 843 |
def launch_settings_tool(self):
|
| 844 |
"""Launch settings configuration tool"""
|
| 845 |
+
import time
|
| 846 |
+
|
| 847 |
+
clear_screen()
|
| 848 |
+
|
| 849 |
+
self.console.print("━" * 70)
|
| 850 |
+
self.console.print("Advanced Settings Configuration Tool")
|
| 851 |
+
self.console.print("━" * 70)
|
| 852 |
+
|
| 853 |
+
_start_time = time.time()
|
| 854 |
+
|
| 855 |
+
with self.console.status("Initializing settings tool...", spinner="dots"):
|
| 856 |
+
from proxy_app.settings_tool import run_settings_tool
|
| 857 |
+
|
| 858 |
+
_elapsed = time.time() - _start_time
|
| 859 |
+
self.console.print(f"✓ Settings tool ready in {_elapsed:.2f}s")
|
| 860 |
+
|
| 861 |
+
time.sleep(0.3)
|
| 862 |
|
| 863 |
run_settings_tool()
|
| 864 |
# Reload environment after settings tool
|
| 865 |
+
load_dotenv(dotenv_path=_get_env_file(), override=True)
|
| 866 |
|
| 867 |
def show_about(self):
|
| 868 |
"""Display About page with project information"""
|
|
|
|
| 950 |
)
|
| 951 |
|
| 952 |
ensure_env_defaults()
|
| 953 |
+
load_dotenv(dotenv_path=_get_env_file(), override=True)
|
| 954 |
run_credential_tool()
|
| 955 |
+
load_dotenv(dotenv_path=_get_env_file(), override=True)
|
| 956 |
|
| 957 |
# Check again after credential tool
|
| 958 |
if not os.getenv("PROXY_API_KEY"):
|
src/proxy_app/main.py
CHANGED
|
@@ -52,11 +52,17 @@ _start_time = time.time()
|
|
| 52 |
from dotenv import load_dotenv
|
| 53 |
from glob import glob
|
| 54 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 55 |
# Load main .env first
|
| 56 |
-
load_dotenv()
|
| 57 |
|
| 58 |
# Load any additional .env files (e.g., antigravity_all_combined.env, gemini_cli_all_combined.env)
|
| 59 |
-
_root_dir = Path.cwd()
|
| 60 |
_env_files_found = list(_root_dir.glob("*.env"))
|
| 61 |
for _env_file in sorted(_root_dir.glob("*.env")):
|
| 62 |
if _env_file.name != ".env": # Skip main .env (already loaded)
|
|
@@ -234,8 +240,10 @@ print(
|
|
| 234 |
# Note: Debug logging will be added after logging configuration below
|
| 235 |
|
| 236 |
# --- Logging Configuration ---
|
| 237 |
-
|
| 238 |
-
|
|
|
|
|
|
|
| 239 |
|
| 240 |
# Configure a console handler with color (INFO and above only, no DEBUG)
|
| 241 |
console_handler = colorlog.StreamHandler(sys.stdout)
|
|
@@ -324,7 +332,7 @@ litellm_logger.propagate = False
|
|
| 324 |
logging.debug(f"Modules loaded in {_elapsed:.2f}s")
|
| 325 |
|
| 326 |
# Load environment variables from .env file
|
| 327 |
-
load_dotenv()
|
| 328 |
|
| 329 |
# --- Configuration ---
|
| 330 |
USE_EMBEDDING_BATCHER = False
|
|
@@ -570,11 +578,11 @@ async def lifespan(app: FastAPI):
|
|
| 570 |
)
|
| 571 |
|
| 572 |
# Log loaded credentials summary (compact, always visible for deployment verification)
|
| 573 |
-
#_api_summary = ', '.join([f"{p}:{len(c)}" for p, c in api_keys.items()]) if api_keys else "none"
|
| 574 |
-
#_oauth_summary = ', '.join([f"{p}:{len(c)}" for p, c in oauth_credentials.items()]) if oauth_credentials else "none"
|
| 575 |
-
#_total_summary = ', '.join([f"{p}:{len(c)}" for p, c in client.all_credentials.items()])
|
| 576 |
-
#print(f"🔑 Credentials loaded: {_total_summary} (API: {_api_summary} | OAuth: {_oauth_summary})")
|
| 577 |
-
client.background_refresher.start()
|
| 578 |
app.state.rotating_client = client
|
| 579 |
|
| 580 |
# Warn if no provider credentials are configured
|
|
@@ -1263,8 +1271,8 @@ async def cost_estimate(request: Request, _=Depends(verify_api_key)):
|
|
| 1263 |
|
| 1264 |
|
| 1265 |
if __name__ == "__main__":
|
| 1266 |
-
# Define ENV_FILE for onboarding checks
|
| 1267 |
-
ENV_FILE =
|
| 1268 |
|
| 1269 |
# Check if launcher TUI should be shown (no arguments provided)
|
| 1270 |
if len(sys.argv) == 1:
|
|
@@ -1331,7 +1339,7 @@ if __name__ == "__main__":
|
|
| 1331 |
|
| 1332 |
ensure_env_defaults()
|
| 1333 |
# Reload environment variables after ensure_env_defaults creates/updates .env
|
| 1334 |
-
load_dotenv(override=True)
|
| 1335 |
run_credential_tool()
|
| 1336 |
else:
|
| 1337 |
# Check if onboarding is needed
|
|
@@ -1349,11 +1357,11 @@ if __name__ == "__main__":
|
|
| 1349 |
from rotator_library.credential_tool import ensure_env_defaults
|
| 1350 |
|
| 1351 |
ensure_env_defaults()
|
| 1352 |
-
load_dotenv(override=True)
|
| 1353 |
run_credential_tool()
|
| 1354 |
|
| 1355 |
# After credential tool exits, reload and re-check
|
| 1356 |
-
load_dotenv(override=True)
|
| 1357 |
# Re-read PROXY_API_KEY from environment
|
| 1358 |
PROXY_API_KEY = os.getenv("PROXY_API_KEY")
|
| 1359 |
|
|
|
|
| 52 |
from dotenv import load_dotenv
|
| 53 |
from glob import glob
|
| 54 |
|
| 55 |
+
# Get the application root directory (EXE dir if frozen, else CWD)
|
| 56 |
+
# Inlined here to avoid triggering heavy rotator_library imports before loading screen
|
| 57 |
+
if getattr(sys, "frozen", False):
|
| 58 |
+
_root_dir = Path(sys.executable).parent
|
| 59 |
+
else:
|
| 60 |
+
_root_dir = Path.cwd()
|
| 61 |
+
|
| 62 |
# Load main .env first
|
| 63 |
+
load_dotenv(_root_dir / ".env")
|
| 64 |
|
| 65 |
# Load any additional .env files (e.g., antigravity_all_combined.env, gemini_cli_all_combined.env)
|
|
|
|
| 66 |
_env_files_found = list(_root_dir.glob("*.env"))
|
| 67 |
for _env_file in sorted(_root_dir.glob("*.env")):
|
| 68 |
if _env_file.name != ".env": # Skip main .env (already loaded)
|
|
|
|
| 240 |
# Note: Debug logging will be added after logging configuration below
|
| 241 |
|
| 242 |
# --- Logging Configuration ---
|
| 243 |
+
# Import path utilities here (after loading screen) to avoid triggering heavy imports early
|
| 244 |
+
from rotator_library.utils.paths import get_logs_dir, get_data_file
|
| 245 |
+
|
| 246 |
+
LOG_DIR = get_logs_dir(_root_dir)
|
| 247 |
|
| 248 |
# Configure a console handler with color (INFO and above only, no DEBUG)
|
| 249 |
console_handler = colorlog.StreamHandler(sys.stdout)
|
|
|
|
| 332 |
logging.debug(f"Modules loaded in {_elapsed:.2f}s")
|
| 333 |
|
| 334 |
# Load environment variables from .env file
|
| 335 |
+
load_dotenv(_root_dir / ".env")
|
| 336 |
|
| 337 |
# --- Configuration ---
|
| 338 |
USE_EMBEDDING_BATCHER = False
|
|
|
|
| 578 |
)
|
| 579 |
|
| 580 |
# Log loaded credentials summary (compact, always visible for deployment verification)
|
| 581 |
+
# _api_summary = ', '.join([f"{p}:{len(c)}" for p, c in api_keys.items()]) if api_keys else "none"
|
| 582 |
+
# _oauth_summary = ', '.join([f"{p}:{len(c)}" for p, c in oauth_credentials.items()]) if oauth_credentials else "none"
|
| 583 |
+
# _total_summary = ', '.join([f"{p}:{len(c)}" for p, c in client.all_credentials.items()])
|
| 584 |
+
# print(f"🔑 Credentials loaded: {_total_summary} (API: {_api_summary} | OAuth: {_oauth_summary})")
|
| 585 |
+
client.background_refresher.start() # Start the background task
|
| 586 |
app.state.rotating_client = client
|
| 587 |
|
| 588 |
# Warn if no provider credentials are configured
|
|
|
|
| 1271 |
|
| 1272 |
|
| 1273 |
if __name__ == "__main__":
|
| 1274 |
+
# Define ENV_FILE for onboarding checks using centralized path
|
| 1275 |
+
ENV_FILE = get_data_file(".env")
|
| 1276 |
|
| 1277 |
# Check if launcher TUI should be shown (no arguments provided)
|
| 1278 |
if len(sys.argv) == 1:
|
|
|
|
| 1339 |
|
| 1340 |
ensure_env_defaults()
|
| 1341 |
# Reload environment variables after ensure_env_defaults creates/updates .env
|
| 1342 |
+
load_dotenv(ENV_FILE, override=True)
|
| 1343 |
run_credential_tool()
|
| 1344 |
else:
|
| 1345 |
# Check if onboarding is needed
|
|
|
|
| 1357 |
from rotator_library.credential_tool import ensure_env_defaults
|
| 1358 |
|
| 1359 |
ensure_env_defaults()
|
| 1360 |
+
load_dotenv(ENV_FILE, override=True)
|
| 1361 |
run_credential_tool()
|
| 1362 |
|
| 1363 |
# After credential tool exits, reload and re-check
|
| 1364 |
+
load_dotenv(ENV_FILE, override=True)
|
| 1365 |
# Re-read PROXY_API_KEY from environment
|
| 1366 |
PROXY_API_KEY = os.getenv("PROXY_API_KEY")
|
| 1367 |
|
src/proxy_app/settings_tool.py
CHANGED
|
@@ -12,8 +12,36 @@ from rich.prompt import Prompt, IntPrompt, Confirm
|
|
| 12 |
from rich.panel import Panel
|
| 13 |
from dotenv import set_key, unset_key
|
| 14 |
|
|
|
|
|
|
|
| 15 |
console = Console()
|
| 16 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
|
| 18 |
def clear_screen():
|
| 19 |
"""
|
|
@@ -31,7 +59,7 @@ class AdvancedSettings:
|
|
| 31 |
"""Manages pending changes to .env"""
|
| 32 |
|
| 33 |
def __init__(self):
|
| 34 |
-
self.env_file =
|
| 35 |
self.pending_changes = {} # key -> value (None means delete)
|
| 36 |
self.load_current_settings()
|
| 37 |
|
|
@@ -39,7 +67,7 @@ class AdvancedSettings:
|
|
| 39 |
"""Load current .env values into env vars"""
|
| 40 |
from dotenv import load_dotenv
|
| 41 |
|
| 42 |
-
load_dotenv(override=True)
|
| 43 |
|
| 44 |
def set(self, key: str, value: str):
|
| 45 |
"""Stage a change"""
|
|
@@ -70,6 +98,70 @@ class AdvancedSettings:
|
|
| 70 |
"""Check if there are pending changes"""
|
| 71 |
return bool(self.pending_changes)
|
| 72 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 73 |
|
| 74 |
class CustomProviderManager:
|
| 75 |
"""Manages custom provider API bases"""
|
|
@@ -383,6 +475,11 @@ ANTIGRAVITY_SETTINGS = {
|
|
| 383 |
"default": "\n\nSTRICT PARAMETERS: {params}.",
|
| 384 |
"description": "Template for Claude strict parameter hints in tool descriptions",
|
| 385 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 386 |
}
|
| 387 |
|
| 388 |
# Gemini CLI provider environment variables
|
|
@@ -427,12 +524,27 @@ GEMINI_CLI_SETTINGS = {
|
|
| 427 |
"default": "",
|
| 428 |
"description": "GCP Project ID for paid tier users (required for paid tiers)",
|
| 429 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 430 |
}
|
| 431 |
|
| 432 |
# Map provider names to their settings definitions
|
| 433 |
PROVIDER_SETTINGS_MAP = {
|
| 434 |
"antigravity": ANTIGRAVITY_SETTINGS,
|
| 435 |
"gemini_cli": GEMINI_CLI_SETTINGS,
|
|
|
|
| 436 |
}
|
| 437 |
|
| 438 |
|
|
@@ -516,9 +628,61 @@ class SettingsTool:
|
|
| 516 |
self.provider_settings_mgr = ProviderSettingsManager(self.settings)
|
| 517 |
self.running = True
|
| 518 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 519 |
def get_available_providers(self) -> List[str]:
|
| 520 |
"""Get list of providers that have credentials configured"""
|
| 521 |
-
env_file =
|
| 522 |
providers = set()
|
| 523 |
|
| 524 |
# Scan for providers with API keys from local .env
|
|
@@ -541,7 +705,9 @@ class SettingsTool:
|
|
| 541 |
pass
|
| 542 |
|
| 543 |
# Also check for OAuth providers from files
|
| 544 |
-
|
|
|
|
|
|
|
| 545 |
if oauth_dir.exists():
|
| 546 |
for file in oauth_dir.glob("*_oauth_*.json"):
|
| 547 |
provider = file.name.split("_oauth_")[0]
|
|
@@ -579,12 +745,7 @@ class SettingsTool:
|
|
| 579 |
self.console.print()
|
| 580 |
self.console.print("━" * 70)
|
| 581 |
|
| 582 |
-
|
| 583 |
-
self.console.print(
|
| 584 |
-
'[yellow]ℹ️ Changes are pending until you select "Save & Exit"[/yellow]'
|
| 585 |
-
)
|
| 586 |
-
else:
|
| 587 |
-
self.console.print("[dim]ℹ️ No pending changes[/dim]")
|
| 588 |
|
| 589 |
self.console.print()
|
| 590 |
self.console.print(
|
|
@@ -618,6 +779,7 @@ class SettingsTool:
|
|
| 618 |
while True:
|
| 619 |
clear_screen()
|
| 620 |
|
|
|
|
| 621 |
providers = self.provider_mgr.get_current_providers()
|
| 622 |
|
| 623 |
self.console.print(
|
|
@@ -631,9 +793,48 @@ class SettingsTool:
|
|
| 631 |
self.console.print("[bold]📋 Configured Custom Providers[/bold]")
|
| 632 |
self.console.print("━" * 70)
|
| 633 |
|
| 634 |
-
|
| 635 |
-
|
| 636 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 637 |
else:
|
| 638 |
self.console.print(" [dim]No custom providers configured[/dim]")
|
| 639 |
|
|
@@ -662,7 +863,7 @@ class SettingsTool:
|
|
| 662 |
if api_base:
|
| 663 |
self.provider_mgr.add_provider(name, api_base)
|
| 664 |
self.console.print(
|
| 665 |
-
f"\n[green]✅ Custom provider '{name}'
|
| 666 |
)
|
| 667 |
self.console.print(
|
| 668 |
f" To use: set {name.upper()}_API_KEY in credentials"
|
|
@@ -670,14 +871,18 @@ class SettingsTool:
|
|
| 670 |
input("\nPress Enter to continue...")
|
| 671 |
|
| 672 |
elif choice == "2":
|
| 673 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 674 |
self.console.print("\n[yellow]No providers to edit[/yellow]")
|
| 675 |
input("\nPress Enter to continue...")
|
| 676 |
continue
|
| 677 |
|
| 678 |
# Show numbered list
|
| 679 |
self.console.print("\n[bold]Select provider to edit:[/bold]")
|
| 680 |
-
providers_list =
|
| 681 |
for idx, prov in enumerate(providers_list, 1):
|
| 682 |
self.console.print(f" {idx}. {prov}")
|
| 683 |
|
|
@@ -686,7 +891,9 @@ class SettingsTool:
|
|
| 686 |
choices=[str(i) for i in range(1, len(providers_list) + 1)],
|
| 687 |
)
|
| 688 |
name = providers_list[choice_idx - 1]
|
| 689 |
-
|
|
|
|
|
|
|
| 690 |
|
| 691 |
self.console.print(f"\nCurrent API Base: {current_base}")
|
| 692 |
new_base = Prompt.ask(
|
|
@@ -703,16 +910,33 @@ class SettingsTool:
|
|
| 703 |
input("\nPress Enter to continue...")
|
| 704 |
|
| 705 |
elif choice == "3":
|
| 706 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 707 |
self.console.print("\n[yellow]No providers to remove[/yellow]")
|
| 708 |
input("\nPress Enter to continue...")
|
| 709 |
continue
|
| 710 |
|
| 711 |
# Show numbered list
|
| 712 |
self.console.print("\n[bold]Select provider to remove:[/bold]")
|
| 713 |
-
|
|
|
|
| 714 |
for idx, prov in enumerate(providers_list, 1):
|
| 715 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 716 |
|
| 717 |
choice_idx = IntPrompt.ask(
|
| 718 |
"Select option",
|
|
@@ -721,10 +945,18 @@ class SettingsTool:
|
|
| 721 |
name = providers_list[choice_idx - 1]
|
| 722 |
|
| 723 |
if Confirm.ask(f"Remove '{name}'?"):
|
| 724 |
-
|
| 725 |
-
|
| 726 |
-
|
| 727 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 728 |
input("\nPress Enter to continue...")
|
| 729 |
|
| 730 |
elif choice == "4":
|
|
@@ -735,7 +967,8 @@ class SettingsTool:
|
|
| 735 |
while True:
|
| 736 |
clear_screen()
|
| 737 |
|
| 738 |
-
|
|
|
|
| 739 |
|
| 740 |
self.console.print(
|
| 741 |
Panel.fit(
|
|
@@ -748,10 +981,69 @@ class SettingsTool:
|
|
| 748 |
self.console.print("[bold]📋 Configured Provider Models[/bold]")
|
| 749 |
self.console.print("━" * 70)
|
| 750 |
|
| 751 |
-
|
| 752 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 753 |
self.console.print(
|
| 754 |
-
|
|
|
|
|
|
|
| 755 |
)
|
| 756 |
else:
|
| 757 |
self.console.print(" [dim]No model definitions configured[/dim]")
|
|
@@ -778,19 +1070,36 @@ class SettingsTool:
|
|
| 778 |
if choice == "1":
|
| 779 |
self.add_model_definitions()
|
| 780 |
elif choice == "2":
|
| 781 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 782 |
self.console.print("\n[yellow]No providers to edit[/yellow]")
|
| 783 |
input("\nPress Enter to continue...")
|
| 784 |
continue
|
| 785 |
-
self.edit_model_definitions(
|
| 786 |
elif choice == "3":
|
| 787 |
-
|
|
|
|
|
|
|
|
|
|
| 788 |
self.console.print("\n[yellow]No providers to view[/yellow]")
|
| 789 |
input("\nPress Enter to continue...")
|
| 790 |
continue
|
| 791 |
-
self.view_model_definitions(
|
| 792 |
elif choice == "4":
|
| 793 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 794 |
self.console.print("\n[yellow]No providers to remove[/yellow]")
|
| 795 |
input("\nPress Enter to continue...")
|
| 796 |
continue
|
|
@@ -799,9 +1108,14 @@ class SettingsTool:
|
|
| 799 |
self.console.print(
|
| 800 |
"\n[bold]Select provider to remove models from:[/bold]"
|
| 801 |
)
|
| 802 |
-
providers_list =
|
| 803 |
for idx, prov in enumerate(providers_list, 1):
|
| 804 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 805 |
|
| 806 |
choice_idx = IntPrompt.ask(
|
| 807 |
"Select option",
|
|
@@ -810,10 +1124,18 @@ class SettingsTool:
|
|
| 810 |
provider = providers_list[choice_idx - 1]
|
| 811 |
|
| 812 |
if Confirm.ask(f"Remove all model definitions for '{provider}'?"):
|
| 813 |
-
|
| 814 |
-
|
| 815 |
-
|
| 816 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 817 |
input("\nPress Enter to continue...")
|
| 818 |
elif choice == "5":
|
| 819 |
break
|
|
@@ -1140,7 +1462,7 @@ class SettingsTool:
|
|
| 1140 |
self.console.print("[bold]📋 Current Settings[/bold]")
|
| 1141 |
self.console.print("━" * 70)
|
| 1142 |
|
| 1143 |
-
# Display all settings with current values
|
| 1144 |
settings_list = list(definitions.keys())
|
| 1145 |
for idx, key in enumerate(settings_list, 1):
|
| 1146 |
definition = definitions[key]
|
|
@@ -1149,37 +1471,88 @@ class SettingsTool:
|
|
| 1149 |
setting_type = definition.get("type", "str")
|
| 1150 |
description = definition.get("description", "")
|
| 1151 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1152 |
# Format value display
|
| 1153 |
if setting_type == "bool":
|
| 1154 |
value_display = (
|
| 1155 |
"[green]✓ Enabled[/green]"
|
| 1156 |
-
if
|
| 1157 |
else "[red]✗ Disabled[/red]"
|
| 1158 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1159 |
elif setting_type == "int":
|
| 1160 |
-
value_display = f"[cyan]{
|
|
|
|
| 1161 |
else:
|
| 1162 |
value_display = (
|
| 1163 |
-
f"[cyan]{
|
| 1164 |
-
if
|
| 1165 |
else "[dim](not set)[/dim]"
|
| 1166 |
)
|
| 1167 |
-
|
| 1168 |
-
|
| 1169 |
-
|
| 1170 |
-
mod_marker = "[yellow]*[/yellow]" if modified else " "
|
| 1171 |
|
| 1172 |
# Short key name for display (strip provider prefix)
|
| 1173 |
short_key = key.replace(f"{provider.upper()}_", "")
|
| 1174 |
|
| 1175 |
-
|
| 1176 |
-
|
| 1177 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1178 |
self.console.print(f" [dim]{description}[/dim]")
|
| 1179 |
|
| 1180 |
self.console.print()
|
| 1181 |
self.console.print("━" * 70)
|
| 1182 |
-
self.console.print(
|
|
|
|
|
|
|
| 1183 |
self.console.print()
|
| 1184 |
self.console.print("[bold]⚙️ Actions[/bold]")
|
| 1185 |
self.console.print()
|
|
@@ -1299,6 +1672,7 @@ class SettingsTool:
|
|
| 1299 |
while True:
|
| 1300 |
clear_screen()
|
| 1301 |
|
|
|
|
| 1302 |
modes = self.rotation_mgr.get_current_modes()
|
| 1303 |
available_providers = self.get_available_providers()
|
| 1304 |
|
|
@@ -1322,20 +1696,78 @@ class SettingsTool:
|
|
| 1322 |
self.console.print("[bold]📋 Current Rotation Mode Settings[/bold]")
|
| 1323 |
self.console.print("━" * 70)
|
| 1324 |
|
| 1325 |
-
|
| 1326 |
-
|
| 1327 |
-
|
| 1328 |
-
|
| 1329 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1330 |
mode_display = (
|
| 1331 |
f"[green]{mode}[/green]"
|
| 1332 |
if mode == "sequential"
|
| 1333 |
else f"[blue]{mode}[/blue]"
|
| 1334 |
)
|
| 1335 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1336 |
|
| 1337 |
# Show providers with default modes
|
| 1338 |
-
providers_with_defaults = [
|
|
|
|
|
|
|
| 1339 |
if providers_with_defaults:
|
| 1340 |
self.console.print()
|
| 1341 |
self.console.print("[dim]Providers using default modes:[/dim]")
|
|
@@ -1423,12 +1855,16 @@ class SettingsTool:
|
|
| 1423 |
|
| 1424 |
self.rotation_mgr.set_mode(provider, new_mode)
|
| 1425 |
self.console.print(
|
| 1426 |
-
f"\n[green]✅ Rotation mode for '{provider}'
|
| 1427 |
)
|
| 1428 |
input("\nPress Enter to continue...")
|
| 1429 |
|
| 1430 |
elif choice == "2":
|
| 1431 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1432 |
self.console.print(
|
| 1433 |
"\n[yellow]No custom rotation modes to reset[/yellow]"
|
| 1434 |
)
|
|
@@ -1439,12 +1875,18 @@ class SettingsTool:
|
|
| 1439 |
self.console.print(
|
| 1440 |
"\n[bold]Select provider to reset to default:[/bold]"
|
| 1441 |
)
|
| 1442 |
-
modes_list =
|
| 1443 |
for idx, prov in enumerate(modes_list, 1):
|
| 1444 |
default_mode = self.rotation_mgr.get_default_mode(prov)
|
| 1445 |
-
|
| 1446 |
-
|
| 1447 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1448 |
|
| 1449 |
choice_idx = IntPrompt.ask(
|
| 1450 |
"Select option",
|
|
@@ -1452,12 +1894,21 @@ class SettingsTool:
|
|
| 1452 |
)
|
| 1453 |
provider = modes_list[choice_idx - 1]
|
| 1454 |
default_mode = self.rotation_mgr.get_default_mode(provider)
|
|
|
|
| 1455 |
|
| 1456 |
if Confirm.ask(f"Reset '{provider}' to default mode ({default_mode})?"):
|
| 1457 |
-
|
| 1458 |
-
|
| 1459 |
-
|
| 1460 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1461 |
input("\nPress Enter to continue...")
|
| 1462 |
|
| 1463 |
elif choice == "3":
|
|
@@ -1630,6 +2081,7 @@ class SettingsTool:
|
|
| 1630 |
while True:
|
| 1631 |
clear_screen()
|
| 1632 |
|
|
|
|
| 1633 |
limits = self.concurrency_mgr.get_current_limits()
|
| 1634 |
|
| 1635 |
self.console.print(
|
|
@@ -1643,10 +2095,57 @@ class SettingsTool:
|
|
| 1643 |
self.console.print("[bold]📋 Current Concurrency Settings[/bold]")
|
| 1644 |
self.console.print("━" * 70)
|
| 1645 |
|
| 1646 |
-
|
| 1647 |
-
|
| 1648 |
-
|
| 1649 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1650 |
else:
|
| 1651 |
self.console.print(" • Default: 1 request/key (all providers)")
|
| 1652 |
|
|
@@ -1704,7 +2203,7 @@ class SettingsTool:
|
|
| 1704 |
if 1 <= limit <= 100:
|
| 1705 |
self.concurrency_mgr.set_limit(provider, limit)
|
| 1706 |
self.console.print(
|
| 1707 |
-
f"\n[green]✅ Concurrency limit
|
| 1708 |
)
|
| 1709 |
else:
|
| 1710 |
self.console.print(
|
|
@@ -1713,14 +2212,18 @@ class SettingsTool:
|
|
| 1713 |
input("\nPress Enter to continue...")
|
| 1714 |
|
| 1715 |
elif choice == "2":
|
| 1716 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1717 |
self.console.print("\n[yellow]No limits to edit[/yellow]")
|
| 1718 |
input("\nPress Enter to continue...")
|
| 1719 |
continue
|
| 1720 |
|
| 1721 |
# Show numbered list
|
| 1722 |
self.console.print("\n[bold]Select provider to edit:[/bold]")
|
| 1723 |
-
limits_list =
|
| 1724 |
for idx, prov in enumerate(limits_list, 1):
|
| 1725 |
self.console.print(f" {idx}. {prov}")
|
| 1726 |
|
|
@@ -1729,7 +2232,8 @@ class SettingsTool:
|
|
| 1729 |
choices=[str(i) for i in range(1, len(limits_list) + 1)],
|
| 1730 |
)
|
| 1731 |
provider = limits_list[choice_idx - 1]
|
| 1732 |
-
|
|
|
|
| 1733 |
|
| 1734 |
self.console.print(f"\nCurrent limit: {current_limit} requests/key")
|
| 1735 |
new_limit = IntPrompt.ask(
|
|
@@ -1750,7 +2254,18 @@ class SettingsTool:
|
|
| 1750 |
input("\nPress Enter to continue...")
|
| 1751 |
|
| 1752 |
elif choice == "3":
|
| 1753 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1754 |
self.console.print("\n[yellow]No limits to remove[/yellow]")
|
| 1755 |
input("\nPress Enter to continue...")
|
| 1756 |
continue
|
|
@@ -1759,9 +2274,14 @@ class SettingsTool:
|
|
| 1759 |
self.console.print(
|
| 1760 |
"\n[bold]Select provider to remove limit from:[/bold]"
|
| 1761 |
)
|
| 1762 |
-
limits_list =
|
| 1763 |
for idx, prov in enumerate(limits_list, 1):
|
| 1764 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1765 |
|
| 1766 |
choice_idx = IntPrompt.ask(
|
| 1767 |
"Select option",
|
|
@@ -1772,18 +2292,118 @@ class SettingsTool:
|
|
| 1772 |
if Confirm.ask(
|
| 1773 |
f"Remove concurrency limit for '{provider}' (reset to default 1)?"
|
| 1774 |
):
|
| 1775 |
-
|
| 1776 |
-
|
| 1777 |
-
|
| 1778 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1779 |
input("\nPress Enter to continue...")
|
| 1780 |
|
| 1781 |
elif choice == "4":
|
| 1782 |
break
|
| 1783 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1784 |
def save_and_exit(self):
|
| 1785 |
"""Save pending changes and exit"""
|
| 1786 |
if self.settings.has_pending():
|
|
|
|
|
|
|
|
|
|
| 1787 |
if Confirm.ask("\n[bold yellow]Save all pending changes?[/bold yellow]"):
|
| 1788 |
self.settings.save()
|
| 1789 |
self.console.print("\n[green]✅ All changes saved to .env![/green]")
|
|
@@ -1801,6 +2421,9 @@ class SettingsTool:
|
|
| 1801 |
def exit_without_saving(self):
|
| 1802 |
"""Exit without saving"""
|
| 1803 |
if self.settings.has_pending():
|
|
|
|
|
|
|
|
|
|
| 1804 |
if Confirm.ask("\n[bold red]Discard all pending changes?[/bold red]"):
|
| 1805 |
self.settings.discard()
|
| 1806 |
self.console.print("\n[yellow]Changes discarded[/yellow]")
|
|
|
|
| 12 |
from rich.panel import Panel
|
| 13 |
from dotenv import set_key, unset_key
|
| 14 |
|
| 15 |
+
from rotator_library.utils.paths import get_data_file
|
| 16 |
+
|
| 17 |
console = Console()
|
| 18 |
|
| 19 |
+
# Sentinel value for distinguishing "no pending change" from "pending change to None"
|
| 20 |
+
_NOT_FOUND = object()
|
| 21 |
+
|
| 22 |
+
# Import default OAuth port values from provider modules
|
| 23 |
+
# These serve as the source of truth for default port values
|
| 24 |
+
try:
|
| 25 |
+
from rotator_library.providers.gemini_auth_base import GeminiAuthBase
|
| 26 |
+
|
| 27 |
+
GEMINI_CLI_DEFAULT_OAUTH_PORT = GeminiAuthBase.CALLBACK_PORT
|
| 28 |
+
except ImportError:
|
| 29 |
+
GEMINI_CLI_DEFAULT_OAUTH_PORT = 8085
|
| 30 |
+
|
| 31 |
+
try:
|
| 32 |
+
from rotator_library.providers.antigravity_auth_base import AntigravityAuthBase
|
| 33 |
+
|
| 34 |
+
ANTIGRAVITY_DEFAULT_OAUTH_PORT = AntigravityAuthBase.CALLBACK_PORT
|
| 35 |
+
except ImportError:
|
| 36 |
+
ANTIGRAVITY_DEFAULT_OAUTH_PORT = 51121
|
| 37 |
+
|
| 38 |
+
try:
|
| 39 |
+
from rotator_library.providers.iflow_auth_base import (
|
| 40 |
+
CALLBACK_PORT as IFLOW_DEFAULT_OAUTH_PORT,
|
| 41 |
+
)
|
| 42 |
+
except ImportError:
|
| 43 |
+
IFLOW_DEFAULT_OAUTH_PORT = 11451
|
| 44 |
+
|
| 45 |
|
| 46 |
def clear_screen():
|
| 47 |
"""
|
|
|
|
| 59 |
"""Manages pending changes to .env"""
|
| 60 |
|
| 61 |
def __init__(self):
|
| 62 |
+
self.env_file = get_data_file(".env")
|
| 63 |
self.pending_changes = {} # key -> value (None means delete)
|
| 64 |
self.load_current_settings()
|
| 65 |
|
|
|
|
| 67 |
"""Load current .env values into env vars"""
|
| 68 |
from dotenv import load_dotenv
|
| 69 |
|
| 70 |
+
load_dotenv(self.env_file, override=True)
|
| 71 |
|
| 72 |
def set(self, key: str, value: str):
|
| 73 |
"""Stage a change"""
|
|
|
|
| 98 |
"""Check if there are pending changes"""
|
| 99 |
return bool(self.pending_changes)
|
| 100 |
|
| 101 |
+
def get_pending_value(self, key: str):
|
| 102 |
+
"""Get pending value for a key. Returns sentinel _NOT_FOUND if no pending change."""
|
| 103 |
+
return self.pending_changes.get(key, _NOT_FOUND)
|
| 104 |
+
|
| 105 |
+
def get_original_value(self, key: str) -> Optional[str]:
|
| 106 |
+
"""Get the current .env value (before pending changes)"""
|
| 107 |
+
return os.getenv(key)
|
| 108 |
+
|
| 109 |
+
def get_change_type(self, key: str) -> Optional[str]:
|
| 110 |
+
"""Returns 'add', 'edit', 'remove', or None if no pending change"""
|
| 111 |
+
if key not in self.pending_changes:
|
| 112 |
+
return None
|
| 113 |
+
if self.pending_changes[key] is None:
|
| 114 |
+
return "remove"
|
| 115 |
+
elif os.getenv(key) is not None:
|
| 116 |
+
return "edit"
|
| 117 |
+
else:
|
| 118 |
+
return "add"
|
| 119 |
+
|
| 120 |
+
def get_pending_keys_by_pattern(
|
| 121 |
+
self, prefix: str = "", suffix: str = ""
|
| 122 |
+
) -> List[str]:
|
| 123 |
+
"""Get all pending change keys that match prefix and/or suffix"""
|
| 124 |
+
return [
|
| 125 |
+
k
|
| 126 |
+
for k in self.pending_changes.keys()
|
| 127 |
+
if k.startswith(prefix) and k.endswith(suffix)
|
| 128 |
+
]
|
| 129 |
+
|
| 130 |
+
def get_changes_summary(self) -> Dict[str, List[tuple]]:
|
| 131 |
+
"""Get categorized summary of all pending changes.
|
| 132 |
+
Returns dict with 'add', 'edit', 'remove' keys,
|
| 133 |
+
each containing list of (key, old_val, new_val) tuples.
|
| 134 |
+
"""
|
| 135 |
+
summary: Dict[str, List[tuple]] = {"add": [], "edit": [], "remove": []}
|
| 136 |
+
for key, new_val in self.pending_changes.items():
|
| 137 |
+
old_val = os.getenv(key)
|
| 138 |
+
change_type = self.get_change_type(key)
|
| 139 |
+
if change_type:
|
| 140 |
+
summary[change_type].append((key, old_val, new_val))
|
| 141 |
+
# Sort each list alphabetically by key
|
| 142 |
+
for change_type in summary:
|
| 143 |
+
summary[change_type].sort(key=lambda x: x[0])
|
| 144 |
+
return summary
|
| 145 |
+
|
| 146 |
+
def get_pending_counts(self) -> Dict[str, int]:
|
| 147 |
+
"""Get counts of pending changes by type"""
|
| 148 |
+
adds = len(
|
| 149 |
+
[
|
| 150 |
+
k
|
| 151 |
+
for k, v in self.pending_changes.items()
|
| 152 |
+
if v is not None and os.getenv(k) is None
|
| 153 |
+
]
|
| 154 |
+
)
|
| 155 |
+
edits = len(
|
| 156 |
+
[
|
| 157 |
+
k
|
| 158 |
+
for k, v in self.pending_changes.items()
|
| 159 |
+
if v is not None and os.getenv(k) is not None
|
| 160 |
+
]
|
| 161 |
+
)
|
| 162 |
+
removes = len([k for k, v in self.pending_changes.items() if v is None])
|
| 163 |
+
return {"add": adds, "edit": edits, "remove": removes}
|
| 164 |
+
|
| 165 |
|
| 166 |
class CustomProviderManager:
|
| 167 |
"""Manages custom provider API bases"""
|
|
|
|
| 475 |
"default": "\n\nSTRICT PARAMETERS: {params}.",
|
| 476 |
"description": "Template for Claude strict parameter hints in tool descriptions",
|
| 477 |
},
|
| 478 |
+
"ANTIGRAVITY_OAUTH_PORT": {
|
| 479 |
+
"type": "int",
|
| 480 |
+
"default": ANTIGRAVITY_DEFAULT_OAUTH_PORT,
|
| 481 |
+
"description": "Local port for OAuth callback server during authentication",
|
| 482 |
+
},
|
| 483 |
}
|
| 484 |
|
| 485 |
# Gemini CLI provider environment variables
|
|
|
|
| 524 |
"default": "",
|
| 525 |
"description": "GCP Project ID for paid tier users (required for paid tiers)",
|
| 526 |
},
|
| 527 |
+
"GEMINI_CLI_OAUTH_PORT": {
|
| 528 |
+
"type": "int",
|
| 529 |
+
"default": GEMINI_CLI_DEFAULT_OAUTH_PORT,
|
| 530 |
+
"description": "Local port for OAuth callback server during authentication",
|
| 531 |
+
},
|
| 532 |
+
}
|
| 533 |
+
|
| 534 |
+
# iFlow provider environment variables
|
| 535 |
+
IFLOW_SETTINGS = {
|
| 536 |
+
"IFLOW_OAUTH_PORT": {
|
| 537 |
+
"type": "int",
|
| 538 |
+
"default": IFLOW_DEFAULT_OAUTH_PORT,
|
| 539 |
+
"description": "Local port for OAuth callback server during authentication",
|
| 540 |
+
},
|
| 541 |
}
|
| 542 |
|
| 543 |
# Map provider names to their settings definitions
|
| 544 |
PROVIDER_SETTINGS_MAP = {
|
| 545 |
"antigravity": ANTIGRAVITY_SETTINGS,
|
| 546 |
"gemini_cli": GEMINI_CLI_SETTINGS,
|
| 547 |
+
"iflow": IFLOW_SETTINGS,
|
| 548 |
}
|
| 549 |
|
| 550 |
|
|
|
|
| 628 |
self.provider_settings_mgr = ProviderSettingsManager(self.settings)
|
| 629 |
self.running = True
|
| 630 |
|
| 631 |
+
def _format_item(
|
| 632 |
+
self,
|
| 633 |
+
name: str,
|
| 634 |
+
value: str,
|
| 635 |
+
change_type: Optional[str],
|
| 636 |
+
old_value: Optional[str] = None,
|
| 637 |
+
width: int = 15,
|
| 638 |
+
) -> str:
|
| 639 |
+
"""Format a list item with change indicator.
|
| 640 |
+
|
| 641 |
+
change_type: None, 'add', 'edit', 'remove'
|
| 642 |
+
Returns formatted string like:
|
| 643 |
+
" + myapi https://api.example.com" (green)
|
| 644 |
+
" ~ openai 1 → 5 requests/key" (yellow)
|
| 645 |
+
" - oldapi https://old.api.com" (red)
|
| 646 |
+
" • groq 3 requests/key" (normal)
|
| 647 |
+
"""
|
| 648 |
+
if change_type == "add":
|
| 649 |
+
return f" [green]+ {name:{width}} {value}[/green]"
|
| 650 |
+
elif change_type == "edit":
|
| 651 |
+
if old_value is not None:
|
| 652 |
+
return f" [yellow]~ {name:{width}} {old_value} → {value}[/yellow]"
|
| 653 |
+
else:
|
| 654 |
+
return f" [yellow]~ {name:{width}} {value}[/yellow]"
|
| 655 |
+
elif change_type == "remove":
|
| 656 |
+
return f" [red]- {name:{width}} {value}[/red]"
|
| 657 |
+
else:
|
| 658 |
+
return f" • {name:{width}} {value}"
|
| 659 |
+
|
| 660 |
+
def _get_pending_status_text(self) -> str:
|
| 661 |
+
"""Get formatted pending changes status text for main menu."""
|
| 662 |
+
if not self.settings.has_pending():
|
| 663 |
+
return "[dim]ℹ️ No pending changes[/dim]"
|
| 664 |
+
|
| 665 |
+
counts = self.settings.get_pending_counts()
|
| 666 |
+
parts = []
|
| 667 |
+
if counts["add"]:
|
| 668 |
+
parts.append(
|
| 669 |
+
f"[green]{counts['add']} addition{'s' if counts['add'] > 1 else ''}[/green]"
|
| 670 |
+
)
|
| 671 |
+
if counts["edit"]:
|
| 672 |
+
parts.append(
|
| 673 |
+
f"[yellow]{counts['edit']} modification{'s' if counts['edit'] > 1 else ''}[/yellow]"
|
| 674 |
+
)
|
| 675 |
+
if counts["remove"]:
|
| 676 |
+
parts.append(
|
| 677 |
+
f"[red]{counts['remove']} removal{'s' if counts['remove'] > 1 else ''}[/red]"
|
| 678 |
+
)
|
| 679 |
+
|
| 680 |
+
return f"[bold]ℹ️ Pending changes: {', '.join(parts)}[/bold]"
|
| 681 |
+
self.running = True
|
| 682 |
+
|
| 683 |
def get_available_providers(self) -> List[str]:
|
| 684 |
"""Get list of providers that have credentials configured"""
|
| 685 |
+
env_file = get_data_file(".env")
|
| 686 |
providers = set()
|
| 687 |
|
| 688 |
# Scan for providers with API keys from local .env
|
|
|
|
| 705 |
pass
|
| 706 |
|
| 707 |
# Also check for OAuth providers from files
|
| 708 |
+
from rotator_library.utils.paths import get_oauth_dir
|
| 709 |
+
|
| 710 |
+
oauth_dir = get_oauth_dir()
|
| 711 |
if oauth_dir.exists():
|
| 712 |
for file in oauth_dir.glob("*_oauth_*.json"):
|
| 713 |
provider = file.name.split("_oauth_")[0]
|
|
|
|
| 745 |
self.console.print()
|
| 746 |
self.console.print("━" * 70)
|
| 747 |
|
| 748 |
+
self.console.print(self._get_pending_status_text())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 749 |
|
| 750 |
self.console.print()
|
| 751 |
self.console.print(
|
|
|
|
| 779 |
while True:
|
| 780 |
clear_screen()
|
| 781 |
|
| 782 |
+
# Get current providers from env
|
| 783 |
providers = self.provider_mgr.get_current_providers()
|
| 784 |
|
| 785 |
self.console.print(
|
|
|
|
| 793 |
self.console.print("[bold]📋 Configured Custom Providers[/bold]")
|
| 794 |
self.console.print("━" * 70)
|
| 795 |
|
| 796 |
+
# Build combined view with pending changes
|
| 797 |
+
all_providers: Dict[str, Dict[str, Any]] = {}
|
| 798 |
+
|
| 799 |
+
# Add current providers (from env)
|
| 800 |
+
for name, base in providers.items():
|
| 801 |
+
key = f"{name.upper()}_API_BASE"
|
| 802 |
+
change_type = self.settings.get_change_type(key)
|
| 803 |
+
if change_type == "remove":
|
| 804 |
+
all_providers[name] = {"value": base, "type": "remove", "old": None}
|
| 805 |
+
elif change_type == "edit":
|
| 806 |
+
new_val = self.settings.pending_changes[key]
|
| 807 |
+
all_providers[name] = {
|
| 808 |
+
"value": new_val,
|
| 809 |
+
"type": "edit",
|
| 810 |
+
"old": base,
|
| 811 |
+
}
|
| 812 |
+
else:
|
| 813 |
+
all_providers[name] = {"value": base, "type": None, "old": None}
|
| 814 |
+
|
| 815 |
+
# Add pending new providers (additions)
|
| 816 |
+
for key in self.settings.get_pending_keys_by_pattern(suffix="_API_BASE"):
|
| 817 |
+
if self.settings.get_change_type(key) == "add":
|
| 818 |
+
name = key.replace("_API_BASE", "").lower()
|
| 819 |
+
if name not in all_providers:
|
| 820 |
+
all_providers[name] = {
|
| 821 |
+
"value": self.settings.pending_changes[key],
|
| 822 |
+
"type": "add",
|
| 823 |
+
"old": None,
|
| 824 |
+
}
|
| 825 |
+
|
| 826 |
+
if all_providers:
|
| 827 |
+
# Sort alphabetically
|
| 828 |
+
for name in sorted(all_providers.keys()):
|
| 829 |
+
info = all_providers[name]
|
| 830 |
+
self.console.print(
|
| 831 |
+
self._format_item(
|
| 832 |
+
name,
|
| 833 |
+
info["value"],
|
| 834 |
+
info["type"],
|
| 835 |
+
info["old"],
|
| 836 |
+
)
|
| 837 |
+
)
|
| 838 |
else:
|
| 839 |
self.console.print(" [dim]No custom providers configured[/dim]")
|
| 840 |
|
|
|
|
| 863 |
if api_base:
|
| 864 |
self.provider_mgr.add_provider(name, api_base)
|
| 865 |
self.console.print(
|
| 866 |
+
f"\n[green]✅ Custom provider '{name}' staged![/green]"
|
| 867 |
)
|
| 868 |
self.console.print(
|
| 869 |
f" To use: set {name.upper()}_API_KEY in credentials"
|
|
|
|
| 871 |
input("\nPress Enter to continue...")
|
| 872 |
|
| 873 |
elif choice == "2":
|
| 874 |
+
# Get editable providers (existing + pending additions, excluding pending removals)
|
| 875 |
+
editable = {
|
| 876 |
+
k: v for k, v in all_providers.items() if v["type"] != "remove"
|
| 877 |
+
}
|
| 878 |
+
if not editable:
|
| 879 |
self.console.print("\n[yellow]No providers to edit[/yellow]")
|
| 880 |
input("\nPress Enter to continue...")
|
| 881 |
continue
|
| 882 |
|
| 883 |
# Show numbered list
|
| 884 |
self.console.print("\n[bold]Select provider to edit:[/bold]")
|
| 885 |
+
providers_list = sorted(editable.keys())
|
| 886 |
for idx, prov in enumerate(providers_list, 1):
|
| 887 |
self.console.print(f" {idx}. {prov}")
|
| 888 |
|
|
|
|
| 891 |
choices=[str(i) for i in range(1, len(providers_list) + 1)],
|
| 892 |
)
|
| 893 |
name = providers_list[choice_idx - 1]
|
| 894 |
+
info = editable[name]
|
| 895 |
+
# Get effective current value (could be pending or from env)
|
| 896 |
+
current_base = info["value"]
|
| 897 |
|
| 898 |
self.console.print(f"\nCurrent API Base: {current_base}")
|
| 899 |
new_base = Prompt.ask(
|
|
|
|
| 910 |
input("\nPress Enter to continue...")
|
| 911 |
|
| 912 |
elif choice == "3":
|
| 913 |
+
# Get removable providers (existing ones not already pending removal)
|
| 914 |
+
removable = {
|
| 915 |
+
k: v
|
| 916 |
+
for k, v in all_providers.items()
|
| 917 |
+
if v["type"] != "remove" and v["type"] != "add"
|
| 918 |
+
}
|
| 919 |
+
# For pending additions, we can "undo" by removing from pending
|
| 920 |
+
pending_adds = {
|
| 921 |
+
k: v for k, v in all_providers.items() if v["type"] == "add"
|
| 922 |
+
}
|
| 923 |
+
|
| 924 |
+
if not removable and not pending_adds:
|
| 925 |
self.console.print("\n[yellow]No providers to remove[/yellow]")
|
| 926 |
input("\nPress Enter to continue...")
|
| 927 |
continue
|
| 928 |
|
| 929 |
# Show numbered list
|
| 930 |
self.console.print("\n[bold]Select provider to remove:[/bold]")
|
| 931 |
+
# Show existing providers first, then pending additions
|
| 932 |
+
providers_list = sorted(removable.keys()) + sorted(pending_adds.keys())
|
| 933 |
for idx, prov in enumerate(providers_list, 1):
|
| 934 |
+
if prov in pending_adds:
|
| 935 |
+
self.console.print(
|
| 936 |
+
f" {idx}. {prov} [green](pending add)[/green]"
|
| 937 |
+
)
|
| 938 |
+
else:
|
| 939 |
+
self.console.print(f" {idx}. {prov}")
|
| 940 |
|
| 941 |
choice_idx = IntPrompt.ask(
|
| 942 |
"Select option",
|
|
|
|
| 945 |
name = providers_list[choice_idx - 1]
|
| 946 |
|
| 947 |
if Confirm.ask(f"Remove '{name}'?"):
|
| 948 |
+
if name in pending_adds:
|
| 949 |
+
# Undo pending addition - remove from pending_changes
|
| 950 |
+
key = f"{name.upper()}_API_BASE"
|
| 951 |
+
del self.settings.pending_changes[key]
|
| 952 |
+
self.console.print(
|
| 953 |
+
f"\n[green]✅ Pending addition of '{name}' cancelled![/green]"
|
| 954 |
+
)
|
| 955 |
+
else:
|
| 956 |
+
self.provider_mgr.remove_provider(name)
|
| 957 |
+
self.console.print(
|
| 958 |
+
f"\n[green]✅ Provider '{name}' marked for removal![/green]"
|
| 959 |
+
)
|
| 960 |
input("\nPress Enter to continue...")
|
| 961 |
|
| 962 |
elif choice == "4":
|
|
|
|
| 967 |
while True:
|
| 968 |
clear_screen()
|
| 969 |
|
| 970 |
+
# Get current providers with models from env
|
| 971 |
+
all_providers_env = self.model_mgr.get_all_providers_with_models()
|
| 972 |
|
| 973 |
self.console.print(
|
| 974 |
Panel.fit(
|
|
|
|
| 981 |
self.console.print("[bold]📋 Configured Provider Models[/bold]")
|
| 982 |
self.console.print("━" * 70)
|
| 983 |
|
| 984 |
+
# Build combined view with pending changes
|
| 985 |
+
all_models: Dict[str, Dict[str, Any]] = {}
|
| 986 |
+
suffix = "_MODELS"
|
| 987 |
+
|
| 988 |
+
# Add current providers (from env)
|
| 989 |
+
for provider, count in all_providers_env.items():
|
| 990 |
+
key = f"{provider.upper()}{suffix}"
|
| 991 |
+
change_type = self.settings.get_change_type(key)
|
| 992 |
+
if change_type == "remove":
|
| 993 |
+
all_models[provider] = {
|
| 994 |
+
"value": f"{count} model{'s' if count > 1 else ''}",
|
| 995 |
+
"type": "remove",
|
| 996 |
+
"old": None,
|
| 997 |
+
}
|
| 998 |
+
elif change_type == "edit":
|
| 999 |
+
# Get new model count from pending
|
| 1000 |
+
new_val = self.settings.pending_changes[key]
|
| 1001 |
+
try:
|
| 1002 |
+
parsed = json.loads(new_val)
|
| 1003 |
+
new_count = (
|
| 1004 |
+
len(parsed) if isinstance(parsed, (dict, list)) else 0
|
| 1005 |
+
)
|
| 1006 |
+
except (json.JSONDecodeError, ValueError):
|
| 1007 |
+
new_count = 0
|
| 1008 |
+
all_models[provider] = {
|
| 1009 |
+
"value": f"{new_count} model{'s' if new_count > 1 else ''}",
|
| 1010 |
+
"type": "edit",
|
| 1011 |
+
"old": f"{count} model{'s' if count > 1 else ''}",
|
| 1012 |
+
}
|
| 1013 |
+
else:
|
| 1014 |
+
all_models[provider] = {
|
| 1015 |
+
"value": f"{count} model{'s' if count > 1 else ''}",
|
| 1016 |
+
"type": None,
|
| 1017 |
+
"old": None,
|
| 1018 |
+
}
|
| 1019 |
+
|
| 1020 |
+
# Add pending new model definitions (additions)
|
| 1021 |
+
for key in self.settings.get_pending_keys_by_pattern(suffix=suffix):
|
| 1022 |
+
if self.settings.get_change_type(key) == "add":
|
| 1023 |
+
provider = key.replace(suffix, "").lower()
|
| 1024 |
+
if provider not in all_models:
|
| 1025 |
+
new_val = self.settings.pending_changes[key]
|
| 1026 |
+
try:
|
| 1027 |
+
parsed = json.loads(new_val)
|
| 1028 |
+
new_count = (
|
| 1029 |
+
len(parsed) if isinstance(parsed, (dict, list)) else 0
|
| 1030 |
+
)
|
| 1031 |
+
except (json.JSONDecodeError, ValueError):
|
| 1032 |
+
new_count = 0
|
| 1033 |
+
all_models[provider] = {
|
| 1034 |
+
"value": f"{new_count} model{'s' if new_count > 1 else ''}",
|
| 1035 |
+
"type": "add",
|
| 1036 |
+
"old": None,
|
| 1037 |
+
}
|
| 1038 |
+
|
| 1039 |
+
if all_models:
|
| 1040 |
+
# Sort alphabetically
|
| 1041 |
+
for provider in sorted(all_models.keys()):
|
| 1042 |
+
info = all_models[provider]
|
| 1043 |
self.console.print(
|
| 1044 |
+
self._format_item(
|
| 1045 |
+
provider, info["value"], info["type"], info["old"]
|
| 1046 |
+
)
|
| 1047 |
)
|
| 1048 |
else:
|
| 1049 |
self.console.print(" [dim]No model definitions configured[/dim]")
|
|
|
|
| 1070 |
if choice == "1":
|
| 1071 |
self.add_model_definitions()
|
| 1072 |
elif choice == "2":
|
| 1073 |
+
# Get editable models (existing + pending additions, excluding pending removals)
|
| 1074 |
+
editable = {
|
| 1075 |
+
k: v for k, v in all_models.items() if v["type"] != "remove"
|
| 1076 |
+
}
|
| 1077 |
+
if not editable:
|
| 1078 |
self.console.print("\n[yellow]No providers to edit[/yellow]")
|
| 1079 |
input("\nPress Enter to continue...")
|
| 1080 |
continue
|
| 1081 |
+
self.edit_model_definitions(sorted(editable.keys()))
|
| 1082 |
elif choice == "3":
|
| 1083 |
+
viewable = {
|
| 1084 |
+
k: v for k, v in all_models.items() if v["type"] != "remove"
|
| 1085 |
+
}
|
| 1086 |
+
if not viewable:
|
| 1087 |
self.console.print("\n[yellow]No providers to view[/yellow]")
|
| 1088 |
input("\nPress Enter to continue...")
|
| 1089 |
continue
|
| 1090 |
+
self.view_model_definitions(sorted(viewable.keys()))
|
| 1091 |
elif choice == "4":
|
| 1092 |
+
# Get removable models (existing ones not already pending removal)
|
| 1093 |
+
removable = {
|
| 1094 |
+
k: v
|
| 1095 |
+
for k, v in all_models.items()
|
| 1096 |
+
if v["type"] != "remove" and v["type"] != "add"
|
| 1097 |
+
}
|
| 1098 |
+
pending_adds = {
|
| 1099 |
+
k: v for k, v in all_models.items() if v["type"] == "add"
|
| 1100 |
+
}
|
| 1101 |
+
|
| 1102 |
+
if not removable and not pending_adds:
|
| 1103 |
self.console.print("\n[yellow]No providers to remove[/yellow]")
|
| 1104 |
input("\nPress Enter to continue...")
|
| 1105 |
continue
|
|
|
|
| 1108 |
self.console.print(
|
| 1109 |
"\n[bold]Select provider to remove models from:[/bold]"
|
| 1110 |
)
|
| 1111 |
+
providers_list = sorted(removable.keys()) + sorted(pending_adds.keys())
|
| 1112 |
for idx, prov in enumerate(providers_list, 1):
|
| 1113 |
+
if prov in pending_adds:
|
| 1114 |
+
self.console.print(
|
| 1115 |
+
f" {idx}. {prov} [green](pending add)[/green]"
|
| 1116 |
+
)
|
| 1117 |
+
else:
|
| 1118 |
+
self.console.print(f" {idx}. {prov}")
|
| 1119 |
|
| 1120 |
choice_idx = IntPrompt.ask(
|
| 1121 |
"Select option",
|
|
|
|
| 1124 |
provider = providers_list[choice_idx - 1]
|
| 1125 |
|
| 1126 |
if Confirm.ask(f"Remove all model definitions for '{provider}'?"):
|
| 1127 |
+
if provider in pending_adds:
|
| 1128 |
+
# Undo pending addition
|
| 1129 |
+
key = f"{provider.upper()}{suffix}"
|
| 1130 |
+
del self.settings.pending_changes[key]
|
| 1131 |
+
self.console.print(
|
| 1132 |
+
f"\n[green]✅ Pending models for '{provider}' cancelled![/green]"
|
| 1133 |
+
)
|
| 1134 |
+
else:
|
| 1135 |
+
self.model_mgr.remove_models(provider)
|
| 1136 |
+
self.console.print(
|
| 1137 |
+
f"\n[green]✅ Model definitions marked for removal for '{provider}'![/green]"
|
| 1138 |
+
)
|
| 1139 |
input("\nPress Enter to continue...")
|
| 1140 |
elif choice == "5":
|
| 1141 |
break
|
|
|
|
| 1462 |
self.console.print("[bold]📋 Current Settings[/bold]")
|
| 1463 |
self.console.print("━" * 70)
|
| 1464 |
|
| 1465 |
+
# Display all settings with current values and pending changes
|
| 1466 |
settings_list = list(definitions.keys())
|
| 1467 |
for idx, key in enumerate(settings_list, 1):
|
| 1468 |
definition = definitions[key]
|
|
|
|
| 1471 |
setting_type = definition.get("type", "str")
|
| 1472 |
description = definition.get("description", "")
|
| 1473 |
|
| 1474 |
+
# Check for pending changes
|
| 1475 |
+
change_type = self.settings.get_change_type(key)
|
| 1476 |
+
pending_val = self.settings.get_pending_value(key)
|
| 1477 |
+
|
| 1478 |
+
# Determine effective value to display
|
| 1479 |
+
if pending_val is not _NOT_FOUND and pending_val is not None:
|
| 1480 |
+
# Has pending change - convert to proper type for display
|
| 1481 |
+
if setting_type == "bool":
|
| 1482 |
+
effective = pending_val.lower() in ("true", "1", "yes")
|
| 1483 |
+
elif setting_type == "int":
|
| 1484 |
+
try:
|
| 1485 |
+
effective = int(pending_val)
|
| 1486 |
+
except (ValueError, TypeError):
|
| 1487 |
+
effective = pending_val
|
| 1488 |
+
else:
|
| 1489 |
+
effective = pending_val
|
| 1490 |
+
elif pending_val is None and change_type == "remove":
|
| 1491 |
+
# Pending removal - will revert to default
|
| 1492 |
+
effective = default
|
| 1493 |
+
else:
|
| 1494 |
+
effective = current
|
| 1495 |
+
|
| 1496 |
# Format value display
|
| 1497 |
if setting_type == "bool":
|
| 1498 |
value_display = (
|
| 1499 |
"[green]✓ Enabled[/green]"
|
| 1500 |
+
if effective
|
| 1501 |
else "[red]✗ Disabled[/red]"
|
| 1502 |
)
|
| 1503 |
+
old_display = (
|
| 1504 |
+
(
|
| 1505 |
+
"[green]✓ Enabled[/green]"
|
| 1506 |
+
if current
|
| 1507 |
+
else "[red]✗ Disabled[/red]"
|
| 1508 |
+
)
|
| 1509 |
+
if change_type
|
| 1510 |
+
else None
|
| 1511 |
+
)
|
| 1512 |
elif setting_type == "int":
|
| 1513 |
+
value_display = f"[cyan]{effective}[/cyan]"
|
| 1514 |
+
old_display = f"[cyan]{current}[/cyan]" if change_type else None
|
| 1515 |
else:
|
| 1516 |
value_display = (
|
| 1517 |
+
f"[cyan]{effective or '(not set)'}[/cyan]"
|
| 1518 |
+
if effective
|
| 1519 |
else "[dim](not set)[/dim]"
|
| 1520 |
)
|
| 1521 |
+
old_display = (
|
| 1522 |
+
f"[cyan]{current}[/cyan]" if change_type and current else None
|
| 1523 |
+
)
|
|
|
|
| 1524 |
|
| 1525 |
# Short key name for display (strip provider prefix)
|
| 1526 |
short_key = key.replace(f"{provider.upper()}_", "")
|
| 1527 |
|
| 1528 |
+
# Determine display marker based on pending change type
|
| 1529 |
+
if change_type == "add":
|
| 1530 |
+
self.console.print(
|
| 1531 |
+
f" [green]+{idx:2}. {short_key:35} {value_display}[/green]"
|
| 1532 |
+
)
|
| 1533 |
+
elif change_type == "edit":
|
| 1534 |
+
self.console.print(
|
| 1535 |
+
f" [yellow]~{idx:2}. {short_key:35} {old_display} → {value_display}[/yellow]"
|
| 1536 |
+
)
|
| 1537 |
+
elif change_type == "remove":
|
| 1538 |
+
self.console.print(
|
| 1539 |
+
f" [red]-{idx:2}. {short_key:35} {old_display} → [dim](default: {default})[/dim][/red]"
|
| 1540 |
+
)
|
| 1541 |
+
else:
|
| 1542 |
+
# Check if modified from default (in env, not pending)
|
| 1543 |
+
modified = current != default
|
| 1544 |
+
mod_marker = "[yellow]*[/yellow]" if modified else " "
|
| 1545 |
+
self.console.print(
|
| 1546 |
+
f" {mod_marker}{idx:2}. {short_key:35} {value_display}"
|
| 1547 |
+
)
|
| 1548 |
+
|
| 1549 |
self.console.print(f" [dim]{description}[/dim]")
|
| 1550 |
|
| 1551 |
self.console.print()
|
| 1552 |
self.console.print("━" * 70)
|
| 1553 |
+
self.console.print(
|
| 1554 |
+
"[dim]* = modified from default, + = pending add, ~ = pending edit, - = pending reset[/dim]"
|
| 1555 |
+
)
|
| 1556 |
self.console.print()
|
| 1557 |
self.console.print("[bold]⚙️ Actions[/bold]")
|
| 1558 |
self.console.print()
|
|
|
|
| 1672 |
while True:
|
| 1673 |
clear_screen()
|
| 1674 |
|
| 1675 |
+
# Get current modes from env
|
| 1676 |
modes = self.rotation_mgr.get_current_modes()
|
| 1677 |
available_providers = self.get_available_providers()
|
| 1678 |
|
|
|
|
| 1696 |
self.console.print("[bold]📋 Current Rotation Mode Settings[/bold]")
|
| 1697 |
self.console.print("━" * 70)
|
| 1698 |
|
| 1699 |
+
# Build combined view with pending changes
|
| 1700 |
+
all_modes: Dict[str, Dict[str, Any]] = {}
|
| 1701 |
+
prefix = "ROTATION_MODE_"
|
| 1702 |
+
|
| 1703 |
+
# Add current modes (from env)
|
| 1704 |
+
for provider, mode in modes.items():
|
| 1705 |
+
key = f"{prefix}{provider.upper()}"
|
| 1706 |
+
change_type = self.settings.get_change_type(key)
|
| 1707 |
+
default_mode = self.rotation_mgr.get_default_mode(provider)
|
| 1708 |
+
if change_type == "remove":
|
| 1709 |
+
all_modes[provider] = {"value": mode, "type": "remove", "old": None}
|
| 1710 |
+
elif change_type == "edit":
|
| 1711 |
+
new_val = self.settings.pending_changes[key]
|
| 1712 |
+
all_modes[provider] = {
|
| 1713 |
+
"value": new_val,
|
| 1714 |
+
"type": "edit",
|
| 1715 |
+
"old": mode,
|
| 1716 |
+
}
|
| 1717 |
+
else:
|
| 1718 |
+
all_modes[provider] = {"value": mode, "type": None, "old": None}
|
| 1719 |
+
|
| 1720 |
+
# Add pending new modes (additions)
|
| 1721 |
+
for key in self.settings.get_pending_keys_by_pattern(prefix=prefix):
|
| 1722 |
+
if self.settings.get_change_type(key) == "add":
|
| 1723 |
+
provider = key.replace(prefix, "").lower()
|
| 1724 |
+
if provider not in all_modes:
|
| 1725 |
+
all_modes[provider] = {
|
| 1726 |
+
"value": self.settings.pending_changes[key],
|
| 1727 |
+
"type": "add",
|
| 1728 |
+
"old": None,
|
| 1729 |
+
}
|
| 1730 |
+
|
| 1731 |
+
if all_modes:
|
| 1732 |
+
# Sort alphabetically
|
| 1733 |
+
for provider in sorted(all_modes.keys()):
|
| 1734 |
+
info = all_modes[provider]
|
| 1735 |
+
mode = info["value"]
|
| 1736 |
mode_display = (
|
| 1737 |
f"[green]{mode}[/green]"
|
| 1738 |
if mode == "sequential"
|
| 1739 |
else f"[blue]{mode}[/blue]"
|
| 1740 |
)
|
| 1741 |
+
old_display = None
|
| 1742 |
+
if info["old"]:
|
| 1743 |
+
old_display = (
|
| 1744 |
+
f"[green]{info['old']}[/green]"
|
| 1745 |
+
if info["old"] == "sequential"
|
| 1746 |
+
else f"[blue]{info['old']}[/blue]"
|
| 1747 |
+
)
|
| 1748 |
+
|
| 1749 |
+
if info["type"] == "add":
|
| 1750 |
+
self.console.print(
|
| 1751 |
+
f" [green]+ {provider:20} {mode_display}[/green]"
|
| 1752 |
+
)
|
| 1753 |
+
elif info["type"] == "edit":
|
| 1754 |
+
self.console.print(
|
| 1755 |
+
f" [yellow]~ {provider:20} {old_display} → {mode_display}[/yellow]"
|
| 1756 |
+
)
|
| 1757 |
+
elif info["type"] == "remove":
|
| 1758 |
+
self.console.print(
|
| 1759 |
+
f" [red]- {provider:20} {mode_display}[/red]"
|
| 1760 |
+
)
|
| 1761 |
+
else:
|
| 1762 |
+
default_mode = self.rotation_mgr.get_default_mode(provider)
|
| 1763 |
+
is_custom = mode != default_mode
|
| 1764 |
+
marker = "[yellow]*[/yellow]" if is_custom else " "
|
| 1765 |
+
self.console.print(f" {marker}• {provider:20} {mode_display}")
|
| 1766 |
|
| 1767 |
# Show providers with default modes
|
| 1768 |
+
providers_with_defaults = [
|
| 1769 |
+
p for p in available_providers if p not in modes and p not in all_modes
|
| 1770 |
+
]
|
| 1771 |
if providers_with_defaults:
|
| 1772 |
self.console.print()
|
| 1773 |
self.console.print("[dim]Providers using default modes:[/dim]")
|
|
|
|
| 1855 |
|
| 1856 |
self.rotation_mgr.set_mode(provider, new_mode)
|
| 1857 |
self.console.print(
|
| 1858 |
+
f"\n[green]✅ Rotation mode for '{provider}' staged as {new_mode}![/green]"
|
| 1859 |
)
|
| 1860 |
input("\nPress Enter to continue...")
|
| 1861 |
|
| 1862 |
elif choice == "2":
|
| 1863 |
+
# Get resettable modes (existing + pending adds, excluding pending removes)
|
| 1864 |
+
resettable = {
|
| 1865 |
+
k: v for k, v in all_modes.items() if v["type"] != "remove"
|
| 1866 |
+
}
|
| 1867 |
+
if not resettable:
|
| 1868 |
self.console.print(
|
| 1869 |
"\n[yellow]No custom rotation modes to reset[/yellow]"
|
| 1870 |
)
|
|
|
|
| 1875 |
self.console.print(
|
| 1876 |
"\n[bold]Select provider to reset to default:[/bold]"
|
| 1877 |
)
|
| 1878 |
+
modes_list = sorted(resettable.keys())
|
| 1879 |
for idx, prov in enumerate(modes_list, 1):
|
| 1880 |
default_mode = self.rotation_mgr.get_default_mode(prov)
|
| 1881 |
+
info = resettable[prov]
|
| 1882 |
+
if info["type"] == "add":
|
| 1883 |
+
self.console.print(
|
| 1884 |
+
f" {idx}. {prov} [green](pending add)[/green] - will cancel"
|
| 1885 |
+
)
|
| 1886 |
+
else:
|
| 1887 |
+
self.console.print(
|
| 1888 |
+
f" {idx}. {prov} (will reset to: {default_mode})"
|
| 1889 |
+
)
|
| 1890 |
|
| 1891 |
choice_idx = IntPrompt.ask(
|
| 1892 |
"Select option",
|
|
|
|
| 1894 |
)
|
| 1895 |
provider = modes_list[choice_idx - 1]
|
| 1896 |
default_mode = self.rotation_mgr.get_default_mode(provider)
|
| 1897 |
+
info = resettable[provider]
|
| 1898 |
|
| 1899 |
if Confirm.ask(f"Reset '{provider}' to default mode ({default_mode})?"):
|
| 1900 |
+
if info["type"] == "add":
|
| 1901 |
+
# Undo pending addition
|
| 1902 |
+
key = f"{prefix}{provider.upper()}"
|
| 1903 |
+
del self.settings.pending_changes[key]
|
| 1904 |
+
self.console.print(
|
| 1905 |
+
f"\n[green]✅ Pending mode for '{provider}' cancelled![/green]"
|
| 1906 |
+
)
|
| 1907 |
+
else:
|
| 1908 |
+
self.rotation_mgr.remove_mode(provider)
|
| 1909 |
+
self.console.print(
|
| 1910 |
+
f"\n[green]✅ Rotation mode for '{provider}' marked for reset to default ({default_mode})![/green]"
|
| 1911 |
+
)
|
| 1912 |
input("\nPress Enter to continue...")
|
| 1913 |
|
| 1914 |
elif choice == "3":
|
|
|
|
| 2081 |
while True:
|
| 2082 |
clear_screen()
|
| 2083 |
|
| 2084 |
+
# Get current limits from env
|
| 2085 |
limits = self.concurrency_mgr.get_current_limits()
|
| 2086 |
|
| 2087 |
self.console.print(
|
|
|
|
| 2095 |
self.console.print("[bold]📋 Current Concurrency Settings[/bold]")
|
| 2096 |
self.console.print("━" * 70)
|
| 2097 |
|
| 2098 |
+
# Build combined view with pending changes
|
| 2099 |
+
all_limits: Dict[str, Dict[str, Any]] = {}
|
| 2100 |
+
prefix = "MAX_CONCURRENT_REQUESTS_PER_KEY_"
|
| 2101 |
+
|
| 2102 |
+
# Add current limits (from env)
|
| 2103 |
+
for provider, limit in limits.items():
|
| 2104 |
+
key = f"{prefix}{provider.upper()}"
|
| 2105 |
+
change_type = self.settings.get_change_type(key)
|
| 2106 |
+
if change_type == "remove":
|
| 2107 |
+
all_limits[provider] = {
|
| 2108 |
+
"value": str(limit),
|
| 2109 |
+
"type": "remove",
|
| 2110 |
+
"old": None,
|
| 2111 |
+
}
|
| 2112 |
+
elif change_type == "edit":
|
| 2113 |
+
new_val = self.settings.pending_changes[key]
|
| 2114 |
+
all_limits[provider] = {
|
| 2115 |
+
"value": new_val,
|
| 2116 |
+
"type": "edit",
|
| 2117 |
+
"old": str(limit),
|
| 2118 |
+
}
|
| 2119 |
+
else:
|
| 2120 |
+
all_limits[provider] = {
|
| 2121 |
+
"value": str(limit),
|
| 2122 |
+
"type": None,
|
| 2123 |
+
"old": None,
|
| 2124 |
+
}
|
| 2125 |
+
|
| 2126 |
+
# Add pending new limits (additions)
|
| 2127 |
+
for key in self.settings.get_pending_keys_by_pattern(prefix=prefix):
|
| 2128 |
+
if self.settings.get_change_type(key) == "add":
|
| 2129 |
+
provider = key.replace(prefix, "").lower()
|
| 2130 |
+
if provider not in all_limits:
|
| 2131 |
+
all_limits[provider] = {
|
| 2132 |
+
"value": self.settings.pending_changes[key],
|
| 2133 |
+
"type": "add",
|
| 2134 |
+
"old": None,
|
| 2135 |
+
}
|
| 2136 |
+
|
| 2137 |
+
if all_limits:
|
| 2138 |
+
# Sort alphabetically
|
| 2139 |
+
for provider in sorted(all_limits.keys()):
|
| 2140 |
+
info = all_limits[provider]
|
| 2141 |
+
value_display = f"{info['value']} requests/key"
|
| 2142 |
+
old_display = f"{info['old']} requests/key" if info["old"] else None
|
| 2143 |
+
self.console.print(
|
| 2144 |
+
self._format_item(
|
| 2145 |
+
provider, value_display, info["type"], old_display
|
| 2146 |
+
)
|
| 2147 |
+
)
|
| 2148 |
+
self.console.print(" • Default: 1 request/key (all others)")
|
| 2149 |
else:
|
| 2150 |
self.console.print(" • Default: 1 request/key (all providers)")
|
| 2151 |
|
|
|
|
| 2203 |
if 1 <= limit <= 100:
|
| 2204 |
self.concurrency_mgr.set_limit(provider, limit)
|
| 2205 |
self.console.print(
|
| 2206 |
+
f"\n[green]✅ Concurrency limit staged for '{provider}': {limit} requests/key[/green]"
|
| 2207 |
)
|
| 2208 |
else:
|
| 2209 |
self.console.print(
|
|
|
|
| 2212 |
input("\nPress Enter to continue...")
|
| 2213 |
|
| 2214 |
elif choice == "2":
|
| 2215 |
+
# Get editable limits (existing + pending additions, excluding pending removals)
|
| 2216 |
+
editable = {
|
| 2217 |
+
k: v for k, v in all_limits.items() if v["type"] != "remove"
|
| 2218 |
+
}
|
| 2219 |
+
if not editable:
|
| 2220 |
self.console.print("\n[yellow]No limits to edit[/yellow]")
|
| 2221 |
input("\nPress Enter to continue...")
|
| 2222 |
continue
|
| 2223 |
|
| 2224 |
# Show numbered list
|
| 2225 |
self.console.print("\n[bold]Select provider to edit:[/bold]")
|
| 2226 |
+
limits_list = sorted(editable.keys())
|
| 2227 |
for idx, prov in enumerate(limits_list, 1):
|
| 2228 |
self.console.print(f" {idx}. {prov}")
|
| 2229 |
|
|
|
|
| 2232 |
choices=[str(i) for i in range(1, len(limits_list) + 1)],
|
| 2233 |
)
|
| 2234 |
provider = limits_list[choice_idx - 1]
|
| 2235 |
+
info = editable[provider]
|
| 2236 |
+
current_limit = int(info["value"])
|
| 2237 |
|
| 2238 |
self.console.print(f"\nCurrent limit: {current_limit} requests/key")
|
| 2239 |
new_limit = IntPrompt.ask(
|
|
|
|
| 2254 |
input("\nPress Enter to continue...")
|
| 2255 |
|
| 2256 |
elif choice == "3":
|
| 2257 |
+
# Get removable limits (existing ones not already pending removal)
|
| 2258 |
+
removable = {
|
| 2259 |
+
k: v
|
| 2260 |
+
for k, v in all_limits.items()
|
| 2261 |
+
if v["type"] != "remove" and v["type"] != "add"
|
| 2262 |
+
}
|
| 2263 |
+
# For pending additions, we can "undo" by removing from pending
|
| 2264 |
+
pending_adds = {
|
| 2265 |
+
k: v for k, v in all_limits.items() if v["type"] == "add"
|
| 2266 |
+
}
|
| 2267 |
+
|
| 2268 |
+
if not removable and not pending_adds:
|
| 2269 |
self.console.print("\n[yellow]No limits to remove[/yellow]")
|
| 2270 |
input("\nPress Enter to continue...")
|
| 2271 |
continue
|
|
|
|
| 2274 |
self.console.print(
|
| 2275 |
"\n[bold]Select provider to remove limit from:[/bold]"
|
| 2276 |
)
|
| 2277 |
+
limits_list = sorted(removable.keys()) + sorted(pending_adds.keys())
|
| 2278 |
for idx, prov in enumerate(limits_list, 1):
|
| 2279 |
+
if prov in pending_adds:
|
| 2280 |
+
self.console.print(
|
| 2281 |
+
f" {idx}. {prov} [green](pending add)[/green]"
|
| 2282 |
+
)
|
| 2283 |
+
else:
|
| 2284 |
+
self.console.print(f" {idx}. {prov}")
|
| 2285 |
|
| 2286 |
choice_idx = IntPrompt.ask(
|
| 2287 |
"Select option",
|
|
|
|
| 2292 |
if Confirm.ask(
|
| 2293 |
f"Remove concurrency limit for '{provider}' (reset to default 1)?"
|
| 2294 |
):
|
| 2295 |
+
if provider in pending_adds:
|
| 2296 |
+
# Undo pending addition
|
| 2297 |
+
key = f"{prefix}{provider.upper()}"
|
| 2298 |
+
del self.settings.pending_changes[key]
|
| 2299 |
+
self.console.print(
|
| 2300 |
+
f"\n[green]✅ Pending limit for '{provider}' cancelled![/green]"
|
| 2301 |
+
)
|
| 2302 |
+
else:
|
| 2303 |
+
self.concurrency_mgr.remove_limit(provider)
|
| 2304 |
+
self.console.print(
|
| 2305 |
+
f"\n[green]✅ Limit marked for removal for '{provider}'[/green]"
|
| 2306 |
+
)
|
| 2307 |
input("\nPress Enter to continue...")
|
| 2308 |
|
| 2309 |
elif choice == "4":
|
| 2310 |
break
|
| 2311 |
|
| 2312 |
+
def _show_changes_summary(self):
|
| 2313 |
+
"""Display categorized summary of all pending changes."""
|
| 2314 |
+
self.console.print(
|
| 2315 |
+
Panel.fit(
|
| 2316 |
+
"[bold cyan]📋 Pending Changes Summary[/bold cyan]",
|
| 2317 |
+
border_style="cyan",
|
| 2318 |
+
)
|
| 2319 |
+
)
|
| 2320 |
+
self.console.print()
|
| 2321 |
+
|
| 2322 |
+
# Define categories with their key patterns
|
| 2323 |
+
categories = [
|
| 2324 |
+
("Custom Provider API Bases", "_API_BASE", "suffix"),
|
| 2325 |
+
("Model Definitions", "_MODELS", "suffix"),
|
| 2326 |
+
("Concurrency Limits", "MAX_CONCURRENT_REQUESTS_PER_KEY_", "prefix"),
|
| 2327 |
+
("Rotation Modes", "ROTATION_MODE_", "prefix"),
|
| 2328 |
+
("Priority Multipliers", "CONCURRENCY_MULTIPLIER_", "prefix"),
|
| 2329 |
+
]
|
| 2330 |
+
|
| 2331 |
+
# Get provider-specific settings keys
|
| 2332 |
+
provider_settings_keys = set()
|
| 2333 |
+
for provider_settings in PROVIDER_SETTINGS_MAP.values():
|
| 2334 |
+
provider_settings_keys.update(provider_settings.keys())
|
| 2335 |
+
|
| 2336 |
+
changes = self.settings.get_changes_summary()
|
| 2337 |
+
displayed_keys = set()
|
| 2338 |
+
|
| 2339 |
+
for category_name, pattern, pattern_type in categories:
|
| 2340 |
+
category_changes = {"add": [], "edit": [], "remove": []}
|
| 2341 |
+
|
| 2342 |
+
for change_type in ["add", "edit", "remove"]:
|
| 2343 |
+
for key, old_val, new_val in changes[change_type]:
|
| 2344 |
+
matches = False
|
| 2345 |
+
if pattern_type == "suffix" and key.endswith(pattern):
|
| 2346 |
+
matches = True
|
| 2347 |
+
elif pattern_type == "prefix" and key.startswith(pattern):
|
| 2348 |
+
matches = True
|
| 2349 |
+
|
| 2350 |
+
if matches:
|
| 2351 |
+
category_changes[change_type].append((key, old_val, new_val))
|
| 2352 |
+
displayed_keys.add(key)
|
| 2353 |
+
|
| 2354 |
+
# Check if this category has any changes
|
| 2355 |
+
has_changes = any(category_changes[t] for t in ["add", "edit", "remove"])
|
| 2356 |
+
if has_changes:
|
| 2357 |
+
self.console.print(f"[bold]{category_name}:[/bold]")
|
| 2358 |
+
# Sort: additions, modifications, removals (alphabetically within each)
|
| 2359 |
+
for change_type in ["add", "edit", "remove"]:
|
| 2360 |
+
for key, old_val, new_val in sorted(
|
| 2361 |
+
category_changes[change_type], key=lambda x: x[0]
|
| 2362 |
+
):
|
| 2363 |
+
if change_type == "add":
|
| 2364 |
+
self.console.print(f" [green]+ {key} = {new_val}[/green]")
|
| 2365 |
+
elif change_type == "edit":
|
| 2366 |
+
self.console.print(
|
| 2367 |
+
f" [yellow]~ {key}: {old_val} → {new_val}[/yellow]"
|
| 2368 |
+
)
|
| 2369 |
+
else:
|
| 2370 |
+
self.console.print(f" [red]- {key}[/red]")
|
| 2371 |
+
self.console.print()
|
| 2372 |
+
|
| 2373 |
+
# Handle provider-specific settings that don't match the patterns above
|
| 2374 |
+
provider_changes = {"add": [], "edit": [], "remove": []}
|
| 2375 |
+
for change_type in ["add", "edit", "remove"]:
|
| 2376 |
+
for key, old_val, new_val in changes[change_type]:
|
| 2377 |
+
if key not in displayed_keys and key in provider_settings_keys:
|
| 2378 |
+
provider_changes[change_type].append((key, old_val, new_val))
|
| 2379 |
+
|
| 2380 |
+
has_provider_changes = any(
|
| 2381 |
+
provider_changes[t] for t in ["add", "edit", "remove"]
|
| 2382 |
+
)
|
| 2383 |
+
if has_provider_changes:
|
| 2384 |
+
self.console.print("[bold]Provider-Specific Settings:[/bold]")
|
| 2385 |
+
for change_type in ["add", "edit", "remove"]:
|
| 2386 |
+
for key, old_val, new_val in sorted(
|
| 2387 |
+
provider_changes[change_type], key=lambda x: x[0]
|
| 2388 |
+
):
|
| 2389 |
+
if change_type == "add":
|
| 2390 |
+
self.console.print(f" [green]+ {key} = {new_val}[/green]")
|
| 2391 |
+
elif change_type == "edit":
|
| 2392 |
+
self.console.print(
|
| 2393 |
+
f" [yellow]~ {key}: {old_val} → {new_val}[/yellow]"
|
| 2394 |
+
)
|
| 2395 |
+
else:
|
| 2396 |
+
self.console.print(f" [red]- {key}[/red]")
|
| 2397 |
+
self.console.print()
|
| 2398 |
+
|
| 2399 |
+
self.console.print("━" * 70)
|
| 2400 |
+
|
| 2401 |
def save_and_exit(self):
|
| 2402 |
"""Save pending changes and exit"""
|
| 2403 |
if self.settings.has_pending():
|
| 2404 |
+
clear_screen()
|
| 2405 |
+
self._show_changes_summary()
|
| 2406 |
+
|
| 2407 |
if Confirm.ask("\n[bold yellow]Save all pending changes?[/bold yellow]"):
|
| 2408 |
self.settings.save()
|
| 2409 |
self.console.print("\n[green]✅ All changes saved to .env![/green]")
|
|
|
|
| 2421 |
def exit_without_saving(self):
|
| 2422 |
"""Exit without saving"""
|
| 2423 |
if self.settings.has_pending():
|
| 2424 |
+
clear_screen()
|
| 2425 |
+
self._show_changes_summary()
|
| 2426 |
+
|
| 2427 |
if Confirm.ask("\n[bold red]Discard all pending changes?[/bold red]"):
|
| 2428 |
self.settings.discard()
|
| 2429 |
self.console.print("\n[yellow]Changes discarded[/yellow]")
|
src/rotator_library/client.py
CHANGED
|
@@ -10,6 +10,7 @@ import litellm
|
|
| 10 |
from litellm.exceptions import APIConnectionError
|
| 11 |
from litellm.litellm_core_utils.token_counter import token_counter
|
| 12 |
import logging
|
|
|
|
| 13 |
from typing import List, Dict, Any, AsyncGenerator, Optional, Union
|
| 14 |
|
| 15 |
lib_logger = logging.getLogger("rotator_library")
|
|
@@ -19,7 +20,7 @@ lib_logger = logging.getLogger("rotator_library")
|
|
| 19 |
lib_logger.propagate = False
|
| 20 |
|
| 21 |
from .usage_manager import UsageManager
|
| 22 |
-
from .failure_logger import log_failure
|
| 23 |
from .error_handler import (
|
| 24 |
PreRequestCallbackError,
|
| 25 |
classify_error,
|
|
@@ -37,6 +38,7 @@ from .cooldown_manager import CooldownManager
|
|
| 37 |
from .credential_manager import CredentialManager
|
| 38 |
from .background_refresher import BackgroundRefresher
|
| 39 |
from .model_definitions import ModelDefinitions
|
|
|
|
| 40 |
|
| 41 |
|
| 42 |
class StreamedAPIError(Exception):
|
|
@@ -58,7 +60,7 @@ class RotatingClient:
|
|
| 58 |
api_keys: Optional[Dict[str, List[str]]] = None,
|
| 59 |
oauth_credentials: Optional[Dict[str, List[str]]] = None,
|
| 60 |
max_retries: int = 2,
|
| 61 |
-
usage_file_path: str =
|
| 62 |
configure_logging: bool = True,
|
| 63 |
global_timeout: int = 30,
|
| 64 |
abort_on_callback_error: bool = True,
|
|
@@ -68,6 +70,7 @@ class RotatingClient:
|
|
| 68 |
enable_request_logging: bool = False,
|
| 69 |
max_concurrent_requests_per_key: Optional[Dict[str, int]] = None,
|
| 70 |
rotation_tolerance: float = 3.0,
|
|
|
|
| 71 |
):
|
| 72 |
"""
|
| 73 |
Initialize the RotatingClient with intelligent credential rotation.
|
|
@@ -76,7 +79,7 @@ class RotatingClient:
|
|
| 76 |
api_keys: Dictionary mapping provider names to lists of API keys
|
| 77 |
oauth_credentials: Dictionary mapping provider names to OAuth credential paths
|
| 78 |
max_retries: Maximum number of retry attempts per credential
|
| 79 |
-
usage_file_path: Path to store usage statistics
|
| 80 |
configure_logging: Whether to configure library logging
|
| 81 |
global_timeout: Global timeout for requests in seconds
|
| 82 |
abort_on_callback_error: Whether to abort on pre-request callback errors
|
|
@@ -89,7 +92,18 @@ class RotatingClient:
|
|
| 89 |
- 0.0: Deterministic, least-used credential always selected
|
| 90 |
- 2.0 - 4.0 (default, recommended): Balanced randomness, can pick credentials within 2 uses of max
|
| 91 |
- 5.0+: High randomness, more unpredictable selection patterns
|
|
|
|
|
|
|
| 92 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 93 |
os.environ["LITELLM_LOG"] = "ERROR"
|
| 94 |
litellm.set_verbose = False
|
| 95 |
litellm.drop_params = True
|
|
@@ -124,7 +138,9 @@ class RotatingClient:
|
|
| 124 |
if oauth_credentials:
|
| 125 |
self.oauth_credentials = oauth_credentials
|
| 126 |
else:
|
| 127 |
-
self.credential_manager = CredentialManager(
|
|
|
|
|
|
|
| 128 |
self.oauth_credentials = self.credential_manager.discover_and_prepare()
|
| 129 |
self.background_refresher = BackgroundRefresher(self)
|
| 130 |
self.oauth_providers = set(self.oauth_credentials.keys())
|
|
@@ -242,8 +258,14 @@ class RotatingClient:
|
|
| 242 |
f"Provider '{provider}' sequential fallback multiplier: {fallback}x"
|
| 243 |
)
|
| 244 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 245 |
self.usage_manager = UsageManager(
|
| 246 |
-
file_path=
|
| 247 |
rotation_tolerance=rotation_tolerance,
|
| 248 |
provider_rotation_modes=provider_rotation_modes,
|
| 249 |
provider_plugins=PROVIDER_PLUGINS,
|
|
|
|
| 10 |
from litellm.exceptions import APIConnectionError
|
| 11 |
from litellm.litellm_core_utils.token_counter import token_counter
|
| 12 |
import logging
|
| 13 |
+
from pathlib import Path
|
| 14 |
from typing import List, Dict, Any, AsyncGenerator, Optional, Union
|
| 15 |
|
| 16 |
lib_logger = logging.getLogger("rotator_library")
|
|
|
|
| 20 |
lib_logger.propagate = False
|
| 21 |
|
| 22 |
from .usage_manager import UsageManager
|
| 23 |
+
from .failure_logger import log_failure, configure_failure_logger
|
| 24 |
from .error_handler import (
|
| 25 |
PreRequestCallbackError,
|
| 26 |
classify_error,
|
|
|
|
| 38 |
from .credential_manager import CredentialManager
|
| 39 |
from .background_refresher import BackgroundRefresher
|
| 40 |
from .model_definitions import ModelDefinitions
|
| 41 |
+
from .utils.paths import get_default_root, get_logs_dir, get_oauth_dir, get_data_file
|
| 42 |
|
| 43 |
|
| 44 |
class StreamedAPIError(Exception):
|
|
|
|
| 60 |
api_keys: Optional[Dict[str, List[str]]] = None,
|
| 61 |
oauth_credentials: Optional[Dict[str, List[str]]] = None,
|
| 62 |
max_retries: int = 2,
|
| 63 |
+
usage_file_path: Optional[Union[str, Path]] = None,
|
| 64 |
configure_logging: bool = True,
|
| 65 |
global_timeout: int = 30,
|
| 66 |
abort_on_callback_error: bool = True,
|
|
|
|
| 70 |
enable_request_logging: bool = False,
|
| 71 |
max_concurrent_requests_per_key: Optional[Dict[str, int]] = None,
|
| 72 |
rotation_tolerance: float = 3.0,
|
| 73 |
+
data_dir: Optional[Union[str, Path]] = None,
|
| 74 |
):
|
| 75 |
"""
|
| 76 |
Initialize the RotatingClient with intelligent credential rotation.
|
|
|
|
| 79 |
api_keys: Dictionary mapping provider names to lists of API keys
|
| 80 |
oauth_credentials: Dictionary mapping provider names to OAuth credential paths
|
| 81 |
max_retries: Maximum number of retry attempts per credential
|
| 82 |
+
usage_file_path: Path to store usage statistics. If None, uses data_dir/key_usage.json
|
| 83 |
configure_logging: Whether to configure library logging
|
| 84 |
global_timeout: Global timeout for requests in seconds
|
| 85 |
abort_on_callback_error: Whether to abort on pre-request callback errors
|
|
|
|
| 92 |
- 0.0: Deterministic, least-used credential always selected
|
| 93 |
- 2.0 - 4.0 (default, recommended): Balanced randomness, can pick credentials within 2 uses of max
|
| 94 |
- 5.0+: High randomness, more unpredictable selection patterns
|
| 95 |
+
data_dir: Root directory for all data files (logs, cache, oauth_creds, key_usage.json).
|
| 96 |
+
If None, auto-detects: EXE directory if frozen, else current working directory.
|
| 97 |
"""
|
| 98 |
+
# Resolve data_dir early - this becomes the root for all file operations
|
| 99 |
+
if data_dir is not None:
|
| 100 |
+
self.data_dir = Path(data_dir).resolve()
|
| 101 |
+
else:
|
| 102 |
+
self.data_dir = get_default_root()
|
| 103 |
+
|
| 104 |
+
# Configure failure logger to use correct logs directory
|
| 105 |
+
configure_failure_logger(get_logs_dir(self.data_dir))
|
| 106 |
+
|
| 107 |
os.environ["LITELLM_LOG"] = "ERROR"
|
| 108 |
litellm.set_verbose = False
|
| 109 |
litellm.drop_params = True
|
|
|
|
| 138 |
if oauth_credentials:
|
| 139 |
self.oauth_credentials = oauth_credentials
|
| 140 |
else:
|
| 141 |
+
self.credential_manager = CredentialManager(
|
| 142 |
+
os.environ, oauth_dir=get_oauth_dir(self.data_dir)
|
| 143 |
+
)
|
| 144 |
self.oauth_credentials = self.credential_manager.discover_and_prepare()
|
| 145 |
self.background_refresher = BackgroundRefresher(self)
|
| 146 |
self.oauth_providers = set(self.oauth_credentials.keys())
|
|
|
|
| 258 |
f"Provider '{provider}' sequential fallback multiplier: {fallback}x"
|
| 259 |
)
|
| 260 |
|
| 261 |
+
# Resolve usage file path - use provided path or default to data_dir
|
| 262 |
+
if usage_file_path is not None:
|
| 263 |
+
resolved_usage_path = Path(usage_file_path)
|
| 264 |
+
else:
|
| 265 |
+
resolved_usage_path = self.data_dir / "key_usage.json"
|
| 266 |
+
|
| 267 |
self.usage_manager = UsageManager(
|
| 268 |
+
file_path=resolved_usage_path,
|
| 269 |
rotation_tolerance=rotation_tolerance,
|
| 270 |
provider_rotation_modes=provider_rotation_modes,
|
| 271 |
provider_plugins=PROVIDER_PLUGINS,
|
src/rotator_library/credential_manager.py
CHANGED
|
@@ -3,12 +3,11 @@ import re
|
|
| 3 |
import shutil
|
| 4 |
import logging
|
| 5 |
from pathlib import Path
|
| 6 |
-
from typing import Dict, List, Optional, Set
|
| 7 |
|
| 8 |
-
|
| 9 |
|
| 10 |
-
|
| 11 |
-
OAUTH_BASE_DIR.mkdir(exist_ok=True)
|
| 12 |
|
| 13 |
# Standard directories where tools like `gemini login` store credentials.
|
| 14 |
DEFAULT_OAUTH_DIRS = {
|
|
@@ -33,38 +32,53 @@ class CredentialManager:
|
|
| 33 |
"""
|
| 34 |
Discovers OAuth credential files from standard locations, copies them locally,
|
| 35 |
and updates the configuration to use the local paths.
|
| 36 |
-
|
| 37 |
Also discovers environment variable-based OAuth credentials for stateless deployments.
|
| 38 |
Supports two env var formats:
|
| 39 |
-
|
| 40 |
1. Single credential (legacy): PROVIDER_ACCESS_TOKEN, PROVIDER_REFRESH_TOKEN
|
| 41 |
2. Multiple credentials (numbered): PROVIDER_1_ACCESS_TOKEN, PROVIDER_2_ACCESS_TOKEN, etc.
|
| 42 |
-
|
| 43 |
When env-based credentials are detected, virtual paths like "env://provider/1" are created.
|
| 44 |
"""
|
| 45 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 46 |
self.env_vars = env_vars
|
|
|
|
|
|
|
| 47 |
|
| 48 |
def _discover_env_oauth_credentials(self) -> Dict[str, List[str]]:
|
| 49 |
"""
|
| 50 |
Discover OAuth credentials defined via environment variables.
|
| 51 |
-
|
| 52 |
Supports two formats:
|
| 53 |
1. Single credential: ANTIGRAVITY_ACCESS_TOKEN + ANTIGRAVITY_REFRESH_TOKEN
|
| 54 |
2. Multiple credentials: ANTIGRAVITY_1_ACCESS_TOKEN + ANTIGRAVITY_1_REFRESH_TOKEN, etc.
|
| 55 |
-
|
| 56 |
Returns:
|
| 57 |
Dict mapping provider name to list of virtual paths (e.g., "env://antigravity/1")
|
| 58 |
"""
|
| 59 |
env_credentials: Dict[str, Set[str]] = {}
|
| 60 |
-
|
| 61 |
for provider, env_prefix in ENV_OAUTH_PROVIDERS.items():
|
| 62 |
found_indices: Set[str] = set()
|
| 63 |
-
|
| 64 |
# Check for numbered credentials (PROVIDER_N_ACCESS_TOKEN pattern)
|
| 65 |
# Pattern: ANTIGRAVITY_1_ACCESS_TOKEN, ANTIGRAVITY_2_ACCESS_TOKEN, etc.
|
| 66 |
numbered_pattern = re.compile(rf"^{env_prefix}_(\d+)_ACCESS_TOKEN$")
|
| 67 |
-
|
| 68 |
for key in self.env_vars.keys():
|
| 69 |
match = numbered_pattern.match(key)
|
| 70 |
if match:
|
|
@@ -73,28 +87,34 @@ class CredentialManager:
|
|
| 73 |
refresh_key = f"{env_prefix}_{index}_REFRESH_TOKEN"
|
| 74 |
if refresh_key in self.env_vars and self.env_vars[refresh_key]:
|
| 75 |
found_indices.add(index)
|
| 76 |
-
|
| 77 |
# Check for legacy single credential (PROVIDER_ACCESS_TOKEN pattern)
|
| 78 |
# Only use this if no numbered credentials exist
|
| 79 |
if not found_indices:
|
| 80 |
access_key = f"{env_prefix}_ACCESS_TOKEN"
|
| 81 |
refresh_key = f"{env_prefix}_REFRESH_TOKEN"
|
| 82 |
-
if (
|
| 83 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 84 |
# Use "0" as the index for legacy single credential
|
| 85 |
found_indices.add("0")
|
| 86 |
-
|
| 87 |
if found_indices:
|
| 88 |
env_credentials[provider] = found_indices
|
| 89 |
-
lib_logger.info(
|
| 90 |
-
|
|
|
|
|
|
|
| 91 |
# Convert to virtual paths
|
| 92 |
result: Dict[str, List[str]] = {}
|
| 93 |
for provider, indices in env_credentials.items():
|
| 94 |
# Sort indices numerically for consistent ordering
|
| 95 |
sorted_indices = sorted(indices, key=lambda x: int(x))
|
| 96 |
result[provider] = [f"env://{provider}/{idx}" for idx in sorted_indices]
|
| 97 |
-
|
| 98 |
return result
|
| 99 |
|
| 100 |
def discover_and_prepare(self) -> Dict[str, List[str]]:
|
|
@@ -105,7 +125,9 @@ class CredentialManager:
|
|
| 105 |
# These take priority for stateless deployments
|
| 106 |
env_oauth_creds = self._discover_env_oauth_credentials()
|
| 107 |
for provider, virtual_paths in env_oauth_creds.items():
|
| 108 |
-
lib_logger.info(
|
|
|
|
|
|
|
| 109 |
final_config[provider] = virtual_paths
|
| 110 |
|
| 111 |
# Extract OAuth file paths from environment variables
|
|
@@ -115,21 +137,29 @@ class CredentialManager:
|
|
| 115 |
provider = key.split("_OAUTH_")[0].lower()
|
| 116 |
if provider not in env_oauth_paths:
|
| 117 |
env_oauth_paths[provider] = []
|
| 118 |
-
if value:
|
| 119 |
env_oauth_paths[provider].append(value)
|
| 120 |
|
| 121 |
# PHASE 2: Discover file-based OAuth credentials
|
| 122 |
for provider, default_dir in DEFAULT_OAUTH_DIRS.items():
|
| 123 |
# Skip if already discovered from environment variables
|
| 124 |
if provider in final_config:
|
| 125 |
-
lib_logger.debug(
|
|
|
|
|
|
|
| 126 |
continue
|
| 127 |
-
|
| 128 |
# Check for existing local credentials first. If found, use them and skip discovery.
|
| 129 |
-
local_provider_creds = sorted(
|
|
|
|
|
|
|
| 130 |
if local_provider_creds:
|
| 131 |
-
lib_logger.info(
|
| 132 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 133 |
continue
|
| 134 |
|
| 135 |
# If no local credentials exist, proceed with a one-time discovery and copy.
|
|
@@ -140,13 +170,13 @@ class CredentialManager:
|
|
| 140 |
path = Path(path_str).expanduser()
|
| 141 |
if path.exists():
|
| 142 |
discovered_paths.add(path)
|
| 143 |
-
|
| 144 |
# 2. If no overrides are provided via .env, scan the default directory
|
| 145 |
# [MODIFIED] This logic is now disabled to prefer local-first credential management.
|
| 146 |
# if not discovered_paths and default_dir.exists():
|
| 147 |
# for json_file in default_dir.glob('*.json'):
|
| 148 |
# discovered_paths.add(json_file)
|
| 149 |
-
|
| 150 |
if not discovered_paths:
|
| 151 |
lib_logger.debug(f"No credential files found for provider: {provider}")
|
| 152 |
continue
|
|
@@ -156,18 +186,24 @@ class CredentialManager:
|
|
| 156 |
for i, source_path in enumerate(sorted(list(discovered_paths))):
|
| 157 |
account_id = i + 1
|
| 158 |
local_filename = f"{provider}_oauth_{account_id}.json"
|
| 159 |
-
local_path =
|
| 160 |
|
| 161 |
try:
|
| 162 |
# Since we've established no local files exist, we can copy directly.
|
| 163 |
shutil.copy(source_path, local_path)
|
| 164 |
-
lib_logger.info(
|
|
|
|
|
|
|
| 165 |
prepared_paths.append(str(local_path.resolve()))
|
| 166 |
except Exception as e:
|
| 167 |
-
lib_logger.error(
|
| 168 |
-
|
|
|
|
|
|
|
| 169 |
if prepared_paths:
|
| 170 |
-
lib_logger.info(
|
|
|
|
|
|
|
| 171 |
final_config[provider] = prepared_paths
|
| 172 |
|
| 173 |
lib_logger.info("OAuth credential discovery complete.")
|
|
|
|
| 3 |
import shutil
|
| 4 |
import logging
|
| 5 |
from pathlib import Path
|
| 6 |
+
from typing import Dict, List, Optional, Set, Union
|
| 7 |
|
| 8 |
+
from .utils.paths import get_oauth_dir
|
| 9 |
|
| 10 |
+
lib_logger = logging.getLogger("rotator_library")
|
|
|
|
| 11 |
|
| 12 |
# Standard directories where tools like `gemini login` store credentials.
|
| 13 |
DEFAULT_OAUTH_DIRS = {
|
|
|
|
| 32 |
"""
|
| 33 |
Discovers OAuth credential files from standard locations, copies them locally,
|
| 34 |
and updates the configuration to use the local paths.
|
| 35 |
+
|
| 36 |
Also discovers environment variable-based OAuth credentials for stateless deployments.
|
| 37 |
Supports two env var formats:
|
| 38 |
+
|
| 39 |
1. Single credential (legacy): PROVIDER_ACCESS_TOKEN, PROVIDER_REFRESH_TOKEN
|
| 40 |
2. Multiple credentials (numbered): PROVIDER_1_ACCESS_TOKEN, PROVIDER_2_ACCESS_TOKEN, etc.
|
| 41 |
+
|
| 42 |
When env-based credentials are detected, virtual paths like "env://provider/1" are created.
|
| 43 |
"""
|
| 44 |
+
|
| 45 |
+
def __init__(
|
| 46 |
+
self,
|
| 47 |
+
env_vars: Dict[str, str],
|
| 48 |
+
oauth_dir: Optional[Union[Path, str]] = None,
|
| 49 |
+
):
|
| 50 |
+
"""
|
| 51 |
+
Initialize the CredentialManager.
|
| 52 |
+
|
| 53 |
+
Args:
|
| 54 |
+
env_vars: Dictionary of environment variables (typically os.environ).
|
| 55 |
+
oauth_dir: Directory for storing OAuth credentials.
|
| 56 |
+
If None, uses get_oauth_dir() which respects EXE vs script mode.
|
| 57 |
+
"""
|
| 58 |
self.env_vars = env_vars
|
| 59 |
+
self.oauth_base_dir = Path(oauth_dir) if oauth_dir else get_oauth_dir()
|
| 60 |
+
self.oauth_base_dir.mkdir(parents=True, exist_ok=True)
|
| 61 |
|
| 62 |
def _discover_env_oauth_credentials(self) -> Dict[str, List[str]]:
|
| 63 |
"""
|
| 64 |
Discover OAuth credentials defined via environment variables.
|
| 65 |
+
|
| 66 |
Supports two formats:
|
| 67 |
1. Single credential: ANTIGRAVITY_ACCESS_TOKEN + ANTIGRAVITY_REFRESH_TOKEN
|
| 68 |
2. Multiple credentials: ANTIGRAVITY_1_ACCESS_TOKEN + ANTIGRAVITY_1_REFRESH_TOKEN, etc.
|
| 69 |
+
|
| 70 |
Returns:
|
| 71 |
Dict mapping provider name to list of virtual paths (e.g., "env://antigravity/1")
|
| 72 |
"""
|
| 73 |
env_credentials: Dict[str, Set[str]] = {}
|
| 74 |
+
|
| 75 |
for provider, env_prefix in ENV_OAUTH_PROVIDERS.items():
|
| 76 |
found_indices: Set[str] = set()
|
| 77 |
+
|
| 78 |
# Check for numbered credentials (PROVIDER_N_ACCESS_TOKEN pattern)
|
| 79 |
# Pattern: ANTIGRAVITY_1_ACCESS_TOKEN, ANTIGRAVITY_2_ACCESS_TOKEN, etc.
|
| 80 |
numbered_pattern = re.compile(rf"^{env_prefix}_(\d+)_ACCESS_TOKEN$")
|
| 81 |
+
|
| 82 |
for key in self.env_vars.keys():
|
| 83 |
match = numbered_pattern.match(key)
|
| 84 |
if match:
|
|
|
|
| 87 |
refresh_key = f"{env_prefix}_{index}_REFRESH_TOKEN"
|
| 88 |
if refresh_key in self.env_vars and self.env_vars[refresh_key]:
|
| 89 |
found_indices.add(index)
|
| 90 |
+
|
| 91 |
# Check for legacy single credential (PROVIDER_ACCESS_TOKEN pattern)
|
| 92 |
# Only use this if no numbered credentials exist
|
| 93 |
if not found_indices:
|
| 94 |
access_key = f"{env_prefix}_ACCESS_TOKEN"
|
| 95 |
refresh_key = f"{env_prefix}_REFRESH_TOKEN"
|
| 96 |
+
if (
|
| 97 |
+
access_key in self.env_vars
|
| 98 |
+
and self.env_vars[access_key]
|
| 99 |
+
and refresh_key in self.env_vars
|
| 100 |
+
and self.env_vars[refresh_key]
|
| 101 |
+
):
|
| 102 |
# Use "0" as the index for legacy single credential
|
| 103 |
found_indices.add("0")
|
| 104 |
+
|
| 105 |
if found_indices:
|
| 106 |
env_credentials[provider] = found_indices
|
| 107 |
+
lib_logger.info(
|
| 108 |
+
f"Found {len(found_indices)} env-based credential(s) for {provider}"
|
| 109 |
+
)
|
| 110 |
+
|
| 111 |
# Convert to virtual paths
|
| 112 |
result: Dict[str, List[str]] = {}
|
| 113 |
for provider, indices in env_credentials.items():
|
| 114 |
# Sort indices numerically for consistent ordering
|
| 115 |
sorted_indices = sorted(indices, key=lambda x: int(x))
|
| 116 |
result[provider] = [f"env://{provider}/{idx}" for idx in sorted_indices]
|
| 117 |
+
|
| 118 |
return result
|
| 119 |
|
| 120 |
def discover_and_prepare(self) -> Dict[str, List[str]]:
|
|
|
|
| 125 |
# These take priority for stateless deployments
|
| 126 |
env_oauth_creds = self._discover_env_oauth_credentials()
|
| 127 |
for provider, virtual_paths in env_oauth_creds.items():
|
| 128 |
+
lib_logger.info(
|
| 129 |
+
f"Using {len(virtual_paths)} env-based credential(s) for {provider}"
|
| 130 |
+
)
|
| 131 |
final_config[provider] = virtual_paths
|
| 132 |
|
| 133 |
# Extract OAuth file paths from environment variables
|
|
|
|
| 137 |
provider = key.split("_OAUTH_")[0].lower()
|
| 138 |
if provider not in env_oauth_paths:
|
| 139 |
env_oauth_paths[provider] = []
|
| 140 |
+
if value: # Only consider non-empty values
|
| 141 |
env_oauth_paths[provider].append(value)
|
| 142 |
|
| 143 |
# PHASE 2: Discover file-based OAuth credentials
|
| 144 |
for provider, default_dir in DEFAULT_OAUTH_DIRS.items():
|
| 145 |
# Skip if already discovered from environment variables
|
| 146 |
if provider in final_config:
|
| 147 |
+
lib_logger.debug(
|
| 148 |
+
f"Skipping file discovery for {provider} - using env-based credentials"
|
| 149 |
+
)
|
| 150 |
continue
|
| 151 |
+
|
| 152 |
# Check for existing local credentials first. If found, use them and skip discovery.
|
| 153 |
+
local_provider_creds = sorted(
|
| 154 |
+
list(self.oauth_base_dir.glob(f"{provider}_oauth_*.json"))
|
| 155 |
+
)
|
| 156 |
if local_provider_creds:
|
| 157 |
+
lib_logger.info(
|
| 158 |
+
f"Found {len(local_provider_creds)} existing local credential(s) for {provider}. Skipping discovery."
|
| 159 |
+
)
|
| 160 |
+
final_config[provider] = [
|
| 161 |
+
str(p.resolve()) for p in local_provider_creds
|
| 162 |
+
]
|
| 163 |
continue
|
| 164 |
|
| 165 |
# If no local credentials exist, proceed with a one-time discovery and copy.
|
|
|
|
| 170 |
path = Path(path_str).expanduser()
|
| 171 |
if path.exists():
|
| 172 |
discovered_paths.add(path)
|
| 173 |
+
|
| 174 |
# 2. If no overrides are provided via .env, scan the default directory
|
| 175 |
# [MODIFIED] This logic is now disabled to prefer local-first credential management.
|
| 176 |
# if not discovered_paths and default_dir.exists():
|
| 177 |
# for json_file in default_dir.glob('*.json'):
|
| 178 |
# discovered_paths.add(json_file)
|
| 179 |
+
|
| 180 |
if not discovered_paths:
|
| 181 |
lib_logger.debug(f"No credential files found for provider: {provider}")
|
| 182 |
continue
|
|
|
|
| 186 |
for i, source_path in enumerate(sorted(list(discovered_paths))):
|
| 187 |
account_id = i + 1
|
| 188 |
local_filename = f"{provider}_oauth_{account_id}.json"
|
| 189 |
+
local_path = self.oauth_base_dir / local_filename
|
| 190 |
|
| 191 |
try:
|
| 192 |
# Since we've established no local files exist, we can copy directly.
|
| 193 |
shutil.copy(source_path, local_path)
|
| 194 |
+
lib_logger.info(
|
| 195 |
+
f"Copied '{source_path.name}' to local pool at '{local_path}'."
|
| 196 |
+
)
|
| 197 |
prepared_paths.append(str(local_path.resolve()))
|
| 198 |
except Exception as e:
|
| 199 |
+
lib_logger.error(
|
| 200 |
+
f"Failed to process OAuth file from '{source_path}': {e}"
|
| 201 |
+
)
|
| 202 |
+
|
| 203 |
if prepared_paths:
|
| 204 |
+
lib_logger.info(
|
| 205 |
+
f"Discovered and prepared {len(prepared_paths)} credential(s) for provider: {provider}"
|
| 206 |
+
)
|
| 207 |
final_config[provider] = prepared_paths
|
| 208 |
|
| 209 |
lib_logger.info("OAuth credential discovery complete.")
|
src/rotator_library/credential_tool.py
CHANGED
|
@@ -3,22 +3,31 @@
|
|
| 3 |
import asyncio
|
| 4 |
import json
|
| 5 |
import os
|
| 6 |
-
import re
|
| 7 |
import time
|
| 8 |
from pathlib import Path
|
| 9 |
from dotenv import set_key, get_key
|
| 10 |
|
| 11 |
-
# NOTE: Heavy imports (provider_factory, PROVIDER_PLUGINS) are deferred
|
| 12 |
# to avoid 6-7 second delay before showing loading screen
|
| 13 |
from rich.console import Console
|
| 14 |
from rich.panel import Panel
|
| 15 |
from rich.prompt import Prompt
|
| 16 |
from rich.text import Text
|
| 17 |
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
|
| 23 |
console = Console()
|
| 24 |
|
|
@@ -26,12 +35,14 @@ console = Console()
|
|
| 26 |
_provider_factory = None
|
| 27 |
_provider_plugins = None
|
| 28 |
|
|
|
|
| 29 |
def _ensure_providers_loaded():
|
| 30 |
"""Lazy load provider modules only when needed"""
|
| 31 |
global _provider_factory, _provider_plugins
|
| 32 |
if _provider_factory is None:
|
| 33 |
from . import provider_factory as pf
|
| 34 |
from .providers import PROVIDER_PLUGINS as pp
|
|
|
|
| 35 |
_provider_factory = pf
|
| 36 |
_provider_plugins = pp
|
| 37 |
return _provider_factory, _provider_plugins
|
|
@@ -39,99 +50,34 @@ def _ensure_providers_loaded():
|
|
| 39 |
|
| 40 |
def clear_screen():
|
| 41 |
"""
|
| 42 |
-
Cross-platform terminal clear that works robustly on both
|
| 43 |
classic Windows conhost and modern terminals (Windows Terminal, Linux, Mac).
|
| 44 |
-
|
| 45 |
Uses native OS commands instead of ANSI escape sequences:
|
| 46 |
- Windows (conhost & Windows Terminal): cls
|
| 47 |
- Unix-like systems (Linux, Mac): clear
|
| 48 |
"""
|
| 49 |
-
os.system(
|
| 50 |
-
|
| 51 |
|
| 52 |
-
def _get_credential_number_from_filename(filename: str) -> int:
|
| 53 |
-
"""
|
| 54 |
-
Extract credential number from filename like 'provider_oauth_1.json' -> 1
|
| 55 |
-
"""
|
| 56 |
-
match = re.search(r'_oauth_(\d+)\.json$', filename)
|
| 57 |
-
if match:
|
| 58 |
-
return int(match.group(1))
|
| 59 |
-
return 1
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
def _build_env_export_content(
|
| 63 |
-
provider_prefix: str,
|
| 64 |
-
cred_number: int,
|
| 65 |
-
creds: dict,
|
| 66 |
-
email: str,
|
| 67 |
-
extra_fields: dict = None,
|
| 68 |
-
include_client_creds: bool = True
|
| 69 |
-
) -> tuple[list[str], str]:
|
| 70 |
-
"""
|
| 71 |
-
Build .env content for OAuth credential export with numbered format.
|
| 72 |
-
Exports all fields from the JSON file as a 1-to-1 mirror.
|
| 73 |
-
|
| 74 |
-
Args:
|
| 75 |
-
provider_prefix: Environment variable prefix (e.g., "ANTIGRAVITY", "GEMINI_CLI")
|
| 76 |
-
cred_number: Credential number for this export (1, 2, 3, etc.)
|
| 77 |
-
creds: The credential dictionary loaded from JSON
|
| 78 |
-
email: User email for comments
|
| 79 |
-
extra_fields: Optional dict of additional fields to include
|
| 80 |
-
include_client_creds: Whether to include client_id/secret (Google OAuth providers)
|
| 81 |
-
|
| 82 |
-
Returns:
|
| 83 |
-
Tuple of (env_lines list, numbered_prefix string for display)
|
| 84 |
-
"""
|
| 85 |
-
# Use numbered format: PROVIDER_N_ACCESS_TOKEN
|
| 86 |
-
numbered_prefix = f"{provider_prefix}_{cred_number}"
|
| 87 |
-
|
| 88 |
-
env_lines = [
|
| 89 |
-
f"# {provider_prefix} Credential #{cred_number} for: {email}",
|
| 90 |
-
f"# Exported from: {provider_prefix.lower()}_oauth_{cred_number}.json",
|
| 91 |
-
f"# Generated at: {time.strftime('%Y-%m-%d %H:%M:%S')}",
|
| 92 |
-
f"# ",
|
| 93 |
-
f"# To combine multiple credentials into one .env file, copy these lines",
|
| 94 |
-
f"# and ensure each credential has a unique number (1, 2, 3, etc.)",
|
| 95 |
-
"",
|
| 96 |
-
f"{numbered_prefix}_ACCESS_TOKEN={creds.get('access_token', '')}",
|
| 97 |
-
f"{numbered_prefix}_REFRESH_TOKEN={creds.get('refresh_token', '')}",
|
| 98 |
-
f"{numbered_prefix}_SCOPE={creds.get('scope', '')}",
|
| 99 |
-
f"{numbered_prefix}_TOKEN_TYPE={creds.get('token_type', 'Bearer')}",
|
| 100 |
-
f"{numbered_prefix}_ID_TOKEN={creds.get('id_token', '')}",
|
| 101 |
-
f"{numbered_prefix}_EXPIRY_DATE={creds.get('expiry_date', 0)}",
|
| 102 |
-
]
|
| 103 |
-
|
| 104 |
-
if include_client_creds:
|
| 105 |
-
env_lines.extend([
|
| 106 |
-
f"{numbered_prefix}_CLIENT_ID={creds.get('client_id', '')}",
|
| 107 |
-
f"{numbered_prefix}_CLIENT_SECRET={creds.get('client_secret', '')}",
|
| 108 |
-
f"{numbered_prefix}_TOKEN_URI={creds.get('token_uri', 'https://oauth2.googleapis.com/token')}",
|
| 109 |
-
f"{numbered_prefix}_UNIVERSE_DOMAIN={creds.get('universe_domain', 'googleapis.com')}",
|
| 110 |
-
])
|
| 111 |
-
|
| 112 |
-
env_lines.append(f"{numbered_prefix}_EMAIL={email}")
|
| 113 |
-
|
| 114 |
-
# Add extra provider-specific fields
|
| 115 |
-
if extra_fields:
|
| 116 |
-
for key, value in extra_fields.items():
|
| 117 |
-
if value: # Only add non-empty values
|
| 118 |
-
env_lines.append(f"{numbered_prefix}_{key}={value}")
|
| 119 |
-
|
| 120 |
-
return env_lines, numbered_prefix
|
| 121 |
|
| 122 |
def ensure_env_defaults():
|
| 123 |
"""
|
| 124 |
Ensures the .env file exists and contains essential default values like PROXY_API_KEY.
|
| 125 |
"""
|
| 126 |
-
if not
|
| 127 |
-
|
| 128 |
-
console.print(
|
|
|
|
|
|
|
| 129 |
|
| 130 |
# Check for PROXY_API_KEY, similar to setup_env.bat
|
| 131 |
-
if get_key(str(
|
| 132 |
default_key = "VerysecretKey"
|
| 133 |
-
console.print(
|
| 134 |
-
|
|
|
|
|
|
|
|
|
|
| 135 |
|
| 136 |
async def setup_api_key():
|
| 137 |
"""
|
|
@@ -144,41 +90,74 @@ async def setup_api_key():
|
|
| 144 |
|
| 145 |
# Verified list of LiteLLM providers with their friendly names and API key variables
|
| 146 |
LITELLM_PROVIDERS = {
|
| 147 |
-
"OpenAI": "OPENAI_API_KEY",
|
| 148 |
-
"
|
| 149 |
-
"
|
| 150 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
| 151 |
"Mistral AI": "MISTRAL_API_KEY",
|
| 152 |
-
"Codestral (Mistral)": "CODESTRAL_API_KEY",
|
| 153 |
-
"
|
| 154 |
-
"
|
| 155 |
-
"
|
| 156 |
-
"
|
| 157 |
-
"
|
| 158 |
-
"
|
| 159 |
-
"
|
| 160 |
-
"
|
| 161 |
-
"
|
| 162 |
-
"
|
| 163 |
-
"
|
| 164 |
-
"
|
| 165 |
-
"
|
| 166 |
-
"
|
| 167 |
-
"
|
| 168 |
-
"
|
| 169 |
-
"
|
| 170 |
-
"
|
| 171 |
-
"
|
| 172 |
-
"
|
| 173 |
-
"
|
| 174 |
-
"
|
| 175 |
-
"
|
| 176 |
-
"
|
| 177 |
-
"
|
| 178 |
-
"
|
| 179 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 180 |
"Deepgram": "DEEPGRAM_API_KEY",
|
| 181 |
-
"GitHub Models": "GITHUB_TOKEN",
|
|
|
|
| 182 |
}
|
| 183 |
|
| 184 |
# Discover custom providers and add them to the list
|
|
@@ -186,37 +165,37 @@ async def setup_api_key():
|
|
| 186 |
# qwen_code API key support is a fallback
|
| 187 |
# iflow API key support is a feature
|
| 188 |
_, PROVIDER_PLUGINS = _ensure_providers_loaded()
|
| 189 |
-
|
| 190 |
# Build a set of environment variables already in LITELLM_PROVIDERS
|
| 191 |
# to avoid duplicates based on the actual API key names
|
| 192 |
litellm_env_vars = set(LITELLM_PROVIDERS.values())
|
| 193 |
-
|
| 194 |
# Providers to exclude from API key list
|
| 195 |
exclude_providers = {
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
}
|
| 201 |
-
|
| 202 |
discovered_providers = {}
|
| 203 |
for provider_key in PROVIDER_PLUGINS.keys():
|
| 204 |
if provider_key in exclude_providers:
|
| 205 |
continue
|
| 206 |
-
|
| 207 |
# Create environment variable name
|
| 208 |
env_var = provider_key.upper() + "_API_KEY"
|
| 209 |
-
|
| 210 |
# Check if this env var already exists in LITELLM_PROVIDERS
|
| 211 |
# This catches duplicates like GEMINI_API_KEY, MISTRAL_API_KEY, etc.
|
| 212 |
if env_var in litellm_env_vars:
|
| 213 |
# Already in LITELLM_PROVIDERS with better name, skip this one
|
| 214 |
continue
|
| 215 |
-
|
| 216 |
# Create display name for this custom provider
|
| 217 |
-
display_name = provider_key.replace(
|
| 218 |
discovered_providers[display_name] = env_var
|
| 219 |
-
|
| 220 |
# LITELLM_PROVIDERS takes precedence (comes first in merge)
|
| 221 |
combined_providers = {**LITELLM_PROVIDERS, **discovered_providers}
|
| 222 |
provider_display_list = sorted(combined_providers.keys())
|
|
@@ -231,15 +210,19 @@ async def setup_api_key():
|
|
| 231 |
else:
|
| 232 |
provider_text.append(f" {i + 1}. {provider_name}\n")
|
| 233 |
|
| 234 |
-
console.print(
|
|
|
|
|
|
|
| 235 |
|
| 236 |
choice = Prompt.ask(
|
| 237 |
-
Text.from_markup(
|
|
|
|
|
|
|
| 238 |
choices=[str(i + 1) for i in range(len(provider_display_list))] + ["b"],
|
| 239 |
-
show_choices=False
|
| 240 |
)
|
| 241 |
|
| 242 |
-
if choice.lower() ==
|
| 243 |
return
|
| 244 |
|
| 245 |
try:
|
|
@@ -251,59 +234,88 @@ async def setup_api_key():
|
|
| 251 |
api_key = Prompt.ask(f"Enter the API key for {display_name}")
|
| 252 |
|
| 253 |
# Check for duplicate API key value
|
| 254 |
-
if
|
| 255 |
-
with open(
|
| 256 |
for line in f:
|
| 257 |
line = line.strip()
|
| 258 |
if line.startswith(api_var_base) and "=" in line:
|
| 259 |
-
existing_key_name, _, existing_key_value = line.partition(
|
|
|
|
|
|
|
| 260 |
if existing_key_value == api_key:
|
| 261 |
-
warning_text = Text.from_markup(
|
| 262 |
-
|
| 263 |
-
|
| 264 |
-
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 268 |
return
|
| 269 |
|
| 270 |
# Special handling for AWS
|
| 271 |
if display_name in ["AWS Bedrock", "AWS SageMaker"]:
|
| 272 |
-
console.print(
|
| 273 |
-
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
|
|
|
|
|
|
|
| 281 |
|
| 282 |
key_index = 1
|
| 283 |
while True:
|
| 284 |
key_name = f"{api_var_base}_{key_index}"
|
| 285 |
-
if
|
| 286 |
-
|
| 287 |
if not any(line.startswith(f"{key_name}=") for line in f):
|
| 288 |
break
|
| 289 |
else:
|
| 290 |
break
|
| 291 |
key_index += 1
|
| 292 |
-
|
| 293 |
key_name = f"{api_var_base}_{key_index}"
|
| 294 |
-
set_key(str(
|
| 295 |
-
|
| 296 |
-
success_text = Text.from_markup(
|
|
|
|
|
|
|
| 297 |
console.print(Panel(success_text, style="bold green", title="Success"))
|
| 298 |
|
| 299 |
else:
|
| 300 |
console.print("[bold red]Invalid choice. Please try again.[/bold red]")
|
| 301 |
except ValueError:
|
| 302 |
-
console.print(
|
|
|
|
|
|
|
|
|
|
| 303 |
|
| 304 |
async def setup_new_credential(provider_name: str):
|
| 305 |
"""
|
| 306 |
Interactively sets up a new OAuth credential for a given provider.
|
|
|
|
|
|
|
| 307 |
"""
|
| 308 |
try:
|
| 309 |
provider_factory, _ = _ensure_providers_loaded()
|
|
@@ -315,668 +327,602 @@ async def setup_new_credential(provider_name: str):
|
|
| 315 |
"gemini_cli": "Gemini CLI (OAuth)",
|
| 316 |
"qwen_code": "Qwen Code (OAuth - also supports API keys)",
|
| 317 |
"iflow": "iFlow (OAuth - also supports API keys)",
|
| 318 |
-
"antigravity": "Antigravity (OAuth)"
|
| 319 |
}
|
| 320 |
-
display_name = oauth_friendly_names.get(
|
| 321 |
-
|
| 322 |
-
|
| 323 |
-
temp_creds = {
|
| 324 |
-
"_proxy_metadata": {
|
| 325 |
-
"provider_name": provider_name,
|
| 326 |
-
"display_name": display_name
|
| 327 |
-
}
|
| 328 |
-
}
|
| 329 |
-
initialized_creds = await auth_instance.initialize_token(temp_creds)
|
| 330 |
-
|
| 331 |
-
user_info = await auth_instance.get_user_info(initialized_creds)
|
| 332 |
-
email = user_info.get("email")
|
| 333 |
|
| 334 |
-
|
| 335 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 336 |
return
|
| 337 |
|
| 338 |
-
|
| 339 |
-
|
| 340 |
-
|
| 341 |
-
|
| 342 |
-
|
| 343 |
-
|
| 344 |
-
|
| 345 |
-
|
|
|
|
|
|
|
|
|
|
| 346 |
|
| 347 |
-
|
| 348 |
-
|
| 349 |
-
|
|
|
|
|
|
|
| 350 |
|
| 351 |
-
success_text = Text.from_markup(f"Successfully updated credential at [bold yellow]'{cred_file.name}'[/bold yellow] for user [bold cyan]'{email}'[/bold cyan].")
|
| 352 |
-
console.print(Panel(success_text, style="bold green", title="Success"))
|
| 353 |
-
return
|
| 354 |
-
|
| 355 |
-
existing_files = list(OAUTH_BASE_DIR.glob(f"{provider_name}_oauth_*.json"))
|
| 356 |
-
next_num = 1
|
| 357 |
-
if existing_files:
|
| 358 |
-
nums = [int(re.search(r'_(\d+)\.json$', f.name).group(1)) for f in existing_files if re.search(r'_(\d+)\.json$', f.name)]
|
| 359 |
-
if nums:
|
| 360 |
-
next_num = max(nums) + 1
|
| 361 |
-
|
| 362 |
-
new_filename = f"{provider_name}_oauth_{next_num}.json"
|
| 363 |
-
new_filepath = OAUTH_BASE_DIR / new_filename
|
| 364 |
-
|
| 365 |
-
with open(new_filepath, 'w') as f:
|
| 366 |
-
json.dump(initialized_creds, f, indent=2)
|
| 367 |
-
|
| 368 |
-
success_text = Text.from_markup(f"Successfully created new credential at [bold yellow]'{new_filepath.name}'[/bold yellow] for user [bold cyan]'{email}'[/bold cyan].")
|
| 369 |
console.print(Panel(success_text, style="bold green", title="Success"))
|
| 370 |
|
| 371 |
except Exception as e:
|
| 372 |
-
console.print(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 373 |
|
| 374 |
|
| 375 |
async def export_gemini_cli_to_env():
|
| 376 |
"""
|
| 377 |
Export a Gemini CLI credential JSON file to .env format.
|
| 378 |
-
Uses
|
| 379 |
"""
|
| 380 |
-
console.print(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 381 |
|
| 382 |
-
#
|
| 383 |
-
|
|
|
|
|
|
|
| 384 |
|
| 385 |
-
|
| 386 |
-
|
| 387 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 388 |
return
|
| 389 |
|
| 390 |
# Display available credentials
|
| 391 |
cred_text = Text()
|
| 392 |
-
for i,
|
| 393 |
-
|
| 394 |
-
|
| 395 |
-
|
| 396 |
-
email = creds.get("_proxy_metadata", {}).get("email", "unknown")
|
| 397 |
-
cred_text.append(f" {i + 1}. {cred_file.name} ({email})\n")
|
| 398 |
-
except Exception as e:
|
| 399 |
-
cred_text.append(f" {i + 1}. {cred_file.name} (error reading: {e})\n")
|
| 400 |
|
| 401 |
-
console.print(
|
|
|
|
|
|
|
| 402 |
|
| 403 |
choice = Prompt.ask(
|
| 404 |
-
Text.from_markup(
|
| 405 |
-
|
| 406 |
-
|
|
|
|
|
|
|
| 407 |
)
|
| 408 |
|
| 409 |
-
if choice.lower() ==
|
| 410 |
return
|
| 411 |
|
| 412 |
try:
|
| 413 |
choice_index = int(choice) - 1
|
| 414 |
-
if 0 <= choice_index < len(
|
| 415 |
-
|
| 416 |
-
|
| 417 |
-
# Load the credential
|
| 418 |
-
with open(cred_file, 'r') as f:
|
| 419 |
-
creds = json.load(f)
|
| 420 |
|
| 421 |
-
#
|
| 422 |
-
|
| 423 |
-
|
| 424 |
-
tier = creds.get("_proxy_metadata", {}).get("tier", "")
|
| 425 |
-
|
| 426 |
-
# Get credential number from filename
|
| 427 |
-
cred_number = _get_credential_number_from_filename(cred_file.name)
|
| 428 |
-
|
| 429 |
-
# Generate .env file name with credential number
|
| 430 |
-
safe_email = email.replace("@", "_at_").replace(".", "_")
|
| 431 |
-
env_filename = f"gemini_cli_{cred_number}_{safe_email}.env"
|
| 432 |
-
env_filepath = OAUTH_BASE_DIR / env_filename
|
| 433 |
-
|
| 434 |
-
# Build extra fields
|
| 435 |
-
extra_fields = {}
|
| 436 |
-
if project_id:
|
| 437 |
-
extra_fields["PROJECT_ID"] = project_id
|
| 438 |
-
if tier:
|
| 439 |
-
extra_fields["TIER"] = tier
|
| 440 |
-
|
| 441 |
-
# Build .env content using helper
|
| 442 |
-
env_lines, numbered_prefix = _build_env_export_content(
|
| 443 |
-
provider_prefix="GEMINI_CLI",
|
| 444 |
-
cred_number=cred_number,
|
| 445 |
-
creds=creds,
|
| 446 |
-
email=email,
|
| 447 |
-
extra_fields=extra_fields,
|
| 448 |
-
include_client_creds=True
|
| 449 |
)
|
| 450 |
|
| 451 |
-
|
| 452 |
-
|
| 453 |
-
|
| 454 |
-
|
| 455 |
-
|
| 456 |
-
|
| 457 |
-
|
| 458 |
-
|
| 459 |
-
|
| 460 |
-
|
| 461 |
-
|
| 462 |
-
|
| 463 |
-
|
| 464 |
-
|
| 465 |
-
|
| 466 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 467 |
else:
|
| 468 |
console.print("[bold red]Invalid choice. Please try again.[/bold red]")
|
| 469 |
except ValueError:
|
| 470 |
-
console.print(
|
|
|
|
|
|
|
| 471 |
except Exception as e:
|
| 472 |
-
console.print(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 473 |
|
| 474 |
|
| 475 |
async def export_qwen_code_to_env():
|
| 476 |
"""
|
| 477 |
Export a Qwen Code credential JSON file to .env format.
|
| 478 |
-
|
| 479 |
"""
|
| 480 |
-
console.print(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 481 |
|
| 482 |
-
#
|
| 483 |
-
|
| 484 |
|
| 485 |
-
if not
|
| 486 |
-
console.print(
|
| 487 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 488 |
return
|
| 489 |
|
| 490 |
# Display available credentials
|
| 491 |
cred_text = Text()
|
| 492 |
-
for i,
|
| 493 |
-
|
| 494 |
-
|
| 495 |
-
|
| 496 |
-
email = creds.get("_proxy_metadata", {}).get("email", "unknown")
|
| 497 |
-
cred_text.append(f" {i + 1}. {cred_file.name} ({email})\n")
|
| 498 |
-
except Exception as e:
|
| 499 |
-
cred_text.append(f" {i + 1}. {cred_file.name} (error reading: {e})\n")
|
| 500 |
|
| 501 |
-
console.print(
|
|
|
|
|
|
|
| 502 |
|
| 503 |
choice = Prompt.ask(
|
| 504 |
-
Text.from_markup(
|
| 505 |
-
|
| 506 |
-
|
|
|
|
|
|
|
| 507 |
)
|
| 508 |
|
| 509 |
-
if choice.lower() ==
|
| 510 |
return
|
| 511 |
|
| 512 |
try:
|
| 513 |
choice_index = int(choice) - 1
|
| 514 |
-
if 0 <= choice_index < len(
|
| 515 |
-
|
| 516 |
-
|
| 517 |
-
# Load the credential
|
| 518 |
-
with open(cred_file, 'r') as f:
|
| 519 |
-
creds = json.load(f)
|
| 520 |
-
|
| 521 |
-
# Extract metadata
|
| 522 |
-
email = creds.get("_proxy_metadata", {}).get("email", "unknown")
|
| 523 |
-
|
| 524 |
-
# Get credential number from filename
|
| 525 |
-
cred_number = _get_credential_number_from_filename(cred_file.name)
|
| 526 |
-
|
| 527 |
-
# Generate .env file name with credential number
|
| 528 |
-
safe_email = email.replace("@", "_at_").replace(".", "_")
|
| 529 |
-
env_filename = f"qwen_code_{cred_number}_{safe_email}.env"
|
| 530 |
-
env_filepath = OAUTH_BASE_DIR / env_filename
|
| 531 |
-
|
| 532 |
-
# Use numbered format: QWEN_CODE_N_*
|
| 533 |
-
numbered_prefix = f"QWEN_CODE_{cred_number}"
|
| 534 |
-
|
| 535 |
-
# Build .env content (Qwen has different structure)
|
| 536 |
-
env_lines = [
|
| 537 |
-
f"# QWEN_CODE Credential #{cred_number} for: {email}",
|
| 538 |
-
f"# Generated at: {time.strftime('%Y-%m-%d %H:%M:%S')}",
|
| 539 |
-
f"# ",
|
| 540 |
-
f"# To combine multiple credentials into one .env file, copy these lines",
|
| 541 |
-
f"# and ensure each credential has a unique number (1, 2, 3, etc.)",
|
| 542 |
-
"",
|
| 543 |
-
f"{numbered_prefix}_ACCESS_TOKEN={creds.get('access_token', '')}",
|
| 544 |
-
f"{numbered_prefix}_REFRESH_TOKEN={creds.get('refresh_token', '')}",
|
| 545 |
-
f"{numbered_prefix}_EXPIRY_DATE={creds.get('expiry_date', 0)}",
|
| 546 |
-
f"{numbered_prefix}_RESOURCE_URL={creds.get('resource_url', 'https://portal.qwen.ai/v1')}",
|
| 547 |
-
f"{numbered_prefix}_EMAIL={email}",
|
| 548 |
-
]
|
| 549 |
-
|
| 550 |
-
# Write to .env file
|
| 551 |
-
with open(env_filepath, 'w') as f:
|
| 552 |
-
f.write('\n'.join(env_lines))
|
| 553 |
|
| 554 |
-
|
| 555 |
-
|
| 556 |
-
|
| 557 |
-
f"[bold]To use this credential:[/bold]\n"
|
| 558 |
-
f"1. Copy the contents to your main .env file, OR\n"
|
| 559 |
-
f"2. Source it: [bold cyan]source {env_filepath.name}[/bold cyan] (Linux/Mac)\n\n"
|
| 560 |
-
f"[bold]To combine multiple credentials:[/bold]\n"
|
| 561 |
-
f"Copy lines from multiple .env files into one file.\n"
|
| 562 |
-
f"Each credential uses a unique number ({numbered_prefix}_*)."
|
| 563 |
)
|
| 564 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 565 |
else:
|
| 566 |
console.print("[bold red]Invalid choice. Please try again.[/bold red]")
|
| 567 |
except ValueError:
|
| 568 |
-
console.print(
|
|
|
|
|
|
|
| 569 |
except Exception as e:
|
| 570 |
-
console.print(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 571 |
|
| 572 |
|
| 573 |
async def export_iflow_to_env():
|
| 574 |
"""
|
| 575 |
Export an iFlow credential JSON file to .env format.
|
| 576 |
-
Uses
|
| 577 |
"""
|
| 578 |
-
console.print(
|
|
|
|
|
|
|
| 579 |
|
| 580 |
-
#
|
| 581 |
-
|
|
|
|
|
|
|
| 582 |
|
| 583 |
-
|
| 584 |
-
|
| 585 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 586 |
return
|
| 587 |
|
| 588 |
# Display available credentials
|
| 589 |
cred_text = Text()
|
| 590 |
-
for i,
|
| 591 |
-
|
| 592 |
-
|
| 593 |
-
|
| 594 |
-
email = creds.get("_proxy_metadata", {}).get("email", "unknown")
|
| 595 |
-
cred_text.append(f" {i + 1}. {cred_file.name} ({email})\n")
|
| 596 |
-
except Exception as e:
|
| 597 |
-
cred_text.append(f" {i + 1}. {cred_file.name} (error reading: {e})\n")
|
| 598 |
|
| 599 |
-
console.print(
|
|
|
|
|
|
|
| 600 |
|
| 601 |
choice = Prompt.ask(
|
| 602 |
-
Text.from_markup(
|
| 603 |
-
|
| 604 |
-
|
|
|
|
|
|
|
| 605 |
)
|
| 606 |
|
| 607 |
-
if choice.lower() ==
|
| 608 |
return
|
| 609 |
|
| 610 |
try:
|
| 611 |
choice_index = int(choice) - 1
|
| 612 |
-
if 0 <= choice_index < len(
|
| 613 |
-
|
| 614 |
|
| 615 |
-
#
|
| 616 |
-
|
| 617 |
-
|
| 618 |
-
|
| 619 |
-
# Extract metadata
|
| 620 |
-
email = creds.get("_proxy_metadata", {}).get("email", "unknown")
|
| 621 |
-
|
| 622 |
-
# Get credential number from filename
|
| 623 |
-
cred_number = _get_credential_number_from_filename(cred_file.name)
|
| 624 |
-
|
| 625 |
-
# Generate .env file name with credential number
|
| 626 |
-
safe_email = email.replace("@", "_at_").replace(".", "_")
|
| 627 |
-
env_filename = f"iflow_{cred_number}_{safe_email}.env"
|
| 628 |
-
env_filepath = OAUTH_BASE_DIR / env_filename
|
| 629 |
-
|
| 630 |
-
# Use numbered format: IFLOW_N_*
|
| 631 |
-
numbered_prefix = f"IFLOW_{cred_number}"
|
| 632 |
-
|
| 633 |
-
# Build .env content (iFlow has different structure with API key)
|
| 634 |
-
env_lines = [
|
| 635 |
-
f"# IFLOW Credential #{cred_number} for: {email}",
|
| 636 |
-
f"# Generated at: {time.strftime('%Y-%m-%d %H:%M:%S')}",
|
| 637 |
-
f"# ",
|
| 638 |
-
f"# To combine multiple credentials into one .env file, copy these lines",
|
| 639 |
-
f"# and ensure each credential has a unique number (1, 2, 3, etc.)",
|
| 640 |
-
"",
|
| 641 |
-
f"{numbered_prefix}_ACCESS_TOKEN={creds.get('access_token', '')}",
|
| 642 |
-
f"{numbered_prefix}_REFRESH_TOKEN={creds.get('refresh_token', '')}",
|
| 643 |
-
f"{numbered_prefix}_API_KEY={creds.get('api_key', '')}",
|
| 644 |
-
f"{numbered_prefix}_EXPIRY_DATE={creds.get('expiry_date', '')}",
|
| 645 |
-
f"{numbered_prefix}_EMAIL={email}",
|
| 646 |
-
f"{numbered_prefix}_TOKEN_TYPE={creds.get('token_type', 'Bearer')}",
|
| 647 |
-
f"{numbered_prefix}_SCOPE={creds.get('scope', 'read write')}",
|
| 648 |
-
]
|
| 649 |
-
|
| 650 |
-
# Write to .env file
|
| 651 |
-
with open(env_filepath, 'w') as f:
|
| 652 |
-
f.write('\n'.join(env_lines))
|
| 653 |
-
|
| 654 |
-
success_text = Text.from_markup(
|
| 655 |
-
f"Successfully exported credential to [bold yellow]'{env_filepath}'[/bold yellow]\n\n"
|
| 656 |
-
f"[bold]Environment variable prefix:[/bold] [cyan]{numbered_prefix}_*[/cyan]\n\n"
|
| 657 |
-
f"[bold]To use this credential:[/bold]\n"
|
| 658 |
-
f"1. Copy the contents to your main .env file, OR\n"
|
| 659 |
-
f"2. Source it: [bold cyan]source {env_filepath.name}[/bold cyan] (Linux/Mac)\n\n"
|
| 660 |
-
f"[bold]To combine multiple credentials:[/bold]\n"
|
| 661 |
-
f"Copy lines from multiple .env files into one file.\n"
|
| 662 |
-
f"Each credential uses a unique number ({numbered_prefix}_*)."
|
| 663 |
)
|
| 664 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 665 |
else:
|
| 666 |
console.print("[bold red]Invalid choice. Please try again.[/bold red]")
|
| 667 |
except ValueError:
|
| 668 |
-
console.print(
|
|
|
|
|
|
|
| 669 |
except Exception as e:
|
| 670 |
-
console.print(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 671 |
|
| 672 |
|
| 673 |
async def export_antigravity_to_env():
|
| 674 |
"""
|
| 675 |
Export an Antigravity credential JSON file to .env format.
|
| 676 |
-
Uses
|
| 677 |
"""
|
| 678 |
-
console.print(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 679 |
|
| 680 |
-
#
|
| 681 |
-
|
| 682 |
|
| 683 |
-
if not
|
| 684 |
-
console.print(
|
| 685 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 686 |
return
|
| 687 |
|
| 688 |
# Display available credentials
|
| 689 |
cred_text = Text()
|
| 690 |
-
for i,
|
| 691 |
-
|
| 692 |
-
|
| 693 |
-
|
| 694 |
-
email = creds.get("_proxy_metadata", {}).get("email", "unknown")
|
| 695 |
-
cred_text.append(f" {i + 1}. {cred_file.name} ({email})\n")
|
| 696 |
-
except Exception as e:
|
| 697 |
-
cred_text.append(f" {i + 1}. {cred_file.name} (error reading: {e})\n")
|
| 698 |
|
| 699 |
-
console.print(
|
|
|
|
|
|
|
| 700 |
|
| 701 |
choice = Prompt.ask(
|
| 702 |
-
Text.from_markup(
|
| 703 |
-
|
| 704 |
-
|
|
|
|
|
|
|
| 705 |
)
|
| 706 |
|
| 707 |
-
if choice.lower() ==
|
| 708 |
return
|
| 709 |
|
| 710 |
try:
|
| 711 |
choice_index = int(choice) - 1
|
| 712 |
-
if 0 <= choice_index < len(
|
| 713 |
-
|
| 714 |
-
|
| 715 |
-
# Load the credential
|
| 716 |
-
with open(cred_file, 'r') as f:
|
| 717 |
-
creds = json.load(f)
|
| 718 |
|
| 719 |
-
#
|
| 720 |
-
|
| 721 |
-
|
| 722 |
-
# Get credential number from filename
|
| 723 |
-
cred_number = _get_credential_number_from_filename(cred_file.name)
|
| 724 |
-
|
| 725 |
-
# Generate .env file name with credential number
|
| 726 |
-
safe_email = email.replace("@", "_at_").replace(".", "_")
|
| 727 |
-
env_filename = f"antigravity_{cred_number}_{safe_email}.env"
|
| 728 |
-
env_filepath = OAUTH_BASE_DIR / env_filename
|
| 729 |
-
|
| 730 |
-
# Build .env content using helper
|
| 731 |
-
env_lines, numbered_prefix = _build_env_export_content(
|
| 732 |
-
provider_prefix="ANTIGRAVITY",
|
| 733 |
-
cred_number=cred_number,
|
| 734 |
-
creds=creds,
|
| 735 |
-
email=email,
|
| 736 |
-
extra_fields=None,
|
| 737 |
-
include_client_creds=True
|
| 738 |
)
|
| 739 |
|
| 740 |
-
|
| 741 |
-
|
| 742 |
-
|
| 743 |
-
|
| 744 |
-
|
| 745 |
-
|
| 746 |
-
|
| 747 |
-
|
| 748 |
-
|
| 749 |
-
|
| 750 |
-
|
| 751 |
-
|
| 752 |
-
|
| 753 |
-
|
| 754 |
-
|
| 755 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 756 |
else:
|
| 757 |
console.print("[bold red]Invalid choice. Please try again.[/bold red]")
|
| 758 |
except ValueError:
|
| 759 |
-
console.print(
|
|
|
|
|
|
|
| 760 |
except Exception as e:
|
| 761 |
-
console.print(
|
| 762 |
-
|
| 763 |
-
|
| 764 |
-
|
| 765 |
-
|
| 766 |
-
email = creds.get("_proxy_metadata", {}).get("email", "unknown")
|
| 767 |
-
project_id = creds.get("_proxy_metadata", {}).get("project_id", "")
|
| 768 |
-
tier = creds.get("_proxy_metadata", {}).get("tier", "")
|
| 769 |
-
|
| 770 |
-
extra_fields = {}
|
| 771 |
-
if project_id:
|
| 772 |
-
extra_fields["PROJECT_ID"] = project_id
|
| 773 |
-
if tier:
|
| 774 |
-
extra_fields["TIER"] = tier
|
| 775 |
-
|
| 776 |
-
env_lines, _ = _build_env_export_content(
|
| 777 |
-
provider_prefix="GEMINI_CLI",
|
| 778 |
-
cred_number=cred_number,
|
| 779 |
-
creds=creds,
|
| 780 |
-
email=email,
|
| 781 |
-
extra_fields=extra_fields,
|
| 782 |
-
include_client_creds=True
|
| 783 |
-
)
|
| 784 |
-
return env_lines
|
| 785 |
-
|
| 786 |
-
|
| 787 |
-
def _build_qwen_code_env_lines(creds: dict, cred_number: int) -> list[str]:
|
| 788 |
-
"""Build .env lines for a Qwen Code credential."""
|
| 789 |
-
email = creds.get("_proxy_metadata", {}).get("email", "unknown")
|
| 790 |
-
numbered_prefix = f"QWEN_CODE_{cred_number}"
|
| 791 |
-
|
| 792 |
-
env_lines = [
|
| 793 |
-
f"# QWEN_CODE Credential #{cred_number} for: {email}",
|
| 794 |
-
f"# Generated at: {time.strftime('%Y-%m-%d %H:%M:%S')}",
|
| 795 |
-
"",
|
| 796 |
-
f"{numbered_prefix}_ACCESS_TOKEN={creds.get('access_token', '')}",
|
| 797 |
-
f"{numbered_prefix}_REFRESH_TOKEN={creds.get('refresh_token', '')}",
|
| 798 |
-
f"{numbered_prefix}_EXPIRY_DATE={creds.get('expiry_date', 0)}",
|
| 799 |
-
f"{numbered_prefix}_RESOURCE_URL={creds.get('resource_url', 'https://portal.qwen.ai/v1')}",
|
| 800 |
-
f"{numbered_prefix}_EMAIL={email}",
|
| 801 |
-
]
|
| 802 |
-
return env_lines
|
| 803 |
-
|
| 804 |
-
|
| 805 |
-
def _build_iflow_env_lines(creds: dict, cred_number: int) -> list[str]:
|
| 806 |
-
"""Build .env lines for an iFlow credential."""
|
| 807 |
-
email = creds.get("_proxy_metadata", {}).get("email", "unknown")
|
| 808 |
-
numbered_prefix = f"IFLOW_{cred_number}"
|
| 809 |
-
|
| 810 |
-
env_lines = [
|
| 811 |
-
f"# IFLOW Credential #{cred_number} for: {email}",
|
| 812 |
-
f"# Generated at: {time.strftime('%Y-%m-%d %H:%M:%S')}",
|
| 813 |
-
"",
|
| 814 |
-
f"{numbered_prefix}_ACCESS_TOKEN={creds.get('access_token', '')}",
|
| 815 |
-
f"{numbered_prefix}_REFRESH_TOKEN={creds.get('refresh_token', '')}",
|
| 816 |
-
f"{numbered_prefix}_API_KEY={creds.get('api_key', '')}",
|
| 817 |
-
f"{numbered_prefix}_EXPIRY_DATE={creds.get('expiry_date', '')}",
|
| 818 |
-
f"{numbered_prefix}_EMAIL={email}",
|
| 819 |
-
f"{numbered_prefix}_TOKEN_TYPE={creds.get('token_type', 'Bearer')}",
|
| 820 |
-
f"{numbered_prefix}_SCOPE={creds.get('scope', 'read write')}",
|
| 821 |
-
]
|
| 822 |
-
return env_lines
|
| 823 |
-
|
| 824 |
-
|
| 825 |
-
def _build_antigravity_env_lines(creds: dict, cred_number: int) -> list[str]:
|
| 826 |
-
"""Build .env lines for an Antigravity credential."""
|
| 827 |
-
email = creds.get("_proxy_metadata", {}).get("email", "unknown")
|
| 828 |
-
|
| 829 |
-
env_lines, _ = _build_env_export_content(
|
| 830 |
-
provider_prefix="ANTIGRAVITY",
|
| 831 |
-
cred_number=cred_number,
|
| 832 |
-
creds=creds,
|
| 833 |
-
email=email,
|
| 834 |
-
extra_fields=None,
|
| 835 |
-
include_client_creds=True
|
| 836 |
-
)
|
| 837 |
-
return env_lines
|
| 838 |
|
| 839 |
|
| 840 |
async def export_all_provider_credentials(provider_name: str):
|
| 841 |
"""
|
| 842 |
Export all credentials for a specific provider to individual .env files.
|
|
|
|
| 843 |
"""
|
| 844 |
-
|
| 845 |
-
|
| 846 |
-
|
| 847 |
-
|
| 848 |
-
|
| 849 |
-
|
| 850 |
-
|
| 851 |
-
if provider_name not in provider_config:
|
| 852 |
console.print(f"[bold red]Unknown provider: {provider_name}[/bold red]")
|
| 853 |
return
|
| 854 |
-
|
| 855 |
-
|
| 856 |
-
|
| 857 |
-
|
| 858 |
-
|
| 859 |
-
|
| 860 |
-
|
| 861 |
-
|
| 862 |
-
|
| 863 |
-
|
| 864 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 865 |
return
|
| 866 |
-
|
| 867 |
exported_count = 0
|
| 868 |
-
for
|
| 869 |
try:
|
| 870 |
-
|
| 871 |
-
|
| 872 |
-
|
| 873 |
-
|
| 874 |
-
|
| 875 |
-
|
| 876 |
-
|
| 877 |
-
|
| 878 |
-
|
| 879 |
-
|
| 880 |
-
|
| 881 |
-
|
| 882 |
-
|
| 883 |
-
|
| 884 |
-
|
| 885 |
-
|
| 886 |
-
console.print(f" ✓ Exported [cyan]{cred_file.name}[/cyan] → [yellow]{env_filename}[/yellow]")
|
| 887 |
-
exported_count += 1
|
| 888 |
-
|
| 889 |
except Exception as e:
|
| 890 |
-
console.print(
|
| 891 |
-
|
| 892 |
-
|
| 893 |
-
|
| 894 |
-
|
| 895 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 896 |
|
| 897 |
|
| 898 |
async def combine_provider_credentials(provider_name: str):
|
| 899 |
"""
|
| 900 |
Combine all credentials for a specific provider into a single .env file.
|
|
|
|
| 901 |
"""
|
| 902 |
-
|
| 903 |
-
|
| 904 |
-
|
| 905 |
-
|
| 906 |
-
|
| 907 |
-
|
| 908 |
-
|
| 909 |
-
if provider_name not in provider_config:
|
| 910 |
console.print(f"[bold red]Unknown provider: {provider_name}[/bold red]")
|
| 911 |
return
|
| 912 |
-
|
| 913 |
-
|
| 914 |
-
|
| 915 |
-
|
| 916 |
-
|
| 917 |
-
|
| 918 |
-
|
| 919 |
-
|
| 920 |
-
|
| 921 |
-
|
| 922 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 923 |
return
|
| 924 |
-
|
| 925 |
combined_lines = [
|
| 926 |
f"# Combined {display_name} Credentials",
|
| 927 |
f"# Generated at: {time.strftime('%Y-%m-%d %H:%M:%S')}",
|
| 928 |
-
f"# Total credentials: {len(
|
| 929 |
"#",
|
| 930 |
"# Copy all lines below into your main .env file",
|
| 931 |
"",
|
| 932 |
]
|
| 933 |
-
|
| 934 |
combined_count = 0
|
| 935 |
-
for
|
| 936 |
try:
|
| 937 |
-
|
|
|
|
| 938 |
creds = json.load(f)
|
| 939 |
-
|
| 940 |
-
|
| 941 |
-
env_lines =
|
| 942 |
-
|
| 943 |
combined_lines.extend(env_lines)
|
| 944 |
combined_lines.append("") # Blank line between credentials
|
| 945 |
combined_count += 1
|
| 946 |
-
|
| 947 |
except Exception as e:
|
| 948 |
-
console.print(
|
| 949 |
-
|
|
|
|
|
|
|
| 950 |
# Write combined file
|
| 951 |
combined_filename = f"{provider_name}_all_combined.env"
|
| 952 |
-
combined_filepath =
|
| 953 |
-
|
| 954 |
-
with open(combined_filepath,
|
| 955 |
-
f.write(
|
| 956 |
-
|
| 957 |
-
console.print(
|
| 958 |
-
|
| 959 |
-
|
| 960 |
-
|
| 961 |
-
|
| 962 |
-
|
| 963 |
-
|
| 964 |
-
|
|
|
|
|
|
|
|
|
|
| 965 |
|
| 966 |
|
| 967 |
async def combine_all_credentials():
|
| 968 |
"""
|
| 969 |
Combine ALL credentials from ALL providers into a single .env file.
|
|
|
|
| 970 |
"""
|
| 971 |
-
console.print(
|
| 972 |
-
|
| 973 |
-
|
| 974 |
-
|
| 975 |
-
|
| 976 |
-
|
| 977 |
-
|
| 978 |
-
|
| 979 |
-
|
| 980 |
combined_lines = [
|
| 981 |
"# Combined All Provider Credentials",
|
| 982 |
f"# Generated at: {time.strftime('%Y-%m-%d %H:%M:%S')}",
|
|
@@ -984,63 +930,83 @@ async def combine_all_credentials():
|
|
| 984 |
"# Copy all lines below into your main .env file",
|
| 985 |
"",
|
| 986 |
]
|
| 987 |
-
|
| 988 |
total_count = 0
|
| 989 |
provider_counts = {}
|
| 990 |
-
|
| 991 |
-
for provider_name
|
| 992 |
-
|
| 993 |
-
|
| 994 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 995 |
continue
|
| 996 |
-
|
| 997 |
-
display_name =
|
| 998 |
combined_lines.append(f"# ===== {display_name} Credentials =====")
|
| 999 |
combined_lines.append("")
|
| 1000 |
-
|
| 1001 |
provider_count = 0
|
| 1002 |
-
for
|
| 1003 |
try:
|
| 1004 |
-
|
|
|
|
| 1005 |
creds = json.load(f)
|
| 1006 |
-
|
| 1007 |
-
|
| 1008 |
-
env_lines =
|
| 1009 |
-
|
| 1010 |
combined_lines.extend(env_lines)
|
| 1011 |
combined_lines.append("")
|
| 1012 |
provider_count += 1
|
| 1013 |
total_count += 1
|
| 1014 |
-
|
| 1015 |
except Exception as e:
|
| 1016 |
-
console.print(
|
| 1017 |
-
|
|
|
|
|
|
|
| 1018 |
provider_counts[display_name] = provider_count
|
| 1019 |
-
|
| 1020 |
if total_count == 0:
|
| 1021 |
-
console.print(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1022 |
return
|
| 1023 |
-
|
| 1024 |
# Write combined file
|
| 1025 |
combined_filename = "all_providers_combined.env"
|
| 1026 |
-
combined_filepath =
|
| 1027 |
-
|
| 1028 |
-
with open(combined_filepath,
|
| 1029 |
-
f.write(
|
| 1030 |
-
|
| 1031 |
# Build summary
|
| 1032 |
-
summary_lines = [
|
|
|
|
|
|
|
| 1033 |
summary = "\n".join(summary_lines)
|
| 1034 |
-
|
| 1035 |
-
console.print(
|
| 1036 |
-
|
| 1037 |
-
|
| 1038 |
-
|
| 1039 |
-
|
| 1040 |
-
|
| 1041 |
-
|
| 1042 |
-
|
| 1043 |
-
|
|
|
|
|
|
|
|
|
|
| 1044 |
|
| 1045 |
|
| 1046 |
async def export_credentials_submenu():
|
|
@@ -1049,40 +1015,65 @@ async def export_credentials_submenu():
|
|
| 1049 |
"""
|
| 1050 |
while True:
|
| 1051 |
clear_screen()
|
| 1052 |
-
console.print(
|
| 1053 |
-
|
| 1054 |
-
|
| 1055 |
-
|
| 1056 |
-
|
| 1057 |
-
|
| 1058 |
-
|
| 1059 |
-
|
| 1060 |
-
|
| 1061 |
-
|
| 1062 |
-
|
| 1063 |
-
|
| 1064 |
-
|
| 1065 |
-
|
| 1066 |
-
|
| 1067 |
-
|
| 1068 |
-
|
| 1069 |
-
|
| 1070 |
-
|
| 1071 |
-
|
| 1072 |
-
|
| 1073 |
-
|
| 1074 |
-
|
| 1075 |
-
|
| 1076 |
-
|
| 1077 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1078 |
|
| 1079 |
export_choice = Prompt.ask(
|
| 1080 |
-
Text.from_markup(
|
| 1081 |
-
|
| 1082 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1083 |
)
|
| 1084 |
|
| 1085 |
-
if export_choice.lower() ==
|
| 1086 |
break
|
| 1087 |
|
| 1088 |
# Individual exports
|
|
@@ -1146,39 +1137,53 @@ async def export_credentials_submenu():
|
|
| 1146 |
async def main(clear_on_start=True):
|
| 1147 |
"""
|
| 1148 |
An interactive CLI tool to add new credentials.
|
| 1149 |
-
|
| 1150 |
Args:
|
| 1151 |
-
clear_on_start: If False, skip initial screen clear (used when called from launcher
|
| 1152 |
to preserve the loading screen)
|
| 1153 |
"""
|
| 1154 |
ensure_env_defaults()
|
| 1155 |
-
|
| 1156 |
# Only show header if we're clearing (standalone mode)
|
| 1157 |
if clear_on_start:
|
| 1158 |
-
console.print(
|
| 1159 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1160 |
while True:
|
| 1161 |
# Clear screen between menu selections for cleaner UX
|
| 1162 |
clear_screen()
|
| 1163 |
-
console.print(
|
| 1164 |
-
|
| 1165 |
-
|
| 1166 |
-
|
| 1167 |
-
|
| 1168 |
-
|
| 1169 |
-
|
| 1170 |
-
|
| 1171 |
-
|
| 1172 |
-
|
| 1173 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1174 |
|
| 1175 |
setup_type = Prompt.ask(
|
| 1176 |
-
Text.from_markup(
|
|
|
|
|
|
|
| 1177 |
choices=["1", "2", "3", "q"],
|
| 1178 |
-
show_choices=False
|
| 1179 |
)
|
| 1180 |
|
| 1181 |
-
if setup_type.lower() ==
|
| 1182 |
break
|
| 1183 |
|
| 1184 |
if setup_type == "1":
|
|
@@ -1190,69 +1195,88 @@ async def main(clear_on_start=True):
|
|
| 1190 |
"iflow": "iFlow (OAuth - also supports API keys)",
|
| 1191 |
"antigravity": "Antigravity (OAuth)",
|
| 1192 |
}
|
| 1193 |
-
|
| 1194 |
provider_text = Text()
|
| 1195 |
for i, provider in enumerate(available_providers):
|
| 1196 |
-
display_name = oauth_friendly_names.get(
|
|
|
|
|
|
|
| 1197 |
provider_text.append(f" {i + 1}. {display_name}\n")
|
| 1198 |
-
|
| 1199 |
-
console.print(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1200 |
|
| 1201 |
choice = Prompt.ask(
|
| 1202 |
-
Text.from_markup(
|
|
|
|
|
|
|
| 1203 |
choices=[str(i + 1) for i in range(len(available_providers))] + ["b"],
|
| 1204 |
-
show_choices=False
|
| 1205 |
)
|
| 1206 |
|
| 1207 |
-
if choice.lower() ==
|
| 1208 |
continue
|
| 1209 |
-
|
| 1210 |
try:
|
| 1211 |
choice_index = int(choice) - 1
|
| 1212 |
if 0 <= choice_index < len(available_providers):
|
| 1213 |
provider_name = available_providers[choice_index]
|
| 1214 |
-
display_name = oauth_friendly_names.get(
|
| 1215 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1216 |
await setup_new_credential(provider_name)
|
| 1217 |
# Don't clear after OAuth - user needs to see full flow
|
| 1218 |
console.print("\n[dim]Press Enter to return to main menu...[/dim]")
|
| 1219 |
input()
|
| 1220 |
else:
|
| 1221 |
-
console.print(
|
|
|
|
|
|
|
| 1222 |
await asyncio.sleep(1.5)
|
| 1223 |
except ValueError:
|
| 1224 |
-
console.print(
|
|
|
|
|
|
|
| 1225 |
await asyncio.sleep(1.5)
|
| 1226 |
|
| 1227 |
elif setup_type == "2":
|
| 1228 |
await setup_api_key()
|
| 1229 |
-
#console.print("\n[dim]Press Enter to return to main menu...[/dim]")
|
| 1230 |
-
#input()
|
| 1231 |
|
| 1232 |
elif setup_type == "3":
|
| 1233 |
await export_credentials_submenu()
|
| 1234 |
|
|
|
|
| 1235 |
def run_credential_tool(from_launcher=False):
|
| 1236 |
"""
|
| 1237 |
Entry point for credential tool.
|
| 1238 |
-
|
| 1239 |
Args:
|
| 1240 |
from_launcher: If True, skip loading screen (launcher already showed it)
|
| 1241 |
"""
|
| 1242 |
# Check if we need to show loading screen
|
| 1243 |
if not from_launcher:
|
| 1244 |
# Standalone mode - show full loading UI
|
| 1245 |
-
os.system(
|
| 1246 |
-
|
| 1247 |
_start_time = time.time()
|
| 1248 |
-
|
| 1249 |
# Phase 1: Show initial message
|
| 1250 |
print("━" * 70)
|
| 1251 |
print("Interactive Credential Setup Tool")
|
| 1252 |
print("GitHub: https://github.com/Mirrowel/LLM-API-Key-Proxy")
|
| 1253 |
print("━" * 70)
|
| 1254 |
print("Loading credential management components...")
|
| 1255 |
-
|
| 1256 |
# Phase 2: Load dependencies with spinner
|
| 1257 |
with console.status("Loading authentication providers...", spinner="dots"):
|
| 1258 |
_ensure_providers_loaded()
|
|
@@ -1261,14 +1285,16 @@ def run_credential_tool(from_launcher=False):
|
|
| 1261 |
with console.status("Initializing credential tool...", spinner="dots"):
|
| 1262 |
time.sleep(0.2) # Brief pause for UI consistency
|
| 1263 |
console.print("✓ Credential tool initialized")
|
| 1264 |
-
|
| 1265 |
_elapsed = time.time() - _start_time
|
| 1266 |
_, PROVIDER_PLUGINS = _ensure_providers_loaded()
|
| 1267 |
-
print(
|
| 1268 |
-
|
|
|
|
|
|
|
| 1269 |
# Small delay to let user see the ready message
|
| 1270 |
time.sleep(0.5)
|
| 1271 |
-
|
| 1272 |
# Run the main async event loop
|
| 1273 |
# If from launcher, don't clear screen at start to preserve loading messages
|
| 1274 |
try:
|
|
|
|
| 3 |
import asyncio
|
| 4 |
import json
|
| 5 |
import os
|
|
|
|
| 6 |
import time
|
| 7 |
from pathlib import Path
|
| 8 |
from dotenv import set_key, get_key
|
| 9 |
|
| 10 |
+
# NOTE: Heavy imports (provider_factory, PROVIDER_PLUGINS) are deferred
|
| 11 |
# to avoid 6-7 second delay before showing loading screen
|
| 12 |
from rich.console import Console
|
| 13 |
from rich.panel import Panel
|
| 14 |
from rich.prompt import Prompt
|
| 15 |
from rich.text import Text
|
| 16 |
|
| 17 |
+
from .utils.paths import get_oauth_dir, get_data_file
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def _get_oauth_base_dir() -> Path:
|
| 21 |
+
"""Get the OAuth base directory (lazy, respects EXE vs script mode)."""
|
| 22 |
+
oauth_dir = get_oauth_dir()
|
| 23 |
+
oauth_dir.mkdir(parents=True, exist_ok=True)
|
| 24 |
+
return oauth_dir
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def _get_env_file() -> Path:
|
| 28 |
+
"""Get the .env file path (lazy, respects EXE vs script mode)."""
|
| 29 |
+
return get_data_file(".env")
|
| 30 |
+
|
| 31 |
|
| 32 |
console = Console()
|
| 33 |
|
|
|
|
| 35 |
_provider_factory = None
|
| 36 |
_provider_plugins = None
|
| 37 |
|
| 38 |
+
|
| 39 |
def _ensure_providers_loaded():
|
| 40 |
"""Lazy load provider modules only when needed"""
|
| 41 |
global _provider_factory, _provider_plugins
|
| 42 |
if _provider_factory is None:
|
| 43 |
from . import provider_factory as pf
|
| 44 |
from .providers import PROVIDER_PLUGINS as pp
|
| 45 |
+
|
| 46 |
_provider_factory = pf
|
| 47 |
_provider_plugins = pp
|
| 48 |
return _provider_factory, _provider_plugins
|
|
|
|
| 50 |
|
| 51 |
def clear_screen():
|
| 52 |
"""
|
| 53 |
+
Cross-platform terminal clear that works robustly on both
|
| 54 |
classic Windows conhost and modern terminals (Windows Terminal, Linux, Mac).
|
| 55 |
+
|
| 56 |
Uses native OS commands instead of ANSI escape sequences:
|
| 57 |
- Windows (conhost & Windows Terminal): cls
|
| 58 |
- Unix-like systems (Linux, Mac): clear
|
| 59 |
"""
|
| 60 |
+
os.system("cls" if os.name == "nt" else "clear")
|
|
|
|
| 61 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 62 |
|
| 63 |
def ensure_env_defaults():
|
| 64 |
"""
|
| 65 |
Ensures the .env file exists and contains essential default values like PROXY_API_KEY.
|
| 66 |
"""
|
| 67 |
+
if not _get_env_file().is_file():
|
| 68 |
+
_get_env_file().touch()
|
| 69 |
+
console.print(
|
| 70 |
+
f"Creating a new [bold yellow]{_get_env_file().name}[/bold yellow] file..."
|
| 71 |
+
)
|
| 72 |
|
| 73 |
# Check for PROXY_API_KEY, similar to setup_env.bat
|
| 74 |
+
if get_key(str(_get_env_file()), "PROXY_API_KEY") is None:
|
| 75 |
default_key = "VerysecretKey"
|
| 76 |
+
console.print(
|
| 77 |
+
f"Adding default [bold cyan]PROXY_API_KEY[/bold cyan] to [bold yellow]{_get_env_file().name}[/bold yellow]..."
|
| 78 |
+
)
|
| 79 |
+
set_key(str(_get_env_file()), "PROXY_API_KEY", default_key)
|
| 80 |
+
|
| 81 |
|
| 82 |
async def setup_api_key():
|
| 83 |
"""
|
|
|
|
| 90 |
|
| 91 |
# Verified list of LiteLLM providers with their friendly names and API key variables
|
| 92 |
LITELLM_PROVIDERS = {
|
| 93 |
+
"OpenAI": "OPENAI_API_KEY",
|
| 94 |
+
"Anthropic": "ANTHROPIC_API_KEY",
|
| 95 |
+
"Google AI Studio (Gemini)": "GEMINI_API_KEY",
|
| 96 |
+
"Azure OpenAI": "AZURE_API_KEY",
|
| 97 |
+
"Vertex AI": "GOOGLE_API_KEY",
|
| 98 |
+
"AWS Bedrock": "AWS_ACCESS_KEY_ID",
|
| 99 |
+
"Cohere": "COHERE_API_KEY",
|
| 100 |
+
"Chutes": "CHUTES_API_KEY",
|
| 101 |
"Mistral AI": "MISTRAL_API_KEY",
|
| 102 |
+
"Codestral (Mistral)": "CODESTRAL_API_KEY",
|
| 103 |
+
"Groq": "GROQ_API_KEY",
|
| 104 |
+
"Perplexity": "PERPLEXITYAI_API_KEY",
|
| 105 |
+
"xAI": "XAI_API_KEY",
|
| 106 |
+
"Together AI": "TOGETHERAI_API_KEY",
|
| 107 |
+
"Fireworks AI": "FIREWORKS_AI_API_KEY",
|
| 108 |
+
"Replicate": "REPLICATE_API_KEY",
|
| 109 |
+
"Hugging Face": "HUGGINGFACE_API_KEY",
|
| 110 |
+
"Anyscale": "ANYSCALE_API_KEY",
|
| 111 |
+
"NVIDIA NIM": "NVIDIA_NIM_API_KEY",
|
| 112 |
+
"Deepseek": "DEEPSEEK_API_KEY",
|
| 113 |
+
"AI21": "AI21_API_KEY",
|
| 114 |
+
"Cerebras": "CEREBRAS_API_KEY",
|
| 115 |
+
"Moonshot": "MOONSHOT_API_KEY",
|
| 116 |
+
"Ollama": "OLLAMA_API_KEY",
|
| 117 |
+
"Xinference": "XINFERENCE_API_KEY",
|
| 118 |
+
"Infinity": "INFINITY_API_KEY",
|
| 119 |
+
"OpenRouter": "OPENROUTER_API_KEY",
|
| 120 |
+
"Deepinfra": "DEEPINFRA_API_KEY",
|
| 121 |
+
"Cloudflare": "CLOUDFLARE_API_KEY",
|
| 122 |
+
"Baseten": "BASETEN_API_KEY",
|
| 123 |
+
"Modal": "MODAL_API_KEY",
|
| 124 |
+
"Databricks": "DATABRICKS_API_KEY",
|
| 125 |
+
"AWS SageMaker": "AWS_ACCESS_KEY_ID",
|
| 126 |
+
"IBM watsonx.ai": "WATSONX_APIKEY",
|
| 127 |
+
"Predibase": "PREDIBASE_API_KEY",
|
| 128 |
+
"Clarifai": "CLARIFAI_API_KEY",
|
| 129 |
+
"NLP Cloud": "NLP_CLOUD_API_KEY",
|
| 130 |
+
"Voyage AI": "VOYAGE_API_KEY",
|
| 131 |
+
"Jina AI": "JINA_API_KEY",
|
| 132 |
+
"Hyperbolic": "HYPERBOLIC_API_KEY",
|
| 133 |
+
"Morph": "MORPH_API_KEY",
|
| 134 |
+
"Lambda AI": "LAMBDA_API_KEY",
|
| 135 |
+
"Novita AI": "NOVITA_API_KEY",
|
| 136 |
+
"Aleph Alpha": "ALEPH_ALPHA_API_KEY",
|
| 137 |
+
"SambaNova": "SAMBANOVA_API_KEY",
|
| 138 |
+
"FriendliAI": "FRIENDLI_TOKEN",
|
| 139 |
+
"Galadriel": "GALADRIEL_API_KEY",
|
| 140 |
+
"CompactifAI": "COMPACTIFAI_API_KEY",
|
| 141 |
+
"Lemonade": "LEMONADE_API_KEY",
|
| 142 |
+
"GradientAI": "GRADIENTAI_API_KEY",
|
| 143 |
+
"Featherless AI": "FEATHERLESS_AI_API_KEY",
|
| 144 |
+
"Nebius AI Studio": "NEBIUS_API_KEY",
|
| 145 |
+
"Dashscope (Qwen)": "DASHSCOPE_API_KEY",
|
| 146 |
+
"Bytez": "BYTEZ_API_KEY",
|
| 147 |
+
"Oracle OCI": "OCI_API_KEY",
|
| 148 |
+
"DataRobot": "DATAROBOT_API_KEY",
|
| 149 |
+
"OVHCloud": "OVHCLOUD_API_KEY",
|
| 150 |
+
"Volcengine": "VOLCENGINE_API_KEY",
|
| 151 |
+
"Snowflake": "SNOWFLAKE_API_KEY",
|
| 152 |
+
"Nscale": "NSCALE_API_KEY",
|
| 153 |
+
"Recraft": "RECRAFT_API_KEY",
|
| 154 |
+
"v0": "V0_API_KEY",
|
| 155 |
+
"Vercel": "VERCEL_AI_GATEWAY_API_KEY",
|
| 156 |
+
"Topaz": "TOPAZ_API_KEY",
|
| 157 |
+
"ElevenLabs": "ELEVENLABS_API_KEY",
|
| 158 |
"Deepgram": "DEEPGRAM_API_KEY",
|
| 159 |
+
"GitHub Models": "GITHUB_TOKEN",
|
| 160 |
+
"GitHub Copilot": "GITHUB_COPILOT_API_KEY",
|
| 161 |
}
|
| 162 |
|
| 163 |
# Discover custom providers and add them to the list
|
|
|
|
| 165 |
# qwen_code API key support is a fallback
|
| 166 |
# iflow API key support is a feature
|
| 167 |
_, PROVIDER_PLUGINS = _ensure_providers_loaded()
|
| 168 |
+
|
| 169 |
# Build a set of environment variables already in LITELLM_PROVIDERS
|
| 170 |
# to avoid duplicates based on the actual API key names
|
| 171 |
litellm_env_vars = set(LITELLM_PROVIDERS.values())
|
| 172 |
+
|
| 173 |
# Providers to exclude from API key list
|
| 174 |
exclude_providers = {
|
| 175 |
+
"gemini_cli", # OAuth-only
|
| 176 |
+
"antigravity", # OAuth-only
|
| 177 |
+
"qwen_code", # API key is fallback, OAuth is primary - don't advertise
|
| 178 |
+
"openai_compatible", # Base class, not a real provider
|
| 179 |
}
|
| 180 |
+
|
| 181 |
discovered_providers = {}
|
| 182 |
for provider_key in PROVIDER_PLUGINS.keys():
|
| 183 |
if provider_key in exclude_providers:
|
| 184 |
continue
|
| 185 |
+
|
| 186 |
# Create environment variable name
|
| 187 |
env_var = provider_key.upper() + "_API_KEY"
|
| 188 |
+
|
| 189 |
# Check if this env var already exists in LITELLM_PROVIDERS
|
| 190 |
# This catches duplicates like GEMINI_API_KEY, MISTRAL_API_KEY, etc.
|
| 191 |
if env_var in litellm_env_vars:
|
| 192 |
# Already in LITELLM_PROVIDERS with better name, skip this one
|
| 193 |
continue
|
| 194 |
+
|
| 195 |
# Create display name for this custom provider
|
| 196 |
+
display_name = provider_key.replace("_", " ").title()
|
| 197 |
discovered_providers[display_name] = env_var
|
| 198 |
+
|
| 199 |
# LITELLM_PROVIDERS takes precedence (comes first in merge)
|
| 200 |
combined_providers = {**LITELLM_PROVIDERS, **discovered_providers}
|
| 201 |
provider_display_list = sorted(combined_providers.keys())
|
|
|
|
| 210 |
else:
|
| 211 |
provider_text.append(f" {i + 1}. {provider_name}\n")
|
| 212 |
|
| 213 |
+
console.print(
|
| 214 |
+
Panel(provider_text, title="Available Providers for API Key", style="bold blue")
|
| 215 |
+
)
|
| 216 |
|
| 217 |
choice = Prompt.ask(
|
| 218 |
+
Text.from_markup(
|
| 219 |
+
"[bold]Please select a provider or type [red]'b'[/red] to go back[/bold]"
|
| 220 |
+
),
|
| 221 |
choices=[str(i + 1) for i in range(len(provider_display_list))] + ["b"],
|
| 222 |
+
show_choices=False,
|
| 223 |
)
|
| 224 |
|
| 225 |
+
if choice.lower() == "b":
|
| 226 |
return
|
| 227 |
|
| 228 |
try:
|
|
|
|
| 234 |
api_key = Prompt.ask(f"Enter the API key for {display_name}")
|
| 235 |
|
| 236 |
# Check for duplicate API key value
|
| 237 |
+
if _get_env_file().is_file():
|
| 238 |
+
with open(_get_env_file(), "r") as f:
|
| 239 |
for line in f:
|
| 240 |
line = line.strip()
|
| 241 |
if line.startswith(api_var_base) and "=" in line:
|
| 242 |
+
existing_key_name, _, existing_key_value = line.partition(
|
| 243 |
+
"="
|
| 244 |
+
)
|
| 245 |
if existing_key_value == api_key:
|
| 246 |
+
warning_text = Text.from_markup(
|
| 247 |
+
f"This API key already exists as [bold yellow]'{existing_key_name}'[/bold yellow]. Overwriting..."
|
| 248 |
+
)
|
| 249 |
+
console.print(
|
| 250 |
+
Panel(
|
| 251 |
+
warning_text,
|
| 252 |
+
style="bold yellow",
|
| 253 |
+
title="Updating API Key",
|
| 254 |
+
)
|
| 255 |
+
)
|
| 256 |
+
|
| 257 |
+
set_key(
|
| 258 |
+
str(_get_env_file()), existing_key_name, api_key
|
| 259 |
+
)
|
| 260 |
+
|
| 261 |
+
success_text = Text.from_markup(
|
| 262 |
+
f"Successfully updated existing key [bold yellow]'{existing_key_name}'[/bold yellow]."
|
| 263 |
+
)
|
| 264 |
+
console.print(
|
| 265 |
+
Panel(
|
| 266 |
+
success_text,
|
| 267 |
+
style="bold green",
|
| 268 |
+
title="Success",
|
| 269 |
+
)
|
| 270 |
+
)
|
| 271 |
return
|
| 272 |
|
| 273 |
# Special handling for AWS
|
| 274 |
if display_name in ["AWS Bedrock", "AWS SageMaker"]:
|
| 275 |
+
console.print(
|
| 276 |
+
Panel(
|
| 277 |
+
Text.from_markup(
|
| 278 |
+
"This provider requires both an Access Key ID and a Secret Access Key.\n"
|
| 279 |
+
f"The key you entered will be saved as [bold yellow]{api_var_base}_1[/bold yellow].\n"
|
| 280 |
+
"Please manually add the [bold cyan]AWS_SECRET_ACCESS_KEY_1[/bold cyan] to your .env file."
|
| 281 |
+
),
|
| 282 |
+
title="[bold yellow]Additional Step Required[/bold yellow]",
|
| 283 |
+
border_style="yellow",
|
| 284 |
+
)
|
| 285 |
+
)
|
| 286 |
|
| 287 |
key_index = 1
|
| 288 |
while True:
|
| 289 |
key_name = f"{api_var_base}_{key_index}"
|
| 290 |
+
if _get_env_file().is_file():
|
| 291 |
+
with open(_get_env_file(), "r") as f:
|
| 292 |
if not any(line.startswith(f"{key_name}=") for line in f):
|
| 293 |
break
|
| 294 |
else:
|
| 295 |
break
|
| 296 |
key_index += 1
|
| 297 |
+
|
| 298 |
key_name = f"{api_var_base}_{key_index}"
|
| 299 |
+
set_key(str(_get_env_file()), key_name, api_key)
|
| 300 |
+
|
| 301 |
+
success_text = Text.from_markup(
|
| 302 |
+
f"Successfully added {display_name} API key as [bold yellow]'{key_name}'[/bold yellow]."
|
| 303 |
+
)
|
| 304 |
console.print(Panel(success_text, style="bold green", title="Success"))
|
| 305 |
|
| 306 |
else:
|
| 307 |
console.print("[bold red]Invalid choice. Please try again.[/bold red]")
|
| 308 |
except ValueError:
|
| 309 |
+
console.print(
|
| 310 |
+
"[bold red]Invalid input. Please enter a number or 'b'.[/bold red]"
|
| 311 |
+
)
|
| 312 |
+
|
| 313 |
|
| 314 |
async def setup_new_credential(provider_name: str):
|
| 315 |
"""
|
| 316 |
Interactively sets up a new OAuth credential for a given provider.
|
| 317 |
+
|
| 318 |
+
Delegates all credential management logic to the auth class's setup_credential() method.
|
| 319 |
"""
|
| 320 |
try:
|
| 321 |
provider_factory, _ = _ensure_providers_loaded()
|
|
|
|
| 327 |
"gemini_cli": "Gemini CLI (OAuth)",
|
| 328 |
"qwen_code": "Qwen Code (OAuth - also supports API keys)",
|
| 329 |
"iflow": "iFlow (OAuth - also supports API keys)",
|
| 330 |
+
"antigravity": "Antigravity (OAuth)",
|
| 331 |
}
|
| 332 |
+
display_name = oauth_friendly_names.get(
|
| 333 |
+
provider_name, provider_name.replace("_", " ").title()
|
| 334 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 335 |
|
| 336 |
+
# Call the auth class's setup_credential() method which handles the entire flow:
|
| 337 |
+
# - OAuth authentication
|
| 338 |
+
# - Email extraction for deduplication
|
| 339 |
+
# - File path determination (new or existing)
|
| 340 |
+
# - Credential file saving
|
| 341 |
+
# - Post-auth discovery (tier/project for Google OAuth providers)
|
| 342 |
+
result = await auth_instance.setup_credential(_get_oauth_base_dir())
|
| 343 |
+
|
| 344 |
+
if not result.success:
|
| 345 |
+
console.print(
|
| 346 |
+
Panel(
|
| 347 |
+
f"Credential setup failed: {result.error}",
|
| 348 |
+
style="bold red",
|
| 349 |
+
title="Error",
|
| 350 |
+
)
|
| 351 |
+
)
|
| 352 |
return
|
| 353 |
|
| 354 |
+
# Display success message with details
|
| 355 |
+
if result.is_update:
|
| 356 |
+
success_text = Text.from_markup(
|
| 357 |
+
f"Successfully updated credential at [bold yellow]'{Path(result.file_path).name}'[/bold yellow] "
|
| 358 |
+
f"for user [bold cyan]'{result.email}'[/bold cyan]."
|
| 359 |
+
)
|
| 360 |
+
else:
|
| 361 |
+
success_text = Text.from_markup(
|
| 362 |
+
f"Successfully created new credential at [bold yellow]'{Path(result.file_path).name}'[/bold yellow] "
|
| 363 |
+
f"for user [bold cyan]'{result.email}'[/bold cyan]."
|
| 364 |
+
)
|
| 365 |
|
| 366 |
+
# Add tier/project info if available (Google OAuth providers)
|
| 367 |
+
if hasattr(result, "tier") and result.tier:
|
| 368 |
+
success_text.append(f"\nTier: {result.tier}")
|
| 369 |
+
if hasattr(result, "project_id") and result.project_id:
|
| 370 |
+
success_text.append(f"\nProject: {result.project_id}")
|
| 371 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 372 |
console.print(Panel(success_text, style="bold green", title="Success"))
|
| 373 |
|
| 374 |
except Exception as e:
|
| 375 |
+
console.print(
|
| 376 |
+
Panel(
|
| 377 |
+
f"An error occurred during setup for {provider_name}: {e}",
|
| 378 |
+
style="bold red",
|
| 379 |
+
title="Error",
|
| 380 |
+
)
|
| 381 |
+
)
|
| 382 |
|
| 383 |
|
| 384 |
async def export_gemini_cli_to_env():
|
| 385 |
"""
|
| 386 |
Export a Gemini CLI credential JSON file to .env format.
|
| 387 |
+
Uses the auth class's build_env_lines() and list_credentials() methods.
|
| 388 |
"""
|
| 389 |
+
console.print(
|
| 390 |
+
Panel(
|
| 391 |
+
"[bold cyan]Export Gemini CLI Credential to .env[/bold cyan]", expand=False
|
| 392 |
+
)
|
| 393 |
+
)
|
| 394 |
|
| 395 |
+
# Get auth instance for this provider
|
| 396 |
+
provider_factory, _ = _ensure_providers_loaded()
|
| 397 |
+
auth_class = provider_factory.get_provider_auth_class("gemini_cli")
|
| 398 |
+
auth_instance = auth_class()
|
| 399 |
|
| 400 |
+
# List available credentials using auth class
|
| 401 |
+
credentials = auth_instance.list_credentials(_get_oauth_base_dir())
|
| 402 |
+
|
| 403 |
+
if not credentials:
|
| 404 |
+
console.print(
|
| 405 |
+
Panel(
|
| 406 |
+
"No Gemini CLI credentials found. Please add one first using 'Add OAuth Credential'.",
|
| 407 |
+
style="bold red",
|
| 408 |
+
title="No Credentials",
|
| 409 |
+
)
|
| 410 |
+
)
|
| 411 |
return
|
| 412 |
|
| 413 |
# Display available credentials
|
| 414 |
cred_text = Text()
|
| 415 |
+
for i, cred_info in enumerate(credentials):
|
| 416 |
+
cred_text.append(
|
| 417 |
+
f" {i + 1}. {Path(cred_info['file_path']).name} ({cred_info['email']})\n"
|
| 418 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 419 |
|
| 420 |
+
console.print(
|
| 421 |
+
Panel(cred_text, title="Available Gemini CLI Credentials", style="bold blue")
|
| 422 |
+
)
|
| 423 |
|
| 424 |
choice = Prompt.ask(
|
| 425 |
+
Text.from_markup(
|
| 426 |
+
"[bold]Please select a credential to export or type [red]'b'[/red] to go back[/bold]"
|
| 427 |
+
),
|
| 428 |
+
choices=[str(i + 1) for i in range(len(credentials))] + ["b"],
|
| 429 |
+
show_choices=False,
|
| 430 |
)
|
| 431 |
|
| 432 |
+
if choice.lower() == "b":
|
| 433 |
return
|
| 434 |
|
| 435 |
try:
|
| 436 |
choice_index = int(choice) - 1
|
| 437 |
+
if 0 <= choice_index < len(credentials):
|
| 438 |
+
cred_info = credentials[choice_index]
|
|
|
|
|
|
|
|
|
|
|
|
|
| 439 |
|
| 440 |
+
# Use auth class to export
|
| 441 |
+
env_path = auth_instance.export_credential_to_env(
|
| 442 |
+
cred_info["file_path"], _get_oauth_base_dir()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 443 |
)
|
| 444 |
|
| 445 |
+
if env_path:
|
| 446 |
+
numbered_prefix = f"GEMINI_CLI_{cred_info['number']}"
|
| 447 |
+
success_text = Text.from_markup(
|
| 448 |
+
f"Successfully exported credential to [bold yellow]'{Path(env_path).name}'[/bold yellow]\n\n"
|
| 449 |
+
f"[bold]Environment variable prefix:[/bold] [cyan]{numbered_prefix}_*[/cyan]\n\n"
|
| 450 |
+
f"[bold]To use this credential:[/bold]\n"
|
| 451 |
+
f"1. Copy the contents to your main .env file, OR\n"
|
| 452 |
+
f"2. Source it: [bold cyan]source {Path(env_path).name}[/bold cyan] (Linux/Mac)\n"
|
| 453 |
+
f"3. Or on Windows: [bold cyan]Get-Content {Path(env_path).name} | ForEach-Object {{ $_ -replace '^([^#].*)$', 'set $1' }} | cmd[/bold cyan]\n\n"
|
| 454 |
+
f"[bold]To combine multiple credentials:[/bold]\n"
|
| 455 |
+
f"Copy lines from multiple .env files into one file.\n"
|
| 456 |
+
f"Each credential uses a unique number ({numbered_prefix}_*)."
|
| 457 |
+
)
|
| 458 |
+
console.print(Panel(success_text, style="bold green", title="Success"))
|
| 459 |
+
else:
|
| 460 |
+
console.print(
|
| 461 |
+
Panel(
|
| 462 |
+
"Failed to export credential", style="bold red", title="Error"
|
| 463 |
+
)
|
| 464 |
+
)
|
| 465 |
else:
|
| 466 |
console.print("[bold red]Invalid choice. Please try again.[/bold red]")
|
| 467 |
except ValueError:
|
| 468 |
+
console.print(
|
| 469 |
+
"[bold red]Invalid input. Please enter a number or 'b'.[/bold red]"
|
| 470 |
+
)
|
| 471 |
except Exception as e:
|
| 472 |
+
console.print(
|
| 473 |
+
Panel(
|
| 474 |
+
f"An error occurred during export: {e}", style="bold red", title="Error"
|
| 475 |
+
)
|
| 476 |
+
)
|
| 477 |
|
| 478 |
|
| 479 |
async def export_qwen_code_to_env():
|
| 480 |
"""
|
| 481 |
Export a Qwen Code credential JSON file to .env format.
|
| 482 |
+
Uses the auth class's build_env_lines() and list_credentials() methods.
|
| 483 |
"""
|
| 484 |
+
console.print(
|
| 485 |
+
Panel(
|
| 486 |
+
"[bold cyan]Export Qwen Code Credential to .env[/bold cyan]", expand=False
|
| 487 |
+
)
|
| 488 |
+
)
|
| 489 |
+
|
| 490 |
+
# Get auth instance for this provider
|
| 491 |
+
provider_factory, _ = _ensure_providers_loaded()
|
| 492 |
+
auth_class = provider_factory.get_provider_auth_class("qwen_code")
|
| 493 |
+
auth_instance = auth_class()
|
| 494 |
|
| 495 |
+
# List available credentials using auth class
|
| 496 |
+
credentials = auth_instance.list_credentials(_get_oauth_base_dir())
|
| 497 |
|
| 498 |
+
if not credentials:
|
| 499 |
+
console.print(
|
| 500 |
+
Panel(
|
| 501 |
+
"No Qwen Code credentials found. Please add one first using 'Add OAuth Credential'.",
|
| 502 |
+
style="bold red",
|
| 503 |
+
title="No Credentials",
|
| 504 |
+
)
|
| 505 |
+
)
|
| 506 |
return
|
| 507 |
|
| 508 |
# Display available credentials
|
| 509 |
cred_text = Text()
|
| 510 |
+
for i, cred_info in enumerate(credentials):
|
| 511 |
+
cred_text.append(
|
| 512 |
+
f" {i + 1}. {Path(cred_info['file_path']).name} ({cred_info['email']})\n"
|
| 513 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 514 |
|
| 515 |
+
console.print(
|
| 516 |
+
Panel(cred_text, title="Available Qwen Code Credentials", style="bold blue")
|
| 517 |
+
)
|
| 518 |
|
| 519 |
choice = Prompt.ask(
|
| 520 |
+
Text.from_markup(
|
| 521 |
+
"[bold]Please select a credential to export or type [red]'b'[/red] to go back[/bold]"
|
| 522 |
+
),
|
| 523 |
+
choices=[str(i + 1) for i in range(len(credentials))] + ["b"],
|
| 524 |
+
show_choices=False,
|
| 525 |
)
|
| 526 |
|
| 527 |
+
if choice.lower() == "b":
|
| 528 |
return
|
| 529 |
|
| 530 |
try:
|
| 531 |
choice_index = int(choice) - 1
|
| 532 |
+
if 0 <= choice_index < len(credentials):
|
| 533 |
+
cred_info = credentials[choice_index]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 534 |
|
| 535 |
+
# Use auth class to export
|
| 536 |
+
env_path = auth_instance.export_credential_to_env(
|
| 537 |
+
cred_info["file_path"], _get_oauth_base_dir()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 538 |
)
|
| 539 |
+
|
| 540 |
+
if env_path:
|
| 541 |
+
numbered_prefix = f"QWEN_CODE_{cred_info['number']}"
|
| 542 |
+
success_text = Text.from_markup(
|
| 543 |
+
f"Successfully exported credential to [bold yellow]'{Path(env_path).name}'[/bold yellow]\n\n"
|
| 544 |
+
f"[bold]Environment variable prefix:[/bold] [cyan]{numbered_prefix}_*[/cyan]\n\n"
|
| 545 |
+
f"[bold]To use this credential:[/bold]\n"
|
| 546 |
+
f"1. Copy the contents to your main .env file, OR\n"
|
| 547 |
+
f"2. Source it: [bold cyan]source {Path(env_path).name}[/bold cyan] (Linux/Mac)\n\n"
|
| 548 |
+
f"[bold]To combine multiple credentials:[/bold]\n"
|
| 549 |
+
f"Copy lines from multiple .env files into one file.\n"
|
| 550 |
+
f"Each credential uses a unique number ({numbered_prefix}_*)."
|
| 551 |
+
)
|
| 552 |
+
console.print(Panel(success_text, style="bold green", title="Success"))
|
| 553 |
+
else:
|
| 554 |
+
console.print(
|
| 555 |
+
Panel(
|
| 556 |
+
"Failed to export credential", style="bold red", title="Error"
|
| 557 |
+
)
|
| 558 |
+
)
|
| 559 |
else:
|
| 560 |
console.print("[bold red]Invalid choice. Please try again.[/bold red]")
|
| 561 |
except ValueError:
|
| 562 |
+
console.print(
|
| 563 |
+
"[bold red]Invalid input. Please enter a number or 'b'.[/bold red]"
|
| 564 |
+
)
|
| 565 |
except Exception as e:
|
| 566 |
+
console.print(
|
| 567 |
+
Panel(
|
| 568 |
+
f"An error occurred during export: {e}", style="bold red", title="Error"
|
| 569 |
+
)
|
| 570 |
+
)
|
| 571 |
|
| 572 |
|
| 573 |
async def export_iflow_to_env():
|
| 574 |
"""
|
| 575 |
Export an iFlow credential JSON file to .env format.
|
| 576 |
+
Uses the auth class's build_env_lines() and list_credentials() methods.
|
| 577 |
"""
|
| 578 |
+
console.print(
|
| 579 |
+
Panel("[bold cyan]Export iFlow Credential to .env[/bold cyan]", expand=False)
|
| 580 |
+
)
|
| 581 |
|
| 582 |
+
# Get auth instance for this provider
|
| 583 |
+
provider_factory, _ = _ensure_providers_loaded()
|
| 584 |
+
auth_class = provider_factory.get_provider_auth_class("iflow")
|
| 585 |
+
auth_instance = auth_class()
|
| 586 |
|
| 587 |
+
# List available credentials using auth class
|
| 588 |
+
credentials = auth_instance.list_credentials(_get_oauth_base_dir())
|
| 589 |
+
|
| 590 |
+
if not credentials:
|
| 591 |
+
console.print(
|
| 592 |
+
Panel(
|
| 593 |
+
"No iFlow credentials found. Please add one first using 'Add OAuth Credential'.",
|
| 594 |
+
style="bold red",
|
| 595 |
+
title="No Credentials",
|
| 596 |
+
)
|
| 597 |
+
)
|
| 598 |
return
|
| 599 |
|
| 600 |
# Display available credentials
|
| 601 |
cred_text = Text()
|
| 602 |
+
for i, cred_info in enumerate(credentials):
|
| 603 |
+
cred_text.append(
|
| 604 |
+
f" {i + 1}. {Path(cred_info['file_path']).name} ({cred_info['email']})\n"
|
| 605 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 606 |
|
| 607 |
+
console.print(
|
| 608 |
+
Panel(cred_text, title="Available iFlow Credentials", style="bold blue")
|
| 609 |
+
)
|
| 610 |
|
| 611 |
choice = Prompt.ask(
|
| 612 |
+
Text.from_markup(
|
| 613 |
+
"[bold]Please select a credential to export or type [red]'b'[/red] to go back[/bold]"
|
| 614 |
+
),
|
| 615 |
+
choices=[str(i + 1) for i in range(len(credentials))] + ["b"],
|
| 616 |
+
show_choices=False,
|
| 617 |
)
|
| 618 |
|
| 619 |
+
if choice.lower() == "b":
|
| 620 |
return
|
| 621 |
|
| 622 |
try:
|
| 623 |
choice_index = int(choice) - 1
|
| 624 |
+
if 0 <= choice_index < len(credentials):
|
| 625 |
+
cred_info = credentials[choice_index]
|
| 626 |
|
| 627 |
+
# Use auth class to export
|
| 628 |
+
env_path = auth_instance.export_credential_to_env(
|
| 629 |
+
cred_info["file_path"], _get_oauth_base_dir()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 630 |
)
|
| 631 |
+
|
| 632 |
+
if env_path:
|
| 633 |
+
numbered_prefix = f"IFLOW_{cred_info['number']}"
|
| 634 |
+
success_text = Text.from_markup(
|
| 635 |
+
f"Successfully exported credential to [bold yellow]'{Path(env_path).name}'[/bold yellow]\n\n"
|
| 636 |
+
f"[bold]Environment variable prefix:[/bold] [cyan]{numbered_prefix}_*[/cyan]\n\n"
|
| 637 |
+
f"[bold]To use this credential:[/bold]\n"
|
| 638 |
+
f"1. Copy the contents to your main .env file, OR\n"
|
| 639 |
+
f"2. Source it: [bold cyan]source {Path(env_path).name}[/bold cyan] (Linux/Mac)\n\n"
|
| 640 |
+
f"[bold]To combine multiple credentials:[/bold]\n"
|
| 641 |
+
f"Copy lines from multiple .env files into one file.\n"
|
| 642 |
+
f"Each credential uses a unique number ({numbered_prefix}_*)."
|
| 643 |
+
)
|
| 644 |
+
console.print(Panel(success_text, style="bold green", title="Success"))
|
| 645 |
+
else:
|
| 646 |
+
console.print(
|
| 647 |
+
Panel(
|
| 648 |
+
"Failed to export credential", style="bold red", title="Error"
|
| 649 |
+
)
|
| 650 |
+
)
|
| 651 |
else:
|
| 652 |
console.print("[bold red]Invalid choice. Please try again.[/bold red]")
|
| 653 |
except ValueError:
|
| 654 |
+
console.print(
|
| 655 |
+
"[bold red]Invalid input. Please enter a number or 'b'.[/bold red]"
|
| 656 |
+
)
|
| 657 |
except Exception as e:
|
| 658 |
+
console.print(
|
| 659 |
+
Panel(
|
| 660 |
+
f"An error occurred during export: {e}", style="bold red", title="Error"
|
| 661 |
+
)
|
| 662 |
+
)
|
| 663 |
|
| 664 |
|
| 665 |
async def export_antigravity_to_env():
|
| 666 |
"""
|
| 667 |
Export an Antigravity credential JSON file to .env format.
|
| 668 |
+
Uses the auth class's build_env_lines() and list_credentials() methods.
|
| 669 |
"""
|
| 670 |
+
console.print(
|
| 671 |
+
Panel(
|
| 672 |
+
"[bold cyan]Export Antigravity Credential to .env[/bold cyan]", expand=False
|
| 673 |
+
)
|
| 674 |
+
)
|
| 675 |
+
|
| 676 |
+
# Get auth instance for this provider
|
| 677 |
+
provider_factory, _ = _ensure_providers_loaded()
|
| 678 |
+
auth_class = provider_factory.get_provider_auth_class("antigravity")
|
| 679 |
+
auth_instance = auth_class()
|
| 680 |
|
| 681 |
+
# List available credentials using auth class
|
| 682 |
+
credentials = auth_instance.list_credentials(_get_oauth_base_dir())
|
| 683 |
|
| 684 |
+
if not credentials:
|
| 685 |
+
console.print(
|
| 686 |
+
Panel(
|
| 687 |
+
"No Antigravity credentials found. Please add one first using 'Add OAuth Credential'.",
|
| 688 |
+
style="bold red",
|
| 689 |
+
title="No Credentials",
|
| 690 |
+
)
|
| 691 |
+
)
|
| 692 |
return
|
| 693 |
|
| 694 |
# Display available credentials
|
| 695 |
cred_text = Text()
|
| 696 |
+
for i, cred_info in enumerate(credentials):
|
| 697 |
+
cred_text.append(
|
| 698 |
+
f" {i + 1}. {Path(cred_info['file_path']).name} ({cred_info['email']})\n"
|
| 699 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 700 |
|
| 701 |
+
console.print(
|
| 702 |
+
Panel(cred_text, title="Available Antigravity Credentials", style="bold blue")
|
| 703 |
+
)
|
| 704 |
|
| 705 |
choice = Prompt.ask(
|
| 706 |
+
Text.from_markup(
|
| 707 |
+
"[bold]Please select a credential to export or type [red]'b'[/red] to go back[/bold]"
|
| 708 |
+
),
|
| 709 |
+
choices=[str(i + 1) for i in range(len(credentials))] + ["b"],
|
| 710 |
+
show_choices=False,
|
| 711 |
)
|
| 712 |
|
| 713 |
+
if choice.lower() == "b":
|
| 714 |
return
|
| 715 |
|
| 716 |
try:
|
| 717 |
choice_index = int(choice) - 1
|
| 718 |
+
if 0 <= choice_index < len(credentials):
|
| 719 |
+
cred_info = credentials[choice_index]
|
|
|
|
|
|
|
|
|
|
|
|
|
| 720 |
|
| 721 |
+
# Use auth class to export
|
| 722 |
+
env_path = auth_instance.export_credential_to_env(
|
| 723 |
+
cred_info["file_path"], _get_oauth_base_dir()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 724 |
)
|
| 725 |
|
| 726 |
+
if env_path:
|
| 727 |
+
numbered_prefix = f"ANTIGRAVITY_{cred_info['number']}"
|
| 728 |
+
success_text = Text.from_markup(
|
| 729 |
+
f"Successfully exported credential to [bold yellow]'{Path(env_path).name}'[/bold yellow]\n\n"
|
| 730 |
+
f"[bold]Environment variable prefix:[/bold] [cyan]{numbered_prefix}_*[/cyan]\n\n"
|
| 731 |
+
f"[bold]To use this credential:[/bold]\n"
|
| 732 |
+
f"1. Copy the contents to your main .env file, OR\n"
|
| 733 |
+
f"2. Source it: [bold cyan]source {Path(env_path).name}[/bold cyan] (Linux/Mac)\n"
|
| 734 |
+
f"3. Or on Windows: [bold cyan]Get-Content {Path(env_path).name} | ForEach-Object {{ $_ -replace '^([^#].*)$', 'set $1' }} | cmd[/bold cyan]\n\n"
|
| 735 |
+
f"[bold]To combine multiple credentials:[/bold]\n"
|
| 736 |
+
f"Copy lines from multiple .env files into one file.\n"
|
| 737 |
+
f"Each credential uses a unique number ({numbered_prefix}_*)."
|
| 738 |
+
)
|
| 739 |
+
console.print(Panel(success_text, style="bold green", title="Success"))
|
| 740 |
+
else:
|
| 741 |
+
console.print(
|
| 742 |
+
Panel(
|
| 743 |
+
"Failed to export credential", style="bold red", title="Error"
|
| 744 |
+
)
|
| 745 |
+
)
|
| 746 |
else:
|
| 747 |
console.print("[bold red]Invalid choice. Please try again.[/bold red]")
|
| 748 |
except ValueError:
|
| 749 |
+
console.print(
|
| 750 |
+
"[bold red]Invalid input. Please enter a number or 'b'.[/bold red]"
|
| 751 |
+
)
|
| 752 |
except Exception as e:
|
| 753 |
+
console.print(
|
| 754 |
+
Panel(
|
| 755 |
+
f"An error occurred during export: {e}", style="bold red", title="Error"
|
| 756 |
+
)
|
| 757 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 758 |
|
| 759 |
|
| 760 |
async def export_all_provider_credentials(provider_name: str):
|
| 761 |
"""
|
| 762 |
Export all credentials for a specific provider to individual .env files.
|
| 763 |
+
Uses the auth class's list_credentials() and export_credential_to_env() methods.
|
| 764 |
"""
|
| 765 |
+
# Get auth instance for this provider
|
| 766 |
+
provider_factory, _ = _ensure_providers_loaded()
|
| 767 |
+
try:
|
| 768 |
+
auth_class = provider_factory.get_provider_auth_class(provider_name)
|
| 769 |
+
auth_instance = auth_class()
|
| 770 |
+
except Exception:
|
|
|
|
|
|
|
| 771 |
console.print(f"[bold red]Unknown provider: {provider_name}[/bold red]")
|
| 772 |
return
|
| 773 |
+
|
| 774 |
+
display_name = provider_name.replace("_", " ").title()
|
| 775 |
+
|
| 776 |
+
console.print(
|
| 777 |
+
Panel(
|
| 778 |
+
f"[bold cyan]Export All {display_name} Credentials[/bold cyan]",
|
| 779 |
+
expand=False,
|
| 780 |
+
)
|
| 781 |
+
)
|
| 782 |
+
|
| 783 |
+
# List all credentials using auth class
|
| 784 |
+
credentials = auth_instance.list_credentials(_get_oauth_base_dir())
|
| 785 |
+
|
| 786 |
+
if not credentials:
|
| 787 |
+
console.print(
|
| 788 |
+
Panel(
|
| 789 |
+
f"No {display_name} credentials found.",
|
| 790 |
+
style="bold red",
|
| 791 |
+
title="No Credentials",
|
| 792 |
+
)
|
| 793 |
+
)
|
| 794 |
return
|
| 795 |
+
|
| 796 |
exported_count = 0
|
| 797 |
+
for cred_info in credentials:
|
| 798 |
try:
|
| 799 |
+
# Use auth class to export
|
| 800 |
+
env_path = auth_instance.export_credential_to_env(
|
| 801 |
+
cred_info["file_path"], _get_oauth_base_dir()
|
| 802 |
+
)
|
| 803 |
+
|
| 804 |
+
if env_path:
|
| 805 |
+
console.print(
|
| 806 |
+
f" ✓ Exported [cyan]{Path(cred_info['file_path']).name}[/cyan] → [yellow]{Path(env_path).name}[/yellow]"
|
| 807 |
+
)
|
| 808 |
+
exported_count += 1
|
| 809 |
+
else:
|
| 810 |
+
console.print(
|
| 811 |
+
f" ✗ Failed to export {Path(cred_info['file_path']).name}"
|
| 812 |
+
)
|
| 813 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
| 814 |
except Exception as e:
|
| 815 |
+
console.print(
|
| 816 |
+
f" ✗ Failed to export {Path(cred_info['file_path']).name}: {e}"
|
| 817 |
+
)
|
| 818 |
+
|
| 819 |
+
console.print(
|
| 820 |
+
Panel(
|
| 821 |
+
f"Successfully exported {exported_count}/{len(credentials)} {display_name} credentials to individual .env files.",
|
| 822 |
+
style="bold green",
|
| 823 |
+
title="Export Complete",
|
| 824 |
+
)
|
| 825 |
+
)
|
| 826 |
|
| 827 |
|
| 828 |
async def combine_provider_credentials(provider_name: str):
|
| 829 |
"""
|
| 830 |
Combine all credentials for a specific provider into a single .env file.
|
| 831 |
+
Uses the auth class's list_credentials() and build_env_lines() methods.
|
| 832 |
"""
|
| 833 |
+
# Get auth instance for this provider
|
| 834 |
+
provider_factory, _ = _ensure_providers_loaded()
|
| 835 |
+
try:
|
| 836 |
+
auth_class = provider_factory.get_provider_auth_class(provider_name)
|
| 837 |
+
auth_instance = auth_class()
|
| 838 |
+
except Exception:
|
|
|
|
|
|
|
| 839 |
console.print(f"[bold red]Unknown provider: {provider_name}[/bold red]")
|
| 840 |
return
|
| 841 |
+
|
| 842 |
+
display_name = provider_name.replace("_", " ").title()
|
| 843 |
+
|
| 844 |
+
console.print(
|
| 845 |
+
Panel(
|
| 846 |
+
f"[bold cyan]Combine All {display_name} Credentials[/bold cyan]",
|
| 847 |
+
expand=False,
|
| 848 |
+
)
|
| 849 |
+
)
|
| 850 |
+
|
| 851 |
+
# List all credentials using auth class
|
| 852 |
+
credentials = auth_instance.list_credentials(_get_oauth_base_dir())
|
| 853 |
+
|
| 854 |
+
if not credentials:
|
| 855 |
+
console.print(
|
| 856 |
+
Panel(
|
| 857 |
+
f"No {display_name} credentials found.",
|
| 858 |
+
style="bold red",
|
| 859 |
+
title="No Credentials",
|
| 860 |
+
)
|
| 861 |
+
)
|
| 862 |
return
|
| 863 |
+
|
| 864 |
combined_lines = [
|
| 865 |
f"# Combined {display_name} Credentials",
|
| 866 |
f"# Generated at: {time.strftime('%Y-%m-%d %H:%M:%S')}",
|
| 867 |
+
f"# Total credentials: {len(credentials)}",
|
| 868 |
"#",
|
| 869 |
"# Copy all lines below into your main .env file",
|
| 870 |
"",
|
| 871 |
]
|
| 872 |
+
|
| 873 |
combined_count = 0
|
| 874 |
+
for cred_info in credentials:
|
| 875 |
try:
|
| 876 |
+
# Load credential file
|
| 877 |
+
with open(cred_info["file_path"], "r") as f:
|
| 878 |
creds = json.load(f)
|
| 879 |
+
|
| 880 |
+
# Use auth class to build env lines
|
| 881 |
+
env_lines = auth_instance.build_env_lines(creds, cred_info["number"])
|
| 882 |
+
|
| 883 |
combined_lines.extend(env_lines)
|
| 884 |
combined_lines.append("") # Blank line between credentials
|
| 885 |
combined_count += 1
|
| 886 |
+
|
| 887 |
except Exception as e:
|
| 888 |
+
console.print(
|
| 889 |
+
f" ✗ Failed to process {Path(cred_info['file_path']).name}: {e}"
|
| 890 |
+
)
|
| 891 |
+
|
| 892 |
# Write combined file
|
| 893 |
combined_filename = f"{provider_name}_all_combined.env"
|
| 894 |
+
combined_filepath = _get_oauth_base_dir() / combined_filename
|
| 895 |
+
|
| 896 |
+
with open(combined_filepath, "w") as f:
|
| 897 |
+
f.write("\n".join(combined_lines))
|
| 898 |
+
|
| 899 |
+
console.print(
|
| 900 |
+
Panel(
|
| 901 |
+
Text.from_markup(
|
| 902 |
+
f"Successfully combined {combined_count} {display_name} credentials into:\n"
|
| 903 |
+
f"[bold yellow]{combined_filepath}[/bold yellow]\n\n"
|
| 904 |
+
f"[bold]To use:[/bold] Copy the contents into your main .env file."
|
| 905 |
+
),
|
| 906 |
+
style="bold green",
|
| 907 |
+
title="Combine Complete",
|
| 908 |
+
)
|
| 909 |
+
)
|
| 910 |
|
| 911 |
|
| 912 |
async def combine_all_credentials():
|
| 913 |
"""
|
| 914 |
Combine ALL credentials from ALL providers into a single .env file.
|
| 915 |
+
Uses auth class list_credentials() and build_env_lines() methods.
|
| 916 |
"""
|
| 917 |
+
console.print(
|
| 918 |
+
Panel("[bold cyan]Combine All Provider Credentials[/bold cyan]", expand=False)
|
| 919 |
+
)
|
| 920 |
+
|
| 921 |
+
# List of providers that support OAuth credentials
|
| 922 |
+
oauth_providers = ["gemini_cli", "qwen_code", "iflow", "antigravity"]
|
| 923 |
+
|
| 924 |
+
provider_factory, _ = _ensure_providers_loaded()
|
| 925 |
+
|
| 926 |
combined_lines = [
|
| 927 |
"# Combined All Provider Credentials",
|
| 928 |
f"# Generated at: {time.strftime('%Y-%m-%d %H:%M:%S')}",
|
|
|
|
| 930 |
"# Copy all lines below into your main .env file",
|
| 931 |
"",
|
| 932 |
]
|
| 933 |
+
|
| 934 |
total_count = 0
|
| 935 |
provider_counts = {}
|
| 936 |
+
|
| 937 |
+
for provider_name in oauth_providers:
|
| 938 |
+
try:
|
| 939 |
+
auth_class = provider_factory.get_provider_auth_class(provider_name)
|
| 940 |
+
auth_instance = auth_class()
|
| 941 |
+
except Exception:
|
| 942 |
+
continue # Skip providers that don't have auth classes
|
| 943 |
+
|
| 944 |
+
credentials = auth_instance.list_credentials(_get_oauth_base_dir())
|
| 945 |
+
|
| 946 |
+
if not credentials:
|
| 947 |
continue
|
| 948 |
+
|
| 949 |
+
display_name = provider_name.replace("_", " ").title()
|
| 950 |
combined_lines.append(f"# ===== {display_name} Credentials =====")
|
| 951 |
combined_lines.append("")
|
| 952 |
+
|
| 953 |
provider_count = 0
|
| 954 |
+
for cred_info in credentials:
|
| 955 |
try:
|
| 956 |
+
# Load credential file
|
| 957 |
+
with open(cred_info["file_path"], "r") as f:
|
| 958 |
creds = json.load(f)
|
| 959 |
+
|
| 960 |
+
# Use auth class to build env lines
|
| 961 |
+
env_lines = auth_instance.build_env_lines(creds, cred_info["number"])
|
| 962 |
+
|
| 963 |
combined_lines.extend(env_lines)
|
| 964 |
combined_lines.append("")
|
| 965 |
provider_count += 1
|
| 966 |
total_count += 1
|
| 967 |
+
|
| 968 |
except Exception as e:
|
| 969 |
+
console.print(
|
| 970 |
+
f" ✗ Failed to process {Path(cred_info['file_path']).name}: {e}"
|
| 971 |
+
)
|
| 972 |
+
|
| 973 |
provider_counts[display_name] = provider_count
|
| 974 |
+
|
| 975 |
if total_count == 0:
|
| 976 |
+
console.print(
|
| 977 |
+
Panel(
|
| 978 |
+
"No credentials found to combine.",
|
| 979 |
+
style="bold red",
|
| 980 |
+
title="No Credentials",
|
| 981 |
+
)
|
| 982 |
+
)
|
| 983 |
return
|
| 984 |
+
|
| 985 |
# Write combined file
|
| 986 |
combined_filename = "all_providers_combined.env"
|
| 987 |
+
combined_filepath = _get_oauth_base_dir() / combined_filename
|
| 988 |
+
|
| 989 |
+
with open(combined_filepath, "w") as f:
|
| 990 |
+
f.write("\n".join(combined_lines))
|
| 991 |
+
|
| 992 |
# Build summary
|
| 993 |
+
summary_lines = [
|
| 994 |
+
f" • {name}: {count} credential(s)" for name, count in provider_counts.items()
|
| 995 |
+
]
|
| 996 |
summary = "\n".join(summary_lines)
|
| 997 |
+
|
| 998 |
+
console.print(
|
| 999 |
+
Panel(
|
| 1000 |
+
Text.from_markup(
|
| 1001 |
+
f"Successfully combined {total_count} credentials from {len(provider_counts)} providers:\n"
|
| 1002 |
+
f"{summary}\n\n"
|
| 1003 |
+
f"[bold]Output file:[/bold] [yellow]{combined_filepath}[/yellow]\n\n"
|
| 1004 |
+
f"[bold]To use:[/bold] Copy the contents into your main .env file."
|
| 1005 |
+
),
|
| 1006 |
+
style="bold green",
|
| 1007 |
+
title="Combine Complete",
|
| 1008 |
+
)
|
| 1009 |
+
)
|
| 1010 |
|
| 1011 |
|
| 1012 |
async def export_credentials_submenu():
|
|
|
|
| 1015 |
"""
|
| 1016 |
while True:
|
| 1017 |
clear_screen()
|
| 1018 |
+
console.print(
|
| 1019 |
+
Panel(
|
| 1020 |
+
"[bold cyan]Export Credentials to .env[/bold cyan]",
|
| 1021 |
+
title="--- API Key Proxy ---",
|
| 1022 |
+
expand=False,
|
| 1023 |
+
)
|
| 1024 |
+
)
|
| 1025 |
+
|
| 1026 |
+
console.print(
|
| 1027 |
+
Panel(
|
| 1028 |
+
Text.from_markup(
|
| 1029 |
+
"[bold]Individual Exports:[/bold]\n"
|
| 1030 |
+
"1. Export Gemini CLI credential\n"
|
| 1031 |
+
"2. Export Qwen Code credential\n"
|
| 1032 |
+
"3. Export iFlow credential\n"
|
| 1033 |
+
"4. Export Antigravity credential\n"
|
| 1034 |
+
"\n"
|
| 1035 |
+
"[bold]Bulk Exports (per provider):[/bold]\n"
|
| 1036 |
+
"5. Export ALL Gemini CLI credentials\n"
|
| 1037 |
+
"6. Export ALL Qwen Code credentials\n"
|
| 1038 |
+
"7. Export ALL iFlow credentials\n"
|
| 1039 |
+
"8. Export ALL Antigravity credentials\n"
|
| 1040 |
+
"\n"
|
| 1041 |
+
"[bold]Combine Credentials:[/bold]\n"
|
| 1042 |
+
"9. Combine all Gemini CLI into one file\n"
|
| 1043 |
+
"10. Combine all Qwen Code into one file\n"
|
| 1044 |
+
"11. Combine all iFlow into one file\n"
|
| 1045 |
+
"12. Combine all Antigravity into one file\n"
|
| 1046 |
+
"13. Combine ALL providers into one file"
|
| 1047 |
+
),
|
| 1048 |
+
title="Choose export option",
|
| 1049 |
+
style="bold blue",
|
| 1050 |
+
)
|
| 1051 |
+
)
|
| 1052 |
|
| 1053 |
export_choice = Prompt.ask(
|
| 1054 |
+
Text.from_markup(
|
| 1055 |
+
"[bold]Please select an option or type [red]'b'[/red] to go back[/bold]"
|
| 1056 |
+
),
|
| 1057 |
+
choices=[
|
| 1058 |
+
"1",
|
| 1059 |
+
"2",
|
| 1060 |
+
"3",
|
| 1061 |
+
"4",
|
| 1062 |
+
"5",
|
| 1063 |
+
"6",
|
| 1064 |
+
"7",
|
| 1065 |
+
"8",
|
| 1066 |
+
"9",
|
| 1067 |
+
"10",
|
| 1068 |
+
"11",
|
| 1069 |
+
"12",
|
| 1070 |
+
"13",
|
| 1071 |
+
"b",
|
| 1072 |
+
],
|
| 1073 |
+
show_choices=False,
|
| 1074 |
)
|
| 1075 |
|
| 1076 |
+
if export_choice.lower() == "b":
|
| 1077 |
break
|
| 1078 |
|
| 1079 |
# Individual exports
|
|
|
|
| 1137 |
async def main(clear_on_start=True):
|
| 1138 |
"""
|
| 1139 |
An interactive CLI tool to add new credentials.
|
| 1140 |
+
|
| 1141 |
Args:
|
| 1142 |
+
clear_on_start: If False, skip initial screen clear (used when called from launcher
|
| 1143 |
to preserve the loading screen)
|
| 1144 |
"""
|
| 1145 |
ensure_env_defaults()
|
| 1146 |
+
|
| 1147 |
# Only show header if we're clearing (standalone mode)
|
| 1148 |
if clear_on_start:
|
| 1149 |
+
console.print(
|
| 1150 |
+
Panel(
|
| 1151 |
+
"[bold cyan]Interactive Credential Setup[/bold cyan]",
|
| 1152 |
+
title="--- API Key Proxy ---",
|
| 1153 |
+
expand=False,
|
| 1154 |
+
)
|
| 1155 |
+
)
|
| 1156 |
+
|
| 1157 |
while True:
|
| 1158 |
# Clear screen between menu selections for cleaner UX
|
| 1159 |
clear_screen()
|
| 1160 |
+
console.print(
|
| 1161 |
+
Panel(
|
| 1162 |
+
"[bold cyan]Interactive Credential Setup[/bold cyan]",
|
| 1163 |
+
title="--- API Key Proxy ---",
|
| 1164 |
+
expand=False,
|
| 1165 |
+
)
|
| 1166 |
+
)
|
| 1167 |
+
|
| 1168 |
+
console.print(
|
| 1169 |
+
Panel(
|
| 1170 |
+
Text.from_markup(
|
| 1171 |
+
"1. Add OAuth Credential\n2. Add API Key\n3. Export Credentials"
|
| 1172 |
+
),
|
| 1173 |
+
title="Choose credential type",
|
| 1174 |
+
style="bold blue",
|
| 1175 |
+
)
|
| 1176 |
+
)
|
| 1177 |
|
| 1178 |
setup_type = Prompt.ask(
|
| 1179 |
+
Text.from_markup(
|
| 1180 |
+
"[bold]Please select an option or type [red]'q'[/red] to quit[/bold]"
|
| 1181 |
+
),
|
| 1182 |
choices=["1", "2", "3", "q"],
|
| 1183 |
+
show_choices=False,
|
| 1184 |
)
|
| 1185 |
|
| 1186 |
+
if setup_type.lower() == "q":
|
| 1187 |
break
|
| 1188 |
|
| 1189 |
if setup_type == "1":
|
|
|
|
| 1195 |
"iflow": "iFlow (OAuth - also supports API keys)",
|
| 1196 |
"antigravity": "Antigravity (OAuth)",
|
| 1197 |
}
|
| 1198 |
+
|
| 1199 |
provider_text = Text()
|
| 1200 |
for i, provider in enumerate(available_providers):
|
| 1201 |
+
display_name = oauth_friendly_names.get(
|
| 1202 |
+
provider, provider.replace("_", " ").title()
|
| 1203 |
+
)
|
| 1204 |
provider_text.append(f" {i + 1}. {display_name}\n")
|
| 1205 |
+
|
| 1206 |
+
console.print(
|
| 1207 |
+
Panel(
|
| 1208 |
+
provider_text,
|
| 1209 |
+
title="Available Providers for OAuth",
|
| 1210 |
+
style="bold blue",
|
| 1211 |
+
)
|
| 1212 |
+
)
|
| 1213 |
|
| 1214 |
choice = Prompt.ask(
|
| 1215 |
+
Text.from_markup(
|
| 1216 |
+
"[bold]Please select a provider or type [red]'b'[/red] to go back[/bold]"
|
| 1217 |
+
),
|
| 1218 |
choices=[str(i + 1) for i in range(len(available_providers))] + ["b"],
|
| 1219 |
+
show_choices=False,
|
| 1220 |
)
|
| 1221 |
|
| 1222 |
+
if choice.lower() == "b":
|
| 1223 |
continue
|
| 1224 |
+
|
| 1225 |
try:
|
| 1226 |
choice_index = int(choice) - 1
|
| 1227 |
if 0 <= choice_index < len(available_providers):
|
| 1228 |
provider_name = available_providers[choice_index]
|
| 1229 |
+
display_name = oauth_friendly_names.get(
|
| 1230 |
+
provider_name, provider_name.replace("_", " ").title()
|
| 1231 |
+
)
|
| 1232 |
+
console.print(
|
| 1233 |
+
f"\nStarting OAuth setup for [bold cyan]{display_name}[/bold cyan]..."
|
| 1234 |
+
)
|
| 1235 |
await setup_new_credential(provider_name)
|
| 1236 |
# Don't clear after OAuth - user needs to see full flow
|
| 1237 |
console.print("\n[dim]Press Enter to return to main menu...[/dim]")
|
| 1238 |
input()
|
| 1239 |
else:
|
| 1240 |
+
console.print(
|
| 1241 |
+
"[bold red]Invalid choice. Please try again.[/bold red]"
|
| 1242 |
+
)
|
| 1243 |
await asyncio.sleep(1.5)
|
| 1244 |
except ValueError:
|
| 1245 |
+
console.print(
|
| 1246 |
+
"[bold red]Invalid input. Please enter a number or 'b'.[/bold red]"
|
| 1247 |
+
)
|
| 1248 |
await asyncio.sleep(1.5)
|
| 1249 |
|
| 1250 |
elif setup_type == "2":
|
| 1251 |
await setup_api_key()
|
| 1252 |
+
# console.print("\n[dim]Press Enter to return to main menu...[/dim]")
|
| 1253 |
+
# input()
|
| 1254 |
|
| 1255 |
elif setup_type == "3":
|
| 1256 |
await export_credentials_submenu()
|
| 1257 |
|
| 1258 |
+
|
| 1259 |
def run_credential_tool(from_launcher=False):
|
| 1260 |
"""
|
| 1261 |
Entry point for credential tool.
|
| 1262 |
+
|
| 1263 |
Args:
|
| 1264 |
from_launcher: If True, skip loading screen (launcher already showed it)
|
| 1265 |
"""
|
| 1266 |
# Check if we need to show loading screen
|
| 1267 |
if not from_launcher:
|
| 1268 |
# Standalone mode - show full loading UI
|
| 1269 |
+
os.system("cls" if os.name == "nt" else "clear")
|
| 1270 |
+
|
| 1271 |
_start_time = time.time()
|
| 1272 |
+
|
| 1273 |
# Phase 1: Show initial message
|
| 1274 |
print("━" * 70)
|
| 1275 |
print("Interactive Credential Setup Tool")
|
| 1276 |
print("GitHub: https://github.com/Mirrowel/LLM-API-Key-Proxy")
|
| 1277 |
print("━" * 70)
|
| 1278 |
print("Loading credential management components...")
|
| 1279 |
+
|
| 1280 |
# Phase 2: Load dependencies with spinner
|
| 1281 |
with console.status("Loading authentication providers...", spinner="dots"):
|
| 1282 |
_ensure_providers_loaded()
|
|
|
|
| 1285 |
with console.status("Initializing credential tool...", spinner="dots"):
|
| 1286 |
time.sleep(0.2) # Brief pause for UI consistency
|
| 1287 |
console.print("✓ Credential tool initialized")
|
| 1288 |
+
|
| 1289 |
_elapsed = time.time() - _start_time
|
| 1290 |
_, PROVIDER_PLUGINS = _ensure_providers_loaded()
|
| 1291 |
+
print(
|
| 1292 |
+
f"✓ Tool ready in {_elapsed:.2f}s ({len(PROVIDER_PLUGINS)} providers available)"
|
| 1293 |
+
)
|
| 1294 |
+
|
| 1295 |
# Small delay to let user see the ready message
|
| 1296 |
time.sleep(0.5)
|
| 1297 |
+
|
| 1298 |
# Run the main async event loop
|
| 1299 |
# If from launcher, don't clear screen at start to preserve loading messages
|
| 1300 |
try:
|
src/rotator_library/failure_logger.py
CHANGED
|
@@ -1,47 +1,93 @@
|
|
| 1 |
import logging
|
| 2 |
import json
|
| 3 |
from logging.handlers import RotatingFileHandler
|
| 4 |
-
import
|
| 5 |
from datetime import datetime
|
|
|
|
|
|
|
| 6 |
from .error_handler import mask_credential
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
|
|
|
|
|
|
|
|
|
|
| 8 |
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
log_dir = "logs"
|
| 12 |
-
if not os.path.exists(log_dir):
|
| 13 |
-
os.makedirs(log_dir)
|
| 14 |
|
| 15 |
-
|
| 16 |
-
|
|
|
|
| 17 |
logger = logging.getLogger("failure_logger")
|
| 18 |
logger.setLevel(logging.INFO)
|
| 19 |
logger.propagate = False
|
| 20 |
|
| 21 |
-
#
|
| 22 |
-
|
| 23 |
-
os.path.join(log_dir, "failures.log"),
|
| 24 |
-
maxBytes=5 * 1024 * 1024, # 5 MB
|
| 25 |
-
backupCount=2,
|
| 26 |
-
)
|
| 27 |
-
|
| 28 |
-
# Custom JSON formatter for structured logs
|
| 29 |
-
class JsonFormatter(logging.Formatter):
|
| 30 |
-
def format(self, record):
|
| 31 |
-
# The message is already a dict, so we just format it as a JSON string
|
| 32 |
-
return json.dumps(record.msg)
|
| 33 |
|
| 34 |
-
|
|
|
|
| 35 |
|
| 36 |
-
|
| 37 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
logger.addHandler(handler)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 39 |
|
| 40 |
return logger
|
| 41 |
|
| 42 |
|
| 43 |
-
|
| 44 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 45 |
|
| 46 |
# Get the main library logger for concise, propagated messages
|
| 47 |
main_lib_logger = logging.getLogger("rotator_library")
|
|
@@ -52,10 +98,27 @@ def _extract_response_body(error: Exception) -> str:
|
|
| 52 |
Extract the full response body from various error types.
|
| 53 |
|
| 54 |
Handles:
|
|
|
|
| 55 |
- httpx.HTTPStatusError: response.text or response.content
|
| 56 |
- litellm exceptions: various response attributes
|
| 57 |
- Other exceptions: str(error)
|
| 58 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
# Try to get response body from httpx errors
|
| 60 |
if hasattr(error, "response") and error.response is not None:
|
| 61 |
response = error.response
|
|
@@ -145,11 +208,19 @@ def log_failure(
|
|
| 145 |
"request_headers": request_headers,
|
| 146 |
"error_chain": error_chain if len(error_chain) > 1 else None,
|
| 147 |
}
|
| 148 |
-
failure_logger.error(detailed_log_data)
|
| 149 |
|
| 150 |
# 2. Log a concise summary to the main library logger, which will propagate
|
| 151 |
summary_message = (
|
| 152 |
f"API call failed for model {model} with key {mask_credential(api_key)}. "
|
| 153 |
f"Error: {type(error).__name__}. See failures.log for details."
|
| 154 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 155 |
main_lib_logger.error(summary_message)
|
|
|
|
| 1 |
import logging
|
| 2 |
import json
|
| 3 |
from logging.handlers import RotatingFileHandler
|
| 4 |
+
from pathlib import Path
|
| 5 |
from datetime import datetime
|
| 6 |
+
from typing import Optional, Union
|
| 7 |
+
|
| 8 |
from .error_handler import mask_credential
|
| 9 |
+
from .utils.paths import get_logs_dir
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class JsonFormatter(logging.Formatter):
|
| 13 |
+
"""Custom JSON formatter for structured logs."""
|
| 14 |
+
|
| 15 |
+
def format(self, record):
|
| 16 |
+
# The message is already a dict, so we just format it as a JSON string
|
| 17 |
+
return json.dumps(record.msg)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
# Module-level state for lazy initialization
|
| 21 |
+
_failure_logger: Optional[logging.Logger] = None
|
| 22 |
+
_configured_logs_dir: Optional[Path] = None
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def configure_failure_logger(logs_dir: Optional[Union[Path, str]] = None) -> None:
|
| 26 |
+
"""
|
| 27 |
+
Configure the failure logger to use a specific logs directory.
|
| 28 |
+
|
| 29 |
+
Call this before first use if you want to override the default location.
|
| 30 |
+
If not called, the logger will use get_logs_dir() on first use.
|
| 31 |
+
|
| 32 |
+
Args:
|
| 33 |
+
logs_dir: Path to the logs directory. If None, uses get_logs_dir().
|
| 34 |
+
"""
|
| 35 |
+
global _configured_logs_dir, _failure_logger
|
| 36 |
+
_configured_logs_dir = Path(logs_dir) if logs_dir else None
|
| 37 |
+
# Reset logger so it gets reconfigured on next use
|
| 38 |
+
_failure_logger = None
|
| 39 |
+
|
| 40 |
|
| 41 |
+
def _setup_failure_logger(logs_dir: Path) -> logging.Logger:
|
| 42 |
+
"""
|
| 43 |
+
Sets up a dedicated JSON logger for writing detailed failure logs to a file.
|
| 44 |
|
| 45 |
+
Args:
|
| 46 |
+
logs_dir: Path to the logs directory.
|
|
|
|
|
|
|
|
|
|
| 47 |
|
| 48 |
+
Returns:
|
| 49 |
+
Configured logger instance.
|
| 50 |
+
"""
|
| 51 |
logger = logging.getLogger("failure_logger")
|
| 52 |
logger.setLevel(logging.INFO)
|
| 53 |
logger.propagate = False
|
| 54 |
|
| 55 |
+
# Clear existing handlers to prevent duplicates on re-setup
|
| 56 |
+
logger.handlers.clear()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 57 |
|
| 58 |
+
try:
|
| 59 |
+
logs_dir.mkdir(parents=True, exist_ok=True)
|
| 60 |
|
| 61 |
+
handler = RotatingFileHandler(
|
| 62 |
+
logs_dir / "failures.log",
|
| 63 |
+
maxBytes=5 * 1024 * 1024, # 5 MB
|
| 64 |
+
backupCount=2,
|
| 65 |
+
)
|
| 66 |
+
handler.setFormatter(JsonFormatter())
|
| 67 |
logger.addHandler(handler)
|
| 68 |
+
except (OSError, PermissionError, IOError) as e:
|
| 69 |
+
logging.warning(f"Cannot create failure log file handler: {e}")
|
| 70 |
+
# Add NullHandler to prevent "no handlers" warning
|
| 71 |
+
logger.addHandler(logging.NullHandler())
|
| 72 |
|
| 73 |
return logger
|
| 74 |
|
| 75 |
|
| 76 |
+
def get_failure_logger() -> logging.Logger:
|
| 77 |
+
"""
|
| 78 |
+
Get the failure logger, initializing it lazily if needed.
|
| 79 |
+
|
| 80 |
+
Returns:
|
| 81 |
+
The configured failure logger.
|
| 82 |
+
"""
|
| 83 |
+
global _failure_logger, _configured_logs_dir
|
| 84 |
+
|
| 85 |
+
if _failure_logger is None:
|
| 86 |
+
logs_dir = _configured_logs_dir if _configured_logs_dir else get_logs_dir()
|
| 87 |
+
_failure_logger = _setup_failure_logger(logs_dir)
|
| 88 |
+
|
| 89 |
+
return _failure_logger
|
| 90 |
+
|
| 91 |
|
| 92 |
# Get the main library logger for concise, propagated messages
|
| 93 |
main_lib_logger = logging.getLogger("rotator_library")
|
|
|
|
| 98 |
Extract the full response body from various error types.
|
| 99 |
|
| 100 |
Handles:
|
| 101 |
+
- StreamedAPIError: wraps original exception in .data attribute
|
| 102 |
- httpx.HTTPStatusError: response.text or response.content
|
| 103 |
- litellm exceptions: various response attributes
|
| 104 |
- Other exceptions: str(error)
|
| 105 |
"""
|
| 106 |
+
# Handle StreamedAPIError which wraps the original exception in .data
|
| 107 |
+
# This is used by our streaming wrapper when catching provider errors
|
| 108 |
+
if hasattr(error, "data") and error.data is not None:
|
| 109 |
+
inner = error.data
|
| 110 |
+
# If data is a dict (parsed JSON error), return it as JSON
|
| 111 |
+
if isinstance(inner, dict):
|
| 112 |
+
try:
|
| 113 |
+
return json.dumps(inner, indent=2)
|
| 114 |
+
except Exception:
|
| 115 |
+
return str(inner)
|
| 116 |
+
# If data is an exception, recurse to extract from it
|
| 117 |
+
if isinstance(inner, Exception):
|
| 118 |
+
result = _extract_response_body(inner)
|
| 119 |
+
if result:
|
| 120 |
+
return result
|
| 121 |
+
|
| 122 |
# Try to get response body from httpx errors
|
| 123 |
if hasattr(error, "response") and error.response is not None:
|
| 124 |
response = error.response
|
|
|
|
| 208 |
"request_headers": request_headers,
|
| 209 |
"error_chain": error_chain if len(error_chain) > 1 else None,
|
| 210 |
}
|
|
|
|
| 211 |
|
| 212 |
# 2. Log a concise summary to the main library logger, which will propagate
|
| 213 |
summary_message = (
|
| 214 |
f"API call failed for model {model} with key {mask_credential(api_key)}. "
|
| 215 |
f"Error: {type(error).__name__}. See failures.log for details."
|
| 216 |
)
|
| 217 |
+
|
| 218 |
+
# Log to failure logger with resilience - if it fails, just continue
|
| 219 |
+
try:
|
| 220 |
+
get_failure_logger().error(detailed_log_data)
|
| 221 |
+
except (OSError, IOError) as e:
|
| 222 |
+
# Log file write failed - log to console instead
|
| 223 |
+
logging.warning(f"Failed to write to failures.log: {e}")
|
| 224 |
+
|
| 225 |
+
# Console log always succeeds
|
| 226 |
main_lib_logger.error(summary_message)
|
src/rotator_library/providers/antigravity_auth_base.py
CHANGED
|
@@ -1,16 +1,36 @@
|
|
| 1 |
# src/rotator_library/providers/antigravity_auth_base.py
|
| 2 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
from .google_oauth_base import GoogleOAuthBase
|
| 4 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
class AntigravityAuthBase(GoogleOAuthBase):
|
| 6 |
"""
|
| 7 |
Antigravity OAuth2 authentication implementation.
|
| 8 |
-
|
| 9 |
Inherits all OAuth functionality from GoogleOAuthBase with Antigravity-specific configuration.
|
| 10 |
Uses Antigravity's OAuth credentials and includes additional scopes for cclog and experimentsandconfigs.
|
|
|
|
|
|
|
|
|
|
| 11 |
"""
|
| 12 |
-
|
| 13 |
-
CLIENT_ID =
|
|
|
|
|
|
|
| 14 |
CLIENT_SECRET = "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"
|
| 15 |
OAUTH_SCOPES = [
|
| 16 |
"https://www.googleapis.com/auth/cloud-platform",
|
|
@@ -22,3 +42,600 @@ class AntigravityAuthBase(GoogleOAuthBase):
|
|
| 22 |
ENV_PREFIX = "ANTIGRAVITY"
|
| 23 |
CALLBACK_PORT = 51121
|
| 24 |
CALLBACK_PATH = "/oauthcallback"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
# src/rotator_library/providers/antigravity_auth_base.py
|
| 2 |
|
| 3 |
+
import asyncio
|
| 4 |
+
import json
|
| 5 |
+
import logging
|
| 6 |
+
import os
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
from typing import Any, Dict, Optional, List
|
| 9 |
+
|
| 10 |
+
import httpx
|
| 11 |
+
|
| 12 |
from .google_oauth_base import GoogleOAuthBase
|
| 13 |
|
| 14 |
+
lib_logger = logging.getLogger("rotator_library")
|
| 15 |
+
|
| 16 |
+
# Code Assist endpoint for project discovery
|
| 17 |
+
CODE_ASSIST_ENDPOINT = "https://cloudcode-pa.googleapis.com/v1internal"
|
| 18 |
+
|
| 19 |
+
|
| 20 |
class AntigravityAuthBase(GoogleOAuthBase):
|
| 21 |
"""
|
| 22 |
Antigravity OAuth2 authentication implementation.
|
| 23 |
+
|
| 24 |
Inherits all OAuth functionality from GoogleOAuthBase with Antigravity-specific configuration.
|
| 25 |
Uses Antigravity's OAuth credentials and includes additional scopes for cclog and experimentsandconfigs.
|
| 26 |
+
|
| 27 |
+
Also provides project/tier discovery functionality that runs during authentication,
|
| 28 |
+
ensuring credentials have their tier and project_id cached before any API requests.
|
| 29 |
"""
|
| 30 |
+
|
| 31 |
+
CLIENT_ID = (
|
| 32 |
+
"1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com"
|
| 33 |
+
)
|
| 34 |
CLIENT_SECRET = "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"
|
| 35 |
OAUTH_SCOPES = [
|
| 36 |
"https://www.googleapis.com/auth/cloud-platform",
|
|
|
|
| 42 |
ENV_PREFIX = "ANTIGRAVITY"
|
| 43 |
CALLBACK_PORT = 51121
|
| 44 |
CALLBACK_PATH = "/oauthcallback"
|
| 45 |
+
|
| 46 |
+
def __init__(self):
|
| 47 |
+
super().__init__()
|
| 48 |
+
# Project and tier caches - shared between auth base and provider
|
| 49 |
+
self.project_id_cache: Dict[str, str] = {}
|
| 50 |
+
self.project_tier_cache: Dict[str, str] = {}
|
| 51 |
+
|
| 52 |
+
# =========================================================================
|
| 53 |
+
# POST-AUTH DISCOVERY HOOK
|
| 54 |
+
# =========================================================================
|
| 55 |
+
|
| 56 |
+
async def _post_auth_discovery(
|
| 57 |
+
self, credential_path: str, access_token: str
|
| 58 |
+
) -> None:
|
| 59 |
+
"""
|
| 60 |
+
Discover and cache tier/project information immediately after OAuth authentication.
|
| 61 |
+
|
| 62 |
+
This is called by GoogleOAuthBase._perform_interactive_oauth() after successful auth,
|
| 63 |
+
ensuring tier and project_id are cached during the authentication flow rather than
|
| 64 |
+
waiting for the first API request.
|
| 65 |
+
|
| 66 |
+
Args:
|
| 67 |
+
credential_path: Path to the credential file
|
| 68 |
+
access_token: The newly obtained access token
|
| 69 |
+
"""
|
| 70 |
+
lib_logger.debug(
|
| 71 |
+
f"Starting post-auth discovery for Antigravity credential: {Path(credential_path).name}"
|
| 72 |
+
)
|
| 73 |
+
|
| 74 |
+
# Skip if already discovered (shouldn't happen during fresh auth, but be defensive)
|
| 75 |
+
if (
|
| 76 |
+
credential_path in self.project_id_cache
|
| 77 |
+
and credential_path in self.project_tier_cache
|
| 78 |
+
):
|
| 79 |
+
lib_logger.debug(
|
| 80 |
+
f"Tier and project already cached for {Path(credential_path).name}, skipping discovery"
|
| 81 |
+
)
|
| 82 |
+
return
|
| 83 |
+
|
| 84 |
+
# Call _discover_project_id which handles tier/project discovery and persistence
|
| 85 |
+
# Pass empty litellm_params since we're in auth context (no model-specific overrides)
|
| 86 |
+
project_id = await self._discover_project_id(
|
| 87 |
+
credential_path, access_token, litellm_params={}
|
| 88 |
+
)
|
| 89 |
+
|
| 90 |
+
tier = self.project_tier_cache.get(credential_path, "unknown")
|
| 91 |
+
lib_logger.info(
|
| 92 |
+
f"Post-auth discovery complete for {Path(credential_path).name}: "
|
| 93 |
+
f"tier={tier}, project={project_id}"
|
| 94 |
+
)
|
| 95 |
+
|
| 96 |
+
# =========================================================================
|
| 97 |
+
# PROJECT ID DISCOVERY
|
| 98 |
+
# =========================================================================
|
| 99 |
+
|
| 100 |
+
async def _discover_project_id(
|
| 101 |
+
self, credential_path: str, access_token: str, litellm_params: Dict[str, Any]
|
| 102 |
+
) -> str:
|
| 103 |
+
"""
|
| 104 |
+
Discovers the Google Cloud Project ID, with caching and onboarding for new accounts.
|
| 105 |
+
|
| 106 |
+
This follows the official Gemini CLI discovery flow adapted for Antigravity:
|
| 107 |
+
1. Check in-memory cache
|
| 108 |
+
2. Check configured project_id override (litellm_params or env var)
|
| 109 |
+
3. Check persisted project_id in credential file
|
| 110 |
+
4. Call loadCodeAssist to check if user is already known (has currentTier)
|
| 111 |
+
- If currentTier exists AND cloudaicompanionProject returned: use server's project
|
| 112 |
+
- If no currentTier: user needs onboarding
|
| 113 |
+
5. Onboard user (FREE tier: pass cloudaicompanionProject=None for server-managed)
|
| 114 |
+
6. Fallback to GCP Resource Manager project listing
|
| 115 |
+
|
| 116 |
+
Note: Unlike GeminiCli, Antigravity doesn't use tier-based credential prioritization,
|
| 117 |
+
but we still cache tier info for debugging and consistency.
|
| 118 |
+
"""
|
| 119 |
+
lib_logger.debug(
|
| 120 |
+
f"Starting Antigravity project discovery for credential: {credential_path}"
|
| 121 |
+
)
|
| 122 |
+
|
| 123 |
+
# Check in-memory cache first
|
| 124 |
+
if credential_path in self.project_id_cache:
|
| 125 |
+
cached_project = self.project_id_cache[credential_path]
|
| 126 |
+
lib_logger.debug(f"Using cached project ID: {cached_project}")
|
| 127 |
+
return cached_project
|
| 128 |
+
|
| 129 |
+
# Check for configured project ID override (from litellm_params or env var)
|
| 130 |
+
configured_project_id = (
|
| 131 |
+
litellm_params.get("project_id")
|
| 132 |
+
or os.getenv("ANTIGRAVITY_PROJECT_ID")
|
| 133 |
+
or os.getenv("GOOGLE_CLOUD_PROJECT")
|
| 134 |
+
)
|
| 135 |
+
if configured_project_id:
|
| 136 |
+
lib_logger.debug(
|
| 137 |
+
f"Found configured project_id override: {configured_project_id}"
|
| 138 |
+
)
|
| 139 |
+
|
| 140 |
+
# Load credentials from file to check for persisted project_id and tier
|
| 141 |
+
# Skip for env:// paths (environment-based credentials don't persist to files)
|
| 142 |
+
credential_index = self._parse_env_credential_path(credential_path)
|
| 143 |
+
if credential_index is None:
|
| 144 |
+
# Only try to load from file if it's not an env:// path
|
| 145 |
+
try:
|
| 146 |
+
with open(credential_path, "r") as f:
|
| 147 |
+
creds = json.load(f)
|
| 148 |
+
|
| 149 |
+
metadata = creds.get("_proxy_metadata", {})
|
| 150 |
+
persisted_project_id = metadata.get("project_id")
|
| 151 |
+
persisted_tier = metadata.get("tier")
|
| 152 |
+
|
| 153 |
+
if persisted_project_id:
|
| 154 |
+
lib_logger.info(
|
| 155 |
+
f"Loaded persisted project ID from credential file: {persisted_project_id}"
|
| 156 |
+
)
|
| 157 |
+
self.project_id_cache[credential_path] = persisted_project_id
|
| 158 |
+
|
| 159 |
+
# Also load tier if available (for debugging/logging purposes)
|
| 160 |
+
if persisted_tier:
|
| 161 |
+
self.project_tier_cache[credential_path] = persisted_tier
|
| 162 |
+
lib_logger.debug(f"Loaded persisted tier: {persisted_tier}")
|
| 163 |
+
|
| 164 |
+
return persisted_project_id
|
| 165 |
+
except (FileNotFoundError, json.JSONDecodeError, KeyError) as e:
|
| 166 |
+
lib_logger.debug(f"Could not load persisted project ID from file: {e}")
|
| 167 |
+
|
| 168 |
+
lib_logger.debug(
|
| 169 |
+
"No cached or configured project ID found, initiating discovery..."
|
| 170 |
+
)
|
| 171 |
+
headers = {
|
| 172 |
+
"Authorization": f"Bearer {access_token}",
|
| 173 |
+
"Content-Type": "application/json",
|
| 174 |
+
}
|
| 175 |
+
|
| 176 |
+
discovered_project_id = None
|
| 177 |
+
discovered_tier = None
|
| 178 |
+
|
| 179 |
+
async with httpx.AsyncClient() as client:
|
| 180 |
+
# 1. Try discovery endpoint with loadCodeAssist
|
| 181 |
+
lib_logger.debug(
|
| 182 |
+
"Attempting project discovery via Code Assist loadCodeAssist endpoint..."
|
| 183 |
+
)
|
| 184 |
+
try:
|
| 185 |
+
# Build metadata - include duetProject only if we have a configured project
|
| 186 |
+
core_client_metadata = {
|
| 187 |
+
"ideType": "IDE_UNSPECIFIED",
|
| 188 |
+
"platform": "PLATFORM_UNSPECIFIED",
|
| 189 |
+
"pluginType": "GEMINI",
|
| 190 |
+
}
|
| 191 |
+
if configured_project_id:
|
| 192 |
+
core_client_metadata["duetProject"] = configured_project_id
|
| 193 |
+
|
| 194 |
+
# Build load request - pass configured_project_id if available, otherwise None
|
| 195 |
+
load_request = {
|
| 196 |
+
"cloudaicompanionProject": configured_project_id, # Can be None
|
| 197 |
+
"metadata": core_client_metadata,
|
| 198 |
+
}
|
| 199 |
+
|
| 200 |
+
lib_logger.debug(
|
| 201 |
+
f"Sending loadCodeAssist request with cloudaicompanionProject={configured_project_id}"
|
| 202 |
+
)
|
| 203 |
+
response = await client.post(
|
| 204 |
+
f"{CODE_ASSIST_ENDPOINT}:loadCodeAssist",
|
| 205 |
+
headers=headers,
|
| 206 |
+
json=load_request,
|
| 207 |
+
timeout=20,
|
| 208 |
+
)
|
| 209 |
+
response.raise_for_status()
|
| 210 |
+
data = response.json()
|
| 211 |
+
|
| 212 |
+
# Log full response for debugging
|
| 213 |
+
lib_logger.debug(
|
| 214 |
+
f"loadCodeAssist full response keys: {list(data.keys())}"
|
| 215 |
+
)
|
| 216 |
+
|
| 217 |
+
# Extract tier information
|
| 218 |
+
allowed_tiers = data.get("allowedTiers", [])
|
| 219 |
+
current_tier = data.get("currentTier")
|
| 220 |
+
|
| 221 |
+
lib_logger.debug(f"=== Tier Information ===")
|
| 222 |
+
lib_logger.debug(f"currentTier: {current_tier}")
|
| 223 |
+
lib_logger.debug(f"allowedTiers count: {len(allowed_tiers)}")
|
| 224 |
+
for i, tier in enumerate(allowed_tiers):
|
| 225 |
+
tier_id = tier.get("id", "unknown")
|
| 226 |
+
is_default = tier.get("isDefault", False)
|
| 227 |
+
user_defined = tier.get("userDefinedCloudaicompanionProject", False)
|
| 228 |
+
lib_logger.debug(
|
| 229 |
+
f" Tier {i + 1}: id={tier_id}, isDefault={is_default}, userDefinedProject={user_defined}"
|
| 230 |
+
)
|
| 231 |
+
lib_logger.debug(f"========================")
|
| 232 |
+
|
| 233 |
+
# Determine the current tier ID
|
| 234 |
+
current_tier_id = None
|
| 235 |
+
if current_tier:
|
| 236 |
+
current_tier_id = current_tier.get("id")
|
| 237 |
+
lib_logger.debug(f"User has currentTier: {current_tier_id}")
|
| 238 |
+
|
| 239 |
+
# Check if user is already known to server (has currentTier)
|
| 240 |
+
if current_tier_id:
|
| 241 |
+
# User is already onboarded - check for project from server
|
| 242 |
+
server_project = data.get("cloudaicompanionProject")
|
| 243 |
+
|
| 244 |
+
# Check if this tier requires user-defined project (paid tiers)
|
| 245 |
+
requires_user_project = any(
|
| 246 |
+
t.get("id") == current_tier_id
|
| 247 |
+
and t.get("userDefinedCloudaicompanionProject", False)
|
| 248 |
+
for t in allowed_tiers
|
| 249 |
+
)
|
| 250 |
+
is_free_tier = current_tier_id == "free-tier"
|
| 251 |
+
|
| 252 |
+
if server_project:
|
| 253 |
+
# Server returned a project - use it (server wins)
|
| 254 |
+
project_id = server_project
|
| 255 |
+
lib_logger.debug(f"Server returned project: {project_id}")
|
| 256 |
+
elif configured_project_id:
|
| 257 |
+
# No server project but we have configured one - use it
|
| 258 |
+
project_id = configured_project_id
|
| 259 |
+
lib_logger.debug(
|
| 260 |
+
f"No server project, using configured: {project_id}"
|
| 261 |
+
)
|
| 262 |
+
elif is_free_tier:
|
| 263 |
+
# Free tier user without server project - try onboarding
|
| 264 |
+
lib_logger.debug(
|
| 265 |
+
"Free tier user with currentTier but no project - will try onboarding"
|
| 266 |
+
)
|
| 267 |
+
project_id = None
|
| 268 |
+
elif requires_user_project:
|
| 269 |
+
# Paid tier requires a project ID to be set
|
| 270 |
+
raise ValueError(
|
| 271 |
+
f"Paid tier '{current_tier_id}' requires setting ANTIGRAVITY_PROJECT_ID environment variable."
|
| 272 |
+
)
|
| 273 |
+
else:
|
| 274 |
+
# Unknown tier without project - proceed to onboarding
|
| 275 |
+
lib_logger.warning(
|
| 276 |
+
f"Tier '{current_tier_id}' has no project and none configured - will try onboarding"
|
| 277 |
+
)
|
| 278 |
+
project_id = None
|
| 279 |
+
|
| 280 |
+
if project_id:
|
| 281 |
+
# Cache tier info
|
| 282 |
+
self.project_tier_cache[credential_path] = current_tier_id
|
| 283 |
+
discovered_tier = current_tier_id
|
| 284 |
+
|
| 285 |
+
# Log appropriately based on tier
|
| 286 |
+
is_paid = current_tier_id and current_tier_id not in [
|
| 287 |
+
"free-tier",
|
| 288 |
+
"legacy-tier",
|
| 289 |
+
"unknown",
|
| 290 |
+
]
|
| 291 |
+
if is_paid:
|
| 292 |
+
lib_logger.info(
|
| 293 |
+
f"Using Antigravity paid tier '{current_tier_id}' with project: {project_id}"
|
| 294 |
+
)
|
| 295 |
+
else:
|
| 296 |
+
lib_logger.info(
|
| 297 |
+
f"Discovered Antigravity project ID via loadCodeAssist: {project_id}"
|
| 298 |
+
)
|
| 299 |
+
|
| 300 |
+
self.project_id_cache[credential_path] = project_id
|
| 301 |
+
discovered_project_id = project_id
|
| 302 |
+
|
| 303 |
+
# Persist to credential file
|
| 304 |
+
await self._persist_project_metadata(
|
| 305 |
+
credential_path, project_id, discovered_tier
|
| 306 |
+
)
|
| 307 |
+
|
| 308 |
+
return project_id
|
| 309 |
+
|
| 310 |
+
# 2. User needs onboarding - no currentTier or no project found
|
| 311 |
+
lib_logger.info(
|
| 312 |
+
"No existing Antigravity session found (no currentTier), attempting to onboard user..."
|
| 313 |
+
)
|
| 314 |
+
|
| 315 |
+
# Determine which tier to onboard with
|
| 316 |
+
onboard_tier = None
|
| 317 |
+
for tier in allowed_tiers:
|
| 318 |
+
if tier.get("isDefault"):
|
| 319 |
+
onboard_tier = tier
|
| 320 |
+
break
|
| 321 |
+
|
| 322 |
+
# Fallback to legacy tier if no default
|
| 323 |
+
if not onboard_tier and allowed_tiers:
|
| 324 |
+
for tier in allowed_tiers:
|
| 325 |
+
if tier.get("id") == "legacy-tier":
|
| 326 |
+
onboard_tier = tier
|
| 327 |
+
break
|
| 328 |
+
if not onboard_tier:
|
| 329 |
+
onboard_tier = allowed_tiers[0]
|
| 330 |
+
|
| 331 |
+
if not onboard_tier:
|
| 332 |
+
raise ValueError("No onboarding tiers available from server")
|
| 333 |
+
|
| 334 |
+
tier_id = onboard_tier.get("id", "free-tier")
|
| 335 |
+
requires_user_project = onboard_tier.get(
|
| 336 |
+
"userDefinedCloudaicompanionProject", False
|
| 337 |
+
)
|
| 338 |
+
|
| 339 |
+
lib_logger.debug(
|
| 340 |
+
f"Onboarding with tier: {tier_id}, requiresUserProject: {requires_user_project}"
|
| 341 |
+
)
|
| 342 |
+
|
| 343 |
+
# Build onboard request based on tier type
|
| 344 |
+
# FREE tier: cloudaicompanionProject = None (server-managed)
|
| 345 |
+
# PAID tier: cloudaicompanionProject = configured_project_id
|
| 346 |
+
is_free_tier = tier_id == "free-tier"
|
| 347 |
+
|
| 348 |
+
if is_free_tier:
|
| 349 |
+
# Free tier uses server-managed project
|
| 350 |
+
onboard_request = {
|
| 351 |
+
"tierId": tier_id,
|
| 352 |
+
"cloudaicompanionProject": None, # Server will create/manage
|
| 353 |
+
"metadata": core_client_metadata,
|
| 354 |
+
}
|
| 355 |
+
lib_logger.debug(
|
| 356 |
+
"Free tier onboarding: using server-managed project"
|
| 357 |
+
)
|
| 358 |
+
else:
|
| 359 |
+
# Paid/legacy tier requires user-provided project
|
| 360 |
+
if not configured_project_id and requires_user_project:
|
| 361 |
+
raise ValueError(
|
| 362 |
+
f"Tier '{tier_id}' requires setting ANTIGRAVITY_PROJECT_ID environment variable."
|
| 363 |
+
)
|
| 364 |
+
onboard_request = {
|
| 365 |
+
"tierId": tier_id,
|
| 366 |
+
"cloudaicompanionProject": configured_project_id,
|
| 367 |
+
"metadata": {
|
| 368 |
+
**core_client_metadata,
|
| 369 |
+
"duetProject": configured_project_id,
|
| 370 |
+
}
|
| 371 |
+
if configured_project_id
|
| 372 |
+
else core_client_metadata,
|
| 373 |
+
}
|
| 374 |
+
lib_logger.debug(
|
| 375 |
+
f"Paid tier onboarding: using project {configured_project_id}"
|
| 376 |
+
)
|
| 377 |
+
|
| 378 |
+
lib_logger.debug("Initiating onboardUser request...")
|
| 379 |
+
lro_response = await client.post(
|
| 380 |
+
f"{CODE_ASSIST_ENDPOINT}:onboardUser",
|
| 381 |
+
headers=headers,
|
| 382 |
+
json=onboard_request,
|
| 383 |
+
timeout=30,
|
| 384 |
+
)
|
| 385 |
+
lro_response.raise_for_status()
|
| 386 |
+
lro_data = lro_response.json()
|
| 387 |
+
lib_logger.debug(
|
| 388 |
+
f"Initial onboarding response: done={lro_data.get('done')}"
|
| 389 |
+
)
|
| 390 |
+
|
| 391 |
+
# Poll for onboarding completion (up to 5 minutes)
|
| 392 |
+
for i in range(150): # 150 × 2s = 5 minutes
|
| 393 |
+
if lro_data.get("done"):
|
| 394 |
+
lib_logger.debug(
|
| 395 |
+
f"Onboarding completed after {i} polling attempts"
|
| 396 |
+
)
|
| 397 |
+
break
|
| 398 |
+
await asyncio.sleep(2)
|
| 399 |
+
if (i + 1) % 15 == 0: # Log every 30 seconds
|
| 400 |
+
lib_logger.info(
|
| 401 |
+
f"Still waiting for onboarding completion... ({(i + 1) * 2}s elapsed)"
|
| 402 |
+
)
|
| 403 |
+
lib_logger.debug(
|
| 404 |
+
f"Polling onboarding status... (Attempt {i + 1}/150)"
|
| 405 |
+
)
|
| 406 |
+
lro_response = await client.post(
|
| 407 |
+
f"{CODE_ASSIST_ENDPOINT}:onboardUser",
|
| 408 |
+
headers=headers,
|
| 409 |
+
json=onboard_request,
|
| 410 |
+
timeout=30,
|
| 411 |
+
)
|
| 412 |
+
lro_response.raise_for_status()
|
| 413 |
+
lro_data = lro_response.json()
|
| 414 |
+
|
| 415 |
+
if not lro_data.get("done"):
|
| 416 |
+
lib_logger.error("Onboarding process timed out after 5 minutes")
|
| 417 |
+
raise ValueError(
|
| 418 |
+
"Onboarding process timed out after 5 minutes. Please try again or contact support."
|
| 419 |
+
)
|
| 420 |
+
|
| 421 |
+
# Extract project ID from LRO response
|
| 422 |
+
# Note: onboardUser returns response.cloudaicompanionProject as an object with .id
|
| 423 |
+
lro_response_data = lro_data.get("response", {})
|
| 424 |
+
lro_project_obj = lro_response_data.get("cloudaicompanionProject", {})
|
| 425 |
+
project_id = (
|
| 426 |
+
lro_project_obj.get("id")
|
| 427 |
+
if isinstance(lro_project_obj, dict)
|
| 428 |
+
else None
|
| 429 |
+
)
|
| 430 |
+
|
| 431 |
+
# Fallback to configured project if LRO didn't return one
|
| 432 |
+
if not project_id and configured_project_id:
|
| 433 |
+
project_id = configured_project_id
|
| 434 |
+
lib_logger.debug(
|
| 435 |
+
f"LRO didn't return project, using configured: {project_id}"
|
| 436 |
+
)
|
| 437 |
+
|
| 438 |
+
if not project_id:
|
| 439 |
+
lib_logger.error(
|
| 440 |
+
"Onboarding completed but no project ID in response and none configured"
|
| 441 |
+
)
|
| 442 |
+
raise ValueError(
|
| 443 |
+
"Onboarding completed, but no project ID was returned. "
|
| 444 |
+
"For paid tiers, set ANTIGRAVITY_PROJECT_ID environment variable."
|
| 445 |
+
)
|
| 446 |
+
|
| 447 |
+
lib_logger.debug(
|
| 448 |
+
f"Successfully extracted project ID from onboarding response: {project_id}"
|
| 449 |
+
)
|
| 450 |
+
|
| 451 |
+
# Cache tier info
|
| 452 |
+
self.project_tier_cache[credential_path] = tier_id
|
| 453 |
+
discovered_tier = tier_id
|
| 454 |
+
lib_logger.debug(f"Cached tier information: {tier_id}")
|
| 455 |
+
|
| 456 |
+
# Log concise message based on tier
|
| 457 |
+
is_paid = tier_id and tier_id not in ["free-tier", "legacy-tier"]
|
| 458 |
+
if is_paid:
|
| 459 |
+
lib_logger.info(
|
| 460 |
+
f"Using Antigravity paid tier '{tier_id}' with project: {project_id}"
|
| 461 |
+
)
|
| 462 |
+
else:
|
| 463 |
+
lib_logger.info(
|
| 464 |
+
f"Successfully onboarded user and discovered project ID: {project_id}"
|
| 465 |
+
)
|
| 466 |
+
|
| 467 |
+
self.project_id_cache[credential_path] = project_id
|
| 468 |
+
discovered_project_id = project_id
|
| 469 |
+
|
| 470 |
+
# Persist to credential file
|
| 471 |
+
await self._persist_project_metadata(
|
| 472 |
+
credential_path, project_id, discovered_tier
|
| 473 |
+
)
|
| 474 |
+
|
| 475 |
+
return project_id
|
| 476 |
+
|
| 477 |
+
except httpx.HTTPStatusError as e:
|
| 478 |
+
error_body = ""
|
| 479 |
+
try:
|
| 480 |
+
error_body = e.response.text
|
| 481 |
+
except Exception:
|
| 482 |
+
pass
|
| 483 |
+
if e.response.status_code == 403:
|
| 484 |
+
lib_logger.error(
|
| 485 |
+
f"Antigravity Code Assist API access denied (403). Response: {error_body}"
|
| 486 |
+
)
|
| 487 |
+
lib_logger.error(
|
| 488 |
+
"Possible causes: 1) cloudaicompanion.googleapis.com API not enabled, 2) Wrong project ID for paid tier, 3) Account lacks permissions"
|
| 489 |
+
)
|
| 490 |
+
elif e.response.status_code == 404:
|
| 491 |
+
lib_logger.warning(
|
| 492 |
+
f"Antigravity Code Assist endpoint not found (404). Falling back to project listing."
|
| 493 |
+
)
|
| 494 |
+
elif e.response.status_code == 412:
|
| 495 |
+
# Precondition Failed - often means wrong project for free tier onboarding
|
| 496 |
+
lib_logger.error(
|
| 497 |
+
f"Precondition failed (412): {error_body}. This may mean the project ID is incompatible with the selected tier."
|
| 498 |
+
)
|
| 499 |
+
else:
|
| 500 |
+
lib_logger.warning(
|
| 501 |
+
f"Antigravity onboarding/discovery failed with status {e.response.status_code}: {error_body}. Falling back to project listing."
|
| 502 |
+
)
|
| 503 |
+
except httpx.RequestError as e:
|
| 504 |
+
lib_logger.warning(
|
| 505 |
+
f"Antigravity onboarding/discovery network error: {e}. Falling back to project listing."
|
| 506 |
+
)
|
| 507 |
+
|
| 508 |
+
# 3. Fallback to listing all available GCP projects (last resort)
|
| 509 |
+
lib_logger.debug(
|
| 510 |
+
"Attempting to discover project via GCP Resource Manager API..."
|
| 511 |
+
)
|
| 512 |
+
try:
|
| 513 |
+
async with httpx.AsyncClient() as client:
|
| 514 |
+
lib_logger.debug(
|
| 515 |
+
"Querying Cloud Resource Manager for available projects..."
|
| 516 |
+
)
|
| 517 |
+
response = await client.get(
|
| 518 |
+
"https://cloudresourcemanager.googleapis.com/v1/projects",
|
| 519 |
+
headers=headers,
|
| 520 |
+
timeout=20,
|
| 521 |
+
)
|
| 522 |
+
response.raise_for_status()
|
| 523 |
+
projects = response.json().get("projects", [])
|
| 524 |
+
lib_logger.debug(f"Found {len(projects)} total projects")
|
| 525 |
+
active_projects = [
|
| 526 |
+
p for p in projects if p.get("lifecycleState") == "ACTIVE"
|
| 527 |
+
]
|
| 528 |
+
lib_logger.debug(f"Found {len(active_projects)} active projects")
|
| 529 |
+
|
| 530 |
+
if not projects:
|
| 531 |
+
lib_logger.error(
|
| 532 |
+
"No GCP projects found for this account. Please create a project in Google Cloud Console."
|
| 533 |
+
)
|
| 534 |
+
elif not active_projects:
|
| 535 |
+
lib_logger.error(
|
| 536 |
+
"No active GCP projects found. Please activate a project in Google Cloud Console."
|
| 537 |
+
)
|
| 538 |
+
else:
|
| 539 |
+
project_id = active_projects[0]["projectId"]
|
| 540 |
+
lib_logger.info(
|
| 541 |
+
f"Discovered Antigravity project ID from active projects list: {project_id}"
|
| 542 |
+
)
|
| 543 |
+
lib_logger.debug(
|
| 544 |
+
f"Selected first active project: {project_id} (out of {len(active_projects)} active projects)"
|
| 545 |
+
)
|
| 546 |
+
self.project_id_cache[credential_path] = project_id
|
| 547 |
+
discovered_project_id = project_id
|
| 548 |
+
|
| 549 |
+
# Persist to credential file (no tier info from resource manager)
|
| 550 |
+
await self._persist_project_metadata(
|
| 551 |
+
credential_path, project_id, None
|
| 552 |
+
)
|
| 553 |
+
|
| 554 |
+
return project_id
|
| 555 |
+
except httpx.HTTPStatusError as e:
|
| 556 |
+
if e.response.status_code == 403:
|
| 557 |
+
lib_logger.error(
|
| 558 |
+
"Failed to list GCP projects due to a 403 Forbidden error. The Cloud Resource Manager API may not be enabled, or your account lacks the 'resourcemanager.projects.list' permission."
|
| 559 |
+
)
|
| 560 |
+
else:
|
| 561 |
+
lib_logger.error(
|
| 562 |
+
f"Failed to list GCP projects with status {e.response.status_code}: {e}"
|
| 563 |
+
)
|
| 564 |
+
except httpx.RequestError as e:
|
| 565 |
+
lib_logger.error(f"Network error while listing GCP projects: {e}")
|
| 566 |
+
|
| 567 |
+
raise ValueError(
|
| 568 |
+
"Could not auto-discover Antigravity project ID. Possible causes:\n"
|
| 569 |
+
" 1. The cloudaicompanion.googleapis.com API is not enabled (enable it in Google Cloud Console)\n"
|
| 570 |
+
" 2. No active GCP projects exist for this account (create one in Google Cloud Console)\n"
|
| 571 |
+
" 3. Account lacks necessary permissions\n"
|
| 572 |
+
"To manually specify a project, set ANTIGRAVITY_PROJECT_ID in your .env file."
|
| 573 |
+
)
|
| 574 |
+
|
| 575 |
+
async def _persist_project_metadata(
    self, credential_path: str, project_id: str, tier: Optional[str]
):
    """Persists project ID and tier to the credential file for faster future startups."""
    # env:// credentials exist only in the environment, so there is no
    # file to write back to - bail out early.
    if self._parse_env_credential_path(credential_path) is not None:
        lib_logger.debug(
            f"Skipping project metadata persistence for env:// credential path: {credential_path}"
        )
        return

    try:
        # Read the current credential document from disk.
        with open(credential_path, "r") as f:
            stored = json.load(f)

        # Merge the discovered values into the proxy metadata section.
        metadata = stored.setdefault("_proxy_metadata", {})
        metadata["project_id"] = project_id
        if tier:
            metadata["tier"] = tier

        # _save_credentials takes care of atomic writes and permissions.
        await self._save_credentials(credential_path, stored)

        lib_logger.debug(
            f"Persisted project_id and tier to credential file: {credential_path}"
        )
    except Exception as e:
        # Deliberately non-fatal: a failed write only means discovery
        # runs again on the next startup.
        lib_logger.warning(
            f"Failed to persist project metadata to credential file: {e}"
        )
|
| 611 |
+
|
| 612 |
+
# =========================================================================
|
| 613 |
+
# CREDENTIAL MANAGEMENT OVERRIDES
|
| 614 |
+
# =========================================================================
|
| 615 |
+
|
| 616 |
+
def _get_provider_file_prefix(self) -> str:
|
| 617 |
+
"""Return the file prefix for Antigravity credentials."""
|
| 618 |
+
return "antigravity"
|
| 619 |
+
|
| 620 |
+
def build_env_lines(self, creds: Dict[str, Any], cred_number: int) -> List[str]:
    """
    Generate .env file lines for an Antigravity credential.

    Extends the base implementation with the tier and project_id values
    carried in the credential's _proxy_metadata section (emitted only
    when non-empty).
    """
    # Start from the provider-agnostic lines produced by the parent.
    env_lines = super().build_env_lines(creds, cred_number)

    meta = creds.get("_proxy_metadata", {})
    var_prefix = f"{self.ENV_PREFIX}_{cred_number}"

    # Append the Antigravity-specific variables when values are present.
    for suffix, value in (
        ("PROJECT_ID", meta.get("project_id", "")),
        ("TIER", meta.get("tier", "")),
    ):
        if value:
            env_lines.append(f"{var_prefix}_{suffix}={value}")

    return env_lines
|
src/rotator_library/providers/antigravity_provider.py
CHANGED
|
@@ -38,6 +38,8 @@ from .provider_interface import ProviderInterface, UsageResetConfigDef, QuotaGro
|
|
| 38 |
from .antigravity_auth_base import AntigravityAuthBase
|
| 39 |
from .provider_cache import ProviderCache
|
| 40 |
from ..model_definitions import ModelDefinitions
|
|
|
|
|
|
|
| 41 |
|
| 42 |
|
| 43 |
# =============================================================================
|
|
@@ -105,12 +107,23 @@ DEFAULT_SAFETY_SETTINGS = [
|
|
| 105 |
{"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "BLOCK_NONE"},
|
| 106 |
]
|
| 107 |
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 114 |
|
| 115 |
# Gemini 3 tool fix system instruction (prevents hallucination)
|
| 116 |
DEFAULT_GEMINI3_SYSTEM_INSTRUCTION = """<CRITICAL_TOOL_USAGE_INSTRUCTIONS>
|
|
@@ -327,6 +340,33 @@ def _recursively_parse_json_strings(obj: Any) -> Any:
|
|
| 327 |
return obj
|
| 328 |
|
| 329 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 330 |
def _clean_claude_schema(schema: Any) -> Any:
|
| 331 |
"""
|
| 332 |
Recursively clean JSON Schema for Antigravity/Google's Proto-based API.
|
|
@@ -384,7 +424,6 @@ def _clean_claude_schema(schema: Any) -> Any:
|
|
| 384 |
return first_option
|
| 385 |
|
| 386 |
cleaned = {}
|
| 387 |
-
|
| 388 |
# Handle 'const' by converting to 'enum' with single value
|
| 389 |
if "const" in schema:
|
| 390 |
const_value = schema["const"]
|
|
@@ -425,7 +464,9 @@ class AntigravityFileLogger:
|
|
| 425 |
|
| 426 |
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
|
| 427 |
safe_model = model_name.replace("/", "_").replace(":", "_")
|
| 428 |
-
self.log_dir =
|
|
|
|
|
|
|
| 429 |
|
| 430 |
try:
|
| 431 |
self.log_dir.mkdir(parents=True, exist_ok=True)
|
|
@@ -658,9 +699,6 @@ class AntigravityProvider(AntigravityAuthBase, ProviderInterface):
|
|
| 658 |
error_obj = data.get("error", data)
|
| 659 |
details = error_obj.get("details", [])
|
| 660 |
|
| 661 |
-
if not details:
|
| 662 |
-
return None
|
| 663 |
-
|
| 664 |
result = {
|
| 665 |
"retry_after": None,
|
| 666 |
"reason": None,
|
|
@@ -711,6 +749,15 @@ class AntigravityProvider(AntigravityAuthBase, ProviderInterface):
|
|
| 711 |
|
| 712 |
# Return None if we couldn't extract retry_after
|
| 713 |
if not result["retry_after"]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 714 |
return None
|
| 715 |
|
| 716 |
return result
|
|
@@ -718,12 +765,7 @@ class AntigravityProvider(AntigravityAuthBase, ProviderInterface):
|
|
| 718 |
def __init__(self):
|
| 719 |
super().__init__()
|
| 720 |
self.model_definitions = ModelDefinitions()
|
| 721 |
-
|
| 722 |
-
str, str
|
| 723 |
-
] = {} # Cache project ID per credential path
|
| 724 |
-
self.project_tier_cache: Dict[
|
| 725 |
-
str, str
|
| 726 |
-
] = {} # Cache project tier per credential path (for debugging)
|
| 727 |
|
| 728 |
# Base URL management
|
| 729 |
self._base_url_index = 0
|
|
@@ -735,13 +777,13 @@ class AntigravityProvider(AntigravityAuthBase, ProviderInterface):
|
|
| 735 |
|
| 736 |
# Initialize caches using shared ProviderCache
|
| 737 |
self._signature_cache = ProviderCache(
|
| 738 |
-
|
| 739 |
memory_ttl,
|
| 740 |
disk_ttl,
|
| 741 |
env_prefix="ANTIGRAVITY_SIGNATURE",
|
| 742 |
)
|
| 743 |
self._thinking_cache = ProviderCache(
|
| 744 |
-
|
| 745 |
memory_ttl,
|
| 746 |
disk_ttl,
|
| 747 |
env_prefix="ANTIGRAVITY_THINKING",
|
|
@@ -871,9 +913,48 @@ class AntigravityProvider(AntigravityAuthBase, ProviderInterface):
|
|
| 871 |
|
| 872 |
This ensures all credential priorities are known before any API calls,
|
| 873 |
preventing unknown credentials from getting priority 999.
|
|
|
|
|
|
|
|
|
|
| 874 |
"""
|
|
|
|
| 875 |
await self._load_persisted_tiers(credential_paths)
|
| 876 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 877 |
async def _load_persisted_tiers(
|
| 878 |
self, credential_paths: List[str]
|
| 879 |
) -> Dict[str, str]:
|
|
@@ -931,6 +1012,8 @@ class AntigravityProvider(AntigravityAuthBase, ProviderInterface):
|
|
| 931 |
|
| 932 |
return loaded
|
| 933 |
|
|
|
|
|
|
|
| 934 |
# =========================================================================
|
| 935 |
# MODEL UTILITIES
|
| 936 |
# =========================================================================
|
|
@@ -1007,524 +1090,7 @@ class AntigravityProvider(AntigravityAuthBase, ProviderInterface):
|
|
| 1007 |
|
| 1008 |
return "thinking_" + "_".join(key_parts) if key_parts else None
|
| 1009 |
|
| 1010 |
-
#
|
| 1011 |
-
# PROJECT ID DISCOVERY
|
| 1012 |
-
# =========================================================================
|
| 1013 |
-
|
| 1014 |
-
async def _discover_project_id(
    self, credential_path: str, access_token: str, litellm_params: Dict[str, Any]
) -> str:
    """
    Discovers the Google Cloud Project ID, with caching and onboarding for new accounts.

    This follows the official Gemini CLI discovery flow adapted for Antigravity:
    1. Check in-memory cache
    2. Check configured project_id override (litellm_params or env var)
    3. Check persisted project_id in credential file
    4. Call loadCodeAssist to check if user is already known (has currentTier)
       - If currentTier exists AND cloudaicompanionProject returned: use server's project
       - If no currentTier: user needs onboarding
    5. Onboard user (FREE tier: pass cloudaicompanionProject=None for server-managed)
    6. Fallback to GCP Resource Manager project listing

    Note: Unlike GeminiCli, Antigravity doesn't use tier-based credential prioritization,
    but we still cache tier info for debugging and consistency.

    Args:
        credential_path: Credential file path (or env:// pseudo-path) the
            discovered project is cached/persisted under.
        access_token: OAuth bearer token used for all discovery requests.
        litellm_params: Per-request params; only "project_id" is read here.

    Returns:
        The resolved Google Cloud project ID.

    Raises:
        ValueError: When no project can be discovered, when a paid tier
            requires a user-supplied project, or when onboarding times out.
    """
    lib_logger.debug(
        f"Starting Antigravity project discovery for credential: {credential_path}"
    )

    # Check in-memory cache first
    if credential_path in self.project_id_cache:
        cached_project = self.project_id_cache[credential_path]
        lib_logger.debug(f"Using cached project ID: {cached_project}")
        return cached_project

    # Check for configured project ID override (from litellm_params or env var)
    configured_project_id = (
        litellm_params.get("project_id")
        or os.getenv("ANTIGRAVITY_PROJECT_ID")
        or os.getenv("GOOGLE_CLOUD_PROJECT")
    )
    if configured_project_id:
        lib_logger.debug(
            f"Found configured project_id override: {configured_project_id}"
        )

    # Load credentials from file to check for persisted project_id and tier
    # Skip for env:// paths (environment-based credentials don't persist to files)
    credential_index = self._parse_env_credential_path(credential_path)
    if credential_index is None:
        # Only try to load from file if it's not an env:// path
        try:
            with open(credential_path, "r") as f:
                creds = json.load(f)

            metadata = creds.get("_proxy_metadata", {})
            persisted_project_id = metadata.get("project_id")
            persisted_tier = metadata.get("tier")

            if persisted_project_id:
                lib_logger.info(
                    f"Loaded persisted project ID from credential file: {persisted_project_id}"
                )
                self.project_id_cache[credential_path] = persisted_project_id

                # Also load tier if available (for debugging/logging purposes)
                if persisted_tier:
                    self.project_tier_cache[credential_path] = persisted_tier
                    lib_logger.debug(f"Loaded persisted tier: {persisted_tier}")

                return persisted_project_id
        except (FileNotFoundError, json.JSONDecodeError, KeyError) as e:
            lib_logger.debug(f"Could not load persisted project ID from file: {e}")

    lib_logger.debug(
        "No cached or configured project ID found, initiating discovery..."
    )
    headers = {
        "Authorization": f"Bearer {access_token}",
        "Content-Type": "application/json",
    }

    discovered_project_id = None
    discovered_tier = None

    # Use production endpoint for loadCodeAssist (more reliable than sandbox URLs)
    code_assist_endpoint = "https://cloudcode-pa.googleapis.com/v1internal"

    async with httpx.AsyncClient() as client:
        # 1. Try discovery endpoint with loadCodeAssist
        lib_logger.debug(
            "Attempting project discovery via Code Assist loadCodeAssist endpoint..."
        )
        try:
            # Build metadata - include duetProject only if we have a configured project
            core_client_metadata = {
                "ideType": "IDE_UNSPECIFIED",
                "platform": "PLATFORM_UNSPECIFIED",
                "pluginType": "GEMINI",
            }
            if configured_project_id:
                core_client_metadata["duetProject"] = configured_project_id

            # Build load request - pass configured_project_id if available, otherwise None
            load_request = {
                "cloudaicompanionProject": configured_project_id,  # Can be None
                "metadata": core_client_metadata,
            }

            lib_logger.debug(
                f"Sending loadCodeAssist request with cloudaicompanionProject={configured_project_id}"
            )
            response = await client.post(
                f"{code_assist_endpoint}:loadCodeAssist",
                headers=headers,
                json=load_request,
                timeout=20,
            )
            response.raise_for_status()
            data = response.json()

            # Log full response for debugging
            lib_logger.debug(
                f"loadCodeAssist full response keys: {list(data.keys())}"
            )

            # Extract tier information
            allowed_tiers = data.get("allowedTiers", [])
            current_tier = data.get("currentTier")

            lib_logger.debug(f"=== Tier Information ===")
            lib_logger.debug(f"currentTier: {current_tier}")
            lib_logger.debug(f"allowedTiers count: {len(allowed_tiers)}")
            for i, tier in enumerate(allowed_tiers):
                tier_id = tier.get("id", "unknown")
                is_default = tier.get("isDefault", False)
                user_defined = tier.get("userDefinedCloudaicompanionProject", False)
                lib_logger.debug(
                    f"  Tier {i + 1}: id={tier_id}, isDefault={is_default}, userDefinedProject={user_defined}"
                )
            lib_logger.debug(f"========================")

            # Determine the current tier ID
            current_tier_id = None
            if current_tier:
                current_tier_id = current_tier.get("id")
                lib_logger.debug(f"User has currentTier: {current_tier_id}")

            # Check if user is already known to server (has currentTier)
            if current_tier_id:
                # User is already onboarded - check for project from server
                server_project = data.get("cloudaicompanionProject")

                # Check if this tier requires user-defined project (paid tiers)
                requires_user_project = any(
                    t.get("id") == current_tier_id
                    and t.get("userDefinedCloudaicompanionProject", False)
                    for t in allowed_tiers
                )
                is_free_tier = current_tier_id == "free-tier"

                if server_project:
                    # Server returned a project - use it (server wins)
                    project_id = server_project
                    lib_logger.debug(f"Server returned project: {project_id}")
                elif configured_project_id:
                    # No server project but we have configured one - use it
                    project_id = configured_project_id
                    lib_logger.debug(
                        f"No server project, using configured: {project_id}"
                    )
                elif is_free_tier:
                    # Free tier user without server project - try onboarding
                    lib_logger.debug(
                        "Free tier user with currentTier but no project - will try onboarding"
                    )
                    project_id = None
                elif requires_user_project:
                    # Paid tier requires a project ID to be set
                    raise ValueError(
                        f"Paid tier '{current_tier_id}' requires setting ANTIGRAVITY_PROJECT_ID environment variable."
                    )
                else:
                    # Unknown tier without project - proceed to onboarding
                    lib_logger.warning(
                        f"Tier '{current_tier_id}' has no project and none configured - will try onboarding"
                    )
                    project_id = None

                if project_id:
                    # Cache tier info
                    self.project_tier_cache[credential_path] = current_tier_id
                    discovered_tier = current_tier_id

                    # Log appropriately based on tier
                    is_paid = current_tier_id and current_tier_id not in [
                        "free-tier",
                        "legacy-tier",
                        "unknown",
                    ]
                    if is_paid:
                        lib_logger.info(
                            f"Using Antigravity paid tier '{current_tier_id}' with project: {project_id}"
                        )
                    else:
                        lib_logger.info(
                            f"Discovered Antigravity project ID via loadCodeAssist: {project_id}"
                        )

                    self.project_id_cache[credential_path] = project_id
                    discovered_project_id = project_id

                    # Persist to credential file
                    await self._persist_project_metadata(
                        credential_path, project_id, discovered_tier
                    )

                    return project_id

            # 2. User needs onboarding - no currentTier or no project found
            lib_logger.info(
                "No existing Antigravity session found (no currentTier), attempting to onboard user..."
            )

            # Determine which tier to onboard with
            onboard_tier = None
            for tier in allowed_tiers:
                if tier.get("isDefault"):
                    onboard_tier = tier
                    break

            # Fallback to legacy tier if no default
            if not onboard_tier and allowed_tiers:
                for tier in allowed_tiers:
                    if tier.get("id") == "legacy-tier":
                        onboard_tier = tier
                        break
                if not onboard_tier:
                    onboard_tier = allowed_tiers[0]

            if not onboard_tier:
                raise ValueError("No onboarding tiers available from server")

            tier_id = onboard_tier.get("id", "free-tier")
            requires_user_project = onboard_tier.get(
                "userDefinedCloudaicompanionProject", False
            )

            lib_logger.debug(
                f"Onboarding with tier: {tier_id}, requiresUserProject: {requires_user_project}"
            )

            # Build onboard request based on tier type
            # FREE tier: cloudaicompanionProject = None (server-managed)
            # PAID tier: cloudaicompanionProject = configured_project_id
            is_free_tier = tier_id == "free-tier"

            if is_free_tier:
                # Free tier uses server-managed project
                onboard_request = {
                    "tierId": tier_id,
                    "cloudaicompanionProject": None,  # Server will create/manage
                    "metadata": core_client_metadata,
                }
                lib_logger.debug(
                    "Free tier onboarding: using server-managed project"
                )
            else:
                # Paid/legacy tier requires user-provided project
                if not configured_project_id and requires_user_project:
                    raise ValueError(
                        f"Tier '{tier_id}' requires setting ANTIGRAVITY_PROJECT_ID environment variable."
                    )
                onboard_request = {
                    "tierId": tier_id,
                    "cloudaicompanionProject": configured_project_id,
                    "metadata": {
                        **core_client_metadata,
                        "duetProject": configured_project_id,
                    }
                    if configured_project_id
                    else core_client_metadata,
                }
                lib_logger.debug(
                    f"Paid tier onboarding: using project {configured_project_id}"
                )

            # onboardUser is a long-running operation (LRO): the same request is
            # re-posted until the response reports done=True.
            lib_logger.debug("Initiating onboardUser request...")
            lro_response = await client.post(
                f"{code_assist_endpoint}:onboardUser",
                headers=headers,
                json=onboard_request,
                timeout=30,
            )
            lro_response.raise_for_status()
            lro_data = lro_response.json()
            lib_logger.debug(
                f"Initial onboarding response: done={lro_data.get('done')}"
            )

            # Poll for onboarding completion (up to 5 minutes)
            for i in range(150):  # 150 × 2s = 5 minutes
                if lro_data.get("done"):
                    lib_logger.debug(
                        f"Onboarding completed after {i} polling attempts"
                    )
                    break
                await asyncio.sleep(2)
                if (i + 1) % 15 == 0:  # Log every 30 seconds
                    lib_logger.info(
                        f"Still waiting for onboarding completion... ({(i + 1) * 2}s elapsed)"
                    )
                lib_logger.debug(
                    f"Polling onboarding status... (Attempt {i + 1}/150)"
                )
                lro_response = await client.post(
                    f"{code_assist_endpoint}:onboardUser",
                    headers=headers,
                    json=onboard_request,
                    timeout=30,
                )
                lro_response.raise_for_status()
                lro_data = lro_response.json()

            if not lro_data.get("done"):
                lib_logger.error("Onboarding process timed out after 5 minutes")
                raise ValueError(
                    "Onboarding process timed out after 5 minutes. Please try again or contact support."
                )

            # Extract project ID from LRO response
            # Note: onboardUser returns response.cloudaicompanionProject as an object with .id
            lro_response_data = lro_data.get("response", {})
            lro_project_obj = lro_response_data.get("cloudaicompanionProject", {})
            project_id = (
                lro_project_obj.get("id")
                if isinstance(lro_project_obj, dict)
                else None
            )

            # Fallback to configured project if LRO didn't return one
            if not project_id and configured_project_id:
                project_id = configured_project_id
                lib_logger.debug(
                    f"LRO didn't return project, using configured: {project_id}"
                )

            if not project_id:
                lib_logger.error(
                    "Onboarding completed but no project ID in response and none configured"
                )
                raise ValueError(
                    "Onboarding completed, but no project ID was returned. "
                    "For paid tiers, set ANTIGRAVITY_PROJECT_ID environment variable."
                )

            lib_logger.debug(
                f"Successfully extracted project ID from onboarding response: {project_id}"
            )

            # Cache tier info
            self.project_tier_cache[credential_path] = tier_id
            discovered_tier = tier_id
            lib_logger.debug(f"Cached tier information: {tier_id}")

            # Log concise message based on tier
            # NOTE(review): unlike the earlier is_paid check, this list omits
            # "unknown" - confirm whether that asymmetry is intentional.
            is_paid = tier_id and tier_id not in ["free-tier", "legacy-tier"]
            if is_paid:
                lib_logger.info(
                    f"Using Antigravity paid tier '{tier_id}' with project: {project_id}"
                )
            else:
                lib_logger.info(
                    f"Successfully onboarded user and discovered project ID: {project_id}"
                )

            self.project_id_cache[credential_path] = project_id
            discovered_project_id = project_id

            # Persist to credential file
            await self._persist_project_metadata(
                credential_path, project_id, discovered_tier
            )

            return project_id

        except httpx.HTTPStatusError as e:
            # HTTP failures here are logged (with status-specific hints) and
            # then fall through to the Resource Manager fallback below.
            error_body = ""
            try:
                error_body = e.response.text
            except Exception:
                pass
            if e.response.status_code == 403:
                lib_logger.error(
                    f"Antigravity Code Assist API access denied (403). Response: {error_body}"
                )
                lib_logger.error(
                    "Possible causes: 1) cloudaicompanion.googleapis.com API not enabled, 2) Wrong project ID for paid tier, 3) Account lacks permissions"
                )
            elif e.response.status_code == 404:
                lib_logger.warning(
                    f"Antigravity Code Assist endpoint not found (404). Falling back to project listing."
                )
            elif e.response.status_code == 412:
                # Precondition Failed - often means wrong project for free tier onboarding
                lib_logger.error(
                    f"Precondition failed (412): {error_body}. This may mean the project ID is incompatible with the selected tier."
                )
            else:
                lib_logger.warning(
                    f"Antigravity onboarding/discovery failed with status {e.response.status_code}: {error_body}. Falling back to project listing."
                )
        except httpx.RequestError as e:
            lib_logger.warning(
                f"Antigravity onboarding/discovery network error: {e}. Falling back to project listing."
            )

    # 3. Fallback to listing all available GCP projects (last resort)
    lib_logger.debug(
        "Attempting to discover project via GCP Resource Manager API..."
    )
    try:
        async with httpx.AsyncClient() as client:
            lib_logger.debug(
                "Querying Cloud Resource Manager for available projects..."
            )
            response = await client.get(
                "https://cloudresourcemanager.googleapis.com/v1/projects",
                headers=headers,
                timeout=20,
            )
            response.raise_for_status()
            projects = response.json().get("projects", [])
            lib_logger.debug(f"Found {len(projects)} total projects")
            active_projects = [
                p for p in projects if p.get("lifecycleState") == "ACTIVE"
            ]
            lib_logger.debug(f"Found {len(active_projects)} active projects")

            if not projects:
                lib_logger.error(
                    "No GCP projects found for this account. Please create a project in Google Cloud Console."
                )
            elif not active_projects:
                lib_logger.error(
                    "No active GCP projects found. Please activate a project in Google Cloud Console."
                )
            else:
                # First ACTIVE project wins; ordering comes from the API response.
                project_id = active_projects[0]["projectId"]
                lib_logger.info(
                    f"Discovered Antigravity project ID from active projects list: {project_id}"
                )
                lib_logger.debug(
                    f"Selected first active project: {project_id} (out of {len(active_projects)} active projects)"
                )
                self.project_id_cache[credential_path] = project_id
                discovered_project_id = project_id

                # Persist to credential file (no tier info from resource manager)
                await self._persist_project_metadata(
                    credential_path, project_id, None
                )

                return project_id
    except httpx.HTTPStatusError as e:
        if e.response.status_code == 403:
            lib_logger.error(
                "Failed to list GCP projects due to a 403 Forbidden error. The Cloud Resource Manager API may not be enabled, or your account lacks the 'resourcemanager.projects.list' permission."
            )
        else:
            lib_logger.error(
                f"Failed to list GCP projects with status {e.response.status_code}: {e}"
            )
    except httpx.RequestError as e:
        lib_logger.error(f"Network error while listing GCP projects: {e}")

    raise ValueError(
        "Could not auto-discover Antigravity project ID. Possible causes:\n"
        " 1. The cloudaicompanion.googleapis.com API is not enabled (enable it in Google Cloud Console)\n"
        " 2. No active GCP projects exist for this account (create one in Google Cloud Console)\n"
        " 3. Account lacks necessary permissions\n"
        "To manually specify a project, set ANTIGRAVITY_PROJECT_ID in your .env file."
    )
| 1491 |
-
|
| 1492 |
-
async def _persist_project_metadata(
    self, credential_path: str, project_id: str, tier: Optional[str]
):
    """Persists project ID and tier to the credential file for faster future startups.

    Args:
        credential_path: Path to the credential JSON file, or an env://
            pseudo-path for environment-based credentials.
        project_id: Discovered Google Cloud project ID to store.
        tier: Tier identifier; written only when truthy.
    """
    # Skip persistence for env:// paths (environment-based credentials)
    credential_index = self._parse_env_credential_path(credential_path)
    if credential_index is not None:
        lib_logger.debug(
            f"Skipping project metadata persistence for env:// credential path: {credential_path}"
        )
        return

    try:
        # Load current credentials
        with open(credential_path, "r") as f:
            creds = json.load(f)

        # Update metadata
        if "_proxy_metadata" not in creds:
            creds["_proxy_metadata"] = {}

        creds["_proxy_metadata"]["project_id"] = project_id
        if tier:
            creds["_proxy_metadata"]["tier"] = tier

        # Save back using the existing save method (handles atomic writes and permissions)
        await self._save_credentials(credential_path, creds)

        lib_logger.debug(
            f"Persisted project_id and tier to credential file: {credential_path}"
        )
    except Exception as e:
        # Broad catch is intentional: persistence is best-effort.
        lib_logger.warning(
            f"Failed to persist project metadata to credential file: {e}"
        )
        # Non-fatal - just means slower startup next time
|
| 1528 |
|
| 1529 |
# =========================================================================
|
| 1530 |
# THINKING MODE SANITIZATION
|
|
@@ -2424,7 +1990,7 @@ class AntigravityProvider(AntigravityAuthBase, ProviderInterface):
|
|
| 2424 |
elif first_func_in_msg:
|
| 2425 |
# Only add bypass to the first function call if no sig available
|
| 2426 |
func_part["thoughtSignature"] = "skip_thought_signature_validator"
|
| 2427 |
-
lib_logger.
|
| 2428 |
f"Missing thoughtSignature for first func call {tool_id}, using bypass"
|
| 2429 |
)
|
| 2430 |
# Subsequent parallel calls: no signature field at all
|
|
@@ -2559,9 +2125,9 @@ class AntigravityProvider(AntigravityAuthBase, ProviderInterface):
|
|
| 2559 |
f"Ignoring duplicate - this may indicate malformed conversation history."
|
| 2560 |
)
|
| 2561 |
continue
|
| 2562 |
-
#lib_logger.debug(
|
| 2563 |
# f"[Grouping] Collected response for ID: {resp_id}"
|
| 2564 |
-
#)
|
| 2565 |
collected_responses[resp_id] = resp
|
| 2566 |
|
| 2567 |
# Try to satisfy pending groups (newest first)
|
|
@@ -2576,10 +2142,10 @@ class AntigravityProvider(AntigravityAuthBase, ProviderInterface):
|
|
| 2576 |
collected_responses.pop(gid) for gid in group_ids
|
| 2577 |
]
|
| 2578 |
new_contents.append({"parts": group_responses, "role": "user"})
|
| 2579 |
-
#lib_logger.debug(
|
| 2580 |
# f"[Grouping] Satisfied group with {len(group_responses)} responses: "
|
| 2581 |
# f"ids={group_ids}"
|
| 2582 |
-
#)
|
| 2583 |
pending_groups.pop(i)
|
| 2584 |
break
|
| 2585 |
continue
|
|
@@ -2599,10 +2165,10 @@ class AntigravityProvider(AntigravityAuthBase, ProviderInterface):
|
|
| 2599 |
]
|
| 2600 |
|
| 2601 |
if call_ids:
|
| 2602 |
-
#lib_logger.debug(
|
| 2603 |
# f"[Grouping] Created pending group expecting {len(call_ids)} responses: "
|
| 2604 |
# f"ids={call_ids}, names={func_names}"
|
| 2605 |
-
#)
|
| 2606 |
pending_groups.append(
|
| 2607 |
{
|
| 2608 |
"ids": call_ids,
|
|
@@ -2967,12 +2533,41 @@ class AntigravityProvider(AntigravityAuthBase, ProviderInterface):
|
|
| 2967 |
|
| 2968 |
if params and isinstance(params, dict):
|
| 2969 |
schema = dict(params)
|
| 2970 |
-
schema.pop("$schema", None)
|
| 2971 |
schema.pop("strict", None)
|
|
|
|
|
|
|
|
|
|
| 2972 |
schema = _normalize_type_arrays(schema)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2973 |
func_decl["parametersJsonSchema"] = schema
|
| 2974 |
else:
|
| 2975 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2976 |
|
| 2977 |
gemini_tools.append({"functionDeclarations": [func_decl]})
|
| 2978 |
|
|
@@ -3097,17 +2692,19 @@ class AntigravityProvider(AntigravityAuthBase, ProviderInterface):
|
|
| 3097 |
return antigravity_payload
|
| 3098 |
|
| 3099 |
def _apply_claude_tool_transform(self, payload: Dict[str, Any]) -> None:
|
| 3100 |
-
"""Apply Claude-specific tool schema transformations.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3101 |
tools = payload["request"].get("tools", [])
|
| 3102 |
for tool in tools:
|
| 3103 |
for func_decl in tool.get("functionDeclarations", []):
|
| 3104 |
if "parametersJsonSchema" in func_decl:
|
| 3105 |
params = func_decl["parametersJsonSchema"]
|
| 3106 |
-
|
| 3107 |
-
|
| 3108 |
-
|
| 3109 |
-
else params
|
| 3110 |
-
)
|
| 3111 |
func_decl["parameters"] = params
|
| 3112 |
del func_decl["parametersJsonSchema"]
|
| 3113 |
|
|
@@ -3336,6 +2933,13 @@ class AntigravityProvider(AntigravityAuthBase, ProviderInterface):
|
|
| 3336 |
raw_args = func_call.get("args", {})
|
| 3337 |
parsed_args = _recursively_parse_json_strings(raw_args)
|
| 3338 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3339 |
tool_call = {
|
| 3340 |
"id": tool_id,
|
| 3341 |
"type": "function",
|
|
@@ -3405,7 +3009,7 @@ class AntigravityProvider(AntigravityAuthBase, ProviderInterface):
|
|
| 3405 |
}
|
| 3406 |
|
| 3407 |
self._thinking_cache.store(cache_key, json.dumps(data))
|
| 3408 |
-
lib_logger.
|
| 3409 |
|
| 3410 |
# =========================================================================
|
| 3411 |
# PROVIDER INTERFACE IMPLEMENTATION
|
|
@@ -3703,7 +3307,12 @@ class AntigravityProvider(AntigravityAuthBase, ProviderInterface):
|
|
| 3703 |
file_logger: Optional[AntigravityFileLogger] = None,
|
| 3704 |
) -> litellm.ModelResponse:
|
| 3705 |
"""Handle non-streaming completion."""
|
| 3706 |
-
response = await client.post(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3707 |
response.raise_for_status()
|
| 3708 |
|
| 3709 |
data = response.json()
|
|
@@ -3736,11 +3345,15 @@ class AntigravityProvider(AntigravityAuthBase, ProviderInterface):
|
|
| 3736 |
}
|
| 3737 |
|
| 3738 |
async with client.stream(
|
| 3739 |
-
"POST",
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3740 |
) as response:
|
| 3741 |
if response.status_code >= 400:
|
| 3742 |
-
# Read error body
|
| 3743 |
-
#
|
| 3744 |
try:
|
| 3745 |
await response.aread()
|
| 3746 |
# lib_logger.error(
|
|
|
|
| 38 |
from .antigravity_auth_base import AntigravityAuthBase
|
| 39 |
from .provider_cache import ProviderCache
|
| 40 |
from ..model_definitions import ModelDefinitions
|
| 41 |
+
from ..timeout_config import TimeoutConfig
|
| 42 |
+
from ..utils.paths import get_logs_dir, get_cache_dir
|
| 43 |
|
| 44 |
|
| 45 |
# =============================================================================
|
|
|
|
| 107 |
{"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "BLOCK_NONE"},
|
| 108 |
]
|
| 109 |
|
| 110 |
+
|
| 111 |
+
# Directory paths - use centralized path management
|
| 112 |
+
def _get_antigravity_logs_dir():
|
| 113 |
+
return get_logs_dir() / "antigravity_logs"
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def _get_antigravity_cache_dir():
|
| 117 |
+
return get_cache_dir(subdir="antigravity")
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def _get_gemini3_signature_cache_file():
|
| 121 |
+
return _get_antigravity_cache_dir() / "gemini3_signatures.json"
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def _get_claude_thinking_cache_file():
|
| 125 |
+
return _get_antigravity_cache_dir() / "claude_thinking.json"
|
| 126 |
+
|
| 127 |
|
| 128 |
# Gemini 3 tool fix system instruction (prevents hallucination)
|
| 129 |
DEFAULT_GEMINI3_SYSTEM_INSTRUCTION = """<CRITICAL_TOOL_USAGE_INSTRUCTIONS>
|
|
|
|
| 340 |
return obj
|
| 341 |
|
| 342 |
|
| 343 |
+
def _inline_schema_refs(schema: Dict[str, Any]) -> Dict[str, Any]:
|
| 344 |
+
"""Inline local $ref definitions before sanitization."""
|
| 345 |
+
if not isinstance(schema, dict):
|
| 346 |
+
return schema
|
| 347 |
+
|
| 348 |
+
defs = schema.get("$defs", schema.get("definitions", {}))
|
| 349 |
+
if not defs:
|
| 350 |
+
return schema
|
| 351 |
+
|
| 352 |
+
def resolve(node, seen=()):
|
| 353 |
+
if not isinstance(node, dict):
|
| 354 |
+
return [resolve(x, seen) for x in node] if isinstance(node, list) else node
|
| 355 |
+
if "$ref" in node:
|
| 356 |
+
ref = node["$ref"]
|
| 357 |
+
if ref in seen: # Circular - drop it
|
| 358 |
+
return {k: resolve(v, seen) for k, v in node.items() if k != "$ref"}
|
| 359 |
+
for prefix in ("#/$defs/", "#/definitions/"):
|
| 360 |
+
if isinstance(ref, str) and ref.startswith(prefix):
|
| 361 |
+
name = ref[len(prefix) :]
|
| 362 |
+
if name in defs:
|
| 363 |
+
return resolve(copy.deepcopy(defs[name]), seen + (ref,))
|
| 364 |
+
return {k: resolve(v, seen) for k, v in node.items() if k != "$ref"}
|
| 365 |
+
return {k: resolve(v, seen) for k, v in node.items()}
|
| 366 |
+
|
| 367 |
+
return resolve(schema)
|
| 368 |
+
|
| 369 |
+
|
| 370 |
def _clean_claude_schema(schema: Any) -> Any:
|
| 371 |
"""
|
| 372 |
Recursively clean JSON Schema for Antigravity/Google's Proto-based API.
|
|
|
|
| 424 |
return first_option
|
| 425 |
|
| 426 |
cleaned = {}
|
|
|
|
| 427 |
# Handle 'const' by converting to 'enum' with single value
|
| 428 |
if "const" in schema:
|
| 429 |
const_value = schema["const"]
|
|
|
|
| 464 |
|
| 465 |
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
|
| 466 |
safe_model = model_name.replace("/", "_").replace(":", "_")
|
| 467 |
+
self.log_dir = (
|
| 468 |
+
_get_antigravity_logs_dir() / f"{timestamp}_{safe_model}_{uuid.uuid4()}"
|
| 469 |
+
)
|
| 470 |
|
| 471 |
try:
|
| 472 |
self.log_dir.mkdir(parents=True, exist_ok=True)
|
|
|
|
| 699 |
error_obj = data.get("error", data)
|
| 700 |
details = error_obj.get("details", [])
|
| 701 |
|
|
|
|
|
|
|
|
|
|
| 702 |
result = {
|
| 703 |
"retry_after": None,
|
| 704 |
"reason": None,
|
|
|
|
| 749 |
|
| 750 |
# Return None if we couldn't extract retry_after
|
| 751 |
if not result["retry_after"]:
|
| 752 |
+
# Handle bare RESOURCE_EXHAUSTED without timing details
|
| 753 |
+
error_status = error_obj.get("status", "")
|
| 754 |
+
error_code = error_obj.get("code")
|
| 755 |
+
|
| 756 |
+
if error_status == "RESOURCE_EXHAUSTED" or error_code == 429:
|
| 757 |
+
result["retry_after"] = 60 # Default fallback
|
| 758 |
+
result["reason"] = result.get("reason") or "RESOURCE_EXHAUSTED"
|
| 759 |
+
return result
|
| 760 |
+
|
| 761 |
return None
|
| 762 |
|
| 763 |
return result
|
|
|
|
| 765 |
def __init__(self):
|
| 766 |
super().__init__()
|
| 767 |
self.model_definitions = ModelDefinitions()
|
| 768 |
+
# NOTE: project_id_cache and project_tier_cache are inherited from AntigravityAuthBase
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 769 |
|
| 770 |
# Base URL management
|
| 771 |
self._base_url_index = 0
|
|
|
|
| 777 |
|
| 778 |
# Initialize caches using shared ProviderCache
|
| 779 |
self._signature_cache = ProviderCache(
|
| 780 |
+
_get_gemini3_signature_cache_file(),
|
| 781 |
memory_ttl,
|
| 782 |
disk_ttl,
|
| 783 |
env_prefix="ANTIGRAVITY_SIGNATURE",
|
| 784 |
)
|
| 785 |
self._thinking_cache = ProviderCache(
|
| 786 |
+
_get_claude_thinking_cache_file(),
|
| 787 |
memory_ttl,
|
| 788 |
disk_ttl,
|
| 789 |
env_prefix="ANTIGRAVITY_THINKING",
|
|
|
|
| 913 |
|
| 914 |
This ensures all credential priorities are known before any API calls,
|
| 915 |
preventing unknown credentials from getting priority 999.
|
| 916 |
+
|
| 917 |
+
For credentials without persisted tier info (new or corrupted), performs
|
| 918 |
+
full discovery to ensure proper prioritization in sequential rotation mode.
|
| 919 |
"""
|
| 920 |
+
# Step 1: Load persisted tiers from files
|
| 921 |
await self._load_persisted_tiers(credential_paths)
|
| 922 |
|
| 923 |
+
# Step 2: Identify credentials still missing tier info
|
| 924 |
+
credentials_needing_discovery = [
|
| 925 |
+
path
|
| 926 |
+
for path in credential_paths
|
| 927 |
+
if path not in self.project_tier_cache
|
| 928 |
+
and self._parse_env_credential_path(path) is None # Skip env:// paths
|
| 929 |
+
]
|
| 930 |
+
|
| 931 |
+
if not credentials_needing_discovery:
|
| 932 |
+
return # All credentials have tier info
|
| 933 |
+
|
| 934 |
+
lib_logger.info(
|
| 935 |
+
f"Antigravity: Discovering tier info for {len(credentials_needing_discovery)} credential(s)..."
|
| 936 |
+
)
|
| 937 |
+
|
| 938 |
+
# Step 3: Perform discovery for each missing credential (sequential to avoid rate limits)
|
| 939 |
+
for credential_path in credentials_needing_discovery:
|
| 940 |
+
try:
|
| 941 |
+
auth_header = await self.get_auth_header(credential_path)
|
| 942 |
+
access_token = auth_header["Authorization"].split(" ")[1]
|
| 943 |
+
await self._discover_project_id(
|
| 944 |
+
credential_path, access_token, litellm_params={}
|
| 945 |
+
)
|
| 946 |
+
discovered_tier = self.project_tier_cache.get(
|
| 947 |
+
credential_path, "unknown"
|
| 948 |
+
)
|
| 949 |
+
lib_logger.debug(
|
| 950 |
+
f"Discovered tier '{discovered_tier}' for {Path(credential_path).name}"
|
| 951 |
+
)
|
| 952 |
+
except Exception as e:
|
| 953 |
+
lib_logger.warning(
|
| 954 |
+
f"Failed to discover tier for {Path(credential_path).name}: {e}. "
|
| 955 |
+
f"Credential will use default priority."
|
| 956 |
+
)
|
| 957 |
+
|
| 958 |
async def _load_persisted_tiers(
|
| 959 |
self, credential_paths: List[str]
|
| 960 |
) -> Dict[str, str]:
|
|
|
|
| 1012 |
|
| 1013 |
return loaded
|
| 1014 |
|
| 1015 |
+
# NOTE: _post_auth_discovery() is inherited from AntigravityAuthBase
|
| 1016 |
+
|
| 1017 |
# =========================================================================
|
| 1018 |
# MODEL UTILITIES
|
| 1019 |
# =========================================================================
|
|
|
|
| 1090 |
|
| 1091 |
return "thinking_" + "_".join(key_parts) if key_parts else None
|
| 1092 |
|
| 1093 |
+
# NOTE: _discover_project_id() and _persist_project_metadata() are inherited from AntigravityAuthBase
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1094 |
|
| 1095 |
# =========================================================================
|
| 1096 |
# THINKING MODE SANITIZATION
|
|
|
|
| 1990 |
elif first_func_in_msg:
|
| 1991 |
# Only add bypass to the first function call if no sig available
|
| 1992 |
func_part["thoughtSignature"] = "skip_thought_signature_validator"
|
| 1993 |
+
lib_logger.debug(
|
| 1994 |
f"Missing thoughtSignature for first func call {tool_id}, using bypass"
|
| 1995 |
)
|
| 1996 |
# Subsequent parallel calls: no signature field at all
|
|
|
|
| 2125 |
f"Ignoring duplicate - this may indicate malformed conversation history."
|
| 2126 |
)
|
| 2127 |
continue
|
| 2128 |
+
# lib_logger.debug(
|
| 2129 |
# f"[Grouping] Collected response for ID: {resp_id}"
|
| 2130 |
+
# )
|
| 2131 |
collected_responses[resp_id] = resp
|
| 2132 |
|
| 2133 |
# Try to satisfy pending groups (newest first)
|
|
|
|
| 2142 |
collected_responses.pop(gid) for gid in group_ids
|
| 2143 |
]
|
| 2144 |
new_contents.append({"parts": group_responses, "role": "user"})
|
| 2145 |
+
# lib_logger.debug(
|
| 2146 |
# f"[Grouping] Satisfied group with {len(group_responses)} responses: "
|
| 2147 |
# f"ids={group_ids}"
|
| 2148 |
+
# )
|
| 2149 |
pending_groups.pop(i)
|
| 2150 |
break
|
| 2151 |
continue
|
|
|
|
| 2165 |
]
|
| 2166 |
|
| 2167 |
if call_ids:
|
| 2168 |
+
# lib_logger.debug(
|
| 2169 |
# f"[Grouping] Created pending group expecting {len(call_ids)} responses: "
|
| 2170 |
# f"ids={call_ids}, names={func_names}"
|
| 2171 |
+
# )
|
| 2172 |
pending_groups.append(
|
| 2173 |
{
|
| 2174 |
"ids": call_ids,
|
|
|
|
| 2533 |
|
| 2534 |
if params and isinstance(params, dict):
|
| 2535 |
schema = dict(params)
|
|
|
|
| 2536 |
schema.pop("strict", None)
|
| 2537 |
+
# Inline $ref definitions, then strip unsupported keywords
|
| 2538 |
+
schema = _inline_schema_refs(schema)
|
| 2539 |
+
schema = _clean_claude_schema(schema)
|
| 2540 |
schema = _normalize_type_arrays(schema)
|
| 2541 |
+
|
| 2542 |
+
# Workaround: Antigravity/Gemini fails to emit functionCall
|
| 2543 |
+
# when tool has empty properties {}. Inject a dummy optional
|
| 2544 |
+
# parameter to ensure the tool call is emitted.
|
| 2545 |
+
# Using a required confirmation parameter forces the model to
|
| 2546 |
+
# commit to the tool call rather than just thinking about it.
|
| 2547 |
+
props = schema.get("properties", {})
|
| 2548 |
+
if not props:
|
| 2549 |
+
schema["properties"] = {
|
| 2550 |
+
"_confirm": {
|
| 2551 |
+
"type": "string",
|
| 2552 |
+
"description": "Enter 'yes' to proceed",
|
| 2553 |
+
}
|
| 2554 |
+
}
|
| 2555 |
+
schema["required"] = ["_confirm"]
|
| 2556 |
+
|
| 2557 |
func_decl["parametersJsonSchema"] = schema
|
| 2558 |
else:
|
| 2559 |
+
# No parameters provided - use default with required confirm param
|
| 2560 |
+
# to ensure the tool call is emitted properly
|
| 2561 |
+
func_decl["parametersJsonSchema"] = {
|
| 2562 |
+
"type": "object",
|
| 2563 |
+
"properties": {
|
| 2564 |
+
"_confirm": {
|
| 2565 |
+
"type": "string",
|
| 2566 |
+
"description": "Enter 'yes' to proceed",
|
| 2567 |
+
}
|
| 2568 |
+
},
|
| 2569 |
+
"required": ["_confirm"],
|
| 2570 |
+
}
|
| 2571 |
|
| 2572 |
gemini_tools.append({"functionDeclarations": [func_decl]})
|
| 2573 |
|
|
|
|
| 2692 |
return antigravity_payload
|
| 2693 |
|
| 2694 |
def _apply_claude_tool_transform(self, payload: Dict[str, Any]) -> None:
|
| 2695 |
+
"""Apply Claude-specific tool schema transformations.
|
| 2696 |
+
|
| 2697 |
+
Converts parametersJsonSchema to parameters and applies Claude-specific
|
| 2698 |
+
schema sanitization (inlines $ref, removes unsupported JSON Schema fields).
|
| 2699 |
+
"""
|
| 2700 |
tools = payload["request"].get("tools", [])
|
| 2701 |
for tool in tools:
|
| 2702 |
for func_decl in tool.get("functionDeclarations", []):
|
| 2703 |
if "parametersJsonSchema" in func_decl:
|
| 2704 |
params = func_decl["parametersJsonSchema"]
|
| 2705 |
+
if isinstance(params, dict):
|
| 2706 |
+
params = _inline_schema_refs(params)
|
| 2707 |
+
params = _clean_claude_schema(params)
|
|
|
|
|
|
|
| 2708 |
func_decl["parameters"] = params
|
| 2709 |
del func_decl["parametersJsonSchema"]
|
| 2710 |
|
|
|
|
| 2933 |
raw_args = func_call.get("args", {})
|
| 2934 |
parsed_args = _recursively_parse_json_strings(raw_args)
|
| 2935 |
|
| 2936 |
+
# Strip the injected _confirm parameter ONLY if it's the sole parameter
|
| 2937 |
+
# This ensures we only strip our injection, not legitimate user params
|
| 2938 |
+
if isinstance(parsed_args, dict) and "_confirm" in parsed_args:
|
| 2939 |
+
if len(parsed_args) == 1:
|
| 2940 |
+
# _confirm is the only param - this was our injection
|
| 2941 |
+
parsed_args.pop("_confirm")
|
| 2942 |
+
|
| 2943 |
tool_call = {
|
| 2944 |
"id": tool_id,
|
| 2945 |
"type": "function",
|
|
|
|
| 3009 |
}
|
| 3010 |
|
| 3011 |
self._thinking_cache.store(cache_key, json.dumps(data))
|
| 3012 |
+
lib_logger.debug(f"Cached thinking: {cache_key[:50]}...")
|
| 3013 |
|
| 3014 |
# =========================================================================
|
| 3015 |
# PROVIDER INTERFACE IMPLEMENTATION
|
|
|
|
| 3307 |
file_logger: Optional[AntigravityFileLogger] = None,
|
| 3308 |
) -> litellm.ModelResponse:
|
| 3309 |
"""Handle non-streaming completion."""
|
| 3310 |
+
response = await client.post(
|
| 3311 |
+
url,
|
| 3312 |
+
headers=headers,
|
| 3313 |
+
json=payload,
|
| 3314 |
+
timeout=TimeoutConfig.non_streaming(),
|
| 3315 |
+
)
|
| 3316 |
response.raise_for_status()
|
| 3317 |
|
| 3318 |
data = response.json()
|
|
|
|
| 3345 |
}
|
| 3346 |
|
| 3347 |
async with client.stream(
|
| 3348 |
+
"POST",
|
| 3349 |
+
url,
|
| 3350 |
+
headers=headers,
|
| 3351 |
+
json=payload,
|
| 3352 |
+
timeout=TimeoutConfig.streaming(),
|
| 3353 |
) as response:
|
| 3354 |
if response.status_code >= 400:
|
| 3355 |
+
# Read error body so it's available in response.text for logging
|
| 3356 |
+
# The actual logging happens in failure_logger via _extract_response_body
|
| 3357 |
try:
|
| 3358 |
await response.aread()
|
| 3359 |
# lib_logger.error(
|
src/rotator_library/providers/gemini_auth_base.py
CHANGED
|
@@ -1,15 +1,35 @@
|
|
| 1 |
# src/rotator_library/providers/gemini_auth_base.py
|
| 2 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
from .google_oauth_base import GoogleOAuthBase
|
| 4 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
class GeminiAuthBase(GoogleOAuthBase):
|
| 6 |
"""
|
| 7 |
Gemini CLI OAuth2 authentication implementation.
|
| 8 |
-
|
| 9 |
Inherits all OAuth functionality from GoogleOAuthBase with Gemini-specific configuration.
|
|
|
|
|
|
|
|
|
|
| 10 |
"""
|
| 11 |
-
|
| 12 |
-
CLIENT_ID =
|
|
|
|
|
|
|
| 13 |
CLIENT_SECRET = "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"
|
| 14 |
OAUTH_SCOPES = [
|
| 15 |
"https://www.googleapis.com/auth/cloud-platform",
|
|
@@ -18,4 +38,606 @@ class GeminiAuthBase(GoogleOAuthBase):
|
|
| 18 |
]
|
| 19 |
ENV_PREFIX = "GEMINI_CLI"
|
| 20 |
CALLBACK_PORT = 8085
|
| 21 |
-
CALLBACK_PATH = "/oauth2callback"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
# src/rotator_library/providers/gemini_auth_base.py
|
| 2 |
|
| 3 |
+
import asyncio
|
| 4 |
+
import json
|
| 5 |
+
import logging
|
| 6 |
+
import os
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
from typing import Any, Dict, Optional, List
|
| 9 |
+
|
| 10 |
+
import httpx
|
| 11 |
+
|
| 12 |
from .google_oauth_base import GoogleOAuthBase
|
| 13 |
|
| 14 |
+
lib_logger = logging.getLogger("rotator_library")
|
| 15 |
+
|
| 16 |
+
# Code Assist endpoint for project discovery
|
| 17 |
+
CODE_ASSIST_ENDPOINT = "https://cloudcode-pa.googleapis.com/v1internal"
|
| 18 |
+
|
| 19 |
+
|
| 20 |
class GeminiAuthBase(GoogleOAuthBase):
|
| 21 |
"""
|
| 22 |
Gemini CLI OAuth2 authentication implementation.
|
| 23 |
+
|
| 24 |
Inherits all OAuth functionality from GoogleOAuthBase with Gemini-specific configuration.
|
| 25 |
+
|
| 26 |
+
Also provides project/tier discovery functionality that runs during authentication,
|
| 27 |
+
ensuring credentials have their tier and project_id cached before any API requests.
|
| 28 |
"""
|
| 29 |
+
|
| 30 |
+
CLIENT_ID = (
|
| 31 |
+
"681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com"
|
| 32 |
+
)
|
| 33 |
CLIENT_SECRET = "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"
|
| 34 |
OAUTH_SCOPES = [
|
| 35 |
"https://www.googleapis.com/auth/cloud-platform",
|
|
|
|
| 38 |
]
|
| 39 |
ENV_PREFIX = "GEMINI_CLI"
|
| 40 |
CALLBACK_PORT = 8085
|
| 41 |
+
CALLBACK_PATH = "/oauth2callback"
|
| 42 |
+
|
| 43 |
+
def __init__(self):
|
| 44 |
+
super().__init__()
|
| 45 |
+
# Project and tier caches - shared between auth base and provider
|
| 46 |
+
self.project_id_cache: Dict[str, str] = {}
|
| 47 |
+
self.project_tier_cache: Dict[str, str] = {}
|
| 48 |
+
|
| 49 |
+
# =========================================================================
|
| 50 |
+
# POST-AUTH DISCOVERY HOOK
|
| 51 |
+
# =========================================================================
|
| 52 |
+
|
| 53 |
+
async def _post_auth_discovery(
|
| 54 |
+
self, credential_path: str, access_token: str
|
| 55 |
+
) -> None:
|
| 56 |
+
"""
|
| 57 |
+
Discover and cache tier/project information immediately after OAuth authentication.
|
| 58 |
+
|
| 59 |
+
This is called by GoogleOAuthBase._perform_interactive_oauth() after successful auth,
|
| 60 |
+
ensuring tier and project_id are cached during the authentication flow rather than
|
| 61 |
+
waiting for the first API request.
|
| 62 |
+
|
| 63 |
+
Args:
|
| 64 |
+
credential_path: Path to the credential file
|
| 65 |
+
access_token: The newly obtained access token
|
| 66 |
+
"""
|
| 67 |
+
lib_logger.debug(
|
| 68 |
+
f"Starting post-auth discovery for GeminiCli credential: {Path(credential_path).name}"
|
| 69 |
+
)
|
| 70 |
+
|
| 71 |
+
# Skip if already discovered (shouldn't happen during fresh auth, but be defensive)
|
| 72 |
+
if (
|
| 73 |
+
credential_path in self.project_id_cache
|
| 74 |
+
and credential_path in self.project_tier_cache
|
| 75 |
+
):
|
| 76 |
+
lib_logger.debug(
|
| 77 |
+
f"Tier and project already cached for {Path(credential_path).name}, skipping discovery"
|
| 78 |
+
)
|
| 79 |
+
return
|
| 80 |
+
|
| 81 |
+
# Call _discover_project_id which handles tier/project discovery and persistence
|
| 82 |
+
# Pass empty litellm_params since we're in auth context (no model-specific overrides)
|
| 83 |
+
project_id = await self._discover_project_id(
|
| 84 |
+
credential_path, access_token, litellm_params={}
|
| 85 |
+
)
|
| 86 |
+
|
| 87 |
+
tier = self.project_tier_cache.get(credential_path, "unknown")
|
| 88 |
+
lib_logger.info(
|
| 89 |
+
f"Post-auth discovery complete for {Path(credential_path).name}: "
|
| 90 |
+
f"tier={tier}, project={project_id}"
|
| 91 |
+
)
|
| 92 |
+
|
| 93 |
+
# =========================================================================
|
| 94 |
+
# PROJECT ID DISCOVERY
|
| 95 |
+
# =========================================================================
|
| 96 |
+
|
| 97 |
+
async def _discover_project_id(
|
| 98 |
+
self, credential_path: str, access_token: str, litellm_params: Dict[str, Any]
|
| 99 |
+
) -> str:
|
| 100 |
+
"""
|
| 101 |
+
Discovers the Google Cloud Project ID, with caching and onboarding for new accounts.
|
| 102 |
+
|
| 103 |
+
This follows the official Gemini CLI discovery flow:
|
| 104 |
+
1. Check in-memory cache
|
| 105 |
+
2. Check configured project_id override (litellm_params or env var)
|
| 106 |
+
3. Check persisted project_id in credential file
|
| 107 |
+
4. Call loadCodeAssist to check if user is already known (has currentTier)
|
| 108 |
+
- If currentTier exists AND cloudaicompanionProject returned: use server's project
|
| 109 |
+
- If currentTier exists but NO cloudaicompanionProject: use configured project_id (paid tier requires this)
|
| 110 |
+
- If no currentTier: user needs onboarding
|
| 111 |
+
5. Onboard user based on tier:
|
| 112 |
+
- FREE tier: pass cloudaicompanionProject=None (server-managed)
|
| 113 |
+
- PAID tier: pass cloudaicompanionProject=configured_project_id
|
| 114 |
+
6. Fallback to GCP Resource Manager project listing
|
| 115 |
+
"""
|
| 116 |
+
lib_logger.debug(
|
| 117 |
+
f"Starting project discovery for credential: {credential_path}"
|
| 118 |
+
)
|
| 119 |
+
|
| 120 |
+
# Check in-memory cache first
|
| 121 |
+
if credential_path in self.project_id_cache:
|
| 122 |
+
cached_project = self.project_id_cache[credential_path]
|
| 123 |
+
lib_logger.debug(f"Using cached project ID: {cached_project}")
|
| 124 |
+
return cached_project
|
| 125 |
+
|
| 126 |
+
# Check for configured project ID override (from litellm_params or env var)
|
| 127 |
+
# This is REQUIRED for paid tier users per the official CLI behavior
|
| 128 |
+
configured_project_id = litellm_params.get("project_id") or os.getenv(
|
| 129 |
+
"GEMINI_CLI_PROJECT_ID"
|
| 130 |
+
)
|
| 131 |
+
if configured_project_id:
|
| 132 |
+
lib_logger.debug(
|
| 133 |
+
f"Found configured project_id override: {configured_project_id}"
|
| 134 |
+
)
|
| 135 |
+
|
| 136 |
+
# Load credentials from file to check for persisted project_id and tier
|
| 137 |
+
# Skip for env:// paths (environment-based credentials don't persist to files)
|
| 138 |
+
credential_index = self._parse_env_credential_path(credential_path)
|
| 139 |
+
if credential_index is None:
|
| 140 |
+
# Only try to load from file if it's not an env:// path
|
| 141 |
+
try:
|
| 142 |
+
with open(credential_path, "r") as f:
|
| 143 |
+
creds = json.load(f)
|
| 144 |
+
|
| 145 |
+
metadata = creds.get("_proxy_metadata", {})
|
| 146 |
+
persisted_project_id = metadata.get("project_id")
|
| 147 |
+
persisted_tier = metadata.get("tier")
|
| 148 |
+
|
| 149 |
+
if persisted_project_id:
|
| 150 |
+
lib_logger.info(
|
| 151 |
+
f"Loaded persisted project ID from credential file: {persisted_project_id}"
|
| 152 |
+
)
|
| 153 |
+
self.project_id_cache[credential_path] = persisted_project_id
|
| 154 |
+
|
| 155 |
+
# Also load tier if available
|
| 156 |
+
if persisted_tier:
|
| 157 |
+
self.project_tier_cache[credential_path] = persisted_tier
|
| 158 |
+
lib_logger.debug(f"Loaded persisted tier: {persisted_tier}")
|
| 159 |
+
|
| 160 |
+
return persisted_project_id
|
| 161 |
+
except (FileNotFoundError, json.JSONDecodeError, KeyError) as e:
|
| 162 |
+
lib_logger.debug(f"Could not load persisted project ID from file: {e}")
|
| 163 |
+
|
| 164 |
+
lib_logger.debug(
|
| 165 |
+
"No cached or configured project ID found, initiating discovery..."
|
| 166 |
+
)
|
| 167 |
+
headers = {
|
| 168 |
+
"Authorization": f"Bearer {access_token}",
|
| 169 |
+
"Content-Type": "application/json",
|
| 170 |
+
}
|
| 171 |
+
|
| 172 |
+
discovered_project_id = None
|
| 173 |
+
discovered_tier = None
|
| 174 |
+
|
| 175 |
+
async with httpx.AsyncClient() as client:
|
| 176 |
+
# 1. Try discovery endpoint with loadCodeAssist
|
| 177 |
+
lib_logger.debug(
|
| 178 |
+
"Attempting project discovery via Code Assist loadCodeAssist endpoint..."
|
| 179 |
+
)
|
| 180 |
+
try:
|
| 181 |
+
# Build metadata - include duetProject only if we have a configured project
|
| 182 |
+
core_client_metadata = {
|
| 183 |
+
"ideType": "IDE_UNSPECIFIED",
|
| 184 |
+
"platform": "PLATFORM_UNSPECIFIED",
|
| 185 |
+
"pluginType": "GEMINI",
|
| 186 |
+
}
|
| 187 |
+
if configured_project_id:
|
| 188 |
+
core_client_metadata["duetProject"] = configured_project_id
|
| 189 |
+
|
| 190 |
+
# Build load request - pass configured_project_id if available, otherwise None
|
| 191 |
+
load_request = {
|
| 192 |
+
"cloudaicompanionProject": configured_project_id, # Can be None
|
| 193 |
+
"metadata": core_client_metadata,
|
| 194 |
+
}
|
| 195 |
+
|
| 196 |
+
lib_logger.debug(
|
| 197 |
+
f"Sending loadCodeAssist request with cloudaicompanionProject={configured_project_id}"
|
| 198 |
+
)
|
| 199 |
+
response = await client.post(
|
| 200 |
+
f"{CODE_ASSIST_ENDPOINT}:loadCodeAssist",
|
| 201 |
+
headers=headers,
|
| 202 |
+
json=load_request,
|
| 203 |
+
timeout=20,
|
| 204 |
+
)
|
| 205 |
+
response.raise_for_status()
|
| 206 |
+
data = response.json()
|
| 207 |
+
|
| 208 |
+
# Log full response for debugging
|
| 209 |
+
lib_logger.debug(
|
| 210 |
+
f"loadCodeAssist full response keys: {list(data.keys())}"
|
| 211 |
+
)
|
| 212 |
+
|
| 213 |
+
# Extract and log ALL tier information for debugging
|
| 214 |
+
allowed_tiers = data.get("allowedTiers", [])
|
| 215 |
+
current_tier = data.get("currentTier")
|
| 216 |
+
|
| 217 |
+
lib_logger.debug(f"=== Tier Information ===")
|
| 218 |
+
lib_logger.debug(f"currentTier: {current_tier}")
|
| 219 |
+
lib_logger.debug(f"allowedTiers count: {len(allowed_tiers)}")
|
| 220 |
+
for i, tier in enumerate(allowed_tiers):
|
| 221 |
+
tier_id = tier.get("id", "unknown")
|
| 222 |
+
is_default = tier.get("isDefault", False)
|
| 223 |
+
user_defined = tier.get("userDefinedCloudaicompanionProject", False)
|
| 224 |
+
lib_logger.debug(
|
| 225 |
+
f" Tier {i + 1}: id={tier_id}, isDefault={is_default}, userDefinedProject={user_defined}"
|
| 226 |
+
)
|
| 227 |
+
lib_logger.debug(f"========================")
|
| 228 |
+
|
| 229 |
+
# Determine the current tier ID
|
| 230 |
+
current_tier_id = None
|
| 231 |
+
if current_tier:
|
| 232 |
+
current_tier_id = current_tier.get("id")
|
| 233 |
+
lib_logger.debug(f"User has currentTier: {current_tier_id}")
|
| 234 |
+
|
| 235 |
+
# Check if user is already known to server (has currentTier)
|
| 236 |
+
if current_tier_id:
|
| 237 |
+
# User is already onboarded - check for project from server
|
| 238 |
+
server_project = data.get("cloudaicompanionProject")
|
| 239 |
+
|
| 240 |
+
# Check if this tier requires user-defined project (paid tiers)
|
| 241 |
+
requires_user_project = any(
|
| 242 |
+
t.get("id") == current_tier_id
|
| 243 |
+
and t.get("userDefinedCloudaicompanionProject", False)
|
| 244 |
+
for t in allowed_tiers
|
| 245 |
+
)
|
| 246 |
+
is_free_tier = current_tier_id == "free-tier"
|
| 247 |
+
|
| 248 |
+
if server_project:
|
| 249 |
+
# Server returned a project - use it (server wins)
|
| 250 |
+
# This is the normal case for FREE tier users
|
| 251 |
+
project_id = server_project
|
| 252 |
+
lib_logger.debug(f"Server returned project: {project_id}")
|
| 253 |
+
elif configured_project_id:
|
| 254 |
+
# No server project but we have configured one - use it
|
| 255 |
+
# This is the PAID TIER case where server doesn't return a project
|
| 256 |
+
project_id = configured_project_id
|
| 257 |
+
lib_logger.debug(
|
| 258 |
+
f"No server project, using configured: {project_id}"
|
| 259 |
+
)
|
| 260 |
+
elif is_free_tier:
|
| 261 |
+
# Free tier user without server project - this shouldn't happen normally
|
| 262 |
+
# but let's not fail, just proceed to onboarding
|
| 263 |
+
lib_logger.debug(
|
| 264 |
+
"Free tier user with currentTier but no project - will try onboarding"
|
| 265 |
+
)
|
| 266 |
+
project_id = None
|
| 267 |
+
elif requires_user_project:
|
| 268 |
+
# Paid tier requires a project ID to be set
|
| 269 |
+
raise ValueError(
|
| 270 |
+
f"Paid tier '{current_tier_id}' requires setting GEMINI_CLI_PROJECT_ID environment variable. "
|
| 271 |
+
"See https://goo.gle/gemini-cli-auth-docs#workspace-gca"
|
| 272 |
+
)
|
| 273 |
+
else:
|
| 274 |
+
# Unknown tier without project - proceed carefully
|
| 275 |
+
lib_logger.warning(
|
| 276 |
+
f"Tier '{current_tier_id}' has no project and none configured - will try onboarding"
|
| 277 |
+
)
|
| 278 |
+
project_id = None
|
| 279 |
+
|
| 280 |
+
if project_id:
|
| 281 |
+
# Cache tier info
|
| 282 |
+
self.project_tier_cache[credential_path] = current_tier_id
|
| 283 |
+
discovered_tier = current_tier_id
|
| 284 |
+
|
| 285 |
+
# Log appropriately based on tier
|
| 286 |
+
is_paid = current_tier_id and current_tier_id not in [
|
| 287 |
+
"free-tier",
|
| 288 |
+
"legacy-tier",
|
| 289 |
+
"unknown",
|
| 290 |
+
]
|
| 291 |
+
if is_paid:
|
| 292 |
+
lib_logger.info(
|
| 293 |
+
f"Using Gemini paid tier '{current_tier_id}' with project: {project_id}"
|
| 294 |
+
)
|
| 295 |
+
else:
|
| 296 |
+
lib_logger.info(
|
| 297 |
+
f"Discovered Gemini project ID via loadCodeAssist: {project_id}"
|
| 298 |
+
)
|
| 299 |
+
|
| 300 |
+
self.project_id_cache[credential_path] = project_id
|
| 301 |
+
discovered_project_id = project_id
|
| 302 |
+
|
| 303 |
+
# Persist to credential file
|
| 304 |
+
await self._persist_project_metadata(
|
| 305 |
+
credential_path, project_id, discovered_tier
|
| 306 |
+
)
|
| 307 |
+
|
| 308 |
+
return project_id
|
| 309 |
+
|
| 310 |
+
# 2. User needs onboarding - no currentTier
|
| 311 |
+
lib_logger.info(
|
| 312 |
+
"No existing Gemini session found (no currentTier), attempting to onboard user..."
|
| 313 |
+
)
|
| 314 |
+
|
| 315 |
+
# Determine which tier to onboard with
|
| 316 |
+
onboard_tier = None
|
| 317 |
+
for tier in allowed_tiers:
|
| 318 |
+
if tier.get("isDefault"):
|
| 319 |
+
onboard_tier = tier
|
| 320 |
+
break
|
| 321 |
+
|
| 322 |
+
# Fallback to LEGACY tier if no default (requires user project)
|
| 323 |
+
if not onboard_tier and allowed_tiers:
|
| 324 |
+
# Look for legacy-tier as fallback
|
| 325 |
+
for tier in allowed_tiers:
|
| 326 |
+
if tier.get("id") == "legacy-tier":
|
| 327 |
+
onboard_tier = tier
|
| 328 |
+
break
|
| 329 |
+
# If still no tier, use first available
|
| 330 |
+
if not onboard_tier:
|
| 331 |
+
onboard_tier = allowed_tiers[0]
|
| 332 |
+
|
| 333 |
+
if not onboard_tier:
|
| 334 |
+
raise ValueError("No onboarding tiers available from server")
|
| 335 |
+
|
| 336 |
+
tier_id = onboard_tier.get("id", "free-tier")
|
| 337 |
+
requires_user_project = onboard_tier.get(
|
| 338 |
+
"userDefinedCloudaicompanionProject", False
|
| 339 |
+
)
|
| 340 |
+
|
| 341 |
+
lib_logger.debug(
|
| 342 |
+
f"Onboarding with tier: {tier_id}, requiresUserProject: {requires_user_project}"
|
| 343 |
+
)
|
| 344 |
+
|
| 345 |
+
# Build onboard request based on tier type (following official CLI logic)
|
| 346 |
+
# FREE tier: cloudaicompanionProject = None (server-managed)
|
| 347 |
+
# PAID tier: cloudaicompanionProject = configured_project_id (user must provide)
|
| 348 |
+
is_free_tier = tier_id == "free-tier"
|
| 349 |
+
|
| 350 |
+
if is_free_tier:
|
| 351 |
+
# Free tier uses server-managed project
|
| 352 |
+
onboard_request = {
|
| 353 |
+
"tierId": tier_id,
|
| 354 |
+
"cloudaicompanionProject": None, # Server will create/manage
|
| 355 |
+
"metadata": core_client_metadata,
|
| 356 |
+
}
|
| 357 |
+
lib_logger.debug(
|
| 358 |
+
"Free tier onboarding: using server-managed project"
|
| 359 |
+
)
|
| 360 |
+
else:
|
| 361 |
+
# Paid/legacy tier requires user-provided project
|
| 362 |
+
if not configured_project_id and requires_user_project:
|
| 363 |
+
raise ValueError(
|
| 364 |
+
f"Tier '{tier_id}' requires setting GEMINI_CLI_PROJECT_ID environment variable. "
|
| 365 |
+
"See https://goo.gle/gemini-cli-auth-docs#workspace-gca"
|
| 366 |
+
)
|
| 367 |
+
onboard_request = {
|
| 368 |
+
"tierId": tier_id,
|
| 369 |
+
"cloudaicompanionProject": configured_project_id,
|
| 370 |
+
"metadata": {
|
| 371 |
+
**core_client_metadata,
|
| 372 |
+
"duetProject": configured_project_id,
|
| 373 |
+
}
|
| 374 |
+
if configured_project_id
|
| 375 |
+
else core_client_metadata,
|
| 376 |
+
}
|
| 377 |
+
lib_logger.debug(
|
| 378 |
+
f"Paid tier onboarding: using project {configured_project_id}"
|
| 379 |
+
)
|
| 380 |
+
|
| 381 |
+
lib_logger.debug("Initiating onboardUser request...")
|
| 382 |
+
lro_response = await client.post(
|
| 383 |
+
f"{CODE_ASSIST_ENDPOINT}:onboardUser",
|
| 384 |
+
headers=headers,
|
| 385 |
+
json=onboard_request,
|
| 386 |
+
timeout=30,
|
| 387 |
+
)
|
| 388 |
+
lro_response.raise_for_status()
|
| 389 |
+
lro_data = lro_response.json()
|
| 390 |
+
lib_logger.debug(
|
| 391 |
+
f"Initial onboarding response: done={lro_data.get('done')}"
|
| 392 |
+
)
|
| 393 |
+
|
| 394 |
+
for i in range(150): # Poll for up to 5 minutes (150 × 2s)
|
| 395 |
+
if lro_data.get("done"):
|
| 396 |
+
lib_logger.debug(
|
| 397 |
+
f"Onboarding completed after {i} polling attempts"
|
| 398 |
+
)
|
| 399 |
+
break
|
| 400 |
+
await asyncio.sleep(2)
|
| 401 |
+
if (i + 1) % 15 == 0: # Log every 30 seconds
|
| 402 |
+
lib_logger.info(
|
| 403 |
+
f"Still waiting for onboarding completion... ({(i + 1) * 2}s elapsed)"
|
| 404 |
+
)
|
| 405 |
+
lib_logger.debug(
|
| 406 |
+
f"Polling onboarding status... (Attempt {i + 1}/150)"
|
| 407 |
+
)
|
| 408 |
+
lro_response = await client.post(
|
| 409 |
+
f"{CODE_ASSIST_ENDPOINT}:onboardUser",
|
| 410 |
+
headers=headers,
|
| 411 |
+
json=onboard_request,
|
| 412 |
+
timeout=30,
|
| 413 |
+
)
|
| 414 |
+
lro_response.raise_for_status()
|
| 415 |
+
lro_data = lro_response.json()
|
| 416 |
+
|
| 417 |
+
if not lro_data.get("done"):
|
| 418 |
+
lib_logger.error("Onboarding process timed out after 5 minutes")
|
| 419 |
+
raise ValueError(
|
| 420 |
+
"Onboarding process timed out after 5 minutes. Please try again or contact support."
|
| 421 |
+
)
|
| 422 |
+
|
| 423 |
+
# Extract project ID from LRO response
|
| 424 |
+
# Note: onboardUser returns response.cloudaicompanionProject as an object with .id
|
| 425 |
+
lro_response_data = lro_data.get("response", {})
|
| 426 |
+
lro_project_obj = lro_response_data.get("cloudaicompanionProject", {})
|
| 427 |
+
project_id = (
|
| 428 |
+
lro_project_obj.get("id")
|
| 429 |
+
if isinstance(lro_project_obj, dict)
|
| 430 |
+
else None
|
| 431 |
+
)
|
| 432 |
+
|
| 433 |
+
# Fallback to configured project if LRO didn't return one
|
| 434 |
+
if not project_id and configured_project_id:
|
| 435 |
+
project_id = configured_project_id
|
| 436 |
+
lib_logger.debug(
|
| 437 |
+
f"LRO didn't return project, using configured: {project_id}"
|
| 438 |
+
)
|
| 439 |
+
|
| 440 |
+
if not project_id:
|
| 441 |
+
lib_logger.error(
|
| 442 |
+
"Onboarding completed but no project ID in response and none configured"
|
| 443 |
+
)
|
| 444 |
+
raise ValueError(
|
| 445 |
+
"Onboarding completed, but no project ID was returned. "
|
| 446 |
+
"For paid tiers, set GEMINI_CLI_PROJECT_ID environment variable."
|
| 447 |
+
)
|
| 448 |
+
|
| 449 |
+
lib_logger.debug(
|
| 450 |
+
f"Successfully extracted project ID from onboarding response: {project_id}"
|
| 451 |
+
)
|
| 452 |
+
|
| 453 |
+
# Cache tier info
|
| 454 |
+
self.project_tier_cache[credential_path] = tier_id
|
| 455 |
+
discovered_tier = tier_id
|
| 456 |
+
lib_logger.debug(f"Cached tier information: {tier_id}")
|
| 457 |
+
|
| 458 |
+
# Log concise message for paid projects
|
| 459 |
+
is_paid = tier_id and tier_id not in ["free-tier", "legacy-tier"]
|
| 460 |
+
if is_paid:
|
| 461 |
+
lib_logger.info(
|
| 462 |
+
f"Using Gemini paid tier '{tier_id}' with project: {project_id}"
|
| 463 |
+
)
|
| 464 |
+
else:
|
| 465 |
+
lib_logger.info(
|
| 466 |
+
f"Successfully onboarded user and discovered project ID: {project_id}"
|
| 467 |
+
)
|
| 468 |
+
|
| 469 |
+
self.project_id_cache[credential_path] = project_id
|
| 470 |
+
discovered_project_id = project_id
|
| 471 |
+
|
| 472 |
+
# Persist to credential file
|
| 473 |
+
await self._persist_project_metadata(
|
| 474 |
+
credential_path, project_id, discovered_tier
|
| 475 |
+
)
|
| 476 |
+
|
| 477 |
+
return project_id
|
| 478 |
+
|
| 479 |
+
except httpx.HTTPStatusError as e:
|
| 480 |
+
error_body = ""
|
| 481 |
+
try:
|
| 482 |
+
error_body = e.response.text
|
| 483 |
+
except Exception:
|
| 484 |
+
pass
|
| 485 |
+
if e.response.status_code == 403:
|
| 486 |
+
lib_logger.error(
|
| 487 |
+
f"Gemini Code Assist API access denied (403). Response: {error_body}"
|
| 488 |
+
)
|
| 489 |
+
lib_logger.error(
|
| 490 |
+
"Possible causes: 1) cloudaicompanion.googleapis.com API not enabled, 2) Wrong project ID for paid tier, 3) Account lacks permissions"
|
| 491 |
+
)
|
| 492 |
+
elif e.response.status_code == 404:
|
| 493 |
+
lib_logger.warning(
|
| 494 |
+
f"Gemini Code Assist endpoint not found (404). Falling back to project listing."
|
| 495 |
+
)
|
| 496 |
+
elif e.response.status_code == 412:
|
| 497 |
+
# Precondition Failed - often means wrong project for free tier onboarding
|
| 498 |
+
lib_logger.error(
|
| 499 |
+
f"Precondition failed (412): {error_body}. This may mean the project ID is incompatible with the selected tier."
|
| 500 |
+
)
|
| 501 |
+
else:
|
| 502 |
+
lib_logger.warning(
|
| 503 |
+
f"Gemini onboarding/discovery failed with status {e.response.status_code}: {error_body}. Falling back to project listing."
|
| 504 |
+
)
|
| 505 |
+
except httpx.RequestError as e:
|
| 506 |
+
lib_logger.warning(
|
| 507 |
+
f"Gemini onboarding/discovery network error: {e}. Falling back to project listing."
|
| 508 |
+
)
|
| 509 |
+
|
| 510 |
+
# 3. Fallback to listing all available GCP projects (last resort)
|
| 511 |
+
lib_logger.debug(
|
| 512 |
+
"Attempting to discover project via GCP Resource Manager API..."
|
| 513 |
+
)
|
| 514 |
+
try:
|
| 515 |
+
async with httpx.AsyncClient() as client:
|
| 516 |
+
lib_logger.debug(
|
| 517 |
+
"Querying Cloud Resource Manager for available projects..."
|
| 518 |
+
)
|
| 519 |
+
response = await client.get(
|
| 520 |
+
"https://cloudresourcemanager.googleapis.com/v1/projects",
|
| 521 |
+
headers=headers,
|
| 522 |
+
timeout=20,
|
| 523 |
+
)
|
| 524 |
+
response.raise_for_status()
|
| 525 |
+
projects = response.json().get("projects", [])
|
| 526 |
+
lib_logger.debug(f"Found {len(projects)} total projects")
|
| 527 |
+
active_projects = [
|
| 528 |
+
p for p in projects if p.get("lifecycleState") == "ACTIVE"
|
| 529 |
+
]
|
| 530 |
+
lib_logger.debug(f"Found {len(active_projects)} active projects")
|
| 531 |
+
|
| 532 |
+
if not projects:
|
| 533 |
+
lib_logger.error(
|
| 534 |
+
"No GCP projects found for this account. Please create a project in Google Cloud Console."
|
| 535 |
+
)
|
| 536 |
+
elif not active_projects:
|
| 537 |
+
lib_logger.error(
|
| 538 |
+
"No active GCP projects found. Please activate a project in Google Cloud Console."
|
| 539 |
+
)
|
| 540 |
+
else:
|
| 541 |
+
project_id = active_projects[0]["projectId"]
|
| 542 |
+
lib_logger.info(
|
| 543 |
+
f"Discovered Gemini project ID from active projects list: {project_id}"
|
| 544 |
+
)
|
| 545 |
+
lib_logger.debug(
|
| 546 |
+
f"Selected first active project: {project_id} (out of {len(active_projects)} active projects)"
|
| 547 |
+
)
|
| 548 |
+
self.project_id_cache[credential_path] = project_id
|
| 549 |
+
discovered_project_id = project_id
|
| 550 |
+
|
| 551 |
+
# Persist to credential file (no tier info from resource manager)
|
| 552 |
+
await self._persist_project_metadata(
|
| 553 |
+
credential_path, project_id, None
|
| 554 |
+
)
|
| 555 |
+
|
| 556 |
+
return project_id
|
| 557 |
+
except httpx.HTTPStatusError as e:
|
| 558 |
+
if e.response.status_code == 403:
|
| 559 |
+
lib_logger.error(
|
| 560 |
+
"Failed to list GCP projects due to a 403 Forbidden error. The Cloud Resource Manager API may not be enabled, or your account lacks the 'resourcemanager.projects.list' permission."
|
| 561 |
+
)
|
| 562 |
+
else:
|
| 563 |
+
lib_logger.error(
|
| 564 |
+
f"Failed to list GCP projects with status {e.response.status_code}: {e}"
|
| 565 |
+
)
|
| 566 |
+
except httpx.RequestError as e:
|
| 567 |
+
lib_logger.error(f"Network error while listing GCP projects: {e}")
|
| 568 |
+
|
| 569 |
+
raise ValueError(
|
| 570 |
+
"Could not auto-discover Gemini project ID. Possible causes:\n"
|
| 571 |
+
" 1. The cloudaicompanion.googleapis.com API is not enabled (enable it in Google Cloud Console)\n"
|
| 572 |
+
" 2. No active GCP projects exist for this account (create one in Google Cloud Console)\n"
|
| 573 |
+
" 3. Account lacks necessary permissions\n"
|
| 574 |
+
"To manually specify a project, set GEMINI_CLI_PROJECT_ID in your .env file."
|
| 575 |
+
)
|
| 576 |
+
|
| 577 |
+
async def _persist_project_metadata(
|
| 578 |
+
self, credential_path: str, project_id: str, tier: Optional[str]
|
| 579 |
+
):
|
| 580 |
+
"""Persists project ID and tier to the credential file for faster future startups."""
|
| 581 |
+
# Skip persistence for env:// paths (environment-based credentials)
|
| 582 |
+
credential_index = self._parse_env_credential_path(credential_path)
|
| 583 |
+
if credential_index is not None:
|
| 584 |
+
lib_logger.debug(
|
| 585 |
+
f"Skipping project metadata persistence for env:// credential path: {credential_path}"
|
| 586 |
+
)
|
| 587 |
+
return
|
| 588 |
+
|
| 589 |
+
try:
|
| 590 |
+
# Load current credentials
|
| 591 |
+
with open(credential_path, "r") as f:
|
| 592 |
+
creds = json.load(f)
|
| 593 |
+
|
| 594 |
+
# Update metadata
|
| 595 |
+
if "_proxy_metadata" not in creds:
|
| 596 |
+
creds["_proxy_metadata"] = {}
|
| 597 |
+
|
| 598 |
+
creds["_proxy_metadata"]["project_id"] = project_id
|
| 599 |
+
if tier:
|
| 600 |
+
creds["_proxy_metadata"]["tier"] = tier
|
| 601 |
+
|
| 602 |
+
# Save back using the existing save method (handles atomic writes and permissions)
|
| 603 |
+
await self._save_credentials(credential_path, creds)
|
| 604 |
+
|
| 605 |
+
lib_logger.debug(
|
| 606 |
+
f"Persisted project_id and tier to credential file: {credential_path}"
|
| 607 |
+
)
|
| 608 |
+
except Exception as e:
|
| 609 |
+
lib_logger.warning(
|
| 610 |
+
f"Failed to persist project metadata to credential file: {e}"
|
| 611 |
+
)
|
| 612 |
+
# Non-fatal - just means slower startup next time
|
| 613 |
+
|
| 614 |
+
# =========================================================================
|
| 615 |
+
# CREDENTIAL MANAGEMENT OVERRIDES
|
| 616 |
+
# =========================================================================
|
| 617 |
+
|
| 618 |
+
def _get_provider_file_prefix(self) -> str:
|
| 619 |
+
"""Return the file prefix for Gemini CLI credentials."""
|
| 620 |
+
return "gemini_cli"
|
| 621 |
+
|
| 622 |
+
def build_env_lines(self, creds: Dict[str, Any], cred_number: int) -> List[str]:
|
| 623 |
+
"""
|
| 624 |
+
Generate .env file lines for a Gemini CLI credential.
|
| 625 |
+
|
| 626 |
+
Includes tier and project_id from _proxy_metadata.
|
| 627 |
+
"""
|
| 628 |
+
# Get base lines from parent class
|
| 629 |
+
lines = super().build_env_lines(creds, cred_number)
|
| 630 |
+
|
| 631 |
+
# Add Gemini-specific fields (tier and project_id)
|
| 632 |
+
metadata = creds.get("_proxy_metadata", {})
|
| 633 |
+
prefix = f"{self.ENV_PREFIX}_{cred_number}"
|
| 634 |
+
|
| 635 |
+
project_id = metadata.get("project_id", "")
|
| 636 |
+
tier = metadata.get("tier", "")
|
| 637 |
+
|
| 638 |
+
if project_id:
|
| 639 |
+
lines.append(f"{prefix}_PROJECT_ID={project_id}")
|
| 640 |
+
if tier:
|
| 641 |
+
lines.append(f"{prefix}_TIER={tier}")
|
| 642 |
+
|
| 643 |
+
return lines
|
src/rotator_library/providers/gemini_cli_provider.py
CHANGED
|
@@ -11,6 +11,8 @@ from .provider_interface import ProviderInterface
|
|
| 11 |
from .gemini_auth_base import GeminiAuthBase
|
| 12 |
from .provider_cache import ProviderCache
|
| 13 |
from ..model_definitions import ModelDefinitions
|
|
|
|
|
|
|
| 14 |
import litellm
|
| 15 |
from litellm.exceptions import RateLimitError
|
| 16 |
from ..error_handler import extract_retry_after_from_body
|
|
@@ -21,8 +23,22 @@ from datetime import datetime
|
|
| 21 |
|
| 22 |
lib_logger = logging.getLogger("rotator_library")
|
| 23 |
|
| 24 |
-
|
| 25 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
|
| 27 |
|
| 28 |
class _GeminiCliFileLogger:
|
|
@@ -38,7 +54,7 @@ class _GeminiCliFileLogger:
|
|
| 38 |
# Sanitize model name for directory
|
| 39 |
safe_model_name = model_name.replace("/", "_").replace(":", "_")
|
| 40 |
self.log_dir = (
|
| 41 |
-
|
| 42 |
)
|
| 43 |
try:
|
| 44 |
self.log_dir.mkdir(parents=True, exist_ok=True)
|
|
@@ -102,12 +118,6 @@ HARDCODED_MODELS = [
|
|
| 102 |
"gemini-3-pro-preview",
|
| 103 |
]
|
| 104 |
|
| 105 |
-
# Cache directory for Gemini CLI
|
| 106 |
-
CACHE_DIR = (
|
| 107 |
-
Path(__file__).resolve().parent.parent.parent.parent / "cache" / "gemini_cli"
|
| 108 |
-
)
|
| 109 |
-
GEMINI3_SIGNATURE_CACHE_FILE = CACHE_DIR / "gemini3_signatures.json"
|
| 110 |
-
|
| 111 |
# Gemini 3 tool fix system instruction (prevents hallucination)
|
| 112 |
DEFAULT_GEMINI3_SYSTEM_INSTRUCTION = """<CRITICAL_TOOL_USAGE_INSTRUCTIONS>
|
| 113 |
You are operating in a CUSTOM ENVIRONMENT where tool definitions COMPLETELY DIFFER from your training data.
|
|
@@ -173,6 +183,98 @@ FINISH_REASON_MAP = {
|
|
| 173 |
}
|
| 174 |
|
| 175 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 176 |
def _env_bool(key: str, default: bool = False) -> bool:
|
| 177 |
"""Get boolean from environment variable."""
|
| 178 |
return os.getenv(key, str(default).lower()).lower() in ("true", "1", "yes")
|
|
@@ -186,8 +288,8 @@ def _env_int(key: str, default: int) -> int:
|
|
| 186 |
class GeminiCliProvider(GeminiAuthBase, ProviderInterface):
|
| 187 |
skip_cost_calculation = True
|
| 188 |
|
| 189 |
-
#
|
| 190 |
-
default_rotation_mode: str = "
|
| 191 |
|
| 192 |
# =========================================================================
|
| 193 |
# TIER CONFIGURATION
|
|
@@ -234,32 +336,156 @@ class GeminiCliProvider(GeminiAuthBase, ProviderInterface):
|
|
| 234 |
error: Exception, error_body: Optional[str] = None
|
| 235 |
) -> Optional[Dict[str, Any]]:
|
| 236 |
"""
|
| 237 |
-
Parse Gemini CLI quota errors.
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 241 |
|
| 242 |
Args:
|
| 243 |
error: The caught exception
|
| 244 |
error_body: Optional raw response body string
|
| 245 |
|
| 246 |
Returns:
|
| 247 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 248 |
"""
|
| 249 |
-
|
| 250 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 251 |
|
| 252 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 253 |
|
| 254 |
def __init__(self):
|
| 255 |
super().__init__()
|
| 256 |
self.model_definitions = ModelDefinitions()
|
| 257 |
-
|
| 258 |
-
str, str
|
| 259 |
-
] = {} # Cache project ID per credential path
|
| 260 |
-
self.project_tier_cache: Dict[
|
| 261 |
-
str, str
|
| 262 |
-
] = {} # Cache project tier per credential path
|
| 263 |
|
| 264 |
# Gemini 3 configuration from environment
|
| 265 |
memory_ttl = _env_int("GEMINI_CLI_SIGNATURE_CACHE_TTL", 3600)
|
|
@@ -267,7 +493,7 @@ class GeminiCliProvider(GeminiAuthBase, ProviderInterface):
|
|
| 267 |
|
| 268 |
# Initialize signature cache for Gemini 3 thoughtSignatures
|
| 269 |
self._signature_cache = ProviderCache(
|
| 270 |
-
|
| 271 |
memory_ttl,
|
| 272 |
disk_ttl,
|
| 273 |
env_prefix="GEMINI_CLI_SIGNATURE",
|
|
@@ -381,7 +607,7 @@ class GeminiCliProvider(GeminiAuthBase, ProviderInterface):
|
|
| 381 |
|
| 382 |
# Gemini 3 requires paid tier
|
| 383 |
if model_name.startswith("gemini-3-"):
|
| 384 |
-
return
|
| 385 |
|
| 386 |
return None # All other models have no restrictions
|
| 387 |
|
|
@@ -391,9 +617,48 @@ class GeminiCliProvider(GeminiAuthBase, ProviderInterface):
|
|
| 391 |
|
| 392 |
This ensures all credential priorities are known before any API calls,
|
| 393 |
preventing unknown credentials from getting priority 999.
|
|
|
|
|
|
|
|
|
|
| 394 |
"""
|
|
|
|
| 395 |
await self._load_persisted_tiers(credential_paths)
|
| 396 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 397 |
async def _load_persisted_tiers(
|
| 398 |
self, credential_paths: List[str]
|
| 399 |
) -> Dict[str, str]:
|
|
@@ -451,6 +716,8 @@ class GeminiCliProvider(GeminiAuthBase, ProviderInterface):
|
|
| 451 |
|
| 452 |
return loaded
|
| 453 |
|
|
|
|
|
|
|
| 454 |
# =========================================================================
|
| 455 |
# MODEL UTILITIES
|
| 456 |
# =========================================================================
|
|
@@ -466,520 +733,7 @@ class GeminiCliProvider(GeminiAuthBase, ProviderInterface):
|
|
| 466 |
return name[len(self._gemini3_tool_prefix) :]
|
| 467 |
return name
|
| 468 |
|
| 469 |
-
|
| 470 |
-
self, credential_path: str, access_token: str, litellm_params: Dict[str, Any]
|
| 471 |
-
) -> str:
|
| 472 |
-
"""
|
| 473 |
-
Discovers the Google Cloud Project ID, with caching and onboarding for new accounts.
|
| 474 |
-
|
| 475 |
-
This follows the official Gemini CLI discovery flow:
|
| 476 |
-
1. Check in-memory cache
|
| 477 |
-
2. Check configured project_id override (litellm_params or env var)
|
| 478 |
-
3. Check persisted project_id in credential file
|
| 479 |
-
4. Call loadCodeAssist to check if user is already known (has currentTier)
|
| 480 |
-
- If currentTier exists AND cloudaicompanionProject returned: use server's project
|
| 481 |
-
- If currentTier exists but NO cloudaicompanionProject: use configured project_id (paid tier requires this)
|
| 482 |
-
- If no currentTier: user needs onboarding
|
| 483 |
-
5. Onboard user based on tier:
|
| 484 |
-
- FREE tier: pass cloudaicompanionProject=None (server-managed)
|
| 485 |
-
- PAID tier: pass cloudaicompanionProject=configured_project_id
|
| 486 |
-
6. Fallback to GCP Resource Manager project listing
|
| 487 |
-
"""
|
| 488 |
-
lib_logger.debug(
|
| 489 |
-
f"Starting project discovery for credential: {credential_path}"
|
| 490 |
-
)
|
| 491 |
-
|
| 492 |
-
# Check in-memory cache first
|
| 493 |
-
if credential_path in self.project_id_cache:
|
| 494 |
-
cached_project = self.project_id_cache[credential_path]
|
| 495 |
-
lib_logger.debug(f"Using cached project ID: {cached_project}")
|
| 496 |
-
return cached_project
|
| 497 |
-
|
| 498 |
-
# Check for configured project ID override (from litellm_params or env var)
|
| 499 |
-
# This is REQUIRED for paid tier users per the official CLI behavior
|
| 500 |
-
configured_project_id = litellm_params.get("project_id")
|
| 501 |
-
if configured_project_id:
|
| 502 |
-
lib_logger.debug(
|
| 503 |
-
f"Found configured project_id override: {configured_project_id}"
|
| 504 |
-
)
|
| 505 |
-
|
| 506 |
-
# Load credentials from file to check for persisted project_id and tier
|
| 507 |
-
# Skip for env:// paths (environment-based credentials don't persist to files)
|
| 508 |
-
credential_index = self._parse_env_credential_path(credential_path)
|
| 509 |
-
if credential_index is None:
|
| 510 |
-
# Only try to load from file if it's not an env:// path
|
| 511 |
-
try:
|
| 512 |
-
with open(credential_path, "r") as f:
|
| 513 |
-
creds = json.load(f)
|
| 514 |
-
|
| 515 |
-
metadata = creds.get("_proxy_metadata", {})
|
| 516 |
-
persisted_project_id = metadata.get("project_id")
|
| 517 |
-
persisted_tier = metadata.get("tier")
|
| 518 |
-
|
| 519 |
-
if persisted_project_id:
|
| 520 |
-
lib_logger.info(
|
| 521 |
-
f"Loaded persisted project ID from credential file: {persisted_project_id}"
|
| 522 |
-
)
|
| 523 |
-
self.project_id_cache[credential_path] = persisted_project_id
|
| 524 |
-
|
| 525 |
-
# Also load tier if available
|
| 526 |
-
if persisted_tier:
|
| 527 |
-
self.project_tier_cache[credential_path] = persisted_tier
|
| 528 |
-
lib_logger.debug(f"Loaded persisted tier: {persisted_tier}")
|
| 529 |
-
|
| 530 |
-
return persisted_project_id
|
| 531 |
-
except (FileNotFoundError, json.JSONDecodeError, KeyError) as e:
|
| 532 |
-
lib_logger.debug(f"Could not load persisted project ID from file: {e}")
|
| 533 |
-
|
| 534 |
-
lib_logger.debug(
|
| 535 |
-
"No cached or configured project ID found, initiating discovery..."
|
| 536 |
-
)
|
| 537 |
-
headers = {
|
| 538 |
-
"Authorization": f"Bearer {access_token}",
|
| 539 |
-
"Content-Type": "application/json",
|
| 540 |
-
}
|
| 541 |
-
|
| 542 |
-
discovered_project_id = None
|
| 543 |
-
discovered_tier = None
|
| 544 |
-
|
| 545 |
-
async with httpx.AsyncClient() as client:
|
| 546 |
-
# 1. Try discovery endpoint with loadCodeAssist
|
| 547 |
-
lib_logger.debug(
|
| 548 |
-
"Attempting project discovery via Code Assist loadCodeAssist endpoint..."
|
| 549 |
-
)
|
| 550 |
-
try:
|
| 551 |
-
# Build metadata - include duetProject only if we have a configured project
|
| 552 |
-
core_client_metadata = {
|
| 553 |
-
"ideType": "IDE_UNSPECIFIED",
|
| 554 |
-
"platform": "PLATFORM_UNSPECIFIED",
|
| 555 |
-
"pluginType": "GEMINI",
|
| 556 |
-
}
|
| 557 |
-
if configured_project_id:
|
| 558 |
-
core_client_metadata["duetProject"] = configured_project_id
|
| 559 |
-
|
| 560 |
-
# Build load request - pass configured_project_id if available, otherwise None
|
| 561 |
-
load_request = {
|
| 562 |
-
"cloudaicompanionProject": configured_project_id, # Can be None
|
| 563 |
-
"metadata": core_client_metadata,
|
| 564 |
-
}
|
| 565 |
-
|
| 566 |
-
lib_logger.debug(
|
| 567 |
-
f"Sending loadCodeAssist request with cloudaicompanionProject={configured_project_id}"
|
| 568 |
-
)
|
| 569 |
-
response = await client.post(
|
| 570 |
-
f"{CODE_ASSIST_ENDPOINT}:loadCodeAssist",
|
| 571 |
-
headers=headers,
|
| 572 |
-
json=load_request,
|
| 573 |
-
timeout=20,
|
| 574 |
-
)
|
| 575 |
-
response.raise_for_status()
|
| 576 |
-
data = response.json()
|
| 577 |
-
|
| 578 |
-
# Log full response for debugging
|
| 579 |
-
lib_logger.debug(
|
| 580 |
-
f"loadCodeAssist full response keys: {list(data.keys())}"
|
| 581 |
-
)
|
| 582 |
-
|
| 583 |
-
# Extract and log ALL tier information for debugging
|
| 584 |
-
allowed_tiers = data.get("allowedTiers", [])
|
| 585 |
-
current_tier = data.get("currentTier")
|
| 586 |
-
|
| 587 |
-
lib_logger.debug(f"=== Tier Information ===")
|
| 588 |
-
lib_logger.debug(f"currentTier: {current_tier}")
|
| 589 |
-
lib_logger.debug(f"allowedTiers count: {len(allowed_tiers)}")
|
| 590 |
-
for i, tier in enumerate(allowed_tiers):
|
| 591 |
-
tier_id = tier.get("id", "unknown")
|
| 592 |
-
is_default = tier.get("isDefault", False)
|
| 593 |
-
user_defined = tier.get("userDefinedCloudaicompanionProject", False)
|
| 594 |
-
lib_logger.debug(
|
| 595 |
-
f" Tier {i + 1}: id={tier_id}, isDefault={is_default}, userDefinedProject={user_defined}"
|
| 596 |
-
)
|
| 597 |
-
lib_logger.debug(f"========================")
|
| 598 |
-
|
| 599 |
-
# Determine the current tier ID
|
| 600 |
-
current_tier_id = None
|
| 601 |
-
if current_tier:
|
| 602 |
-
current_tier_id = current_tier.get("id")
|
| 603 |
-
lib_logger.debug(f"User has currentTier: {current_tier_id}")
|
| 604 |
-
|
| 605 |
-
# Check if user is already known to server (has currentTier)
|
| 606 |
-
if current_tier_id:
|
| 607 |
-
# User is already onboarded - check for project from server
|
| 608 |
-
server_project = data.get("cloudaicompanionProject")
|
| 609 |
-
|
| 610 |
-
# Check if this tier requires user-defined project (paid tiers)
|
| 611 |
-
requires_user_project = any(
|
| 612 |
-
t.get("id") == current_tier_id
|
| 613 |
-
and t.get("userDefinedCloudaicompanionProject", False)
|
| 614 |
-
for t in allowed_tiers
|
| 615 |
-
)
|
| 616 |
-
is_free_tier = current_tier_id == "free-tier"
|
| 617 |
-
|
| 618 |
-
if server_project:
|
| 619 |
-
# Server returned a project - use it (server wins)
|
| 620 |
-
# This is the normal case for FREE tier users
|
| 621 |
-
project_id = server_project
|
| 622 |
-
lib_logger.debug(f"Server returned project: {project_id}")
|
| 623 |
-
elif configured_project_id:
|
| 624 |
-
# No server project but we have configured one - use it
|
| 625 |
-
# This is the PAID TIER case where server doesn't return a project
|
| 626 |
-
project_id = configured_project_id
|
| 627 |
-
lib_logger.debug(
|
| 628 |
-
f"No server project, using configured: {project_id}"
|
| 629 |
-
)
|
| 630 |
-
elif is_free_tier:
|
| 631 |
-
# Free tier user without server project - this shouldn't happen normally
|
| 632 |
-
# but let's not fail, just proceed to onboarding
|
| 633 |
-
lib_logger.debug(
|
| 634 |
-
"Free tier user with currentTier but no project - will try onboarding"
|
| 635 |
-
)
|
| 636 |
-
project_id = None
|
| 637 |
-
elif requires_user_project:
|
| 638 |
-
# Paid tier requires a project ID to be set
|
| 639 |
-
raise ValueError(
|
| 640 |
-
f"Paid tier '{current_tier_id}' requires setting GEMINI_CLI_PROJECT_ID environment variable. "
|
| 641 |
-
"See https://goo.gle/gemini-cli-auth-docs#workspace-gca"
|
| 642 |
-
)
|
| 643 |
-
else:
|
| 644 |
-
# Unknown tier without project - proceed carefully
|
| 645 |
-
lib_logger.warning(
|
| 646 |
-
f"Tier '{current_tier_id}' has no project and none configured - will try onboarding"
|
| 647 |
-
)
|
| 648 |
-
project_id = None
|
| 649 |
-
|
| 650 |
-
if project_id:
|
| 651 |
-
# Cache tier info
|
| 652 |
-
self.project_tier_cache[credential_path] = current_tier_id
|
| 653 |
-
discovered_tier = current_tier_id
|
| 654 |
-
|
| 655 |
-
# Log appropriately based on tier
|
| 656 |
-
is_paid = current_tier_id and current_tier_id not in [
|
| 657 |
-
"free-tier",
|
| 658 |
-
"legacy-tier",
|
| 659 |
-
"unknown",
|
| 660 |
-
]
|
| 661 |
-
if is_paid:
|
| 662 |
-
lib_logger.info(
|
| 663 |
-
f"Using Gemini paid tier '{current_tier_id}' with project: {project_id}"
|
| 664 |
-
)
|
| 665 |
-
else:
|
| 666 |
-
lib_logger.info(
|
| 667 |
-
f"Discovered Gemini project ID via loadCodeAssist: {project_id}"
|
| 668 |
-
)
|
| 669 |
-
|
| 670 |
-
self.project_id_cache[credential_path] = project_id
|
| 671 |
-
discovered_project_id = project_id
|
| 672 |
-
|
| 673 |
-
# Persist to credential file
|
| 674 |
-
await self._persist_project_metadata(
|
| 675 |
-
credential_path, project_id, discovered_tier
|
| 676 |
-
)
|
| 677 |
-
|
| 678 |
-
return project_id
|
| 679 |
-
|
| 680 |
-
# 2. User needs onboarding - no currentTier
|
| 681 |
-
lib_logger.info(
|
| 682 |
-
"No existing Gemini session found (no currentTier), attempting to onboard user..."
|
| 683 |
-
)
|
| 684 |
-
|
| 685 |
-
# Determine which tier to onboard with
|
| 686 |
-
onboard_tier = None
|
| 687 |
-
for tier in allowed_tiers:
|
| 688 |
-
if tier.get("isDefault"):
|
| 689 |
-
onboard_tier = tier
|
| 690 |
-
break
|
| 691 |
-
|
| 692 |
-
# Fallback to LEGACY tier if no default (requires user project)
|
| 693 |
-
if not onboard_tier and allowed_tiers:
|
| 694 |
-
# Look for legacy-tier as fallback
|
| 695 |
-
for tier in allowed_tiers:
|
| 696 |
-
if tier.get("id") == "legacy-tier":
|
| 697 |
-
onboard_tier = tier
|
| 698 |
-
break
|
| 699 |
-
# If still no tier, use first available
|
| 700 |
-
if not onboard_tier:
|
| 701 |
-
onboard_tier = allowed_tiers[0]
|
| 702 |
-
|
| 703 |
-
if not onboard_tier:
|
| 704 |
-
raise ValueError("No onboarding tiers available from server")
|
| 705 |
-
|
| 706 |
-
tier_id = onboard_tier.get("id", "free-tier")
|
| 707 |
-
requires_user_project = onboard_tier.get(
|
| 708 |
-
"userDefinedCloudaicompanionProject", False
|
| 709 |
-
)
|
| 710 |
-
|
| 711 |
-
lib_logger.debug(
|
| 712 |
-
f"Onboarding with tier: {tier_id}, requiresUserProject: {requires_user_project}"
|
| 713 |
-
)
|
| 714 |
-
|
| 715 |
-
# Build onboard request based on tier type (following official CLI logic)
|
| 716 |
-
# FREE tier: cloudaicompanionProject = None (server-managed)
|
| 717 |
-
# PAID tier: cloudaicompanionProject = configured_project_id (user must provide)
|
| 718 |
-
is_free_tier = tier_id == "free-tier"
|
| 719 |
-
|
| 720 |
-
if is_free_tier:
|
| 721 |
-
# Free tier uses server-managed project
|
| 722 |
-
onboard_request = {
|
| 723 |
-
"tierId": tier_id,
|
| 724 |
-
"cloudaicompanionProject": None, # Server will create/manage
|
| 725 |
-
"metadata": core_client_metadata,
|
| 726 |
-
}
|
| 727 |
-
lib_logger.debug(
|
| 728 |
-
"Free tier onboarding: using server-managed project"
|
| 729 |
-
)
|
| 730 |
-
else:
|
| 731 |
-
# Paid/legacy tier requires user-provided project
|
| 732 |
-
if not configured_project_id and requires_user_project:
|
| 733 |
-
raise ValueError(
|
| 734 |
-
f"Tier '{tier_id}' requires setting GEMINI_CLI_PROJECT_ID environment variable. "
|
| 735 |
-
"See https://goo.gle/gemini-cli-auth-docs#workspace-gca"
|
| 736 |
-
)
|
| 737 |
-
onboard_request = {
|
| 738 |
-
"tierId": tier_id,
|
| 739 |
-
"cloudaicompanionProject": configured_project_id,
|
| 740 |
-
"metadata": {
|
| 741 |
-
**core_client_metadata,
|
| 742 |
-
"duetProject": configured_project_id,
|
| 743 |
-
}
|
| 744 |
-
if configured_project_id
|
| 745 |
-
else core_client_metadata,
|
| 746 |
-
}
|
| 747 |
-
lib_logger.debug(
|
| 748 |
-
f"Paid tier onboarding: using project {configured_project_id}"
|
| 749 |
-
)
|
| 750 |
-
|
| 751 |
-
lib_logger.debug("Initiating onboardUser request...")
|
| 752 |
-
lro_response = await client.post(
|
| 753 |
-
f"{CODE_ASSIST_ENDPOINT}:onboardUser",
|
| 754 |
-
headers=headers,
|
| 755 |
-
json=onboard_request,
|
| 756 |
-
timeout=30,
|
| 757 |
-
)
|
| 758 |
-
lro_response.raise_for_status()
|
| 759 |
-
lro_data = lro_response.json()
|
| 760 |
-
lib_logger.debug(
|
| 761 |
-
f"Initial onboarding response: done={lro_data.get('done')}"
|
| 762 |
-
)
|
| 763 |
-
|
| 764 |
-
for i in range(150): # Poll for up to 5 minutes (150 × 2s)
|
| 765 |
-
if lro_data.get("done"):
|
| 766 |
-
lib_logger.debug(
|
| 767 |
-
f"Onboarding completed after {i} polling attempts"
|
| 768 |
-
)
|
| 769 |
-
break
|
| 770 |
-
await asyncio.sleep(2)
|
| 771 |
-
if (i + 1) % 15 == 0: # Log every 30 seconds
|
| 772 |
-
lib_logger.info(
|
| 773 |
-
f"Still waiting for onboarding completion... ({(i + 1) * 2}s elapsed)"
|
| 774 |
-
)
|
| 775 |
-
lib_logger.debug(
|
| 776 |
-
f"Polling onboarding status... (Attempt {i + 1}/150)"
|
| 777 |
-
)
|
| 778 |
-
lro_response = await client.post(
|
| 779 |
-
f"{CODE_ASSIST_ENDPOINT}:onboardUser",
|
| 780 |
-
headers=headers,
|
| 781 |
-
json=onboard_request,
|
| 782 |
-
timeout=30,
|
| 783 |
-
)
|
| 784 |
-
lro_response.raise_for_status()
|
| 785 |
-
lro_data = lro_response.json()
|
| 786 |
-
|
| 787 |
-
if not lro_data.get("done"):
|
| 788 |
-
lib_logger.error("Onboarding process timed out after 5 minutes")
|
| 789 |
-
raise ValueError(
|
| 790 |
-
"Onboarding process timed out after 5 minutes. Please try again or contact support."
|
| 791 |
-
)
|
| 792 |
-
|
| 793 |
-
# Extract project ID from LRO response
|
| 794 |
-
# Note: onboardUser returns response.cloudaicompanionProject as an object with .id
|
| 795 |
-
lro_response_data = lro_data.get("response", {})
|
| 796 |
-
lro_project_obj = lro_response_data.get("cloudaicompanionProject", {})
|
| 797 |
-
project_id = (
|
| 798 |
-
lro_project_obj.get("id")
|
| 799 |
-
if isinstance(lro_project_obj, dict)
|
| 800 |
-
else None
|
| 801 |
-
)
|
| 802 |
-
|
| 803 |
-
# Fallback to configured project if LRO didn't return one
|
| 804 |
-
if not project_id and configured_project_id:
|
| 805 |
-
project_id = configured_project_id
|
| 806 |
-
lib_logger.debug(
|
| 807 |
-
f"LRO didn't return project, using configured: {project_id}"
|
| 808 |
-
)
|
| 809 |
-
|
| 810 |
-
if not project_id:
|
| 811 |
-
lib_logger.error(
|
| 812 |
-
"Onboarding completed but no project ID in response and none configured"
|
| 813 |
-
)
|
| 814 |
-
raise ValueError(
|
| 815 |
-
"Onboarding completed, but no project ID was returned. "
|
| 816 |
-
"For paid tiers, set GEMINI_CLI_PROJECT_ID environment variable."
|
| 817 |
-
)
|
| 818 |
-
|
| 819 |
-
lib_logger.debug(
|
| 820 |
-
f"Successfully extracted project ID from onboarding response: {project_id}"
|
| 821 |
-
)
|
| 822 |
-
|
| 823 |
-
# Cache tier info
|
| 824 |
-
self.project_tier_cache[credential_path] = tier_id
|
| 825 |
-
discovered_tier = tier_id
|
| 826 |
-
lib_logger.debug(f"Cached tier information: {tier_id}")
|
| 827 |
-
|
| 828 |
-
# Log concise message for paid projects
|
| 829 |
-
is_paid = tier_id and tier_id not in ["free-tier", "legacy-tier"]
|
| 830 |
-
if is_paid:
|
| 831 |
-
lib_logger.info(
|
| 832 |
-
f"Using Gemini paid tier '{tier_id}' with project: {project_id}"
|
| 833 |
-
)
|
| 834 |
-
else:
|
| 835 |
-
lib_logger.info(
|
| 836 |
-
f"Successfully onboarded user and discovered project ID: {project_id}"
|
| 837 |
-
)
|
| 838 |
-
|
| 839 |
-
self.project_id_cache[credential_path] = project_id
|
| 840 |
-
discovered_project_id = project_id
|
| 841 |
-
|
| 842 |
-
# Persist to credential file
|
| 843 |
-
await self._persist_project_metadata(
|
| 844 |
-
credential_path, project_id, discovered_tier
|
| 845 |
-
)
|
| 846 |
-
|
| 847 |
-
return project_id
|
| 848 |
-
|
| 849 |
-
except httpx.HTTPStatusError as e:
|
| 850 |
-
error_body = ""
|
| 851 |
-
try:
|
| 852 |
-
error_body = e.response.text
|
| 853 |
-
except Exception:
|
| 854 |
-
pass
|
| 855 |
-
if e.response.status_code == 403:
|
| 856 |
-
lib_logger.error(
|
| 857 |
-
f"Gemini Code Assist API access denied (403). Response: {error_body}"
|
| 858 |
-
)
|
| 859 |
-
lib_logger.error(
|
| 860 |
-
"Possible causes: 1) cloudaicompanion.googleapis.com API not enabled, 2) Wrong project ID for paid tier, 3) Account lacks permissions"
|
| 861 |
-
)
|
| 862 |
-
elif e.response.status_code == 404:
|
| 863 |
-
lib_logger.warning(
|
| 864 |
-
f"Gemini Code Assist endpoint not found (404). Falling back to project listing."
|
| 865 |
-
)
|
| 866 |
-
elif e.response.status_code == 412:
|
| 867 |
-
# Precondition Failed - often means wrong project for free tier onboarding
|
| 868 |
-
lib_logger.error(
|
| 869 |
-
f"Precondition failed (412): {error_body}. This may mean the project ID is incompatible with the selected tier."
|
| 870 |
-
)
|
| 871 |
-
else:
|
| 872 |
-
lib_logger.warning(
|
| 873 |
-
f"Gemini onboarding/discovery failed with status {e.response.status_code}: {error_body}. Falling back to project listing."
|
| 874 |
-
)
|
| 875 |
-
except httpx.RequestError as e:
|
| 876 |
-
lib_logger.warning(
|
| 877 |
-
f"Gemini onboarding/discovery network error: {e}. Falling back to project listing."
|
| 878 |
-
)
|
| 879 |
-
|
| 880 |
-
# 3. Fallback to listing all available GCP projects (last resort)
|
| 881 |
-
lib_logger.debug(
|
| 882 |
-
"Attempting to discover project via GCP Resource Manager API..."
|
| 883 |
-
)
|
| 884 |
-
try:
|
| 885 |
-
async with httpx.AsyncClient() as client:
|
| 886 |
-
lib_logger.debug(
|
| 887 |
-
"Querying Cloud Resource Manager for available projects..."
|
| 888 |
-
)
|
| 889 |
-
response = await client.get(
|
| 890 |
-
"https://cloudresourcemanager.googleapis.com/v1/projects",
|
| 891 |
-
headers=headers,
|
| 892 |
-
timeout=20,
|
| 893 |
-
)
|
| 894 |
-
response.raise_for_status()
|
| 895 |
-
projects = response.json().get("projects", [])
|
| 896 |
-
lib_logger.debug(f"Found {len(projects)} total projects")
|
| 897 |
-
active_projects = [
|
| 898 |
-
p for p in projects if p.get("lifecycleState") == "ACTIVE"
|
| 899 |
-
]
|
| 900 |
-
lib_logger.debug(f"Found {len(active_projects)} active projects")
|
| 901 |
-
|
| 902 |
-
if not projects:
|
| 903 |
-
lib_logger.error(
|
| 904 |
-
"No GCP projects found for this account. Please create a project in Google Cloud Console."
|
| 905 |
-
)
|
| 906 |
-
elif not active_projects:
|
| 907 |
-
lib_logger.error(
|
| 908 |
-
"No active GCP projects found. Please activate a project in Google Cloud Console."
|
| 909 |
-
)
|
| 910 |
-
else:
|
| 911 |
-
project_id = active_projects[0]["projectId"]
|
| 912 |
-
lib_logger.info(
|
| 913 |
-
f"Discovered Gemini project ID from active projects list: {project_id}"
|
| 914 |
-
)
|
| 915 |
-
lib_logger.debug(
|
| 916 |
-
f"Selected first active project: {project_id} (out of {len(active_projects)} active projects)"
|
| 917 |
-
)
|
| 918 |
-
self.project_id_cache[credential_path] = project_id
|
| 919 |
-
discovered_project_id = project_id
|
| 920 |
-
|
| 921 |
-
# [NEW] Persist to credential file (no tier info from resource manager)
|
| 922 |
-
await self._persist_project_metadata(
|
| 923 |
-
credential_path, project_id, None
|
| 924 |
-
)
|
| 925 |
-
|
| 926 |
-
return project_id
|
| 927 |
-
except httpx.HTTPStatusError as e:
|
| 928 |
-
if e.response.status_code == 403:
|
| 929 |
-
lib_logger.error(
|
| 930 |
-
"Failed to list GCP projects due to a 403 Forbidden error. The Cloud Resource Manager API may not be enabled, or your account lacks the 'resourcemanager.projects.list' permission."
|
| 931 |
-
)
|
| 932 |
-
else:
|
| 933 |
-
lib_logger.error(
|
| 934 |
-
f"Failed to list GCP projects with status {e.response.status_code}: {e}"
|
| 935 |
-
)
|
| 936 |
-
except httpx.RequestError as e:
|
| 937 |
-
lib_logger.error(f"Network error while listing GCP projects: {e}")
|
| 938 |
-
|
| 939 |
-
raise ValueError(
|
| 940 |
-
"Could not auto-discover Gemini project ID. Possible causes:\n"
|
| 941 |
-
" 1. The cloudaicompanion.googleapis.com API is not enabled (enable it in Google Cloud Console)\n"
|
| 942 |
-
" 2. No active GCP projects exist for this account (create one in Google Cloud Console)\n"
|
| 943 |
-
" 3. Account lacks necessary permissions\n"
|
| 944 |
-
"To manually specify a project, set GEMINI_CLI_PROJECT_ID in your .env file."
|
| 945 |
-
)
|
| 946 |
-
|
| 947 |
-
async def _persist_project_metadata(
|
| 948 |
-
self, credential_path: str, project_id: str, tier: Optional[str]
|
| 949 |
-
):
|
| 950 |
-
"""Persists project ID and tier to the credential file for faster future startups."""
|
| 951 |
-
# Skip persistence for env:// paths (environment-based credentials)
|
| 952 |
-
credential_index = self._parse_env_credential_path(credential_path)
|
| 953 |
-
if credential_index is not None:
|
| 954 |
-
lib_logger.debug(
|
| 955 |
-
f"Skipping project metadata persistence for env:// credential path: {credential_path}"
|
| 956 |
-
)
|
| 957 |
-
return
|
| 958 |
-
|
| 959 |
-
try:
|
| 960 |
-
# Load current credentials
|
| 961 |
-
with open(credential_path, "r") as f:
|
| 962 |
-
creds = json.load(f)
|
| 963 |
-
|
| 964 |
-
# Update metadata
|
| 965 |
-
if "_proxy_metadata" not in creds:
|
| 966 |
-
creds["_proxy_metadata"] = {}
|
| 967 |
-
|
| 968 |
-
creds["_proxy_metadata"]["project_id"] = project_id
|
| 969 |
-
if tier:
|
| 970 |
-
creds["_proxy_metadata"]["tier"] = tier
|
| 971 |
-
|
| 972 |
-
# Save back using the existing save method (handles atomic writes and permissions)
|
| 973 |
-
await self._save_credentials(credential_path, creds)
|
| 974 |
-
|
| 975 |
-
lib_logger.debug(
|
| 976 |
-
f"Persisted project_id and tier to credential file: {credential_path}"
|
| 977 |
-
)
|
| 978 |
-
except Exception as e:
|
| 979 |
-
lib_logger.warning(
|
| 980 |
-
f"Failed to persist project metadata to credential file: {e}"
|
| 981 |
-
)
|
| 982 |
-
# Non-fatal - just means slower startup next time
|
| 983 |
|
| 984 |
def _check_mixed_tier_warning(self):
|
| 985 |
"""Check if mixed free/paid tier credentials are loaded and emit warning."""
|
|
@@ -1166,7 +920,7 @@ class GeminiCliProvider(GeminiAuthBase, ProviderInterface):
|
|
| 1166 |
func_part["thoughtSignature"] = (
|
| 1167 |
"skip_thought_signature_validator"
|
| 1168 |
)
|
| 1169 |
-
lib_logger.
|
| 1170 |
f"Missing thoughtSignature for first func call {tool_id}, using bypass"
|
| 1171 |
)
|
| 1172 |
# Subsequent parallel calls: no signature field at all
|
|
@@ -1178,23 +932,39 @@ class GeminiCliProvider(GeminiAuthBase, ProviderInterface):
|
|
| 1178 |
elif role == "tool":
|
| 1179 |
tool_call_id = msg.get("tool_call_id")
|
| 1180 |
function_name = tool_call_id_to_name.get(tool_call_id)
|
| 1181 |
-
|
| 1182 |
-
|
| 1183 |
-
|
| 1184 |
-
|
| 1185 |
-
|
| 1186 |
-
|
| 1187 |
-
|
| 1188 |
-
|
| 1189 |
-
|
| 1190 |
-
|
| 1191 |
-
|
| 1192 |
-
|
| 1193 |
-
|
| 1194 |
-
|
| 1195 |
-
|
| 1196 |
-
|
|
|
|
| 1197 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1198 |
# Don't add parts here - tool responses are handled via pending_tool_parts
|
| 1199 |
continue
|
| 1200 |
|
|
@@ -1210,6 +980,216 @@ class GeminiCliProvider(GeminiAuthBase, ProviderInterface):
|
|
| 1210 |
|
| 1211 |
return system_instruction, gemini_contents
|
| 1212 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1213 |
def _handle_reasoning_parameters(
|
| 1214 |
self, payload: Dict[str, Any], model: str
|
| 1215 |
) -> Optional[Dict[str, Any]]:
|
|
@@ -1329,13 +1309,24 @@ class GeminiCliProvider(GeminiAuthBase, ProviderInterface):
|
|
| 1329 |
# Get current tool index from accumulator (default 0) and increment
|
| 1330 |
current_tool_idx = accumulator.get("tool_idx", 0) if accumulator else 0
|
| 1331 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1332 |
tool_call = {
|
| 1333 |
"index": current_tool_idx,
|
| 1334 |
"id": tool_call_id,
|
| 1335 |
"type": "function",
|
| 1336 |
"function": {
|
| 1337 |
"name": function_name,
|
| 1338 |
-
"arguments": json.dumps(
|
| 1339 |
},
|
| 1340 |
}
|
| 1341 |
|
|
@@ -1643,13 +1634,32 @@ class GeminiCliProvider(GeminiAuthBase, ProviderInterface):
|
|
| 1643 |
schema = self._gemini_cli_transform_schema(
|
| 1644 |
new_function["parameters"]
|
| 1645 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1646 |
new_function["parametersJsonSchema"] = schema
|
| 1647 |
del new_function["parameters"]
|
| 1648 |
elif "parametersJsonSchema" not in new_function:
|
| 1649 |
-
# Set default
|
| 1650 |
new_function["parametersJsonSchema"] = {
|
| 1651 |
"type": "object",
|
| 1652 |
-
"properties": {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1653 |
}
|
| 1654 |
|
| 1655 |
# Gemini 3 specific transformations
|
|
@@ -1889,6 +1899,9 @@ class GeminiCliProvider(GeminiAuthBase, ProviderInterface):
|
|
| 1889 |
system_instruction, contents = self._transform_messages(
|
| 1890 |
kwargs.get("messages", []), model_name
|
| 1891 |
)
|
|
|
|
|
|
|
|
|
|
| 1892 |
request_payload = {
|
| 1893 |
"model": model_name,
|
| 1894 |
"project": project_id,
|
|
@@ -1965,7 +1978,7 @@ class GeminiCliProvider(GeminiAuthBase, ProviderInterface):
|
|
| 1965 |
headers=final_headers,
|
| 1966 |
json=request_payload,
|
| 1967 |
params={"alt": "sse"},
|
| 1968 |
-
timeout=
|
| 1969 |
) as response:
|
| 1970 |
# Read and log error body before raise_for_status for better debugging
|
| 1971 |
if response.status_code >= 400:
|
|
@@ -2176,6 +2189,8 @@ class GeminiCliProvider(GeminiAuthBase, ProviderInterface):
|
|
| 2176 |
|
| 2177 |
# Transform messages to Gemini format
|
| 2178 |
system_instruction, contents = self._transform_messages(messages)
|
|
|
|
|
|
|
| 2179 |
|
| 2180 |
# Build request payload
|
| 2181 |
request_payload = {
|
|
|
|
| 11 |
from .gemini_auth_base import GeminiAuthBase
|
| 12 |
from .provider_cache import ProviderCache
|
| 13 |
from ..model_definitions import ModelDefinitions
|
| 14 |
+
from ..timeout_config import TimeoutConfig
|
| 15 |
+
from ..utils.paths import get_logs_dir, get_cache_dir
|
| 16 |
import litellm
|
| 17 |
from litellm.exceptions import RateLimitError
|
| 18 |
from ..error_handler import extract_retry_after_from_body
|
|
|
|
| 23 |
|
| 24 |
lib_logger = logging.getLogger("rotator_library")
|
| 25 |
|
| 26 |
+
|
| 27 |
+
def _get_gemini_cli_logs_dir() -> Path:
|
| 28 |
+
"""Get the Gemini CLI logs directory."""
|
| 29 |
+
logs_dir = get_logs_dir() / "gemini_cli_logs"
|
| 30 |
+
logs_dir.mkdir(parents=True, exist_ok=True)
|
| 31 |
+
return logs_dir
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def _get_gemini_cli_cache_dir() -> Path:
|
| 35 |
+
"""Get the Gemini CLI cache directory."""
|
| 36 |
+
return get_cache_dir(subdir="gemini_cli")
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def _get_gemini3_signature_cache_file() -> Path:
|
| 40 |
+
"""Get the Gemini 3 signature cache file path."""
|
| 41 |
+
return _get_gemini_cli_cache_dir() / "gemini3_signatures.json"
|
| 42 |
|
| 43 |
|
| 44 |
class _GeminiCliFileLogger:
|
|
|
|
| 54 |
# Sanitize model name for directory
|
| 55 |
safe_model_name = model_name.replace("/", "_").replace(":", "_")
|
| 56 |
self.log_dir = (
|
| 57 |
+
_get_gemini_cli_logs_dir() / f"{timestamp}_{safe_model_name}_{request_id}"
|
| 58 |
)
|
| 59 |
try:
|
| 60 |
self.log_dir.mkdir(parents=True, exist_ok=True)
|
|
|
|
| 118 |
"gemini-3-pro-preview",
|
| 119 |
]
|
| 120 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 121 |
# Gemini 3 tool fix system instruction (prevents hallucination)
|
| 122 |
DEFAULT_GEMINI3_SYSTEM_INSTRUCTION = """<CRITICAL_TOOL_USAGE_INSTRUCTIONS>
|
| 123 |
You are operating in a CUSTOM ENVIRONMENT where tool definitions COMPLETELY DIFFER from your training data.
|
|
|
|
| 183 |
}
|
| 184 |
|
| 185 |
|
| 186 |
+
def _recursively_parse_json_strings(obj: Any) -> Any:
|
| 187 |
+
"""
|
| 188 |
+
Recursively parse JSON strings in nested data structures.
|
| 189 |
+
|
| 190 |
+
Gemini sometimes returns tool arguments with JSON-stringified values:
|
| 191 |
+
{"files": "[{...}]"} instead of {"files": [{...}]}.
|
| 192 |
+
|
| 193 |
+
Additionally handles:
|
| 194 |
+
- Malformed double-encoded JSON (extra trailing '}' or ']')
|
| 195 |
+
- Escaped string content (\n, \t, etc.)
|
| 196 |
+
"""
|
| 197 |
+
if isinstance(obj, dict):
|
| 198 |
+
return {k: _recursively_parse_json_strings(v) for k, v in obj.items()}
|
| 199 |
+
elif isinstance(obj, list):
|
| 200 |
+
return [_recursively_parse_json_strings(item) for item in obj]
|
| 201 |
+
elif isinstance(obj, str):
|
| 202 |
+
stripped = obj.strip()
|
| 203 |
+
|
| 204 |
+
# Check if string contains control character escape sequences that need unescaping
|
| 205 |
+
# This handles cases where diff content has literal \n or \t instead of actual newlines/tabs
|
| 206 |
+
#
|
| 207 |
+
# IMPORTANT: We intentionally do NOT unescape strings containing \" or \\
|
| 208 |
+
# because these are typically intentional escapes in code/config content
|
| 209 |
+
# (e.g., JSON embedded in YAML: BOT_NAMES_JSON: '["mirrobot", ...]')
|
| 210 |
+
# Unescaping these would corrupt the content and cause issues like
|
| 211 |
+
# oldString and newString becoming identical when they should differ.
|
| 212 |
+
has_control_char_escapes = "\\n" in obj or "\\t" in obj
|
| 213 |
+
has_intentional_escapes = '\\"' in obj or "\\\\" in obj
|
| 214 |
+
|
| 215 |
+
if has_control_char_escapes and not has_intentional_escapes:
|
| 216 |
+
try:
|
| 217 |
+
# Use json.loads with quotes to properly unescape the string
|
| 218 |
+
# This converts \n -> newline, \t -> tab
|
| 219 |
+
unescaped = json.loads(f'"{obj}"')
|
| 220 |
+
# Log the fix with a snippet for debugging
|
| 221 |
+
snippet = obj[:80] + "..." if len(obj) > 80 else obj
|
| 222 |
+
lib_logger.debug(
|
| 223 |
+
f"[GeminiCli] Unescaped control chars in string: "
|
| 224 |
+
f"{len(obj) - len(unescaped)} chars changed. Snippet: {snippet!r}"
|
| 225 |
+
)
|
| 226 |
+
return unescaped
|
| 227 |
+
except (json.JSONDecodeError, ValueError):
|
| 228 |
+
# If unescaping fails, continue with original processing
|
| 229 |
+
pass
|
| 230 |
+
|
| 231 |
+
# Check if it looks like JSON (starts with { or [)
|
| 232 |
+
if stripped and stripped[0] in ("{", "["):
|
| 233 |
+
# Try standard parsing first
|
| 234 |
+
if (stripped.startswith("{") and stripped.endswith("}")) or (
|
| 235 |
+
stripped.startswith("[") and stripped.endswith("]")
|
| 236 |
+
):
|
| 237 |
+
try:
|
| 238 |
+
parsed = json.loads(obj)
|
| 239 |
+
return _recursively_parse_json_strings(parsed)
|
| 240 |
+
except (json.JSONDecodeError, ValueError):
|
| 241 |
+
pass
|
| 242 |
+
|
| 243 |
+
# Handle malformed JSON: array that doesn't end with ]
|
| 244 |
+
# e.g., '[{"path": "..."}]}' instead of '[{"path": "..."}]'
|
| 245 |
+
if stripped.startswith("[") and not stripped.endswith("]"):
|
| 246 |
+
try:
|
| 247 |
+
# Find the last ] and truncate there
|
| 248 |
+
last_bracket = stripped.rfind("]")
|
| 249 |
+
if last_bracket > 0:
|
| 250 |
+
cleaned = stripped[: last_bracket + 1]
|
| 251 |
+
parsed = json.loads(cleaned)
|
| 252 |
+
lib_logger.warning(
|
| 253 |
+
f"[GeminiCli] Auto-corrected malformed JSON string: "
|
| 254 |
+
f"truncated {len(stripped) - len(cleaned)} extra chars"
|
| 255 |
+
)
|
| 256 |
+
return _recursively_parse_json_strings(parsed)
|
| 257 |
+
except (json.JSONDecodeError, ValueError):
|
| 258 |
+
pass
|
| 259 |
+
|
| 260 |
+
# Handle malformed JSON: object that doesn't end with }
|
| 261 |
+
if stripped.startswith("{") and not stripped.endswith("}"):
|
| 262 |
+
try:
|
| 263 |
+
# Find the last } and truncate there
|
| 264 |
+
last_brace = stripped.rfind("}")
|
| 265 |
+
if last_brace > 0:
|
| 266 |
+
cleaned = stripped[: last_brace + 1]
|
| 267 |
+
parsed = json.loads(cleaned)
|
| 268 |
+
lib_logger.warning(
|
| 269 |
+
f"[GeminiCli] Auto-corrected malformed JSON string: "
|
| 270 |
+
f"truncated {len(stripped) - len(cleaned)} extra chars"
|
| 271 |
+
)
|
| 272 |
+
return _recursively_parse_json_strings(parsed)
|
| 273 |
+
except (json.JSONDecodeError, ValueError):
|
| 274 |
+
pass
|
| 275 |
+
return obj
|
| 276 |
+
|
| 277 |
+
|
| 278 |
def _env_bool(key: str, default: bool = False) -> bool:
|
| 279 |
"""Get boolean from environment variable."""
|
| 280 |
return os.getenv(key, str(default).lower()).lower() in ("true", "1", "yes")
|
|
|
|
| 288 |
class GeminiCliProvider(GeminiAuthBase, ProviderInterface):
|
| 289 |
skip_cost_calculation = True
|
| 290 |
|
| 291 |
+
# Sequential mode - stick with one credential until it gets a 429, then switch
|
| 292 |
+
default_rotation_mode: str = "sequential"
|
| 293 |
|
| 294 |
# =========================================================================
|
| 295 |
# TIER CONFIGURATION
|
|
|
|
| 336 |
error: Exception, error_body: Optional[str] = None
|
| 337 |
) -> Optional[Dict[str, Any]]:
|
| 338 |
"""
|
| 339 |
+
Parse Gemini CLI rate limit/quota errors.
|
| 340 |
+
|
| 341 |
+
Handles the Gemini CLI error format which embeds reset time in the message:
|
| 342 |
+
"You have exhausted your capacity on this model. Your quota will reset after 2s."
|
| 343 |
+
|
| 344 |
+
Unlike Antigravity which uses structured RetryInfo/quotaResetDelay metadata,
|
| 345 |
+
Gemini CLI embeds the reset time in a human-readable message.
|
| 346 |
+
|
| 347 |
+
Example error format:
|
| 348 |
+
{
|
| 349 |
+
"error": {
|
| 350 |
+
"code": 429,
|
| 351 |
+
"message": "You have exhausted your capacity on this model. Your quota will reset after 2s.",
|
| 352 |
+
"status": "RESOURCE_EXHAUSTED",
|
| 353 |
+
"details": [
|
| 354 |
+
{
|
| 355 |
+
"@type": "type.googleapis.com/google.rpc.ErrorInfo",
|
| 356 |
+
"reason": "RATE_LIMIT_EXCEEDED",
|
| 357 |
+
"domain": "cloudcode-pa.googleapis.com",
|
| 358 |
+
"metadata": { "uiMessage": "true", "model": "gemini-3-pro-preview" }
|
| 359 |
+
}
|
| 360 |
+
]
|
| 361 |
+
}
|
| 362 |
+
}
|
| 363 |
|
| 364 |
Args:
|
| 365 |
error: The caught exception
|
| 366 |
error_body: Optional raw response body string
|
| 367 |
|
| 368 |
Returns:
|
| 369 |
+
None if not a parseable quota error, otherwise:
|
| 370 |
+
{
|
| 371 |
+
"retry_after": int,
|
| 372 |
+
"reason": str | None,
|
| 373 |
+
"reset_timestamp": str | None,
|
| 374 |
+
"quota_reset_timestamp": float | None,
|
| 375 |
+
}
|
| 376 |
"""
|
| 377 |
+
import re as regex_module
|
| 378 |
+
|
| 379 |
+
# Get error body from exception if not provided
|
| 380 |
+
body = error_body
|
| 381 |
+
if not body:
|
| 382 |
+
if hasattr(error, "response") and hasattr(error.response, "text"):
|
| 383 |
+
try:
|
| 384 |
+
body = error.response.text
|
| 385 |
+
except Exception:
|
| 386 |
+
pass
|
| 387 |
+
if not body and hasattr(error, "body"):
|
| 388 |
+
body = str(error.body)
|
| 389 |
+
if not body and hasattr(error, "message"):
|
| 390 |
+
body = str(error.message)
|
| 391 |
+
if not body:
|
| 392 |
+
body = str(error)
|
| 393 |
+
|
| 394 |
+
if not body:
|
| 395 |
+
return None
|
| 396 |
+
|
| 397 |
+
result = {
|
| 398 |
+
"retry_after": None,
|
| 399 |
+
"reason": None,
|
| 400 |
+
"reset_timestamp": None,
|
| 401 |
+
"quota_reset_timestamp": None,
|
| 402 |
+
}
|
| 403 |
+
|
| 404 |
+
# 1. Try to extract retry time from human-readable message
|
| 405 |
+
# Pattern: "Your quota will reset after 2s." or "quota will reset after 156h14m36s"
|
| 406 |
+
retry_after = extract_retry_after_from_body(body)
|
| 407 |
+
if retry_after:
|
| 408 |
+
result["retry_after"] = retry_after
|
| 409 |
+
|
| 410 |
+
# 2. Try to parse JSON to get structured details (reason, any RetryInfo fallback)
|
| 411 |
+
try:
|
| 412 |
+
json_match = regex_module.search(r"\{[\s\S]*\}", body)
|
| 413 |
+
if json_match:
|
| 414 |
+
data = json.loads(json_match.group(0))
|
| 415 |
+
error_obj = data.get("error", data)
|
| 416 |
+
details = error_obj.get("details", [])
|
| 417 |
+
|
| 418 |
+
for detail in details:
|
| 419 |
+
detail_type = detail.get("@type", "")
|
| 420 |
+
|
| 421 |
+
# Extract reason from ErrorInfo
|
| 422 |
+
if "ErrorInfo" in detail_type:
|
| 423 |
+
if not result["reason"]:
|
| 424 |
+
result["reason"] = detail.get("reason")
|
| 425 |
+
# Check metadata for any additional timing info
|
| 426 |
+
metadata = detail.get("metadata", {})
|
| 427 |
+
quota_delay = metadata.get("quotaResetDelay")
|
| 428 |
+
if quota_delay and not result["retry_after"]:
|
| 429 |
+
parsed = GeminiCliProvider._parse_duration(quota_delay)
|
| 430 |
+
if parsed:
|
| 431 |
+
result["retry_after"] = parsed
|
| 432 |
+
|
| 433 |
+
# Check for RetryInfo (fallback, in case format changes)
|
| 434 |
+
if "RetryInfo" in detail_type and not result["retry_after"]:
|
| 435 |
+
retry_delay = detail.get("retryDelay")
|
| 436 |
+
if retry_delay:
|
| 437 |
+
parsed = GeminiCliProvider._parse_duration(retry_delay)
|
| 438 |
+
if parsed:
|
| 439 |
+
result["retry_after"] = parsed
|
| 440 |
+
|
| 441 |
+
except (json.JSONDecodeError, AttributeError, TypeError):
|
| 442 |
+
pass
|
| 443 |
|
| 444 |
+
# Return None if we couldn't extract retry_after
|
| 445 |
+
if not result["retry_after"]:
|
| 446 |
+
return None
|
| 447 |
+
|
| 448 |
+
return result
|
| 449 |
+
|
| 450 |
+
@staticmethod
|
| 451 |
+
def _parse_duration(duration_str: str) -> Optional[int]:
|
| 452 |
+
"""
|
| 453 |
+
Parse duration strings like '2s', '156h14m36.73s', '515092.73s' to seconds.
|
| 454 |
+
|
| 455 |
+
Args:
|
| 456 |
+
duration_str: Duration string to parse
|
| 457 |
+
|
| 458 |
+
Returns:
|
| 459 |
+
Total seconds as integer, or None if parsing fails
|
| 460 |
+
"""
|
| 461 |
+
import re as regex_module
|
| 462 |
+
|
| 463 |
+
if not duration_str:
|
| 464 |
+
return None
|
| 465 |
+
|
| 466 |
+
# Handle pure seconds format: "515092.730699158s" or "2s"
|
| 467 |
+
pure_seconds_match = regex_module.match(r"^([\d.]+)s$", duration_str)
|
| 468 |
+
if pure_seconds_match:
|
| 469 |
+
return int(float(pure_seconds_match.group(1)))
|
| 470 |
+
|
| 471 |
+
# Handle compound format: "143h4m52.730699158s"
|
| 472 |
+
total_seconds = 0
|
| 473 |
+
patterns = [
|
| 474 |
+
(r"(\d+)h", 3600), # hours
|
| 475 |
+
(r"(\d+)m", 60), # minutes
|
| 476 |
+
(r"([\d.]+)s", 1), # seconds
|
| 477 |
+
]
|
| 478 |
+
for pattern, multiplier in patterns:
|
| 479 |
+
match = regex_module.search(pattern, duration_str)
|
| 480 |
+
if match:
|
| 481 |
+
total_seconds += float(match.group(1)) * multiplier
|
| 482 |
+
|
| 483 |
+
return int(total_seconds) if total_seconds > 0 else None
|
| 484 |
|
| 485 |
def __init__(self):
|
| 486 |
super().__init__()
|
| 487 |
self.model_definitions = ModelDefinitions()
|
| 488 |
+
# NOTE: project_id_cache and project_tier_cache are inherited from GeminiAuthBase
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 489 |
|
| 490 |
# Gemini 3 configuration from environment
|
| 491 |
memory_ttl = _env_int("GEMINI_CLI_SIGNATURE_CACHE_TTL", 3600)
|
|
|
|
| 493 |
|
| 494 |
# Initialize signature cache for Gemini 3 thoughtSignatures
|
| 495 |
self._signature_cache = ProviderCache(
|
| 496 |
+
_get_gemini3_signature_cache_file(),
|
| 497 |
memory_ttl,
|
| 498 |
disk_ttl,
|
| 499 |
env_prefix="GEMINI_CLI_SIGNATURE",
|
|
|
|
| 607 |
|
| 608 |
# Gemini 3 requires paid tier
|
| 609 |
if model_name.startswith("gemini-3-"):
|
| 610 |
+
return 2 # Only priority 2 (paid) credentials
|
| 611 |
|
| 612 |
return None # All other models have no restrictions
|
| 613 |
|
|
|
|
| 617 |
|
| 618 |
This ensures all credential priorities are known before any API calls,
|
| 619 |
preventing unknown credentials from getting priority 999.
|
| 620 |
+
|
| 621 |
+
For credentials without persisted tier info (new or corrupted), performs
|
| 622 |
+
full discovery to ensure proper prioritization in sequential rotation mode.
|
| 623 |
"""
|
| 624 |
+
# Step 1: Load persisted tiers from files
|
| 625 |
await self._load_persisted_tiers(credential_paths)
|
| 626 |
|
| 627 |
+
# Step 2: Identify credentials still missing tier info
|
| 628 |
+
credentials_needing_discovery = [
|
| 629 |
+
path
|
| 630 |
+
for path in credential_paths
|
| 631 |
+
if path not in self.project_tier_cache
|
| 632 |
+
and self._parse_env_credential_path(path) is None # Skip env:// paths
|
| 633 |
+
]
|
| 634 |
+
|
| 635 |
+
if not credentials_needing_discovery:
|
| 636 |
+
return # All credentials have tier info
|
| 637 |
+
|
| 638 |
+
lib_logger.info(
|
| 639 |
+
f"GeminiCli: Discovering tier info for {len(credentials_needing_discovery)} credential(s)..."
|
| 640 |
+
)
|
| 641 |
+
|
| 642 |
+
# Step 3: Perform discovery for each missing credential (sequential to avoid rate limits)
|
| 643 |
+
for credential_path in credentials_needing_discovery:
|
| 644 |
+
try:
|
| 645 |
+
auth_header = await self.get_auth_header(credential_path)
|
| 646 |
+
access_token = auth_header["Authorization"].split(" ")[1]
|
| 647 |
+
await self._discover_project_id(
|
| 648 |
+
credential_path, access_token, litellm_params={}
|
| 649 |
+
)
|
| 650 |
+
discovered_tier = self.project_tier_cache.get(
|
| 651 |
+
credential_path, "unknown"
|
| 652 |
+
)
|
| 653 |
+
lib_logger.debug(
|
| 654 |
+
f"Discovered tier '{discovered_tier}' for {Path(credential_path).name}"
|
| 655 |
+
)
|
| 656 |
+
except Exception as e:
|
| 657 |
+
lib_logger.warning(
|
| 658 |
+
f"Failed to discover tier for {Path(credential_path).name}: {e}. "
|
| 659 |
+
f"Credential will use default priority."
|
| 660 |
+
)
|
| 661 |
+
|
| 662 |
async def _load_persisted_tiers(
|
| 663 |
self, credential_paths: List[str]
|
| 664 |
) -> Dict[str, str]:
|
|
|
|
| 716 |
|
| 717 |
return loaded
|
| 718 |
|
| 719 |
+
# NOTE: _post_auth_discovery() is inherited from GeminiAuthBase
|
| 720 |
+
|
| 721 |
# =========================================================================
|
| 722 |
# MODEL UTILITIES
|
| 723 |
# =========================================================================
|
|
|
|
| 733 |
return name[len(self._gemini3_tool_prefix) :]
|
| 734 |
return name
|
| 735 |
|
| 736 |
+
# NOTE: _discover_project_id() and _persist_project_metadata() are inherited from GeminiAuthBase
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 737 |
|
| 738 |
def _check_mixed_tier_warning(self):
|
| 739 |
"""Check if mixed free/paid tier credentials are loaded and emit warning."""
|
|
|
|
| 920 |
func_part["thoughtSignature"] = (
|
| 921 |
"skip_thought_signature_validator"
|
| 922 |
)
|
| 923 |
+
lib_logger.debug(
|
| 924 |
f"Missing thoughtSignature for first func call {tool_id}, using bypass"
|
| 925 |
)
|
| 926 |
# Subsequent parallel calls: no signature field at all
|
|
|
|
| 932 |
elif role == "tool":
|
| 933 |
tool_call_id = msg.get("tool_call_id")
|
| 934 |
function_name = tool_call_id_to_name.get(tool_call_id)
|
| 935 |
+
|
| 936 |
+
# Log warning if tool_call_id not found in mapping (can happen after context compaction)
|
| 937 |
+
if not function_name:
|
| 938 |
+
lib_logger.warning(
|
| 939 |
+
f"[ID Mismatch] Tool response has ID '{tool_call_id}' which was not found in tool_id_to_name map. "
|
| 940 |
+
f"Available IDs: {list(tool_call_id_to_name.keys())}. Using 'unknown_function' as fallback."
|
| 941 |
+
)
|
| 942 |
+
function_name = "unknown_function"
|
| 943 |
+
|
| 944 |
+
# Add prefix for Gemini 3
|
| 945 |
+
if is_gemini_3 and self._enable_gemini3_tool_fix:
|
| 946 |
+
function_name = f"{self._gemini3_tool_prefix}{function_name}"
|
| 947 |
+
|
| 948 |
+
# Try to parse content as JSON first, fall back to string
|
| 949 |
+
try:
|
| 950 |
+
parsed_content = (
|
| 951 |
+
json.loads(content) if isinstance(content, str) else content
|
| 952 |
)
|
| 953 |
+
except (json.JSONDecodeError, TypeError):
|
| 954 |
+
parsed_content = content
|
| 955 |
+
|
| 956 |
+
# Wrap the tool response in a 'result' object
|
| 957 |
+
response_content = {"result": parsed_content}
|
| 958 |
+
# Accumulate tool responses - they'll be combined into one user message
|
| 959 |
+
pending_tool_parts.append(
|
| 960 |
+
{
|
| 961 |
+
"functionResponse": {
|
| 962 |
+
"name": function_name,
|
| 963 |
+
"response": response_content,
|
| 964 |
+
"id": tool_call_id,
|
| 965 |
+
}
|
| 966 |
+
}
|
| 967 |
+
)
|
| 968 |
# Don't add parts here - tool responses are handled via pending_tool_parts
|
| 969 |
continue
|
| 970 |
|
|
|
|
| 980 |
|
| 981 |
return system_instruction, gemini_contents
|
| 982 |
|
| 983 |
+
def _fix_tool_response_grouping(
|
| 984 |
+
self, contents: List[Dict[str, Any]]
|
| 985 |
+
) -> List[Dict[str, Any]]:
|
| 986 |
+
"""
|
| 987 |
+
Group function calls with their responses for Gemini CLI compatibility.
|
| 988 |
+
|
| 989 |
+
Converts linear format (call, response, call, response)
|
| 990 |
+
to grouped format (model with calls, user with all responses).
|
| 991 |
+
|
| 992 |
+
IMPORTANT: Preserves ID-based pairing to prevent mismatches.
|
| 993 |
+
When IDs don't match, attempts recovery by:
|
| 994 |
+
1. Matching by function name first
|
| 995 |
+
2. Matching by order if names don't match
|
| 996 |
+
3. Inserting placeholder responses if responses are missing
|
| 997 |
+
4. Inserting responses at the CORRECT position (after their corresponding call)
|
| 998 |
+
"""
|
| 999 |
+
new_contents = []
|
| 1000 |
+
# Each pending group tracks:
|
| 1001 |
+
# - ids: expected response IDs
|
| 1002 |
+
# - func_names: expected function names (for orphan matching)
|
| 1003 |
+
# - insert_after_idx: position in new_contents where model message was added
|
| 1004 |
+
pending_groups = []
|
| 1005 |
+
collected_responses = {} # Dict mapping ID -> response_part
|
| 1006 |
+
|
| 1007 |
+
for content in contents:
|
| 1008 |
+
role = content.get("role")
|
| 1009 |
+
parts = content.get("parts", [])
|
| 1010 |
+
|
| 1011 |
+
response_parts = [p for p in parts if "functionResponse" in p]
|
| 1012 |
+
|
| 1013 |
+
if response_parts:
|
| 1014 |
+
# Collect responses by ID (ignore duplicates - keep first occurrence)
|
| 1015 |
+
for resp in response_parts:
|
| 1016 |
+
resp_id = resp.get("functionResponse", {}).get("id", "")
|
| 1017 |
+
if resp_id:
|
| 1018 |
+
if resp_id in collected_responses:
|
| 1019 |
+
lib_logger.warning(
|
| 1020 |
+
f"[Grouping] Duplicate response ID detected: {resp_id}. "
|
| 1021 |
+
f"Ignoring duplicate - this may indicate malformed conversation history."
|
| 1022 |
+
)
|
| 1023 |
+
continue
|
| 1024 |
+
collected_responses[resp_id] = resp
|
| 1025 |
+
|
| 1026 |
+
# Try to satisfy pending groups (newest first)
|
| 1027 |
+
for i in range(len(pending_groups) - 1, -1, -1):
|
| 1028 |
+
group = pending_groups[i]
|
| 1029 |
+
group_ids = group["ids"]
|
| 1030 |
+
|
| 1031 |
+
# Check if we have ALL responses for this group
|
| 1032 |
+
if all(gid in collected_responses for gid in group_ids):
|
| 1033 |
+
# Extract responses in the same order as the function calls
|
| 1034 |
+
group_responses = [
|
| 1035 |
+
collected_responses.pop(gid) for gid in group_ids
|
| 1036 |
+
]
|
| 1037 |
+
new_contents.append({"parts": group_responses, "role": "user"})
|
| 1038 |
+
pending_groups.pop(i)
|
| 1039 |
+
break
|
| 1040 |
+
continue
|
| 1041 |
+
|
| 1042 |
+
if role == "model":
|
| 1043 |
+
func_calls = [p for p in parts if "functionCall" in p]
|
| 1044 |
+
new_contents.append(content)
|
| 1045 |
+
if func_calls:
|
| 1046 |
+
call_ids = [
|
| 1047 |
+
fc.get("functionCall", {}).get("id", "") for fc in func_calls
|
| 1048 |
+
]
|
| 1049 |
+
call_ids = [cid for cid in call_ids if cid] # Filter empty IDs
|
| 1050 |
+
|
| 1051 |
+
# Also extract function names for orphan matching
|
| 1052 |
+
func_names = [
|
| 1053 |
+
fc.get("functionCall", {}).get("name", "") for fc in func_calls
|
| 1054 |
+
]
|
| 1055 |
+
|
| 1056 |
+
if call_ids:
|
| 1057 |
+
pending_groups.append(
|
| 1058 |
+
{
|
| 1059 |
+
"ids": call_ids,
|
| 1060 |
+
"func_names": func_names,
|
| 1061 |
+
"insert_after_idx": len(new_contents) - 1,
|
| 1062 |
+
}
|
| 1063 |
+
)
|
| 1064 |
+
else:
|
| 1065 |
+
new_contents.append(content)
|
| 1066 |
+
|
| 1067 |
+
# Handle remaining groups (shouldn't happen in well-formed conversations)
|
| 1068 |
+
# Attempt recovery by matching orphans to unsatisfied calls
|
| 1069 |
+
# Process in REVERSE order of insert_after_idx so insertions don't shift indices
|
| 1070 |
+
pending_groups.sort(key=lambda g: g["insert_after_idx"], reverse=True)
|
| 1071 |
+
|
| 1072 |
+
for group in pending_groups:
|
| 1073 |
+
group_ids = group["ids"]
|
| 1074 |
+
group_func_names = group.get("func_names", [])
|
| 1075 |
+
insert_idx = group["insert_after_idx"] + 1
|
| 1076 |
+
group_responses = []
|
| 1077 |
+
|
| 1078 |
+
lib_logger.debug(
|
| 1079 |
+
f"[Grouping Recovery] Processing unsatisfied group: "
|
| 1080 |
+
f"ids={group_ids}, names={group_func_names}, insert_at={insert_idx}"
|
| 1081 |
+
)
|
| 1082 |
+
|
| 1083 |
+
for i, expected_id in enumerate(group_ids):
|
| 1084 |
+
expected_name = group_func_names[i] if i < len(group_func_names) else ""
|
| 1085 |
+
|
| 1086 |
+
if expected_id in collected_responses:
|
| 1087 |
+
# Direct ID match
|
| 1088 |
+
group_responses.append(collected_responses.pop(expected_id))
|
| 1089 |
+
lib_logger.debug(
|
| 1090 |
+
f"[Grouping Recovery] Direct ID match for '{expected_id}'"
|
| 1091 |
+
)
|
| 1092 |
+
elif collected_responses:
|
| 1093 |
+
# Try to find orphan with matching function name first
|
| 1094 |
+
matched_orphan_id = None
|
| 1095 |
+
|
| 1096 |
+
# First pass: match by function name
|
| 1097 |
+
for orphan_id, orphan_resp in collected_responses.items():
|
| 1098 |
+
orphan_name = orphan_resp.get("functionResponse", {}).get(
|
| 1099 |
+
"name", ""
|
| 1100 |
+
)
|
| 1101 |
+
# Match if names are equal
|
| 1102 |
+
if orphan_name == expected_name:
|
| 1103 |
+
matched_orphan_id = orphan_id
|
| 1104 |
+
lib_logger.debug(
|
| 1105 |
+
f"[Grouping Recovery] Matched orphan '{orphan_id}' by name '{orphan_name}'"
|
| 1106 |
+
)
|
| 1107 |
+
break
|
| 1108 |
+
|
| 1109 |
+
# Second pass: if no name match, try "unknown_function" orphans
|
| 1110 |
+
if not matched_orphan_id:
|
| 1111 |
+
for orphan_id, orphan_resp in collected_responses.items():
|
| 1112 |
+
orphan_name = orphan_resp.get("functionResponse", {}).get(
|
| 1113 |
+
"name", ""
|
| 1114 |
+
)
|
| 1115 |
+
if orphan_name == "unknown_function":
|
| 1116 |
+
matched_orphan_id = orphan_id
|
| 1117 |
+
lib_logger.debug(
|
| 1118 |
+
f"[Grouping Recovery] Matched unknown_function orphan '{orphan_id}' "
|
| 1119 |
+
f"to expected '{expected_name}'"
|
| 1120 |
+
)
|
| 1121 |
+
break
|
| 1122 |
+
|
| 1123 |
+
# Third pass: if still no match, take first available (order-based)
|
| 1124 |
+
if not matched_orphan_id:
|
| 1125 |
+
matched_orphan_id = next(iter(collected_responses))
|
| 1126 |
+
lib_logger.debug(
|
| 1127 |
+
f"[Grouping Recovery] No name match, using first available orphan '{matched_orphan_id}'"
|
| 1128 |
+
)
|
| 1129 |
+
|
| 1130 |
+
if matched_orphan_id:
|
| 1131 |
+
orphan_resp = collected_responses.pop(matched_orphan_id)
|
| 1132 |
+
|
| 1133 |
+
# Fix the ID in the response to match the call
|
| 1134 |
+
old_id = orphan_resp["functionResponse"].get("id", "")
|
| 1135 |
+
orphan_resp["functionResponse"]["id"] = expected_id
|
| 1136 |
+
|
| 1137 |
+
# Fix the name if it was "unknown_function"
|
| 1138 |
+
if (
|
| 1139 |
+
orphan_resp["functionResponse"].get("name")
|
| 1140 |
+
== "unknown_function"
|
| 1141 |
+
and expected_name
|
| 1142 |
+
):
|
| 1143 |
+
orphan_resp["functionResponse"]["name"] = expected_name
|
| 1144 |
+
lib_logger.info(
|
| 1145 |
+
f"[Grouping Recovery] Fixed function name from 'unknown_function' to '{expected_name}'"
|
| 1146 |
+
)
|
| 1147 |
+
|
| 1148 |
+
lib_logger.warning(
|
| 1149 |
+
f"[Grouping] Auto-repaired ID mismatch: mapped response '{old_id}' "
|
| 1150 |
+
f"to call '{expected_id}' (function: {expected_name})"
|
| 1151 |
+
)
|
| 1152 |
+
group_responses.append(orphan_resp)
|
| 1153 |
+
else:
|
| 1154 |
+
# No responses available - create placeholder
|
| 1155 |
+
placeholder_resp = {
|
| 1156 |
+
"functionResponse": {
|
| 1157 |
+
"name": expected_name or "unknown_function",
|
| 1158 |
+
"response": {
|
| 1159 |
+
"result": {
|
| 1160 |
+
"error": "Tool response was lost during context processing. "
|
| 1161 |
+
"This is a recovered placeholder.",
|
| 1162 |
+
"recovered": True,
|
| 1163 |
+
}
|
| 1164 |
+
},
|
| 1165 |
+
"id": expected_id,
|
| 1166 |
+
}
|
| 1167 |
+
}
|
| 1168 |
+
lib_logger.warning(
|
| 1169 |
+
f"[Grouping Recovery] Created placeholder response for missing tool: "
|
| 1170 |
+
f"id='{expected_id}', name='{expected_name}'"
|
| 1171 |
+
)
|
| 1172 |
+
group_responses.append(placeholder_resp)
|
| 1173 |
+
|
| 1174 |
+
if group_responses:
|
| 1175 |
+
# Insert at the correct position (right after the model message with the calls)
|
| 1176 |
+
new_contents.insert(
|
| 1177 |
+
insert_idx, {"parts": group_responses, "role": "user"}
|
| 1178 |
+
)
|
| 1179 |
+
lib_logger.info(
|
| 1180 |
+
f"[Grouping Recovery] Inserted {len(group_responses)} responses at position {insert_idx} "
|
| 1181 |
+
f"(expected {len(group_ids)})"
|
| 1182 |
+
)
|
| 1183 |
+
|
| 1184 |
+
# Warn about unmatched responses
|
| 1185 |
+
if collected_responses:
|
| 1186 |
+
lib_logger.warning(
|
| 1187 |
+
f"[Grouping] {len(collected_responses)} unmatched responses remaining: "
|
| 1188 |
+
f"ids={list(collected_responses.keys())}"
|
| 1189 |
+
)
|
| 1190 |
+
|
| 1191 |
+
return new_contents
|
| 1192 |
+
|
| 1193 |
def _handle_reasoning_parameters(
|
| 1194 |
self, payload: Dict[str, Any], model: str
|
| 1195 |
) -> Optional[Dict[str, Any]]:
|
|
|
|
| 1309 |
# Get current tool index from accumulator (default 0) and increment
|
| 1310 |
current_tool_idx = accumulator.get("tool_idx", 0) if accumulator else 0
|
| 1311 |
|
| 1312 |
+
# Get args, recursively parse any JSON strings, and strip _confirm if sole param
|
| 1313 |
+
raw_args = function_call.get("args", {})
|
| 1314 |
+
tool_args = _recursively_parse_json_strings(raw_args)
|
| 1315 |
+
|
| 1316 |
+
# Strip _confirm ONLY if it's the sole parameter
|
| 1317 |
+
# This ensures we only strip our injection, not legitimate user params
|
| 1318 |
+
if isinstance(tool_args, dict) and "_confirm" in tool_args:
|
| 1319 |
+
if len(tool_args) == 1:
|
| 1320 |
+
# _confirm is the only param - this was our injection
|
| 1321 |
+
tool_args.pop("_confirm")
|
| 1322 |
+
|
| 1323 |
tool_call = {
|
| 1324 |
"index": current_tool_idx,
|
| 1325 |
"id": tool_call_id,
|
| 1326 |
"type": "function",
|
| 1327 |
"function": {
|
| 1328 |
"name": function_name,
|
| 1329 |
+
"arguments": json.dumps(tool_args),
|
| 1330 |
},
|
| 1331 |
}
|
| 1332 |
|
|
|
|
| 1634 |
schema = self._gemini_cli_transform_schema(
|
| 1635 |
new_function["parameters"]
|
| 1636 |
)
|
| 1637 |
+
# Workaround: Gemini fails to emit functionCall for tools
|
| 1638 |
+
# with empty properties {}. Inject a required confirmation param.
|
| 1639 |
+
# Using a required parameter forces the model to commit to
|
| 1640 |
+
# the tool call rather than just thinking about it.
|
| 1641 |
+
props = schema.get("properties", {})
|
| 1642 |
+
if not props:
|
| 1643 |
+
schema["properties"] = {
|
| 1644 |
+
"_confirm": {
|
| 1645 |
+
"type": "string",
|
| 1646 |
+
"description": "Enter 'yes' to proceed",
|
| 1647 |
+
}
|
| 1648 |
+
}
|
| 1649 |
+
schema["required"] = ["_confirm"]
|
| 1650 |
new_function["parametersJsonSchema"] = schema
|
| 1651 |
del new_function["parameters"]
|
| 1652 |
elif "parametersJsonSchema" not in new_function:
|
| 1653 |
+
# Set default schema with required confirm param if neither exists
|
| 1654 |
new_function["parametersJsonSchema"] = {
|
| 1655 |
"type": "object",
|
| 1656 |
+
"properties": {
|
| 1657 |
+
"_confirm": {
|
| 1658 |
+
"type": "string",
|
| 1659 |
+
"description": "Enter 'yes' to proceed",
|
| 1660 |
+
}
|
| 1661 |
+
},
|
| 1662 |
+
"required": ["_confirm"],
|
| 1663 |
}
|
| 1664 |
|
| 1665 |
# Gemini 3 specific transformations
|
|
|
|
| 1899 |
system_instruction, contents = self._transform_messages(
|
| 1900 |
kwargs.get("messages", []), model_name
|
| 1901 |
)
|
| 1902 |
+
# Fix tool response grouping (handles ID mismatches, missing responses)
|
| 1903 |
+
contents = self._fix_tool_response_grouping(contents)
|
| 1904 |
+
|
| 1905 |
request_payload = {
|
| 1906 |
"model": model_name,
|
| 1907 |
"project": project_id,
|
|
|
|
| 1978 |
headers=final_headers,
|
| 1979 |
json=request_payload,
|
| 1980 |
params={"alt": "sse"},
|
| 1981 |
+
timeout=TimeoutConfig.streaming(),
|
| 1982 |
) as response:
|
| 1983 |
# Read and log error body before raise_for_status for better debugging
|
| 1984 |
if response.status_code >= 400:
|
|
|
|
| 2189 |
|
| 2190 |
# Transform messages to Gemini format
|
| 2191 |
system_instruction, contents = self._transform_messages(messages)
|
| 2192 |
+
# Fix tool response grouping (handles ID mismatches, missing responses)
|
| 2193 |
+
contents = self._fix_tool_response_grouping(contents)
|
| 2194 |
|
| 2195 |
# Build request payload
|
| 2196 |
request_payload = {
|
src/rotator_library/providers/google_oauth_base.py
CHANGED
|
@@ -1,16 +1,17 @@
|
|
| 1 |
# src/rotator_library/providers/google_oauth_base.py
|
| 2 |
|
| 3 |
import os
|
|
|
|
| 4 |
import webbrowser
|
| 5 |
-
from
|
|
|
|
| 6 |
import json
|
| 7 |
import time
|
| 8 |
import asyncio
|
| 9 |
import logging
|
| 10 |
from pathlib import Path
|
| 11 |
from typing import Dict, Any
|
| 12 |
-
import
|
| 13 |
-
import shutil
|
| 14 |
|
| 15 |
import httpx
|
| 16 |
from rich.console import Console
|
|
@@ -20,12 +21,31 @@ from rich.markup import escape as rich_escape
|
|
| 20 |
|
| 21 |
from ..utils.headless_detection import is_headless_environment
|
| 22 |
from ..utils.reauth_coordinator import get_reauth_coordinator
|
|
|
|
| 23 |
|
| 24 |
lib_logger = logging.getLogger("rotator_library")
|
| 25 |
|
| 26 |
console = Console()
|
| 27 |
|
| 28 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
class GoogleOAuthBase:
|
| 30 |
"""
|
| 31 |
Base class for Google OAuth2 authentication providers.
|
|
@@ -55,6 +75,25 @@ class GoogleOAuthBase:
|
|
| 55 |
CALLBACK_PATH: str = "/oauth2callback"
|
| 56 |
REFRESH_EXPIRY_BUFFER_SECONDS: int = 30 * 60 # 30 minutes
|
| 57 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 58 |
def __init__(self):
|
| 59 |
# Validate that subclass has set required attributes
|
| 60 |
if self.CLIENT_ID is None:
|
|
@@ -83,19 +122,36 @@ class GoogleOAuthBase:
|
|
| 83 |
str, float
|
| 84 |
] = {} # Track backoff timers (Unix timestamp)
|
| 85 |
|
| 86 |
-
# [QUEUE SYSTEM] Sequential refresh processing
|
|
|
|
| 87 |
self._refresh_queue: asyncio.Queue = asyncio.Queue()
|
| 88 |
-
self.
|
| 89 |
-
|
| 90 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 91 |
self._unavailable_credentials: Dict[
|
| 92 |
str, float
|
| 93 |
] = {} # Maps credential path -> timestamp when marked unavailable
|
| 94 |
-
|
|
|
|
| 95 |
self._queue_tracking_lock = asyncio.Lock() # Protects queue sets
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 99 |
|
| 100 |
def _parse_env_credential_path(self, path: str) -> Optional[str]:
|
| 101 |
"""
|
|
@@ -228,17 +284,7 @@ class GoogleOAuthBase:
|
|
| 228 |
f"Environment variables for {self.ENV_PREFIX} credential index {credential_index} not found"
|
| 229 |
)
|
| 230 |
|
| 231 |
-
#
|
| 232 |
-
env_creds = self._load_from_env()
|
| 233 |
-
if env_creds:
|
| 234 |
-
lib_logger.info(
|
| 235 |
-
f"Using {self.ENV_PREFIX} credentials from environment variables"
|
| 236 |
-
)
|
| 237 |
-
# Cache env-based credentials using the path as key
|
| 238 |
-
self._credentials_cache[path] = env_creds
|
| 239 |
-
return env_creds
|
| 240 |
-
|
| 241 |
-
# Fall back to file-based loading
|
| 242 |
try:
|
| 243 |
lib_logger.debug(
|
| 244 |
f"Loading {self.ENV_PREFIX} credentials from file: {path}"
|
|
@@ -251,6 +297,15 @@ class GoogleOAuthBase:
|
|
| 251 |
self._credentials_cache[path] = creds
|
| 252 |
return creds
|
| 253 |
except FileNotFoundError:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 254 |
raise IOError(
|
| 255 |
f"{self.ENV_PREFIX} OAuth credential file not found at '{path}'"
|
| 256 |
)
|
|
@@ -258,70 +313,29 @@ class GoogleOAuthBase:
|
|
| 258 |
raise IOError(
|
| 259 |
f"Failed to load {self.ENV_PREFIX} OAuth credentials from '{path}': {e}"
|
| 260 |
)
|
| 261 |
-
except Exception as e:
|
| 262 |
-
raise IOError(
|
| 263 |
-
f"Failed to load {self.ENV_PREFIX} OAuth credentials from '{path}': {e}"
|
| 264 |
-
)
|
| 265 |
|
| 266 |
async def _save_credentials(self, path: str, creds: Dict[str, Any]):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 267 |
# Don't save to file if credentials were loaded from environment
|
| 268 |
if creds.get("_proxy_metadata", {}).get("loaded_from_env"):
|
| 269 |
lib_logger.debug("Credentials loaded from env, skipping file save")
|
| 270 |
-
# Still update cache for in-memory consistency
|
| 271 |
-
self._credentials_cache[path] = creds
|
| 272 |
return
|
| 273 |
|
| 274 |
-
#
|
| 275 |
-
#
|
| 276 |
-
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
tmp_fd = None
|
| 280 |
-
tmp_path = None
|
| 281 |
-
try:
|
| 282 |
-
# Create temp file in same directory as target (ensures same filesystem)
|
| 283 |
-
tmp_fd, tmp_path = tempfile.mkstemp(
|
| 284 |
-
dir=parent_dir, prefix=".tmp_", suffix=".json", text=True
|
| 285 |
-
)
|
| 286 |
-
|
| 287 |
-
# Write JSON to temp file
|
| 288 |
-
with os.fdopen(tmp_fd, "w") as f:
|
| 289 |
-
json.dump(creds, f, indent=2)
|
| 290 |
-
tmp_fd = None # fdopen closes the fd
|
| 291 |
-
|
| 292 |
-
# Set secure permissions (0600 = owner read/write only)
|
| 293 |
-
try:
|
| 294 |
-
os.chmod(tmp_path, 0o600)
|
| 295 |
-
except (OSError, AttributeError):
|
| 296 |
-
# Windows may not support chmod, ignore
|
| 297 |
-
pass
|
| 298 |
-
|
| 299 |
-
# Atomic move (overwrites target if it exists)
|
| 300 |
-
shutil.move(tmp_path, path)
|
| 301 |
-
tmp_path = None # Successfully moved
|
| 302 |
-
|
| 303 |
-
# Update cache AFTER successful file write (prevents cache/file inconsistency)
|
| 304 |
-
self._credentials_cache[path] = creds
|
| 305 |
lib_logger.debug(
|
| 306 |
-
f"Saved updated {self.ENV_PREFIX} OAuth credentials to '{path}'
|
| 307 |
)
|
| 308 |
-
|
| 309 |
-
|
| 310 |
-
|
| 311 |
-
f"Failed to save updated {self.ENV_PREFIX} OAuth credentials to '{path}': {e}"
|
| 312 |
)
|
| 313 |
-
# Clean up temp file if it still exists
|
| 314 |
-
if tmp_fd is not None:
|
| 315 |
-
try:
|
| 316 |
-
os.close(tmp_fd)
|
| 317 |
-
except:
|
| 318 |
-
pass
|
| 319 |
-
if tmp_path and os.path.exists(tmp_path):
|
| 320 |
-
try:
|
| 321 |
-
os.unlink(tmp_path)
|
| 322 |
-
except:
|
| 323 |
-
pass
|
| 324 |
-
raise
|
| 325 |
|
| 326 |
def _is_token_expired(self, creds: Dict[str, Any]) -> bool:
|
| 327 |
expiry = creds.get("token_expiry") # gcloud format
|
|
@@ -518,7 +532,7 @@ class GoogleOAuthBase:
|
|
| 518 |
"""Proactively refresh a credential by queueing it for refresh."""
|
| 519 |
creds = await self._load_credentials(credential_path)
|
| 520 |
if self._is_token_expired(creds):
|
| 521 |
-
#
|
| 522 |
await self._queue_refresh(credential_path, force=False, needs_reauth=False)
|
| 523 |
|
| 524 |
async def _get_lock(self, path: str) -> asyncio.Lock:
|
|
@@ -529,34 +543,69 @@ class GoogleOAuthBase:
|
|
| 529 |
self._refresh_locks[path] = asyncio.Lock()
|
| 530 |
return self._refresh_locks[path]
|
| 531 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 532 |
def is_credential_available(self, path: str) -> bool:
|
| 533 |
-
"""Check if a credential is available for rotation
|
|
|
|
|
|
|
|
|
|
|
|
|
| 534 |
|
| 535 |
-
|
| 536 |
-
|
| 537 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 538 |
"""
|
| 539 |
-
|
| 540 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 541 |
|
| 542 |
-
#
|
| 543 |
-
|
| 544 |
-
if
|
| 545 |
-
|
| 546 |
-
|
| 547 |
-
|
| 548 |
-
lib_logger.
|
| 549 |
-
|
| 550 |
-
|
| 551 |
-
|
|
|
|
| 552 |
)
|
| 553 |
-
|
| 554 |
-
# However, pop from dict is thread-safe for single operations.
|
| 555 |
-
# The _queue_tracking_lock protects concurrent modifications in async context.
|
| 556 |
-
self._unavailable_credentials.pop(path, None)
|
| 557 |
-
return True
|
| 558 |
|
| 559 |
-
return
|
| 560 |
|
| 561 |
async def _ensure_queue_processor_running(self):
|
| 562 |
"""Lazily starts the queue processor if not already running."""
|
|
@@ -565,15 +614,27 @@ class GoogleOAuthBase:
|
|
| 565 |
self._process_refresh_queue()
|
| 566 |
)
|
| 567 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 568 |
async def _queue_refresh(
|
| 569 |
self, path: str, force: bool = False, needs_reauth: bool = False
|
| 570 |
):
|
| 571 |
-
"""Add a credential to the refresh queue if not already queued.
|
| 572 |
|
| 573 |
Args:
|
| 574 |
path: Credential file path
|
| 575 |
force: Force refresh even if not expired
|
| 576 |
-
needs_reauth: True if full re-authentication needed (
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 577 |
"""
|
| 578 |
# IMPORTANT: Only check backoff for simple automated refreshes
|
| 579 |
# Re-authentication (interactive OAuth) should BYPASS backoff since it needs user input
|
|
@@ -583,108 +644,226 @@ class GoogleOAuthBase:
|
|
| 583 |
backoff_until = self._next_refresh_after[path]
|
| 584 |
if now < backoff_until:
|
| 585 |
# Credential is in backoff for automated refresh, do not queue
|
| 586 |
-
remaining = int(backoff_until - now)
|
| 587 |
-
lib_logger.debug(
|
| 588 |
-
|
| 589 |
-
)
|
| 590 |
return
|
| 591 |
|
| 592 |
async with self._queue_tracking_lock:
|
| 593 |
if path not in self._queued_credentials:
|
| 594 |
self._queued_credentials.add(path)
|
| 595 |
-
|
| 596 |
-
|
| 597 |
-
|
| 598 |
-
|
| 599 |
-
|
| 600 |
-
|
| 601 |
-
|
| 602 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 603 |
|
| 604 |
async def _process_refresh_queue(self):
|
| 605 |
-
"""Background worker that processes refresh requests sequentially.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 606 |
while True:
|
| 607 |
path = None
|
| 608 |
try:
|
| 609 |
# Wait for an item with timeout to allow graceful shutdown
|
| 610 |
try:
|
| 611 |
-
path, force
|
| 612 |
self._refresh_queue.get(), timeout=60.0
|
| 613 |
)
|
| 614 |
except asyncio.TimeoutError:
|
| 615 |
-
#
|
| 616 |
-
# If we're idle for 60s, no refreshes are in progress
|
| 617 |
async with self._queue_tracking_lock:
|
| 618 |
-
|
| 619 |
-
|
| 620 |
-
lib_logger.warning(
|
| 621 |
-
f"Queue processor idle timeout. Cleaning {stale_count} "
|
| 622 |
-
f"stale unavailable credentials: {list(self._unavailable_credentials.keys())}"
|
| 623 |
-
)
|
| 624 |
-
self._unavailable_credentials.clear()
|
| 625 |
self._queue_processor_task = None
|
|
|
|
| 626 |
return
|
| 627 |
|
| 628 |
try:
|
| 629 |
-
#
|
| 630 |
-
|
| 631 |
-
|
| 632 |
-
|
| 633 |
-
|
| 634 |
-
|
| 635 |
-
|
| 636 |
-
|
| 637 |
-
|
| 638 |
-
|
| 639 |
-
|
| 640 |
-
|
| 641 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 642 |
|
| 643 |
-
#
|
| 644 |
-
|
| 645 |
-
|
| 646 |
-
|
| 647 |
-
|
| 648 |
-
|
| 649 |
-
|
| 650 |
-
|
| 651 |
-
|
| 652 |
-
|
| 653 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 654 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 655 |
|
| 656 |
finally:
|
| 657 |
-
#
|
| 658 |
-
# This ensures cleanup happens in ALL exit paths (success, exception, etc.)
|
| 659 |
async with self._queue_tracking_lock:
|
| 660 |
self._queued_credentials.discard(path)
|
| 661 |
-
# [FIX PR#34] Always clean up unavailable credentials in finally block
|
| 662 |
self._unavailable_credentials.pop(path, None)
|
| 663 |
-
lib_logger.debug(
|
| 664 |
-
|
| 665 |
-
|
| 666 |
-
)
|
| 667 |
-
self.
|
|
|
|
| 668 |
except asyncio.CancelledError:
|
| 669 |
-
#
|
| 670 |
if path:
|
| 671 |
async with self._queue_tracking_lock:
|
|
|
|
| 672 |
self._unavailable_credentials.pop(path, None)
|
| 673 |
-
|
| 674 |
-
f"CancelledError cleanup for '{Path(path).name}'. "
|
| 675 |
-
f"Remaining unavailable: {len(self._unavailable_credentials)}"
|
| 676 |
-
)
|
| 677 |
break
|
| 678 |
except Exception as e:
|
| 679 |
-
lib_logger.error(f"Error in queue processor: {e}")
|
| 680 |
-
# Even on error, mark as available (backoff will prevent immediate retry)
|
| 681 |
if path:
|
| 682 |
async with self._queue_tracking_lock:
|
|
|
|
| 683 |
self._unavailable_credentials.pop(path, None)
|
| 684 |
-
lib_logger.debug(
|
| 685 |
-
f"Error cleanup for '{Path(path).name}': {e}. "
|
| 686 |
-
f"Remaining unavailable: {len(self._unavailable_credentials)}"
|
| 687 |
-
)
|
| 688 |
|
| 689 |
async def _perform_interactive_oauth(
|
| 690 |
self, path: str, creds: Dict[str, Any], display_name: str
|
|
@@ -744,14 +923,14 @@ class GoogleOAuthBase:
|
|
| 744 |
|
| 745 |
try:
|
| 746 |
server = await asyncio.start_server(
|
| 747 |
-
handle_callback, "127.0.0.1", self.
|
| 748 |
)
|
| 749 |
from urllib.parse import urlencode
|
| 750 |
|
| 751 |
auth_url = "https://accounts.google.com/o/oauth2/v2/auth?" + urlencode(
|
| 752 |
{
|
| 753 |
"client_id": self.CLIENT_ID,
|
| 754 |
-
"redirect_uri": f"http://localhost:{self.
|
| 755 |
"scope": " ".join(self.OAUTH_SCOPES),
|
| 756 |
"access_type": "offline",
|
| 757 |
"response_type": "code",
|
|
@@ -826,7 +1005,7 @@ class GoogleOAuthBase:
|
|
| 826 |
"code": auth_code.strip(),
|
| 827 |
"client_id": self.CLIENT_ID,
|
| 828 |
"client_secret": self.CLIENT_SECRET,
|
| 829 |
-
"redirect_uri": f"http://localhost:{self.
|
| 830 |
"grant_type": "authorization_code",
|
| 831 |
},
|
| 832 |
)
|
|
@@ -864,6 +1043,18 @@ class GoogleOAuthBase:
|
|
| 864 |
lib_logger.info(
|
| 865 |
f"{self.ENV_PREFIX} OAuth initialized successfully for '{display_name}'."
|
| 866 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 867 |
return new_creds
|
| 868 |
|
| 869 |
async def initialize_token(
|
|
@@ -940,10 +1131,51 @@ class GoogleOAuthBase:
|
|
| 940 |
)
|
| 941 |
|
| 942 |
async def get_auth_header(self, credential_path: str) -> Dict[str, str]:
|
| 943 |
-
|
| 944 |
-
|
| 945 |
-
creds = await self.
|
| 946 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 947 |
|
| 948 |
async def get_user_info(
|
| 949 |
self, creds_or_path: Union[Dict[str, Any], str]
|
|
@@ -976,3 +1208,372 @@ class GoogleOAuthBase:
|
|
| 976 |
if path:
|
| 977 |
await self._save_credentials(path, creds)
|
| 978 |
return {"email": user_info.get("email")}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
# src/rotator_library/providers/google_oauth_base.py
|
| 2 |
|
| 3 |
import os
|
| 4 |
+
import re
|
| 5 |
import webbrowser
|
| 6 |
+
from dataclasses import dataclass, field
|
| 7 |
+
from typing import Union, Optional, List
|
| 8 |
import json
|
| 9 |
import time
|
| 10 |
import asyncio
|
| 11 |
import logging
|
| 12 |
from pathlib import Path
|
| 13 |
from typing import Dict, Any
|
| 14 |
+
from glob import glob
|
|
|
|
| 15 |
|
| 16 |
import httpx
|
| 17 |
from rich.console import Console
|
|
|
|
| 21 |
|
| 22 |
from ..utils.headless_detection import is_headless_environment
|
| 23 |
from ..utils.reauth_coordinator import get_reauth_coordinator
|
| 24 |
+
from ..utils.resilient_io import safe_write_json
|
| 25 |
|
| 26 |
lib_logger = logging.getLogger("rotator_library")
|
| 27 |
|
| 28 |
console = Console()
|
| 29 |
|
| 30 |
|
| 31 |
+
@dataclass
|
| 32 |
+
class CredentialSetupResult:
|
| 33 |
+
"""
|
| 34 |
+
Standardized result structure for credential setup operations.
|
| 35 |
+
|
| 36 |
+
Used by all auth classes to return consistent setup results to the credential tool.
|
| 37 |
+
"""
|
| 38 |
+
|
| 39 |
+
success: bool
|
| 40 |
+
file_path: Optional[str] = None
|
| 41 |
+
email: Optional[str] = None
|
| 42 |
+
tier: Optional[str] = None
|
| 43 |
+
project_id: Optional[str] = None
|
| 44 |
+
is_update: bool = False
|
| 45 |
+
error: Optional[str] = None
|
| 46 |
+
credentials: Optional[Dict[str, Any]] = field(default=None, repr=False)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
class GoogleOAuthBase:
|
| 50 |
"""
|
| 51 |
Base class for Google OAuth2 authentication providers.
|
|
|
|
| 75 |
CALLBACK_PATH: str = "/oauth2callback"
|
| 76 |
REFRESH_EXPIRY_BUFFER_SECONDS: int = 30 * 60 # 30 minutes
|
| 77 |
|
| 78 |
+
@property
|
| 79 |
+
def callback_port(self) -> int:
|
| 80 |
+
"""
|
| 81 |
+
Get the OAuth callback port, checking environment variable first.
|
| 82 |
+
|
| 83 |
+
Reads from {ENV_PREFIX}_OAUTH_PORT environment variable, falling back
|
| 84 |
+
to the class's CALLBACK_PORT default if not set.
|
| 85 |
+
"""
|
| 86 |
+
env_var = f"{self.ENV_PREFIX}_OAUTH_PORT"
|
| 87 |
+
env_value = os.getenv(env_var)
|
| 88 |
+
if env_value:
|
| 89 |
+
try:
|
| 90 |
+
return int(env_value)
|
| 91 |
+
except ValueError:
|
| 92 |
+
lib_logger.warning(
|
| 93 |
+
f"Invalid {env_var} value: {env_value}, using default {self.CALLBACK_PORT}"
|
| 94 |
+
)
|
| 95 |
+
return self.CALLBACK_PORT
|
| 96 |
+
|
| 97 |
def __init__(self):
|
| 98 |
# Validate that subclass has set required attributes
|
| 99 |
if self.CLIENT_ID is None:
|
|
|
|
| 122 |
str, float
|
| 123 |
] = {} # Track backoff timers (Unix timestamp)
|
| 124 |
|
| 125 |
+
# [QUEUE SYSTEM] Sequential refresh processing with two separate queues
|
| 126 |
+
# Normal refresh queue: for proactive token refresh (old token still valid)
|
| 127 |
self._refresh_queue: asyncio.Queue = asyncio.Queue()
|
| 128 |
+
self._queue_processor_task: Optional[asyncio.Task] = None
|
| 129 |
+
|
| 130 |
+
# Re-auth queue: for invalid refresh tokens (requires user interaction)
|
| 131 |
+
self._reauth_queue: asyncio.Queue = asyncio.Queue()
|
| 132 |
+
self._reauth_processor_task: Optional[asyncio.Task] = None
|
| 133 |
+
|
| 134 |
+
# Tracking sets/dicts
|
| 135 |
+
self._queued_credentials: set = set() # Track credentials in either queue
|
| 136 |
+
# Only credentials in re-auth queue are marked unavailable (not normal refresh)
|
| 137 |
+
# TTL cleanup is defense-in-depth for edge cases where re-auth processor crashes
|
| 138 |
self._unavailable_credentials: Dict[
|
| 139 |
str, float
|
| 140 |
] = {} # Maps credential path -> timestamp when marked unavailable
|
| 141 |
+
# TTL should exceed reauth timeout (300s) to avoid premature cleanup
|
| 142 |
+
self._unavailable_ttl_seconds: int = 360 # 6 minutes TTL for stale entries
|
| 143 |
self._queue_tracking_lock = asyncio.Lock() # Protects queue sets
|
| 144 |
+
|
| 145 |
+
# Retry tracking for normal refresh queue
|
| 146 |
+
self._queue_retry_count: Dict[
|
| 147 |
+
str, int
|
| 148 |
+
] = {} # Track retry attempts per credential
|
| 149 |
+
|
| 150 |
+
# Configuration constants
|
| 151 |
+
self._refresh_timeout_seconds: int = 15 # Max time for single refresh
|
| 152 |
+
self._refresh_interval_seconds: int = 30 # Delay between queue items
|
| 153 |
+
self._refresh_max_retries: int = 3 # Attempts before kicked out
|
| 154 |
+
self._reauth_timeout_seconds: int = 300 # Time for user to complete OAuth
|
| 155 |
|
| 156 |
def _parse_env_credential_path(self, path: str) -> Optional[str]:
|
| 157 |
"""
|
|
|
|
| 284 |
f"Environment variables for {self.ENV_PREFIX} credential index {credential_index} not found"
|
| 285 |
)
|
| 286 |
|
| 287 |
+
# Try file-based loading first (preferred for explicit file paths)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 288 |
try:
|
| 289 |
lib_logger.debug(
|
| 290 |
f"Loading {self.ENV_PREFIX} credentials from file: {path}"
|
|
|
|
| 297 |
self._credentials_cache[path] = creds
|
| 298 |
return creds
|
| 299 |
except FileNotFoundError:
|
| 300 |
+
# File not found - fall back to legacy env vars for backwards compatibility
|
| 301 |
+
# This handles the case where only env vars are set and file paths are placeholders
|
| 302 |
+
env_creds = self._load_from_env()
|
| 303 |
+
if env_creds:
|
| 304 |
+
lib_logger.info(
|
| 305 |
+
f"File '{path}' not found, using {self.ENV_PREFIX} credentials from environment variables"
|
| 306 |
+
)
|
| 307 |
+
self._credentials_cache[path] = env_creds
|
| 308 |
+
return env_creds
|
| 309 |
raise IOError(
|
| 310 |
f"{self.ENV_PREFIX} OAuth credential file not found at '{path}'"
|
| 311 |
)
|
|
|
|
| 313 |
raise IOError(
|
| 314 |
f"Failed to load {self.ENV_PREFIX} OAuth credentials from '{path}': {e}"
|
| 315 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 316 |
|
| 317 |
async def _save_credentials(self, path: str, creds: Dict[str, Any]):
|
| 318 |
+
"""Save credentials with in-memory fallback if disk unavailable."""
|
| 319 |
+
# Always update cache first (memory is reliable)
|
| 320 |
+
self._credentials_cache[path] = creds
|
| 321 |
+
|
| 322 |
# Don't save to file if credentials were loaded from environment
|
| 323 |
if creds.get("_proxy_metadata", {}).get("loaded_from_env"):
|
| 324 |
lib_logger.debug("Credentials loaded from env, skipping file save")
|
|
|
|
|
|
|
| 325 |
return
|
| 326 |
|
| 327 |
+
# Attempt disk write - if it fails, we still have the cache
|
| 328 |
+
# buffer_on_failure ensures data is retried periodically and saved on shutdown
|
| 329 |
+
if safe_write_json(
|
| 330 |
+
path, creds, lib_logger, secure_permissions=True, buffer_on_failure=True
|
| 331 |
+
):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 332 |
lib_logger.debug(
|
| 333 |
+
f"Saved updated {self.ENV_PREFIX} OAuth credentials to '{path}'."
|
| 334 |
)
|
| 335 |
+
else:
|
| 336 |
+
lib_logger.warning(
|
| 337 |
+
f"Credentials for {self.ENV_PREFIX} cached in memory only (buffered for retry)."
|
|
|
|
| 338 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 339 |
|
| 340 |
def _is_token_expired(self, creds: Dict[str, Any]) -> bool:
|
| 341 |
expiry = creds.get("token_expiry") # gcloud format
|
|
|
|
| 532 |
"""Proactively refresh a credential by queueing it for refresh."""
|
| 533 |
creds = await self._load_credentials(credential_path)
|
| 534 |
if self._is_token_expired(creds):
|
| 535 |
+
# lib_logger.info(f"Proactive refresh triggered for '{Path(credential_path).name}'")
|
| 536 |
await self._queue_refresh(credential_path, force=False, needs_reauth=False)
|
| 537 |
|
| 538 |
async def _get_lock(self, path: str) -> asyncio.Lock:
|
|
|
|
| 543 |
self._refresh_locks[path] = asyncio.Lock()
|
| 544 |
return self._refresh_locks[path]
|
| 545 |
|
| 546 |
+
def _is_token_truly_expired(self, creds: Dict[str, Any]) -> bool:
|
| 547 |
+
"""Check if token is TRULY expired (past actual expiry, not just threshold).
|
| 548 |
+
|
| 549 |
+
This is different from _is_token_expired() which uses a buffer for proactive refresh.
|
| 550 |
+
This method checks if the token is actually unusable.
|
| 551 |
+
"""
|
| 552 |
+
expiry = creds.get("token_expiry") # gcloud format
|
| 553 |
+
if not expiry: # gemini-cli format
|
| 554 |
+
expiry_timestamp = creds.get("expiry_date", 0) / 1000
|
| 555 |
+
else:
|
| 556 |
+
expiry_timestamp = time.mktime(time.strptime(expiry, "%Y-%m-%dT%H:%M:%SZ"))
|
| 557 |
+
return expiry_timestamp < time.time()
|
| 558 |
+
|
| 559 |
def is_credential_available(self, path: str) -> bool:
|
| 560 |
+
"""Check if a credential is available for rotation.
|
| 561 |
+
|
| 562 |
+
Credentials are unavailable if:
|
| 563 |
+
1. In re-auth queue (token is truly broken, requires user interaction)
|
| 564 |
+
2. Token is TRULY expired (past actual expiry, not just threshold)
|
| 565 |
|
| 566 |
+
Note: Credentials in normal refresh queue are still available because
|
| 567 |
+
the old token is valid until actual expiry.
|
| 568 |
+
|
| 569 |
+
TTL cleanup (defense-in-depth): If a credential has been in the re-auth
|
| 570 |
+
queue longer than _unavailable_ttl_seconds without being processed, it's
|
| 571 |
+
cleaned up. This should only happen if the re-auth processor crashes or
|
| 572 |
+
is cancelled without proper cleanup.
|
| 573 |
"""
|
| 574 |
+
# Check if in re-auth queue (truly unavailable)
|
| 575 |
+
if path in self._unavailable_credentials:
|
| 576 |
+
marked_time = self._unavailable_credentials.get(path)
|
| 577 |
+
if marked_time is not None:
|
| 578 |
+
now = time.time()
|
| 579 |
+
if now - marked_time > self._unavailable_ttl_seconds:
|
| 580 |
+
# Entry is stale - clean it up and return available
|
| 581 |
+
# This is a defense-in-depth for edge cases where re-auth
|
| 582 |
+
# processor crashed or was cancelled without cleanup
|
| 583 |
+
lib_logger.warning(
|
| 584 |
+
f"Credential '{Path(path).name}' stuck in re-auth queue for "
|
| 585 |
+
f"{int(now - marked_time)}s (TTL: {self._unavailable_ttl_seconds}s). "
|
| 586 |
+
f"Re-auth processor may have crashed. Auto-cleaning stale entry."
|
| 587 |
+
)
|
| 588 |
+
# Clean up both tracking structures for consistency
|
| 589 |
+
self._unavailable_credentials.pop(path, None)
|
| 590 |
+
self._queued_credentials.discard(path)
|
| 591 |
+
else:
|
| 592 |
+
return False # Still in re-auth, not available
|
| 593 |
|
| 594 |
+
# Check if token is TRULY expired (not just threshold-expired)
|
| 595 |
+
creds = self._credentials_cache.get(path)
|
| 596 |
+
if creds and self._is_token_truly_expired(creds):
|
| 597 |
+
# Token is actually expired - should not be used
|
| 598 |
+
# Queue for refresh if not already queued
|
| 599 |
+
if path not in self._queued_credentials:
|
| 600 |
+
# lib_logger.debug(
|
| 601 |
+
# f"Credential '{Path(path).name}' is truly expired, queueing for refresh"
|
| 602 |
+
# )
|
| 603 |
+
asyncio.create_task(
|
| 604 |
+
self._queue_refresh(path, force=True, needs_reauth=False)
|
| 605 |
)
|
| 606 |
+
return False
|
|
|
|
|
|
|
|
|
|
|
|
|
| 607 |
|
| 608 |
+
return True
|
| 609 |
|
| 610 |
async def _ensure_queue_processor_running(self):
|
| 611 |
"""Lazily starts the queue processor if not already running."""
|
|
|
|
| 614 |
self._process_refresh_queue()
|
| 615 |
)
|
| 616 |
|
| 617 |
+
async def _ensure_reauth_processor_running(self):
|
| 618 |
+
"""Lazily starts the re-auth queue processor if not already running."""
|
| 619 |
+
if self._reauth_processor_task is None or self._reauth_processor_task.done():
|
| 620 |
+
self._reauth_processor_task = asyncio.create_task(
|
| 621 |
+
self._process_reauth_queue()
|
| 622 |
+
)
|
| 623 |
+
|
| 624 |
async def _queue_refresh(
|
| 625 |
self, path: str, force: bool = False, needs_reauth: bool = False
|
| 626 |
):
|
| 627 |
+
"""Add a credential to the appropriate refresh queue if not already queued.
|
| 628 |
|
| 629 |
Args:
|
| 630 |
path: Credential file path
|
| 631 |
force: Force refresh even if not expired
|
| 632 |
+
needs_reauth: True if full re-authentication needed (routes to re-auth queue)
|
| 633 |
+
|
| 634 |
+
Queue routing:
|
| 635 |
+
- needs_reauth=True: Goes to re-auth queue, marks as unavailable
|
| 636 |
+
- needs_reauth=False: Goes to normal refresh queue, does NOT mark unavailable
|
| 637 |
+
(old token is still valid until actual expiry)
|
| 638 |
"""
|
| 639 |
# IMPORTANT: Only check backoff for simple automated refreshes
|
| 640 |
# Re-authentication (interactive OAuth) should BYPASS backoff since it needs user input
|
|
|
|
| 644 |
backoff_until = self._next_refresh_after[path]
|
| 645 |
if now < backoff_until:
|
| 646 |
# Credential is in backoff for automated refresh, do not queue
|
| 647 |
+
# remaining = int(backoff_until - now)
|
| 648 |
+
# lib_logger.debug(
|
| 649 |
+
# f"Skipping automated refresh for '{Path(path).name}' (in backoff for {remaining}s)"
|
| 650 |
+
# )
|
| 651 |
return
|
| 652 |
|
| 653 |
async with self._queue_tracking_lock:
|
| 654 |
if path not in self._queued_credentials:
|
| 655 |
self._queued_credentials.add(path)
|
| 656 |
+
|
| 657 |
+
if needs_reauth:
|
| 658 |
+
# Re-auth queue: mark as unavailable (token is truly broken)
|
| 659 |
+
self._unavailable_credentials[path] = time.time()
|
| 660 |
+
# lib_logger.debug(
|
| 661 |
+
# f"Queued '{Path(path).name}' for RE-AUTH (marked unavailable). "
|
| 662 |
+
# f"Total unavailable: {len(self._unavailable_credentials)}"
|
| 663 |
+
# )
|
| 664 |
+
await self._reauth_queue.put(path)
|
| 665 |
+
await self._ensure_reauth_processor_running()
|
| 666 |
+
else:
|
| 667 |
+
# Normal refresh queue: do NOT mark unavailable (old token still valid)
|
| 668 |
+
# lib_logger.debug(
|
| 669 |
+
# f"Queued '{Path(path).name}' for refresh (still available). "
|
| 670 |
+
# f"Queue size: {self._refresh_queue.qsize() + 1}"
|
| 671 |
+
# )
|
| 672 |
+
await self._refresh_queue.put((path, force))
|
| 673 |
+
await self._ensure_queue_processor_running()
|
| 674 |
|
| 675 |
async def _process_refresh_queue(self):
|
| 676 |
+
"""Background worker that processes normal refresh requests sequentially.
|
| 677 |
+
|
| 678 |
+
Key behaviors:
|
| 679 |
+
- 15s timeout per refresh operation
|
| 680 |
+
- 30s delay between processing credentials (prevents thundering herd)
|
| 681 |
+
- On failure: back of queue, max 3 retries before kicked
|
| 682 |
+
- If 401/403 detected: routes to re-auth queue
|
| 683 |
+
- Does NOT mark credentials unavailable (old token still valid)
|
| 684 |
+
"""
|
| 685 |
+
# lib_logger.info("Refresh queue processor started")
|
| 686 |
while True:
|
| 687 |
path = None
|
| 688 |
try:
|
| 689 |
# Wait for an item with timeout to allow graceful shutdown
|
| 690 |
try:
|
| 691 |
+
path, force = await asyncio.wait_for(
|
| 692 |
self._refresh_queue.get(), timeout=60.0
|
| 693 |
)
|
| 694 |
except asyncio.TimeoutError:
|
| 695 |
+
# Queue is empty and idle for 60s - clean up and exit
|
|
|
|
| 696 |
async with self._queue_tracking_lock:
|
| 697 |
+
# Clear any stale retry counts
|
| 698 |
+
self._queue_retry_count.clear()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 699 |
self._queue_processor_task = None
|
| 700 |
+
# lib_logger.debug("Refresh queue processor idle, shutting down")
|
| 701 |
return
|
| 702 |
|
| 703 |
try:
|
| 704 |
+
# Quick check if still expired (optimization to avoid unnecessary refresh)
|
| 705 |
+
creds = self._credentials_cache.get(path)
|
| 706 |
+
if creds and not self._is_token_expired(creds):
|
| 707 |
+
# No longer expired, skip refresh
|
| 708 |
+
# lib_logger.debug(
|
| 709 |
+
# f"Credential '{Path(path).name}' no longer expired, skipping refresh"
|
| 710 |
+
# )
|
| 711 |
+
# Clear retry count on skip (not a failure)
|
| 712 |
+
self._queue_retry_count.pop(path, None)
|
| 713 |
+
continue
|
| 714 |
+
|
| 715 |
+
# Perform refresh with timeout
|
| 716 |
+
if not creds:
|
| 717 |
+
creds = await self._load_credentials(path)
|
| 718 |
+
|
| 719 |
+
try:
|
| 720 |
+
async with asyncio.timeout(self._refresh_timeout_seconds):
|
| 721 |
+
await self._refresh_token(path, creds, force=force)
|
| 722 |
|
| 723 |
+
# SUCCESS: Clear retry count
|
| 724 |
+
self._queue_retry_count.pop(path, None)
|
| 725 |
+
# lib_logger.info(f"Refresh SUCCESS for '{Path(path).name}'")
|
| 726 |
+
|
| 727 |
+
except asyncio.TimeoutError:
|
| 728 |
+
lib_logger.warning(
|
| 729 |
+
f"Refresh timeout ({self._refresh_timeout_seconds}s) for '{Path(path).name}'"
|
| 730 |
+
)
|
| 731 |
+
await self._handle_refresh_failure(path, force, "timeout")
|
| 732 |
+
|
| 733 |
+
except httpx.HTTPStatusError as e:
|
| 734 |
+
status_code = e.response.status_code
|
| 735 |
+
if status_code in (401, 403):
|
| 736 |
+
# Invalid refresh token - route to re-auth queue
|
| 737 |
+
lib_logger.warning(
|
| 738 |
+
f"Refresh token invalid for '{Path(path).name}' (HTTP {status_code}). "
|
| 739 |
+
f"Routing to re-auth queue."
|
| 740 |
)
|
| 741 |
+
self._queue_retry_count.pop(path, None) # Clear retry count
|
| 742 |
+
async with self._queue_tracking_lock:
|
| 743 |
+
self._queued_credentials.discard(
|
| 744 |
+
path
|
| 745 |
+
) # Remove from queued
|
| 746 |
+
await self._queue_refresh(
|
| 747 |
+
path, force=True, needs_reauth=True
|
| 748 |
+
)
|
| 749 |
+
else:
|
| 750 |
+
await self._handle_refresh_failure(
|
| 751 |
+
path, force, f"HTTP {status_code}"
|
| 752 |
+
)
|
| 753 |
+
|
| 754 |
+
except Exception as e:
|
| 755 |
+
await self._handle_refresh_failure(path, force, str(e))
|
| 756 |
+
|
| 757 |
+
finally:
|
| 758 |
+
# Remove from queued set (unless re-queued by failure handler)
|
| 759 |
+
async with self._queue_tracking_lock:
|
| 760 |
+
# Only discard if not re-queued (check if still in queue set from retry)
|
| 761 |
+
if (
|
| 762 |
+
path in self._queued_credentials
|
| 763 |
+
and self._queue_retry_count.get(path, 0) == 0
|
| 764 |
+
):
|
| 765 |
+
self._queued_credentials.discard(path)
|
| 766 |
+
self._refresh_queue.task_done()
|
| 767 |
+
|
| 768 |
+
# Wait between credentials to spread load
|
| 769 |
+
await asyncio.sleep(self._refresh_interval_seconds)
|
| 770 |
+
|
| 771 |
+
except asyncio.CancelledError:
|
| 772 |
+
# lib_logger.debug("Refresh queue processor cancelled")
|
| 773 |
+
break
|
| 774 |
+
except Exception as e:
|
| 775 |
+
lib_logger.error(f"Error in refresh queue processor: {e}")
|
| 776 |
+
if path:
|
| 777 |
+
async with self._queue_tracking_lock:
|
| 778 |
+
self._queued_credentials.discard(path)
|
| 779 |
+
|
| 780 |
+
async def _handle_refresh_failure(self, path: str, force: bool, error: str):
|
| 781 |
+
"""Handle a refresh failure with back-of-line retry logic.
|
| 782 |
+
|
| 783 |
+
- Increments retry count
|
| 784 |
+
- If under max retries: re-adds to END of queue
|
| 785 |
+
- If at max retries: kicks credential out (retried next BackgroundRefresher cycle)
|
| 786 |
+
"""
|
| 787 |
+
retry_count = self._queue_retry_count.get(path, 0) + 1
|
| 788 |
+
self._queue_retry_count[path] = retry_count
|
| 789 |
+
|
| 790 |
+
if retry_count >= self._refresh_max_retries:
|
| 791 |
+
# Kicked out until next BackgroundRefresher cycle
|
| 792 |
+
lib_logger.error(
|
| 793 |
+
f"Max retries ({self._refresh_max_retries}) reached for '{Path(path).name}' "
|
| 794 |
+
f"(last error: {error}). Will retry next refresh cycle."
|
| 795 |
+
)
|
| 796 |
+
self._queue_retry_count.pop(path, None)
|
| 797 |
+
async with self._queue_tracking_lock:
|
| 798 |
+
self._queued_credentials.discard(path)
|
| 799 |
+
return
|
| 800 |
+
|
| 801 |
+
# Re-add to END of queue for retry
|
| 802 |
+
lib_logger.warning(
|
| 803 |
+
f"Refresh failed for '{Path(path).name}' ({error}). "
|
| 804 |
+
f"Retry {retry_count}/{self._refresh_max_retries}, back of queue."
|
| 805 |
+
)
|
| 806 |
+
# Keep in queued_credentials set, add back to queue
|
| 807 |
+
await self._refresh_queue.put((path, force))
|
| 808 |
+
|
| 809 |
+
async def _process_reauth_queue(self):
    """Background worker that drains the re-authentication queue.

    Key behaviors:
    - Credentials ARE marked unavailable (token is truly broken)
    - Uses ReauthCoordinator for interactive OAuth
    - No automatic retry (requires user action)
    - Cleans up unavailable status when done
    """

    async def _untrack(cred_path):
        # Drop the credential from both tracking structures under the lock.
        async with self._queue_tracking_lock:
            self._queued_credentials.discard(cred_path)
            self._unavailable_credentials.pop(cred_path, None)

    while True:
        path = None
        try:
            # Wait for an item with a timeout so the worker can shut down
            # gracefully when the queue stays empty.
            try:
                path = await asyncio.wait_for(
                    self._reauth_queue.get(), timeout=60.0
                )
            except asyncio.TimeoutError:
                # Idle for 60s: exit; the worker is lazily restarted when
                # something is enqueued again.
                self._reauth_processor_task = None
                return

            try:
                lib_logger.info(f"Starting re-auth for '{Path(path).name}'...")
                await self.initialize_token(path)
                lib_logger.info(f"Re-auth SUCCESS for '{Path(path).name}'")
            except Exception as e:
                lib_logger.error(f"Re-auth FAILED for '{Path(path).name}': {e}")
                # No automatic retry for re-auth (requires user action).
            finally:
                # Cleanup runs on every exit path of the attempt.
                await _untrack(path)
                self._reauth_queue.task_done()

        except asyncio.CancelledError:
            # Clean up the in-flight credential before terminating.
            if path:
                await _untrack(path)
            break
        except Exception as e:
            lib_logger.error(f"Error in re-auth queue processor: {e}")
            if path:
                await _untrack(path)
| 867 |
|
| 868 |
async def _perform_interactive_oauth(
|
| 869 |
self, path: str, creds: Dict[str, Any], display_name: str
|
|
|
|
| 923 |
|
| 924 |
try:
|
| 925 |
server = await asyncio.start_server(
|
| 926 |
+
handle_callback, "127.0.0.1", self.callback_port
|
| 927 |
)
|
| 928 |
from urllib.parse import urlencode
|
| 929 |
|
| 930 |
auth_url = "https://accounts.google.com/o/oauth2/v2/auth?" + urlencode(
|
| 931 |
{
|
| 932 |
"client_id": self.CLIENT_ID,
|
| 933 |
+
"redirect_uri": f"http://localhost:{self.callback_port}{self.CALLBACK_PATH}",
|
| 934 |
"scope": " ".join(self.OAUTH_SCOPES),
|
| 935 |
"access_type": "offline",
|
| 936 |
"response_type": "code",
|
|
|
|
| 1005 |
"code": auth_code.strip(),
|
| 1006 |
"client_id": self.CLIENT_ID,
|
| 1007 |
"client_secret": self.CLIENT_SECRET,
|
| 1008 |
+
"redirect_uri": f"http://localhost:{self.callback_port}{self.CALLBACK_PATH}",
|
| 1009 |
"grant_type": "authorization_code",
|
| 1010 |
},
|
| 1011 |
)
|
|
|
|
| 1043 |
lib_logger.info(
|
| 1044 |
f"{self.ENV_PREFIX} OAuth initialized successfully for '{display_name}'."
|
| 1045 |
)
|
| 1046 |
+
|
| 1047 |
+
# Perform post-auth discovery (tier, project, etc.) while we have a fresh token
|
| 1048 |
+
if path:
|
| 1049 |
+
try:
|
| 1050 |
+
await self._post_auth_discovery(path, new_creds["access_token"])
|
| 1051 |
+
except Exception as e:
|
| 1052 |
+
# Don't fail auth if discovery fails - it can be retried on first request
|
| 1053 |
+
lib_logger.warning(
|
| 1054 |
+
f"Post-auth discovery failed for '{display_name}': {e}. "
|
| 1055 |
+
"Tier/project will be discovered on first request."
|
| 1056 |
+
)
|
| 1057 |
+
|
| 1058 |
return new_creds
|
| 1059 |
|
| 1060 |
async def initialize_token(
|
|
|
|
| 1131 |
)
|
| 1132 |
|
| 1133 |
async def get_auth_header(self, credential_path: str) -> Dict[str, str]:
    """Get auth header with graceful degradation if refresh fails."""

    def _bearer(token: str) -> Dict[str, str]:
        # Shape of the header returned on every success path.
        return {"Authorization": f"Bearer {token}"}

    try:
        creds = await self._load_credentials(credential_path)
        if self._is_token_expired(creds):
            try:
                creds = await self._refresh_token(credential_path, creds)
            except Exception as e:
                # Fall back to a cached token that might still be accepted.
                fallback = self._credentials_cache.get(credential_path)
                if not (fallback and fallback.get("access_token")):
                    raise
                lib_logger.warning(
                    f"Token refresh failed for {Path(credential_path).name}: {e}. "
                    "Using cached token (may be expired)."
                )
                creds = fallback
        return _bearer(creds["access_token"])
    except Exception as e:
        # Last resort: any cached credential for this path.
        fallback = self._credentials_cache.get(credential_path)
        if fallback and fallback.get("access_token"):
            lib_logger.error(
                f"Credential load failed for {credential_path}: {e}. "
                "Using stale cached token as last resort."
            )
            return _bearer(fallback["access_token"])
        raise
+
|
| 1163 |
+
async def _post_auth_discovery(
|
| 1164 |
+
self, credential_path: str, access_token: str
|
| 1165 |
+
) -> None:
|
| 1166 |
+
"""
|
| 1167 |
+
Hook for subclasses to perform post-authentication discovery.
|
| 1168 |
+
|
| 1169 |
+
Called after successful OAuth authentication (both initial and re-auth).
|
| 1170 |
+
Subclasses can override this to discover and cache tier/project information
|
| 1171 |
+
during the authentication flow rather than waiting for the first API request.
|
| 1172 |
+
|
| 1173 |
+
Args:
|
| 1174 |
+
credential_path: Path to the credential file
|
| 1175 |
+
access_token: The newly obtained access token
|
| 1176 |
+
"""
|
| 1177 |
+
# Default implementation does nothing - subclasses can override
|
| 1178 |
+
pass
|
| 1179 |
|
| 1180 |
async def get_user_info(
|
| 1181 |
self, creds_or_path: Union[Dict[str, Any], str]
|
|
|
|
| 1208 |
if path:
|
| 1209 |
await self._save_credentials(path, creds)
|
| 1210 |
return {"email": user_info.get("email")}
|
| 1211 |
+
|
| 1212 |
+
# =========================================================================
|
| 1213 |
+
# CREDENTIAL MANAGEMENT METHODS
|
| 1214 |
+
# =========================================================================
|
| 1215 |
+
|
| 1216 |
+
def _get_provider_file_prefix(self) -> str:
|
| 1217 |
+
"""
|
| 1218 |
+
Get the file name prefix for this provider's credential files.
|
| 1219 |
+
|
| 1220 |
+
Override in subclasses if the prefix differs from ENV_PREFIX.
|
| 1221 |
+
Default: lowercase ENV_PREFIX with underscores (e.g., "gemini_cli")
|
| 1222 |
+
"""
|
| 1223 |
+
return self.ENV_PREFIX.lower()
|
| 1224 |
+
|
| 1225 |
+
def _get_oauth_base_dir(self) -> Path:
|
| 1226 |
+
"""
|
| 1227 |
+
Get the base directory for OAuth credential files.
|
| 1228 |
+
|
| 1229 |
+
Can be overridden to customize credential storage location.
|
| 1230 |
+
"""
|
| 1231 |
+
return Path.cwd() / "oauth_creds"
|
| 1232 |
+
|
| 1233 |
+
def _find_existing_credential_by_email(
|
| 1234 |
+
self, email: str, base_dir: Optional[Path] = None
|
| 1235 |
+
) -> Optional[Path]:
|
| 1236 |
+
"""
|
| 1237 |
+
Find an existing credential file for the given email.
|
| 1238 |
+
|
| 1239 |
+
Args:
|
| 1240 |
+
email: Email address to search for
|
| 1241 |
+
base_dir: Directory to search in (defaults to oauth_creds)
|
| 1242 |
+
|
| 1243 |
+
Returns:
|
| 1244 |
+
Path to existing credential file, or None if not found
|
| 1245 |
+
"""
|
| 1246 |
+
if base_dir is None:
|
| 1247 |
+
base_dir = self._get_oauth_base_dir()
|
| 1248 |
+
|
| 1249 |
+
prefix = self._get_provider_file_prefix()
|
| 1250 |
+
pattern = str(base_dir / f"{prefix}_oauth_*.json")
|
| 1251 |
+
|
| 1252 |
+
for cred_file in glob(pattern):
|
| 1253 |
+
try:
|
| 1254 |
+
with open(cred_file, "r") as f:
|
| 1255 |
+
creds = json.load(f)
|
| 1256 |
+
existing_email = creds.get("_proxy_metadata", {}).get("email")
|
| 1257 |
+
if existing_email == email:
|
| 1258 |
+
return Path(cred_file)
|
| 1259 |
+
except (json.JSONDecodeError, IOError) as e:
|
| 1260 |
+
lib_logger.debug(f"Could not read credential file {cred_file}: {e}")
|
| 1261 |
+
continue
|
| 1262 |
+
|
| 1263 |
+
return None
|
| 1264 |
+
|
| 1265 |
+
def _get_next_credential_number(self, base_dir: Optional[Path] = None) -> int:
|
| 1266 |
+
"""
|
| 1267 |
+
Get the next available credential number for new credential files.
|
| 1268 |
+
|
| 1269 |
+
Args:
|
| 1270 |
+
base_dir: Directory to scan (defaults to oauth_creds)
|
| 1271 |
+
|
| 1272 |
+
Returns:
|
| 1273 |
+
Next available credential number (1, 2, 3, etc.)
|
| 1274 |
+
"""
|
| 1275 |
+
if base_dir is None:
|
| 1276 |
+
base_dir = self._get_oauth_base_dir()
|
| 1277 |
+
|
| 1278 |
+
prefix = self._get_provider_file_prefix()
|
| 1279 |
+
pattern = str(base_dir / f"{prefix}_oauth_*.json")
|
| 1280 |
+
|
| 1281 |
+
existing_numbers = []
|
| 1282 |
+
for cred_file in glob(pattern):
|
| 1283 |
+
match = re.search(r"_oauth_(\d+)\.json$", cred_file)
|
| 1284 |
+
if match:
|
| 1285 |
+
existing_numbers.append(int(match.group(1)))
|
| 1286 |
+
|
| 1287 |
+
if not existing_numbers:
|
| 1288 |
+
return 1
|
| 1289 |
+
return max(existing_numbers) + 1
|
| 1290 |
+
|
| 1291 |
+
def _build_credential_path(
|
| 1292 |
+
self, base_dir: Optional[Path] = None, number: Optional[int] = None
|
| 1293 |
+
) -> Path:
|
| 1294 |
+
"""
|
| 1295 |
+
Build a path for a new credential file.
|
| 1296 |
+
|
| 1297 |
+
Args:
|
| 1298 |
+
base_dir: Directory for credential files (defaults to oauth_creds)
|
| 1299 |
+
number: Credential number (auto-determined if None)
|
| 1300 |
+
|
| 1301 |
+
Returns:
|
| 1302 |
+
Path for the new credential file
|
| 1303 |
+
"""
|
| 1304 |
+
if base_dir is None:
|
| 1305 |
+
base_dir = self._get_oauth_base_dir()
|
| 1306 |
+
|
| 1307 |
+
if number is None:
|
| 1308 |
+
number = self._get_next_credential_number(base_dir)
|
| 1309 |
+
|
| 1310 |
+
prefix = self._get_provider_file_prefix()
|
| 1311 |
+
filename = f"{prefix}_oauth_{number}.json"
|
| 1312 |
+
return base_dir / filename
|
| 1313 |
+
|
| 1314 |
+
async def setup_credential(
    self, base_dir: Optional[Path] = None
) -> CredentialSetupResult:
    """
    Complete credential setup flow: OAuth -> save -> discovery.

    This is the main entry point for setting up new credentials.
    Handles the entire lifecycle:
    1. Perform OAuth authentication
    2. Get user info (email) for deduplication
    3. Find existing credential or create new file path
    4. Save credentials to file
    5. Perform post-auth discovery (tier/project for Google OAuth)

    Args:
        base_dir: Directory for credential files (defaults to oauth_creds)

    Returns:
        CredentialSetupResult with status and details
    """
    if base_dir is None:
        base_dir = self._get_oauth_base_dir()

    # Ensure directory exists. parents=True so a custom/nested base_dir whose
    # parent does not yet exist cannot make setup fail with FileNotFoundError.
    base_dir.mkdir(parents=True, exist_ok=True)

    try:
        # Step 1: Perform OAuth authentication (returns credentials dict)
        temp_creds = {
            "_proxy_metadata": {"display_name": f"new {self.ENV_PREFIX} credential"}
        }
        new_creds = await self.initialize_token(temp_creds)

        # Step 2: Get user info for deduplication
        user_info = await self.get_user_info(new_creds)
        email = user_info.get("email")

        if not email:
            return CredentialSetupResult(
                success=False, error="Could not retrieve email from OAuth response"
            )

        # Step 3: Check for existing credential with same email
        existing_path = self._find_existing_credential_by_email(email, base_dir)
        is_update = existing_path is not None

        if is_update:
            file_path = existing_path
            lib_logger.info(
                f"Found existing credential for {email}, updating {file_path.name}"
            )
        else:
            file_path = self._build_credential_path(base_dir)
            lib_logger.info(
                f"Creating new credential for {email} at {file_path.name}"
            )

        # Step 4: Save credentials to file
        await self._save_credentials(str(file_path), new_creds)

        # Step 5: Perform post-auth discovery (tier, project_id).
        # Already called in _perform_interactive_oauth, but repeated here in
        # case credentials were loaded from an existing token.
        tier = None
        project_id = None
        try:
            await self._post_auth_discovery(
                str(file_path), new_creds["access_token"]
            )
            # Reload credentials to pick up metadata written by discovery.
            with open(file_path, "r") as f:
                updated_creds = json.load(f)
            metadata = updated_creds.get("_proxy_metadata", {})
            tier = metadata.get("tier")
            project_id = metadata.get("project_id")
            new_creds = updated_creds
        except Exception as e:
            # Discovery is best-effort: it can be retried on first request.
            lib_logger.warning(
                f"Post-auth discovery failed: {e}. Tier/project will be discovered on first request."
            )

        return CredentialSetupResult(
            success=True,
            file_path=str(file_path),
            email=email,
            tier=tier,
            project_id=project_id,
            is_update=is_update,
            credentials=new_creds,
        )

    except Exception as e:
        lib_logger.error(f"Credential setup failed: {e}")
        return CredentialSetupResult(success=False, error=str(e))
|
| 1408 |
+
def build_env_lines(self, creds: Dict[str, Any], cred_number: int) -> List[str]:
    """
    Generate .env file lines for a credential.

    Subclasses should override to include provider-specific fields
    (e.g., tier, project_id for Google OAuth providers).

    Args:
        creds: Credential dictionary loaded from JSON
        cred_number: Credential number (1, 2, 3, etc.)

    Returns:
        List of .env file lines
    """
    email = creds.get("_proxy_metadata", {}).get("email", "unknown")
    var_prefix = f"{self.ENV_PREFIX}_{cred_number}"

    header = [
        f"# {self.ENV_PREFIX} Credential #{cred_number} for: {email}",
        f"# Exported from: {self._get_provider_file_prefix()}_oauth_{cred_number}.json",
        f"# Generated at: {time.strftime('%Y-%m-%d %H:%M:%S')}",
        "#",
        "# To combine multiple credentials into one .env file, copy these lines",
        "# and ensure each credential has a unique number (1, 2, 3, etc.)",
        "",
    ]

    # Field order matters for readers comparing exported files; keep stable.
    values = [
        f"{var_prefix}_ACCESS_TOKEN={creds.get('access_token', '')}",
        f"{var_prefix}_REFRESH_TOKEN={creds.get('refresh_token', '')}",
        f"{var_prefix}_SCOPE={creds.get('scope', '')}",
        f"{var_prefix}_TOKEN_TYPE={creds.get('token_type', 'Bearer')}",
        f"{var_prefix}_ID_TOKEN={creds.get('id_token', '')}",
        f"{var_prefix}_EXPIRY_DATE={creds.get('expiry_date', 0)}",
        f"{var_prefix}_CLIENT_ID={creds.get('client_id', '')}",
        f"{var_prefix}_CLIENT_SECRET={creds.get('client_secret', '')}",
        f"{var_prefix}_TOKEN_URI={creds.get('token_uri', 'https://oauth2.googleapis.com/token')}",
        f"{var_prefix}_UNIVERSE_DOMAIN={creds.get('universe_domain', 'googleapis.com')}",
        f"{var_prefix}_EMAIL={email}",
    ]

    return header + values
|
| 1448 |
+
def export_credential_to_env(
    self, credential_path: str, output_dir: Optional[Path] = None
) -> Optional[str]:
    """
    Export a credential file to .env format.

    Args:
        credential_path: Path to the credential JSON file
        output_dir: Directory for output .env file (defaults to same as credential)

    Returns:
        Path to the exported .env file, or None on error
    """
    import os  # local import: only needed for the permission fix below

    try:
        cred_path = Path(credential_path)

        # Load credential
        with open(cred_path, "r") as f:
            creds = json.load(f)

        # Extract metadata
        email = creds.get("_proxy_metadata", {}).get("email", "unknown")

        # Get credential number from filename
        match = re.search(r"_oauth_(\d+)\.json$", cred_path.name)
        cred_number = int(match.group(1)) if match else 1

        # Build output path
        if output_dir is None:
            output_dir = cred_path.parent

        safe_email = email.replace("@", "_at_").replace(".", "_")
        prefix = self._get_provider_file_prefix()
        env_filename = f"{prefix}_{cred_number}_{safe_email}.env"
        env_path = output_dir / env_filename

        # Build and write content (with trailing newline so files can be
        # safely concatenated, as the exported header suggests).
        env_lines = self.build_env_lines(creds, cred_number)
        with open(env_path, "w") as f:
            f.write("\n".join(env_lines) + "\n")

        # The .env contains live access/refresh tokens: restrict permissions
        # to owner read/write, matching the credential-file convention.
        try:
            os.chmod(env_path, 0o600)
        except (OSError, AttributeError):
            # chmod may be unsupported (e.g. Windows); best-effort only.
            pass

        lib_logger.info(f"Exported credential to {env_path}")
        return str(env_path)

    except Exception as e:
        lib_logger.error(f"Failed to export credential: {e}")
        return None
|
| 1496 |
+
def list_credentials(self, base_dir: Optional[Path] = None) -> List[Dict[str, Any]]:
    """
    List all credential files for this provider.

    Args:
        base_dir: Directory to search (defaults to oauth_creds)

    Returns:
        List of dicts with credential info:
        - file_path: Path to credential file
        - email: User email
        - tier: Tier info (if available)
        - project_id: Project ID (if available)
        - number: Credential number
    """
    scan_dir = base_dir if base_dir is not None else self._get_oauth_base_dir()
    pattern = str(scan_dir / f"{self._get_provider_file_prefix()}_oauth_*.json")

    found: List[Dict[str, Any]] = []
    for cred_file in sorted(glob(pattern)):
        try:
            with open(cred_file, "r") as fh:
                data = json.load(fh)

            meta = data.get("_proxy_metadata", {})

            # Derive the credential number from the filename.
            num_match = re.search(r"_oauth_(\d+)\.json$", cred_file)

            found.append(
                {
                    "file_path": cred_file,
                    "email": meta.get("email", "unknown"),
                    "tier": meta.get("tier"),
                    "project_id": meta.get("project_id"),
                    "number": int(num_match.group(1)) if num_match else 0,
                }
            )
        except Exception as e:
            # Unreadable entries are skipped, not fatal.
            lib_logger.debug(f"Could not read credential file {cred_file}: {e}")
            continue

    return found
|
| 1544 |
+
def delete_credential(self, credential_path: str) -> bool:
    """
    Delete a credential file.

    Args:
        credential_path: Path to the credential file

    Returns:
        True if deleted successfully, False otherwise
    """
    try:
        target = Path(credential_path)

        # Refuse to delete files that do not follow this provider's
        # credential naming convention.
        expected_prefix = f"{self._get_provider_file_prefix()}_oauth_"
        if not target.name.startswith(expected_prefix):
            lib_logger.error(
                f"File {target.name} does not appear to be a {self.ENV_PREFIX} credential"
            )
            return False

        if not target.exists():
            lib_logger.warning(f"Credential file does not exist: {credential_path}")
            return False

        # Drop any cached copy before removing the file itself.
        self._credentials_cache.pop(credential_path, None)

        target.unlink()
        lib_logger.info(f"Deleted credential file: {credential_path}")
        return True

    except Exception as e:
        lib_logger.error(f"Failed to delete credential: {e}")
        return False
|
src/rotator_library/providers/iflow_auth_base.py
CHANGED
|
@@ -9,11 +9,12 @@ import logging
|
|
| 9 |
import webbrowser
|
| 10 |
import socket
|
| 11 |
import os
|
|
|
|
|
|
|
| 12 |
from pathlib import Path
|
| 13 |
-
from
|
|
|
|
| 14 |
from urllib.parse import urlencode, parse_qs, urlparse
|
| 15 |
-
import tempfile
|
| 16 |
-
import shutil
|
| 17 |
|
| 18 |
import httpx
|
| 19 |
from aiohttp import web
|
|
@@ -24,6 +25,7 @@ from rich.text import Text
|
|
| 24 |
from rich.markup import escape as rich_escape
|
| 25 |
from ..utils.headless_detection import is_headless_environment
|
| 26 |
from ..utils.reauth_coordinator import get_reauth_coordinator
|
|
|
|
| 27 |
|
| 28 |
lib_logger = logging.getLogger("rotator_library")
|
| 29 |
|
|
@@ -40,6 +42,39 @@ IFLOW_CLIENT_SECRET = "4Z3YjXycVsQvyGF1etiNlIBB4RsqSDtW"
|
|
| 40 |
# Local callback server port
|
| 41 |
CALLBACK_PORT = 11451
|
| 42 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
# Refresh tokens 24 hours before expiry
|
| 44 |
REFRESH_EXPIRY_BUFFER_SECONDS = 24 * 60 * 60
|
| 45 |
|
|
@@ -171,19 +206,36 @@ class IFlowAuthBase:
|
|
| 171 |
str, float
|
| 172 |
] = {} # Track backoff timers (Unix timestamp)
|
| 173 |
|
| 174 |
-
# [QUEUE SYSTEM] Sequential refresh processing
|
|
|
|
| 175 |
self._refresh_queue: asyncio.Queue = asyncio.Queue()
|
| 176 |
-
self.
|
| 177 |
-
|
| 178 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 179 |
self._unavailable_credentials: Dict[
|
| 180 |
str, float
|
| 181 |
] = {} # Maps credential path -> timestamp when marked unavailable
|
| 182 |
-
|
|
|
|
| 183 |
self._queue_tracking_lock = asyncio.Lock() # Protects queue sets
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 187 |
|
| 188 |
def _parse_env_credential_path(self, path: str) -> Optional[str]:
|
| 189 |
"""
|
|
@@ -305,76 +357,40 @@ class IFlowAuthBase:
|
|
| 305 |
f"Environment variables for iFlow credential index {credential_index} not found"
|
| 306 |
)
|
| 307 |
|
| 308 |
-
#
|
| 309 |
-
|
| 310 |
-
|
| 311 |
-
|
| 312 |
-
|
| 313 |
-
|
| 314 |
-
|
| 315 |
-
|
| 316 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 317 |
|
| 318 |
async def _save_credentials(self, path: str, creds: Dict[str, Any]):
|
| 319 |
-
"""
|
|
|
|
|
|
|
|
|
|
| 320 |
# Don't save to file if credentials were loaded from environment
|
| 321 |
if creds.get("_proxy_metadata", {}).get("loaded_from_env"):
|
| 322 |
lib_logger.debug("Credentials loaded from env, skipping file save")
|
| 323 |
-
# Still update cache for in-memory consistency
|
| 324 |
-
self._credentials_cache[path] = creds
|
| 325 |
return
|
| 326 |
|
| 327 |
-
#
|
| 328 |
-
#
|
| 329 |
-
|
| 330 |
-
|
| 331 |
-
|
| 332 |
-
|
| 333 |
-
|
| 334 |
-
|
| 335 |
-
|
| 336 |
-
tmp_fd, tmp_path = tempfile.mkstemp(
|
| 337 |
-
dir=parent_dir, prefix=".tmp_", suffix=".json", text=True
|
| 338 |
-
)
|
| 339 |
-
|
| 340 |
-
# Write JSON to temp file
|
| 341 |
-
with os.fdopen(tmp_fd, "w") as f:
|
| 342 |
-
json.dump(creds, f, indent=2)
|
| 343 |
-
tmp_fd = None # fdopen closes the fd
|
| 344 |
-
|
| 345 |
-
# Set secure permissions (0600 = owner read/write only)
|
| 346 |
-
try:
|
| 347 |
-
os.chmod(tmp_path, 0o600)
|
| 348 |
-
except (OSError, AttributeError):
|
| 349 |
-
# Windows may not support chmod, ignore
|
| 350 |
-
pass
|
| 351 |
-
|
| 352 |
-
# Atomic move (overwrites target if it exists)
|
| 353 |
-
shutil.move(tmp_path, path)
|
| 354 |
-
tmp_path = None # Successfully moved
|
| 355 |
-
|
| 356 |
-
# Update cache AFTER successful file write
|
| 357 |
-
self._credentials_cache[path] = creds
|
| 358 |
-
lib_logger.debug(
|
| 359 |
-
f"Saved updated iFlow OAuth credentials to '{path}' (atomic write)."
|
| 360 |
-
)
|
| 361 |
-
|
| 362 |
-
except Exception as e:
|
| 363 |
-
lib_logger.error(
|
| 364 |
-
f"Failed to save updated iFlow OAuth credentials to '{path}': {e}"
|
| 365 |
)
|
| 366 |
-
# Clean up temp file if it still exists
|
| 367 |
-
if tmp_fd is not None:
|
| 368 |
-
try:
|
| 369 |
-
os.close(tmp_fd)
|
| 370 |
-
except:
|
| 371 |
-
pass
|
| 372 |
-
if tmp_path and os.path.exists(tmp_path):
|
| 373 |
-
try:
|
| 374 |
-
os.unlink(tmp_path)
|
| 375 |
-
except:
|
| 376 |
-
pass
|
| 377 |
-
raise
|
| 378 |
|
| 379 |
def _is_token_expired(self, creds: Dict[str, Any]) -> bool:
|
| 380 |
"""Checks if the token is expired (with buffer for proactive refresh)."""
|
|
@@ -399,6 +415,29 @@ class IFlowAuthBase:
|
|
| 399 |
|
| 400 |
return expiry_timestamp < time.time() + REFRESH_EXPIRY_BUFFER_SECONDS
|
| 401 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 402 |
async def _fetch_user_info(self, access_token: str) -> Dict[str, Any]:
|
| 403 |
"""
|
| 404 |
Fetches user info (including API key) from iFlow API.
|
|
@@ -553,6 +592,26 @@ class IFlowAuthBase:
|
|
| 553 |
)
|
| 554 |
response.raise_for_status()
|
| 555 |
new_token_data = response.json()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 556 |
break # Success
|
| 557 |
|
| 558 |
except httpx.HTTPStatusError as e:
|
|
@@ -654,6 +713,16 @@ class IFlowAuthBase:
|
|
| 654 |
# Update tokens
|
| 655 |
access_token = new_token_data.get("access_token")
|
| 656 |
if not access_token:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 657 |
raise ValueError("Missing access_token in refresh response")
|
| 658 |
|
| 659 |
creds_from_file["access_token"] = access_token
|
|
@@ -749,7 +818,7 @@ class IFlowAuthBase:
|
|
| 749 |
Proactively refreshes tokens if they're close to expiry.
|
| 750 |
Only applies to OAuth credentials (file paths or env:// paths). Direct API keys are skipped.
|
| 751 |
"""
|
| 752 |
-
lib_logger.debug(f"proactively_refresh called for: {credential_identifier}")
|
| 753 |
|
| 754 |
# Try to load credentials - this will fail for direct API keys
|
| 755 |
# and succeed for OAuth credentials (file paths or env:// paths)
|
|
@@ -757,21 +826,21 @@ class IFlowAuthBase:
|
|
| 757 |
creds = await self._load_credentials(credential_identifier)
|
| 758 |
except IOError as e:
|
| 759 |
# Not a valid credential path (likely a direct API key string)
|
| 760 |
-
lib_logger.debug(
|
| 761 |
-
|
| 762 |
-
)
|
| 763 |
return
|
| 764 |
|
| 765 |
is_expired = self._is_token_expired(creds)
|
| 766 |
-
lib_logger.debug(
|
| 767 |
-
|
| 768 |
-
)
|
| 769 |
|
| 770 |
if is_expired:
|
| 771 |
-
lib_logger.debug(
|
| 772 |
-
|
| 773 |
-
)
|
| 774 |
-
#
|
| 775 |
await self._queue_refresh(
|
| 776 |
credential_identifier, force=False, needs_reauth=False
|
| 777 |
)
|
|
@@ -785,30 +854,55 @@ class IFlowAuthBase:
|
|
| 785 |
return self._refresh_locks[path]
|
| 786 |
|
| 787 |
def is_credential_available(self, path: str) -> bool:
|
| 788 |
-
"""Check if a credential is available for rotation
|
| 789 |
|
| 790 |
-
|
| 791 |
-
|
| 792 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 793 |
"""
|
| 794 |
-
|
| 795 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 796 |
|
| 797 |
-
#
|
| 798 |
-
|
| 799 |
-
if
|
| 800 |
-
|
| 801 |
-
|
| 802 |
-
|
| 803 |
-
lib_logger.
|
| 804 |
-
|
| 805 |
-
|
| 806 |
-
|
|
|
|
| 807 |
)
|
| 808 |
-
|
| 809 |
-
return True
|
| 810 |
|
| 811 |
-
return
|
| 812 |
|
| 813 |
async def _ensure_queue_processor_running(self):
|
| 814 |
"""Lazily starts the queue processor if not already running."""
|
|
@@ -817,15 +911,27 @@ class IFlowAuthBase:
|
|
| 817 |
self._process_refresh_queue()
|
| 818 |
)
|
| 819 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 820 |
async def _queue_refresh(
|
| 821 |
self, path: str, force: bool = False, needs_reauth: bool = False
|
| 822 |
):
|
| 823 |
-
"""Add a credential to the refresh queue if not already queued.
|
| 824 |
|
| 825 |
Args:
|
| 826 |
path: Credential file path
|
| 827 |
force: Force refresh even if not expired
|
| 828 |
-
needs_reauth: True if full re-authentication needed (
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 829 |
"""
|
| 830 |
# IMPORTANT: Only check backoff for simple automated refreshes
|
| 831 |
# Re-authentication (interactive OAuth) should BYPASS backoff since it needs user input
|
|
@@ -835,114 +941,223 @@ class IFlowAuthBase:
|
|
| 835 |
backoff_until = self._next_refresh_after[path]
|
| 836 |
if now < backoff_until:
|
| 837 |
# Credential is in backoff for automated refresh, do not queue
|
| 838 |
-
remaining = int(backoff_until - now)
|
| 839 |
-
lib_logger.debug(
|
| 840 |
-
|
| 841 |
-
)
|
| 842 |
return
|
| 843 |
|
| 844 |
async with self._queue_tracking_lock:
|
| 845 |
if path not in self._queued_credentials:
|
| 846 |
self._queued_credentials.add(path)
|
| 847 |
-
|
| 848 |
-
|
| 849 |
-
|
| 850 |
-
|
| 851 |
-
|
| 852 |
-
|
| 853 |
-
|
| 854 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 855 |
|
| 856 |
async def _process_refresh_queue(self):
|
| 857 |
-
"""Background worker that processes refresh requests sequentially.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 858 |
while True:
|
| 859 |
path = None
|
| 860 |
try:
|
| 861 |
# Wait for an item with timeout to allow graceful shutdown
|
| 862 |
try:
|
| 863 |
-
path, force
|
| 864 |
self._refresh_queue.get(), timeout=60.0
|
| 865 |
)
|
| 866 |
except asyncio.TimeoutError:
|
| 867 |
-
#
|
| 868 |
-
# If we're idle for 60s, no refreshes are in progress
|
| 869 |
async with self._queue_tracking_lock:
|
| 870 |
-
|
| 871 |
-
|
| 872 |
-
lib_logger.warning(
|
| 873 |
-
f"Queue processor idle timeout. Cleaning {stale_count} "
|
| 874 |
-
f"stale unavailable credentials: {list(self._unavailable_credentials.keys())}"
|
| 875 |
-
)
|
| 876 |
-
self._unavailable_credentials.clear()
|
| 877 |
-
# [FIX BUG#6] Also clear queued credentials to prevent stuck state
|
| 878 |
-
if self._queued_credentials:
|
| 879 |
-
lib_logger.debug(
|
| 880 |
-
f"Clearing {len(self._queued_credentials)} queued credentials on timeout"
|
| 881 |
-
)
|
| 882 |
-
self._queued_credentials.clear()
|
| 883 |
self._queue_processor_task = None
|
|
|
|
| 884 |
return
|
| 885 |
|
| 886 |
try:
|
| 887 |
-
#
|
| 888 |
-
|
| 889 |
-
|
| 890 |
-
|
| 891 |
-
|
| 892 |
-
|
| 893 |
-
|
| 894 |
-
|
| 895 |
-
|
| 896 |
-
|
| 897 |
-
|
| 898 |
-
|
| 899 |
-
|
|
|
|
|
|
|
| 900 |
|
| 901 |
-
#
|
| 902 |
-
|
| 903 |
-
|
| 904 |
-
await self._refresh_token(path, force=force)
|
| 905 |
|
| 906 |
-
|
| 907 |
-
|
| 908 |
-
self.
|
| 909 |
-
|
| 910 |
-
|
| 911 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 912 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 913 |
|
| 914 |
finally:
|
| 915 |
-
#
|
| 916 |
-
# This ensures cleanup happens in ALL exit paths (success, exception, etc.)
|
| 917 |
async with self._queue_tracking_lock:
|
| 918 |
self._queued_credentials.discard(path)
|
| 919 |
-
# [FIX PR#34] Always clean up unavailable credentials in finally block
|
| 920 |
self._unavailable_credentials.pop(path, None)
|
| 921 |
-
lib_logger.debug(
|
| 922 |
-
|
| 923 |
-
|
| 924 |
-
)
|
| 925 |
-
self.
|
|
|
|
| 926 |
except asyncio.CancelledError:
|
| 927 |
-
#
|
| 928 |
if path:
|
| 929 |
async with self._queue_tracking_lock:
|
|
|
|
| 930 |
self._unavailable_credentials.pop(path, None)
|
| 931 |
-
|
| 932 |
-
f"CancelledError cleanup for '{Path(path).name}'. "
|
| 933 |
-
f"Remaining unavailable: {len(self._unavailable_credentials)}"
|
| 934 |
-
)
|
| 935 |
break
|
| 936 |
except Exception as e:
|
| 937 |
-
lib_logger.error(f"Error in queue processor: {e}")
|
| 938 |
-
# Even on error, mark as available (backoff will prevent immediate retry)
|
| 939 |
if path:
|
| 940 |
async with self._queue_tracking_lock:
|
|
|
|
| 941 |
self._unavailable_credentials.pop(path, None)
|
| 942 |
-
lib_logger.debug(
|
| 943 |
-
f"Error cleanup for '{Path(path).name}': {e}. "
|
| 944 |
-
f"Remaining unavailable: {len(self._unavailable_credentials)}"
|
| 945 |
-
)
|
| 946 |
|
| 947 |
async def _perform_interactive_oauth(
|
| 948 |
self, path: str, creds: Dict[str, Any], display_name: str
|
|
@@ -968,7 +1183,8 @@ class IFlowAuthBase:
|
|
| 968 |
state = secrets.token_urlsafe(32)
|
| 969 |
|
| 970 |
# Build authorization URL
|
| 971 |
-
|
|
|
|
| 972 |
auth_params = {
|
| 973 |
"loginMethod": "phone",
|
| 974 |
"type": "phone",
|
|
@@ -979,7 +1195,7 @@ class IFlowAuthBase:
|
|
| 979 |
auth_url = f"{IFLOW_OAUTH_AUTHORIZE_ENDPOINT}?{urlencode(auth_params)}"
|
| 980 |
|
| 981 |
# Start OAuth callback server
|
| 982 |
-
callback_server = OAuthCallbackServer(port=
|
| 983 |
try:
|
| 984 |
await callback_server.start(expected_state=state)
|
| 985 |
|
|
@@ -1182,3 +1398,261 @@ class IFlowAuthBase:
|
|
| 1182 |
except Exception as e:
|
| 1183 |
lib_logger.error(f"Failed to get iFlow user info from credentials: {e}")
|
| 1184 |
return {"email": None}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
import webbrowser
|
| 10 |
import socket
|
| 11 |
import os
|
| 12 |
+
import re
|
| 13 |
+
from dataclasses import dataclass, field
|
| 14 |
from pathlib import Path
|
| 15 |
+
from glob import glob
|
| 16 |
+
from typing import Dict, Any, Tuple, Union, Optional, List
|
| 17 |
from urllib.parse import urlencode, parse_qs, urlparse
|
|
|
|
|
|
|
| 18 |
|
| 19 |
import httpx
|
| 20 |
from aiohttp import web
|
|
|
|
| 25 |
from rich.markup import escape as rich_escape
|
| 26 |
from ..utils.headless_detection import is_headless_environment
|
| 27 |
from ..utils.reauth_coordinator import get_reauth_coordinator
|
| 28 |
+
from ..utils.resilient_io import safe_write_json
|
| 29 |
|
| 30 |
lib_logger = logging.getLogger("rotator_library")
|
| 31 |
|
|
|
|
| 42 |
# Local callback server port
|
| 43 |
CALLBACK_PORT = 11451
|
| 44 |
|
| 45 |
+
|
| 46 |
+
@dataclass
|
| 47 |
+
class IFlowCredentialSetupResult:
|
| 48 |
+
"""
|
| 49 |
+
Standardized result structure for iFlow credential setup operations.
|
| 50 |
+
"""
|
| 51 |
+
|
| 52 |
+
success: bool
|
| 53 |
+
file_path: Optional[str] = None
|
| 54 |
+
email: Optional[str] = None
|
| 55 |
+
is_update: bool = False
|
| 56 |
+
error: Optional[str] = None
|
| 57 |
+
credentials: Optional[Dict[str, Any]] = field(default=None, repr=False)
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def get_callback_port() -> int:
|
| 61 |
+
"""
|
| 62 |
+
Get the OAuth callback port, checking environment variable first.
|
| 63 |
+
|
| 64 |
+
Reads from IFLOW_OAUTH_PORT environment variable, falling back
|
| 65 |
+
to the default CALLBACK_PORT if not set.
|
| 66 |
+
"""
|
| 67 |
+
env_value = os.getenv("IFLOW_OAUTH_PORT")
|
| 68 |
+
if env_value:
|
| 69 |
+
try:
|
| 70 |
+
return int(env_value)
|
| 71 |
+
except ValueError:
|
| 72 |
+
logging.getLogger("rotator_library").warning(
|
| 73 |
+
f"Invalid IFLOW_OAUTH_PORT value: {env_value}, using default {CALLBACK_PORT}"
|
| 74 |
+
)
|
| 75 |
+
return CALLBACK_PORT
|
| 76 |
+
|
| 77 |
+
|
| 78 |
# Refresh tokens 24 hours before expiry
|
| 79 |
REFRESH_EXPIRY_BUFFER_SECONDS = 24 * 60 * 60
|
| 80 |
|
|
|
|
| 206 |
str, float
|
| 207 |
] = {} # Track backoff timers (Unix timestamp)
|
| 208 |
|
| 209 |
+
# [QUEUE SYSTEM] Sequential refresh processing with two separate queues
|
| 210 |
+
# Normal refresh queue: for proactive token refresh (old token still valid)
|
| 211 |
self._refresh_queue: asyncio.Queue = asyncio.Queue()
|
| 212 |
+
self._queue_processor_task: Optional[asyncio.Task] = None
|
| 213 |
+
|
| 214 |
+
# Re-auth queue: for invalid refresh tokens (requires user interaction)
|
| 215 |
+
self._reauth_queue: asyncio.Queue = asyncio.Queue()
|
| 216 |
+
self._reauth_processor_task: Optional[asyncio.Task] = None
|
| 217 |
+
|
| 218 |
+
# Tracking sets/dicts
|
| 219 |
+
self._queued_credentials: set = set() # Track credentials in either queue
|
| 220 |
+
# Only credentials in re-auth queue are marked unavailable (not normal refresh)
|
| 221 |
+
# TTL cleanup is defense-in-depth for edge cases where re-auth processor crashes
|
| 222 |
self._unavailable_credentials: Dict[
|
| 223 |
str, float
|
| 224 |
] = {} # Maps credential path -> timestamp when marked unavailable
|
| 225 |
+
# TTL should exceed reauth timeout (300s) to avoid premature cleanup
|
| 226 |
+
self._unavailable_ttl_seconds: int = 360 # 6 minutes TTL for stale entries
|
| 227 |
self._queue_tracking_lock = asyncio.Lock() # Protects queue sets
|
| 228 |
+
|
| 229 |
+
# Retry tracking for normal refresh queue
|
| 230 |
+
self._queue_retry_count: Dict[
|
| 231 |
+
str, int
|
| 232 |
+
] = {} # Track retry attempts per credential
|
| 233 |
+
|
| 234 |
+
# Configuration constants
|
| 235 |
+
self._refresh_timeout_seconds: int = 15 # Max time for single refresh
|
| 236 |
+
self._refresh_interval_seconds: int = 30 # Delay between queue items
|
| 237 |
+
self._refresh_max_retries: int = 3 # Attempts before kicked out
|
| 238 |
+
self._reauth_timeout_seconds: int = 300 # Time for user to complete OAuth
|
| 239 |
|
| 240 |
def _parse_env_credential_path(self, path: str) -> Optional[str]:
|
| 241 |
"""
|
|
|
|
| 357 |
f"Environment variables for iFlow credential index {credential_index} not found"
|
| 358 |
)
|
| 359 |
|
| 360 |
+
# Try file-based loading first (preferred for explicit file paths)
|
| 361 |
+
try:
|
| 362 |
+
return await self._read_creds_from_file(path)
|
| 363 |
+
except IOError:
|
| 364 |
+
# File not found - fall back to legacy env vars for backwards compatibility
|
| 365 |
+
env_creds = self._load_from_env()
|
| 366 |
+
if env_creds:
|
| 367 |
+
lib_logger.info(
|
| 368 |
+
f"File '{path}' not found, using iFlow credentials from environment variables"
|
| 369 |
+
)
|
| 370 |
+
self._credentials_cache[path] = env_creds
|
| 371 |
+
return env_creds
|
| 372 |
+
raise # Re-raise the original file not found error
|
| 373 |
|
| 374 |
async def _save_credentials(self, path: str, creds: Dict[str, Any]):
|
| 375 |
+
"""Save credentials with in-memory fallback if disk unavailable."""
|
| 376 |
+
# Always update cache first (memory is reliable)
|
| 377 |
+
self._credentials_cache[path] = creds
|
| 378 |
+
|
| 379 |
# Don't save to file if credentials were loaded from environment
|
| 380 |
if creds.get("_proxy_metadata", {}).get("loaded_from_env"):
|
| 381 |
lib_logger.debug("Credentials loaded from env, skipping file save")
|
|
|
|
|
|
|
| 382 |
return
|
| 383 |
|
| 384 |
+
# Attempt disk write - if it fails, we still have the cache
|
| 385 |
+
# buffer_on_failure ensures data is retried periodically and saved on shutdown
|
| 386 |
+
if safe_write_json(
|
| 387 |
+
path, creds, lib_logger, secure_permissions=True, buffer_on_failure=True
|
| 388 |
+
):
|
| 389 |
+
lib_logger.debug(f"Saved updated iFlow OAuth credentials to '{path}'.")
|
| 390 |
+
else:
|
| 391 |
+
lib_logger.warning(
|
| 392 |
+
"iFlow credentials cached in memory only (buffered for retry)."
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 393 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 394 |
|
| 395 |
def _is_token_expired(self, creds: Dict[str, Any]) -> bool:
|
| 396 |
"""Checks if the token is expired (with buffer for proactive refresh)."""
|
|
|
|
| 415 |
|
| 416 |
return expiry_timestamp < time.time() + REFRESH_EXPIRY_BUFFER_SECONDS
|
| 417 |
|
| 418 |
+
def _is_token_truly_expired(self, creds: Dict[str, Any]) -> bool:
|
| 419 |
+
"""Check if token is TRULY expired (past actual expiry, not just threshold).
|
| 420 |
+
|
| 421 |
+
This is different from _is_token_expired() which uses a buffer for proactive refresh.
|
| 422 |
+
This method checks if the token is actually unusable.
|
| 423 |
+
"""
|
| 424 |
+
expiry_str = creds.get("expiry_date")
|
| 425 |
+
if not expiry_str:
|
| 426 |
+
return True
|
| 427 |
+
|
| 428 |
+
try:
|
| 429 |
+
from datetime import datetime
|
| 430 |
+
|
| 431 |
+
expiry_dt = datetime.fromisoformat(expiry_str.replace("Z", "+00:00"))
|
| 432 |
+
expiry_timestamp = expiry_dt.timestamp()
|
| 433 |
+
except (ValueError, AttributeError):
|
| 434 |
+
try:
|
| 435 |
+
expiry_timestamp = float(expiry_str)
|
| 436 |
+
except (ValueError, TypeError):
|
| 437 |
+
return True
|
| 438 |
+
|
| 439 |
+
return expiry_timestamp < time.time()
|
| 440 |
+
|
| 441 |
async def _fetch_user_info(self, access_token: str) -> Dict[str, Any]:
|
| 442 |
"""
|
| 443 |
Fetches user info (including API key) from iFlow API.
|
|
|
|
| 592 |
)
|
| 593 |
response.raise_for_status()
|
| 594 |
new_token_data = response.json()
|
| 595 |
+
|
| 596 |
+
# [FIX] Handle wrapped response format: {success: bool, data: {...}}
|
| 597 |
+
# iFlow API may return tokens nested inside a 'data' key
|
| 598 |
+
if (
|
| 599 |
+
isinstance(new_token_data, dict)
|
| 600 |
+
and "data" in new_token_data
|
| 601 |
+
):
|
| 602 |
+
lib_logger.debug(
|
| 603 |
+
f"iFlow refresh response wrapped in 'data' key, extracting..."
|
| 604 |
+
)
|
| 605 |
+
# Check for error in wrapped response
|
| 606 |
+
if not new_token_data.get("success", True):
|
| 607 |
+
error_msg = new_token_data.get(
|
| 608 |
+
"message", "Unknown error"
|
| 609 |
+
)
|
| 610 |
+
raise ValueError(
|
| 611 |
+
f"iFlow token refresh failed: {error_msg}"
|
| 612 |
+
)
|
| 613 |
+
new_token_data = new_token_data.get("data", {})
|
| 614 |
+
|
| 615 |
break # Success
|
| 616 |
|
| 617 |
except httpx.HTTPStatusError as e:
|
|
|
|
| 713 |
# Update tokens
|
| 714 |
access_token = new_token_data.get("access_token")
|
| 715 |
if not access_token:
|
| 716 |
+
# Log response keys for debugging
|
| 717 |
+
response_keys = (
|
| 718 |
+
list(new_token_data.keys())
|
| 719 |
+
if isinstance(new_token_data, dict)
|
| 720 |
+
else type(new_token_data).__name__
|
| 721 |
+
)
|
| 722 |
+
lib_logger.error(
|
| 723 |
+
f"Missing access_token in refresh response for '{Path(path).name}'. "
|
| 724 |
+
f"Response keys: {response_keys}"
|
| 725 |
+
)
|
| 726 |
raise ValueError("Missing access_token in refresh response")
|
| 727 |
|
| 728 |
creds_from_file["access_token"] = access_token
|
|
|
|
| 818 |
Proactively refreshes tokens if they're close to expiry.
|
| 819 |
Only applies to OAuth credentials (file paths or env:// paths). Direct API keys are skipped.
|
| 820 |
"""
|
| 821 |
+
# lib_logger.debug(f"proactively_refresh called for: {credential_identifier}")
|
| 822 |
|
| 823 |
# Try to load credentials - this will fail for direct API keys
|
| 824 |
# and succeed for OAuth credentials (file paths or env:// paths)
|
|
|
|
| 826 |
creds = await self._load_credentials(credential_identifier)
|
| 827 |
except IOError as e:
|
| 828 |
# Not a valid credential path (likely a direct API key string)
|
| 829 |
+
# lib_logger.debug(
|
| 830 |
+
# f"Skipping refresh for '{credential_identifier}' - not an OAuth credential: {e}"
|
| 831 |
+
# )
|
| 832 |
return
|
| 833 |
|
| 834 |
is_expired = self._is_token_expired(creds)
|
| 835 |
+
# lib_logger.debug(
|
| 836 |
+
# f"Token expired check for '{Path(credential_identifier).name}': {is_expired}"
|
| 837 |
+
# )
|
| 838 |
|
| 839 |
if is_expired:
|
| 840 |
+
# lib_logger.debug(
|
| 841 |
+
# f"Queueing refresh for '{Path(credential_identifier).name}'"
|
| 842 |
+
# )
|
| 843 |
+
# lib_logger.info(f"Proactive refresh triggered for '{Path(credential_identifier).name}'")
|
| 844 |
await self._queue_refresh(
|
| 845 |
credential_identifier, force=False, needs_reauth=False
|
| 846 |
)
|
|
|
|
| 854 |
return self._refresh_locks[path]
|
| 855 |
|
| 856 |
def is_credential_available(self, path: str) -> bool:
|
| 857 |
+
"""Check if a credential is available for rotation.
|
| 858 |
|
| 859 |
+
Credentials are unavailable if:
|
| 860 |
+
1. In re-auth queue (token is truly broken, requires user interaction)
|
| 861 |
+
2. Token is TRULY expired (past actual expiry, not just threshold)
|
| 862 |
+
|
| 863 |
+
Note: Credentials in normal refresh queue are still available because
|
| 864 |
+
the old token is valid until actual expiry.
|
| 865 |
+
|
| 866 |
+
TTL cleanup (defense-in-depth): If a credential has been in the re-auth
|
| 867 |
+
queue longer than _unavailable_ttl_seconds without being processed, it's
|
| 868 |
+
cleaned up. This should only happen if the re-auth processor crashes or
|
| 869 |
+
is cancelled without proper cleanup.
|
| 870 |
"""
|
| 871 |
+
# Check if in re-auth queue (truly unavailable)
|
| 872 |
+
if path in self._unavailable_credentials:
|
| 873 |
+
marked_time = self._unavailable_credentials.get(path)
|
| 874 |
+
if marked_time is not None:
|
| 875 |
+
now = time.time()
|
| 876 |
+
if now - marked_time > self._unavailable_ttl_seconds:
|
| 877 |
+
# Entry is stale - clean it up and return available
|
| 878 |
+
# This is a defense-in-depth for edge cases where re-auth
|
| 879 |
+
# processor crashed or was cancelled without cleanup
|
| 880 |
+
lib_logger.warning(
|
| 881 |
+
f"Credential '{Path(path).name}' stuck in re-auth queue for "
|
| 882 |
+
f"{int(now - marked_time)}s (TTL: {self._unavailable_ttl_seconds}s). "
|
| 883 |
+
f"Re-auth processor may have crashed. Auto-cleaning stale entry."
|
| 884 |
+
)
|
| 885 |
+
# Clean up both tracking structures for consistency
|
| 886 |
+
self._unavailable_credentials.pop(path, None)
|
| 887 |
+
self._queued_credentials.discard(path)
|
| 888 |
+
else:
|
| 889 |
+
return False # Still in re-auth, not available
|
| 890 |
|
| 891 |
+
# Check if token is TRULY expired (not just threshold-expired)
|
| 892 |
+
creds = self._credentials_cache.get(path)
|
| 893 |
+
if creds and self._is_token_truly_expired(creds):
|
| 894 |
+
# Token is actually expired - should not be used
|
| 895 |
+
# Queue for refresh if not already queued
|
| 896 |
+
if path not in self._queued_credentials:
|
| 897 |
+
# lib_logger.debug(
|
| 898 |
+
# f"Credential '{Path(path).name}' is truly expired, queueing for refresh"
|
| 899 |
+
# )
|
| 900 |
+
asyncio.create_task(
|
| 901 |
+
self._queue_refresh(path, force=True, needs_reauth=False)
|
| 902 |
)
|
| 903 |
+
return False
|
|
|
|
| 904 |
|
| 905 |
+
return True
|
| 906 |
|
| 907 |
async def _ensure_queue_processor_running(self):
|
| 908 |
"""Lazily starts the queue processor if not already running."""
|
|
|
|
| 911 |
self._process_refresh_queue()
|
| 912 |
)
|
| 913 |
|
| 914 |
+
async def _ensure_reauth_processor_running(self):
|
| 915 |
+
"""Lazily starts the re-auth queue processor if not already running."""
|
| 916 |
+
if self._reauth_processor_task is None or self._reauth_processor_task.done():
|
| 917 |
+
self._reauth_processor_task = asyncio.create_task(
|
| 918 |
+
self._process_reauth_queue()
|
| 919 |
+
)
|
| 920 |
+
|
| 921 |
async def _queue_refresh(
|
| 922 |
self, path: str, force: bool = False, needs_reauth: bool = False
|
| 923 |
):
|
| 924 |
+
"""Add a credential to the appropriate refresh queue if not already queued.
|
| 925 |
|
| 926 |
Args:
|
| 927 |
path: Credential file path
|
| 928 |
force: Force refresh even if not expired
|
| 929 |
+
needs_reauth: True if full re-authentication needed (routes to re-auth queue)
|
| 930 |
+
|
| 931 |
+
Queue routing:
|
| 932 |
+
- needs_reauth=True: Goes to re-auth queue, marks as unavailable
|
| 933 |
+
- needs_reauth=False: Goes to normal refresh queue, does NOT mark unavailable
|
| 934 |
+
(old token is still valid until actual expiry)
|
| 935 |
"""
|
| 936 |
# IMPORTANT: Only check backoff for simple automated refreshes
|
| 937 |
# Re-authentication (interactive OAuth) should BYPASS backoff since it needs user input
|
|
|
|
| 941 |
backoff_until = self._next_refresh_after[path]
|
| 942 |
if now < backoff_until:
|
| 943 |
# Credential is in backoff for automated refresh, do not queue
|
| 944 |
+
# remaining = int(backoff_until - now)
|
| 945 |
+
# lib_logger.debug(
|
| 946 |
+
# f"Skipping automated refresh for '{Path(path).name}' (in backoff for {remaining}s)"
|
| 947 |
+
# )
|
| 948 |
return
|
| 949 |
|
| 950 |
async with self._queue_tracking_lock:
|
| 951 |
if path not in self._queued_credentials:
|
| 952 |
self._queued_credentials.add(path)
|
| 953 |
+
|
| 954 |
+
if needs_reauth:
|
| 955 |
+
# Re-auth queue: mark as unavailable (token is truly broken)
|
| 956 |
+
self._unavailable_credentials[path] = time.time()
|
| 957 |
+
# lib_logger.debug(
|
| 958 |
+
# f"Queued '{Path(path).name}' for RE-AUTH (marked unavailable). "
|
| 959 |
+
# f"Total unavailable: {len(self._unavailable_credentials)}"
|
| 960 |
+
# )
|
| 961 |
+
await self._reauth_queue.put(path)
|
| 962 |
+
await self._ensure_reauth_processor_running()
|
| 963 |
+
else:
|
| 964 |
+
# Normal refresh queue: do NOT mark unavailable (old token still valid)
|
| 965 |
+
# lib_logger.debug(
|
| 966 |
+
# f"Queued '{Path(path).name}' for refresh (still available). "
|
| 967 |
+
# f"Queue size: {self._refresh_queue.qsize() + 1}"
|
| 968 |
+
# )
|
| 969 |
+
await self._refresh_queue.put((path, force))
|
| 970 |
+
await self._ensure_queue_processor_running()
|
| 971 |
|
| 972 |
async def _process_refresh_queue(self):
|
| 973 |
+
"""Background worker that processes normal refresh requests sequentially.
|
| 974 |
+
|
| 975 |
+
Key behaviors:
|
| 976 |
+
- 15s timeout per refresh operation
|
| 977 |
+
- 30s delay between processing credentials (prevents thundering herd)
|
| 978 |
+
- On failure: back of queue, max 3 retries before kicked
|
| 979 |
+
- If 401/403 detected: routes to re-auth queue
|
| 980 |
+
- Does NOT mark credentials unavailable (old token still valid)
|
| 981 |
+
"""
|
| 982 |
+
# lib_logger.info("Refresh queue processor started")
|
| 983 |
while True:
|
| 984 |
path = None
|
| 985 |
try:
|
| 986 |
# Wait for an item with timeout to allow graceful shutdown
|
| 987 |
try:
|
| 988 |
+
path, force = await asyncio.wait_for(
|
| 989 |
self._refresh_queue.get(), timeout=60.0
|
| 990 |
)
|
| 991 |
except asyncio.TimeoutError:
|
| 992 |
+
# Queue is empty and idle for 60s - clean up and exit
|
|
|
|
| 993 |
async with self._queue_tracking_lock:
|
| 994 |
+
# Clear any stale retry counts
|
| 995 |
+
self._queue_retry_count.clear()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 996 |
self._queue_processor_task = None
|
| 997 |
+
# lib_logger.debug("Refresh queue processor idle, shutting down")
|
| 998 |
return
|
| 999 |
|
| 1000 |
try:
|
| 1001 |
+
# Quick check if still expired (optimization to avoid unnecessary refresh)
|
| 1002 |
+
creds = self._credentials_cache.get(path)
|
| 1003 |
+
if creds and not self._is_token_expired(creds):
|
| 1004 |
+
# No longer expired, skip refresh
|
| 1005 |
+
# lib_logger.debug(
|
| 1006 |
+
# f"Credential '{Path(path).name}' no longer expired, skipping refresh"
|
| 1007 |
+
# )
|
| 1008 |
+
# Clear retry count on skip (not a failure)
|
| 1009 |
+
self._queue_retry_count.pop(path, None)
|
| 1010 |
+
continue
|
| 1011 |
+
|
| 1012 |
+
# Perform refresh with timeout
|
| 1013 |
+
try:
|
| 1014 |
+
async with asyncio.timeout(self._refresh_timeout_seconds):
|
| 1015 |
+
await self._refresh_token(path, force=force)
|
| 1016 |
|
| 1017 |
+
# SUCCESS: Clear retry count
|
| 1018 |
+
self._queue_retry_count.pop(path, None)
|
| 1019 |
+
# lib_logger.info(f"Refresh SUCCESS for '{Path(path).name}'")
|
|
|
|
| 1020 |
|
| 1021 |
+
except asyncio.TimeoutError:
|
| 1022 |
+
lib_logger.warning(
|
| 1023 |
+
f"Refresh timeout ({self._refresh_timeout_seconds}s) for '{Path(path).name}'"
|
| 1024 |
+
)
|
| 1025 |
+
await self._handle_refresh_failure(path, force, "timeout")
|
| 1026 |
+
|
| 1027 |
+
except httpx.HTTPStatusError as e:
|
| 1028 |
+
status_code = e.response.status_code
|
| 1029 |
+
if status_code in (401, 403):
|
| 1030 |
+
# Invalid refresh token - route to re-auth queue
|
| 1031 |
+
lib_logger.warning(
|
| 1032 |
+
f"Refresh token invalid for '{Path(path).name}' (HTTP {status_code}). "
|
| 1033 |
+
f"Routing to re-auth queue."
|
| 1034 |
)
|
| 1035 |
+
self._queue_retry_count.pop(path, None) # Clear retry count
|
| 1036 |
+
async with self._queue_tracking_lock:
|
| 1037 |
+
self._queued_credentials.discard(
|
| 1038 |
+
path
|
| 1039 |
+
) # Remove from queued
|
| 1040 |
+
await self._queue_refresh(
|
| 1041 |
+
path, force=True, needs_reauth=True
|
| 1042 |
+
)
|
| 1043 |
+
else:
|
| 1044 |
+
await self._handle_refresh_failure(
|
| 1045 |
+
path, force, f"HTTP {status_code}"
|
| 1046 |
+
)
|
| 1047 |
+
|
| 1048 |
+
except Exception as e:
|
| 1049 |
+
await self._handle_refresh_failure(path, force, str(e))
|
| 1050 |
+
|
| 1051 |
+
finally:
|
| 1052 |
+
# Remove from queued set (unless re-queued by failure handler)
|
| 1053 |
+
async with self._queue_tracking_lock:
|
| 1054 |
+
# Only discard if not re-queued (check if still in queue set from retry)
|
| 1055 |
+
if (
|
| 1056 |
+
path in self._queued_credentials
|
| 1057 |
+
and self._queue_retry_count.get(path, 0) == 0
|
| 1058 |
+
):
|
| 1059 |
+
self._queued_credentials.discard(path)
|
| 1060 |
+
self._refresh_queue.task_done()
|
| 1061 |
+
|
| 1062 |
+
# Wait between credentials to spread load
|
| 1063 |
+
await asyncio.sleep(self._refresh_interval_seconds)
|
| 1064 |
+
|
| 1065 |
+
except asyncio.CancelledError:
|
| 1066 |
+
# lib_logger.debug("Refresh queue processor cancelled")
|
| 1067 |
+
break
|
| 1068 |
+
except Exception as e:
|
| 1069 |
+
lib_logger.error(f"Error in refresh queue processor: {e}")
|
| 1070 |
+
if path:
|
| 1071 |
+
async with self._queue_tracking_lock:
|
| 1072 |
+
self._queued_credentials.discard(path)
|
| 1073 |
+
|
| 1074 |
+
async def _handle_refresh_failure(self, path: str, force: bool, error: str):
|
| 1075 |
+
"""Handle a refresh failure with back-of-line retry logic.
|
| 1076 |
+
|
| 1077 |
+
- Increments retry count
|
| 1078 |
+
- If under max retries: re-adds to END of queue
|
| 1079 |
+
- If at max retries: kicks credential out (retried next BackgroundRefresher cycle)
|
| 1080 |
+
"""
|
| 1081 |
+
retry_count = self._queue_retry_count.get(path, 0) + 1
|
| 1082 |
+
self._queue_retry_count[path] = retry_count
|
| 1083 |
+
|
| 1084 |
+
if retry_count >= self._refresh_max_retries:
|
| 1085 |
+
# Kicked out until next BackgroundRefresher cycle
|
| 1086 |
+
lib_logger.error(
|
| 1087 |
+
f"Max retries ({self._refresh_max_retries}) reached for '{Path(path).name}' "
|
| 1088 |
+
f"(last error: {error}). Will retry next refresh cycle."
|
| 1089 |
+
)
|
| 1090 |
+
self._queue_retry_count.pop(path, None)
|
| 1091 |
+
async with self._queue_tracking_lock:
|
| 1092 |
+
self._queued_credentials.discard(path)
|
| 1093 |
+
return
|
| 1094 |
+
|
| 1095 |
+
# Re-add to END of queue for retry
|
| 1096 |
+
lib_logger.warning(
|
| 1097 |
+
f"Refresh failed for '{Path(path).name}' ({error}). "
|
| 1098 |
+
f"Retry {retry_count}/{self._refresh_max_retries}, back of queue."
|
| 1099 |
+
)
|
| 1100 |
+
# Keep in queued_credentials set, add back to queue
|
| 1101 |
+
await self._refresh_queue.put((path, force))
|
| 1102 |
+
|
| 1103 |
+
async def _process_reauth_queue(self):
|
| 1104 |
+
"""Background worker that processes re-auth requests.
|
| 1105 |
+
|
| 1106 |
+
Key behaviors:
|
| 1107 |
+
- Credentials ARE marked unavailable (token is truly broken)
|
| 1108 |
+
- Uses ReauthCoordinator for interactive OAuth
|
| 1109 |
+
- No automatic retry (requires user action)
|
| 1110 |
+
- Cleans up unavailable status when done
|
| 1111 |
+
"""
|
| 1112 |
+
# lib_logger.info("Re-auth queue processor started")
|
| 1113 |
+
while True:
|
| 1114 |
+
path = None
|
| 1115 |
+
try:
|
| 1116 |
+
# Wait for an item with timeout to allow graceful shutdown
|
| 1117 |
+
try:
|
| 1118 |
+
path = await asyncio.wait_for(
|
| 1119 |
+
self._reauth_queue.get(), timeout=60.0
|
| 1120 |
+
)
|
| 1121 |
+
except asyncio.TimeoutError:
|
| 1122 |
+
# Queue is empty and idle for 60s - exit
|
| 1123 |
+
self._reauth_processor_task = None
|
| 1124 |
+
# lib_logger.debug("Re-auth queue processor idle, shutting down")
|
| 1125 |
+
return
|
| 1126 |
+
|
| 1127 |
+
try:
|
| 1128 |
+
lib_logger.info(f"Starting re-auth for '{Path(path).name}'...")
|
| 1129 |
+
await self.initialize_token(path)
|
| 1130 |
+
lib_logger.info(f"Re-auth SUCCESS for '{Path(path).name}'")
|
| 1131 |
+
|
| 1132 |
+
except Exception as e:
|
| 1133 |
+
lib_logger.error(f"Re-auth FAILED for '{Path(path).name}': {e}")
|
| 1134 |
+
# No automatic retry for re-auth (requires user action)
|
| 1135 |
|
| 1136 |
finally:
|
| 1137 |
+
# Always clean up
|
|
|
|
| 1138 |
async with self._queue_tracking_lock:
|
| 1139 |
self._queued_credentials.discard(path)
|
|
|
|
| 1140 |
self._unavailable_credentials.pop(path, None)
|
| 1141 |
+
# lib_logger.debug(
|
| 1142 |
+
# f"Re-auth cleanup for '{Path(path).name}'. "
|
| 1143 |
+
# f"Remaining unavailable: {len(self._unavailable_credentials)}"
|
| 1144 |
+
# )
|
| 1145 |
+
self._reauth_queue.task_done()
|
| 1146 |
+
|
| 1147 |
except asyncio.CancelledError:
|
| 1148 |
+
# Clean up current credential before breaking
|
| 1149 |
if path:
|
| 1150 |
async with self._queue_tracking_lock:
|
| 1151 |
+
self._queued_credentials.discard(path)
|
| 1152 |
self._unavailable_credentials.pop(path, None)
|
| 1153 |
+
# lib_logger.debug("Re-auth queue processor cancelled")
|
|
|
|
|
|
|
|
|
|
| 1154 |
break
|
| 1155 |
except Exception as e:
|
| 1156 |
+
lib_logger.error(f"Error in re-auth queue processor: {e}")
|
|
|
|
| 1157 |
if path:
|
| 1158 |
async with self._queue_tracking_lock:
|
| 1159 |
+
self._queued_credentials.discard(path)
|
| 1160 |
self._unavailable_credentials.pop(path, None)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1161 |
|
| 1162 |
async def _perform_interactive_oauth(
|
| 1163 |
self, path: str, creds: Dict[str, Any], display_name: str
|
|
|
|
| 1183 |
state = secrets.token_urlsafe(32)
|
| 1184 |
|
| 1185 |
# Build authorization URL
|
| 1186 |
+
callback_port = get_callback_port()
|
| 1187 |
+
redirect_uri = f"http://localhost:{callback_port}/oauth2callback"
|
| 1188 |
auth_params = {
|
| 1189 |
"loginMethod": "phone",
|
| 1190 |
"type": "phone",
|
|
|
|
| 1195 |
auth_url = f"{IFLOW_OAUTH_AUTHORIZE_ENDPOINT}?{urlencode(auth_params)}"
|
| 1196 |
|
| 1197 |
# Start OAuth callback server
|
| 1198 |
+
callback_server = OAuthCallbackServer(port=callback_port)
|
| 1199 |
try:
|
| 1200 |
await callback_server.start(expected_state=state)
|
| 1201 |
|
|
|
|
| 1398 |
except Exception as e:
|
| 1399 |
lib_logger.error(f"Failed to get iFlow user info from credentials: {e}")
|
| 1400 |
return {"email": None}
|
| 1401 |
+
|
| 1402 |
+
# =========================================================================
|
| 1403 |
+
# CREDENTIAL MANAGEMENT METHODS
|
| 1404 |
+
# =========================================================================
|
| 1405 |
+
|
| 1406 |
+
def _get_provider_file_prefix(self) -> str:
|
| 1407 |
+
"""Return the file prefix for iFlow credentials."""
|
| 1408 |
+
return "iflow"
|
| 1409 |
+
|
| 1410 |
+
def _get_oauth_base_dir(self) -> Path:
|
| 1411 |
+
"""Get the base directory for OAuth credential files."""
|
| 1412 |
+
return Path.cwd() / "oauth_creds"
|
| 1413 |
+
|
| 1414 |
+
def _find_existing_credential_by_email(
|
| 1415 |
+
self, email: str, base_dir: Optional[Path] = None
|
| 1416 |
+
) -> Optional[Path]:
|
| 1417 |
+
"""Find an existing credential file for the given email."""
|
| 1418 |
+
if base_dir is None:
|
| 1419 |
+
base_dir = self._get_oauth_base_dir()
|
| 1420 |
+
|
| 1421 |
+
prefix = self._get_provider_file_prefix()
|
| 1422 |
+
pattern = str(base_dir / f"{prefix}_oauth_*.json")
|
| 1423 |
+
|
| 1424 |
+
for cred_file in glob(pattern):
|
| 1425 |
+
try:
|
| 1426 |
+
with open(cred_file, "r") as f:
|
| 1427 |
+
creds = json.load(f)
|
| 1428 |
+
existing_email = creds.get("email") or creds.get(
|
| 1429 |
+
"_proxy_metadata", {}
|
| 1430 |
+
).get("email")
|
| 1431 |
+
if existing_email == email:
|
| 1432 |
+
return Path(cred_file)
|
| 1433 |
+
except (json.JSONDecodeError, IOError) as e:
|
| 1434 |
+
lib_logger.debug(f"Could not read credential file {cred_file}: {e}")
|
| 1435 |
+
continue
|
| 1436 |
+
|
| 1437 |
+
return None
|
| 1438 |
+
|
| 1439 |
+
def _get_next_credential_number(self, base_dir: Optional[Path] = None) -> int:
|
| 1440 |
+
"""Get the next available credential number."""
|
| 1441 |
+
if base_dir is None:
|
| 1442 |
+
base_dir = self._get_oauth_base_dir()
|
| 1443 |
+
|
| 1444 |
+
prefix = self._get_provider_file_prefix()
|
| 1445 |
+
pattern = str(base_dir / f"{prefix}_oauth_*.json")
|
| 1446 |
+
|
| 1447 |
+
existing_numbers = []
|
| 1448 |
+
for cred_file in glob(pattern):
|
| 1449 |
+
match = re.search(r"_oauth_(\d+)\.json$", cred_file)
|
| 1450 |
+
if match:
|
| 1451 |
+
existing_numbers.append(int(match.group(1)))
|
| 1452 |
+
|
| 1453 |
+
if not existing_numbers:
|
| 1454 |
+
return 1
|
| 1455 |
+
return max(existing_numbers) + 1
|
| 1456 |
+
|
| 1457 |
+
def _build_credential_path(
|
| 1458 |
+
self, base_dir: Optional[Path] = None, number: Optional[int] = None
|
| 1459 |
+
) -> Path:
|
| 1460 |
+
"""Build a path for a new credential file."""
|
| 1461 |
+
if base_dir is None:
|
| 1462 |
+
base_dir = self._get_oauth_base_dir()
|
| 1463 |
+
|
| 1464 |
+
if number is None:
|
| 1465 |
+
number = self._get_next_credential_number(base_dir)
|
| 1466 |
+
|
| 1467 |
+
prefix = self._get_provider_file_prefix()
|
| 1468 |
+
filename = f"{prefix}_oauth_{number}.json"
|
| 1469 |
+
return base_dir / filename
|
| 1470 |
+
|
| 1471 |
+
async def setup_credential(
|
| 1472 |
+
self, base_dir: Optional[Path] = None
|
| 1473 |
+
) -> IFlowCredentialSetupResult:
|
| 1474 |
+
"""
|
| 1475 |
+
Complete credential setup flow: OAuth -> save.
|
| 1476 |
+
|
| 1477 |
+
This is the main entry point for setting up new credentials.
|
| 1478 |
+
"""
|
| 1479 |
+
if base_dir is None:
|
| 1480 |
+
base_dir = self._get_oauth_base_dir()
|
| 1481 |
+
|
| 1482 |
+
# Ensure directory exists
|
| 1483 |
+
base_dir.mkdir(exist_ok=True)
|
| 1484 |
+
|
| 1485 |
+
try:
|
| 1486 |
+
# Step 1: Perform OAuth authentication
|
| 1487 |
+
temp_creds = {"_proxy_metadata": {"display_name": "new iFlow credential"}}
|
| 1488 |
+
new_creds = await self.initialize_token(temp_creds)
|
| 1489 |
+
|
| 1490 |
+
# Step 2: Get user info for deduplication
|
| 1491 |
+
email = new_creds.get("email") or new_creds.get("_proxy_metadata", {}).get(
|
| 1492 |
+
"email"
|
| 1493 |
+
)
|
| 1494 |
+
|
| 1495 |
+
if not email:
|
| 1496 |
+
return IFlowCredentialSetupResult(
|
| 1497 |
+
success=False, error="Could not retrieve email from OAuth response"
|
| 1498 |
+
)
|
| 1499 |
+
|
| 1500 |
+
# Step 3: Check for existing credential with same email
|
| 1501 |
+
existing_path = self._find_existing_credential_by_email(email, base_dir)
|
| 1502 |
+
is_update = existing_path is not None
|
| 1503 |
+
|
| 1504 |
+
if is_update:
|
| 1505 |
+
file_path = existing_path
|
| 1506 |
+
lib_logger.info(
|
| 1507 |
+
f"Found existing credential for {email}, updating {file_path.name}"
|
| 1508 |
+
)
|
| 1509 |
+
else:
|
| 1510 |
+
file_path = self._build_credential_path(base_dir)
|
| 1511 |
+
lib_logger.info(
|
| 1512 |
+
f"Creating new credential for {email} at {file_path.name}"
|
| 1513 |
+
)
|
| 1514 |
+
|
| 1515 |
+
# Step 4: Save credentials to file
|
| 1516 |
+
await self._save_credentials(str(file_path), new_creds)
|
| 1517 |
+
|
| 1518 |
+
return IFlowCredentialSetupResult(
|
| 1519 |
+
success=True,
|
| 1520 |
+
file_path=str(file_path),
|
| 1521 |
+
email=email,
|
| 1522 |
+
is_update=is_update,
|
| 1523 |
+
credentials=new_creds,
|
| 1524 |
+
)
|
| 1525 |
+
|
| 1526 |
+
except Exception as e:
|
| 1527 |
+
lib_logger.error(f"Credential setup failed: {e}")
|
| 1528 |
+
return IFlowCredentialSetupResult(success=False, error=str(e))
|
| 1529 |
+
|
| 1530 |
+
def build_env_lines(self, creds: Dict[str, Any], cred_number: int) -> List[str]:
|
| 1531 |
+
"""Generate .env file lines for an iFlow credential."""
|
| 1532 |
+
email = creds.get("email") or creds.get("_proxy_metadata", {}).get(
|
| 1533 |
+
"email", "unknown"
|
| 1534 |
+
)
|
| 1535 |
+
prefix = f"IFLOW_{cred_number}"
|
| 1536 |
+
|
| 1537 |
+
lines = [
|
| 1538 |
+
f"# IFLOW Credential #{cred_number} for: {email}",
|
| 1539 |
+
f"# Exported from: iflow_oauth_{cred_number}.json",
|
| 1540 |
+
f"# Generated at: {time.strftime('%Y-%m-%d %H:%M:%S')}",
|
| 1541 |
+
"#",
|
| 1542 |
+
"# To combine multiple credentials into one .env file, copy these lines",
|
| 1543 |
+
"# and ensure each credential has a unique number (1, 2, 3, etc.)",
|
| 1544 |
+
"",
|
| 1545 |
+
f"{prefix}_ACCESS_TOKEN={creds.get('access_token', '')}",
|
| 1546 |
+
f"{prefix}_REFRESH_TOKEN={creds.get('refresh_token', '')}",
|
| 1547 |
+
f"{prefix}_API_KEY={creds.get('api_key', '')}",
|
| 1548 |
+
f"{prefix}_EXPIRY_DATE={creds.get('expiry_date', '')}",
|
| 1549 |
+
f"{prefix}_EMAIL={email}",
|
| 1550 |
+
f"{prefix}_TOKEN_TYPE={creds.get('token_type', 'Bearer')}",
|
| 1551 |
+
f"{prefix}_SCOPE={creds.get('scope', 'read write')}",
|
| 1552 |
+
]
|
| 1553 |
+
|
| 1554 |
+
return lines
|
| 1555 |
+
|
| 1556 |
+
def export_credential_to_env(
|
| 1557 |
+
self, credential_path: str, output_dir: Optional[Path] = None
|
| 1558 |
+
) -> Optional[str]:
|
| 1559 |
+
"""Export a credential file to .env format."""
|
| 1560 |
+
try:
|
| 1561 |
+
cred_path = Path(credential_path)
|
| 1562 |
+
|
| 1563 |
+
# Load credential
|
| 1564 |
+
with open(cred_path, "r") as f:
|
| 1565 |
+
creds = json.load(f)
|
| 1566 |
+
|
| 1567 |
+
# Extract metadata
|
| 1568 |
+
email = creds.get("email") or creds.get("_proxy_metadata", {}).get(
|
| 1569 |
+
"email", "unknown"
|
| 1570 |
+
)
|
| 1571 |
+
|
| 1572 |
+
# Get credential number from filename
|
| 1573 |
+
match = re.search(r"_oauth_(\d+)\.json$", cred_path.name)
|
| 1574 |
+
cred_number = int(match.group(1)) if match else 1
|
| 1575 |
+
|
| 1576 |
+
# Build output path
|
| 1577 |
+
if output_dir is None:
|
| 1578 |
+
output_dir = cred_path.parent
|
| 1579 |
+
|
| 1580 |
+
safe_email = email.replace("@", "_at_").replace(".", "_")
|
| 1581 |
+
env_filename = f"iflow_{cred_number}_{safe_email}.env"
|
| 1582 |
+
env_path = output_dir / env_filename
|
| 1583 |
+
|
| 1584 |
+
# Build and write content
|
| 1585 |
+
env_lines = self.build_env_lines(creds, cred_number)
|
| 1586 |
+
with open(env_path, "w") as f:
|
| 1587 |
+
f.write("\n".join(env_lines))
|
| 1588 |
+
|
| 1589 |
+
lib_logger.info(f"Exported credential to {env_path}")
|
| 1590 |
+
return str(env_path)
|
| 1591 |
+
|
| 1592 |
+
except Exception as e:
|
| 1593 |
+
lib_logger.error(f"Failed to export credential: {e}")
|
| 1594 |
+
return None
|
| 1595 |
+
|
| 1596 |
+
def list_credentials(self, base_dir: Optional[Path] = None) -> List[Dict[str, Any]]:
|
| 1597 |
+
"""List all iFlow credential files."""
|
| 1598 |
+
if base_dir is None:
|
| 1599 |
+
base_dir = self._get_oauth_base_dir()
|
| 1600 |
+
|
| 1601 |
+
prefix = self._get_provider_file_prefix()
|
| 1602 |
+
pattern = str(base_dir / f"{prefix}_oauth_*.json")
|
| 1603 |
+
|
| 1604 |
+
credentials = []
|
| 1605 |
+
for cred_file in sorted(glob(pattern)):
|
| 1606 |
+
try:
|
| 1607 |
+
with open(cred_file, "r") as f:
|
| 1608 |
+
creds = json.load(f)
|
| 1609 |
+
|
| 1610 |
+
email = creds.get("email") or creds.get("_proxy_metadata", {}).get(
|
| 1611 |
+
"email", "unknown"
|
| 1612 |
+
)
|
| 1613 |
+
|
| 1614 |
+
# Extract number from filename
|
| 1615 |
+
match = re.search(r"_oauth_(\d+)\.json$", cred_file)
|
| 1616 |
+
number = int(match.group(1)) if match else 0
|
| 1617 |
+
|
| 1618 |
+
credentials.append(
|
| 1619 |
+
{
|
| 1620 |
+
"file_path": cred_file,
|
| 1621 |
+
"email": email,
|
| 1622 |
+
"number": number,
|
| 1623 |
+
}
|
| 1624 |
+
)
|
| 1625 |
+
except Exception as e:
|
| 1626 |
+
lib_logger.debug(f"Could not read credential file {cred_file}: {e}")
|
| 1627 |
+
continue
|
| 1628 |
+
|
| 1629 |
+
return credentials
|
| 1630 |
+
|
| 1631 |
+
def delete_credential(self, credential_path: str) -> bool:
|
| 1632 |
+
"""Delete a credential file."""
|
| 1633 |
+
try:
|
| 1634 |
+
cred_path = Path(credential_path)
|
| 1635 |
+
|
| 1636 |
+
# Validate that it's one of our credential files
|
| 1637 |
+
prefix = self._get_provider_file_prefix()
|
| 1638 |
+
if not cred_path.name.startswith(f"{prefix}_oauth_"):
|
| 1639 |
+
lib_logger.error(
|
| 1640 |
+
f"File {cred_path.name} does not appear to be an iFlow credential"
|
| 1641 |
+
)
|
| 1642 |
+
return False
|
| 1643 |
+
|
| 1644 |
+
if not cred_path.exists():
|
| 1645 |
+
lib_logger.warning(f"Credential file does not exist: {credential_path}")
|
| 1646 |
+
return False
|
| 1647 |
+
|
| 1648 |
+
# Remove from cache if present
|
| 1649 |
+
self._credentials_cache.pop(credential_path, None)
|
| 1650 |
+
|
| 1651 |
+
# Delete the file
|
| 1652 |
+
cred_path.unlink()
|
| 1653 |
+
lib_logger.info(f"Deleted credential file: {credential_path}")
|
| 1654 |
+
return True
|
| 1655 |
+
|
| 1656 |
+
except Exception as e:
|
| 1657 |
+
lib_logger.error(f"Failed to delete credential: {e}")
|
| 1658 |
+
return False
|
src/rotator_library/providers/iflow_provider.py
CHANGED
|
@@ -10,19 +10,27 @@ from typing import Union, AsyncGenerator, List, Dict, Any
|
|
| 10 |
from .provider_interface import ProviderInterface
|
| 11 |
from .iflow_auth_base import IFlowAuthBase
|
| 12 |
from ..model_definitions import ModelDefinitions
|
|
|
|
|
|
|
| 13 |
import litellm
|
| 14 |
from litellm.exceptions import RateLimitError, AuthenticationError
|
| 15 |
from pathlib import Path
|
| 16 |
import uuid
|
| 17 |
from datetime import datetime
|
| 18 |
|
| 19 |
-
lib_logger = logging.getLogger(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
|
| 21 |
-
LOGS_DIR = Path(__file__).resolve().parent.parent.parent.parent / "logs"
|
| 22 |
-
IFLOW_LOGS_DIR = LOGS_DIR / "iflow_logs"
|
| 23 |
|
| 24 |
class _IFlowFileLogger:
|
| 25 |
"""A simple file logger for a single iFlow transaction."""
|
|
|
|
| 26 |
def __init__(self, model_name: str, enabled: bool = True):
|
| 27 |
self.enabled = enabled
|
| 28 |
if not self.enabled:
|
|
@@ -31,8 +39,10 @@ class _IFlowFileLogger:
|
|
| 31 |
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
|
| 32 |
request_id = str(uuid.uuid4())
|
| 33 |
# Sanitize model name for directory
|
| 34 |
-
safe_model_name = model_name.replace(
|
| 35 |
-
self.log_dir =
|
|
|
|
|
|
|
| 36 |
try:
|
| 37 |
self.log_dir.mkdir(parents=True, exist_ok=True)
|
| 38 |
except Exception as e:
|
|
@@ -41,16 +51,20 @@ class _IFlowFileLogger:
|
|
| 41 |
|
| 42 |
def log_request(self, payload: Dict[str, Any]):
|
| 43 |
"""Logs the request payload sent to iFlow."""
|
| 44 |
-
if not self.enabled:
|
|
|
|
| 45 |
try:
|
| 46 |
-
with open(
|
|
|
|
|
|
|
| 47 |
json.dump(payload, f, indent=2, ensure_ascii=False)
|
| 48 |
except Exception as e:
|
| 49 |
lib_logger.error(f"_IFlowFileLogger: Failed to write request: {e}")
|
| 50 |
|
| 51 |
def log_response_chunk(self, chunk: str):
|
| 52 |
"""Logs a raw chunk from the iFlow response stream."""
|
| 53 |
-
if not self.enabled:
|
|
|
|
| 54 |
try:
|
| 55 |
with open(self.log_dir / "response_stream.log", "a", encoding="utf-8") as f:
|
| 56 |
f.write(chunk + "\n")
|
|
@@ -59,7 +73,8 @@ class _IFlowFileLogger:
|
|
| 59 |
|
| 60 |
def log_error(self, error_message: str):
|
| 61 |
"""Logs an error message."""
|
| 62 |
-
if not self.enabled:
|
|
|
|
| 63 |
try:
|
| 64 |
with open(self.log_dir / "error.log", "a", encoding="utf-8") as f:
|
| 65 |
f.write(f"[{datetime.utcnow().isoformat()}] {error_message}\n")
|
|
@@ -68,13 +83,15 @@ class _IFlowFileLogger:
|
|
| 68 |
|
| 69 |
def log_final_response(self, response_data: Dict[str, Any]):
|
| 70 |
"""Logs the final, reassembled response."""
|
| 71 |
-
if not self.enabled:
|
|
|
|
| 72 |
try:
|
| 73 |
with open(self.log_dir / "final_response.json", "w", encoding="utf-8") as f:
|
| 74 |
json.dump(response_data, f, indent=2, ensure_ascii=False)
|
| 75 |
except Exception as e:
|
| 76 |
lib_logger.error(f"_IFlowFileLogger: Failed to write final response: {e}")
|
| 77 |
|
|
|
|
| 78 |
# Model list can be expanded as iFlow supports more models
|
| 79 |
HARDCODED_MODELS = [
|
| 80 |
"glm-4.6",
|
|
@@ -90,14 +107,25 @@ HARDCODED_MODELS = [
|
|
| 90 |
"deepseek-v3",
|
| 91 |
"qwen3-vl-plus",
|
| 92 |
"qwen3-235b-a22b-instruct",
|
| 93 |
-
"qwen3-235b"
|
| 94 |
]
|
| 95 |
|
| 96 |
# OpenAI-compatible parameters supported by iFlow API
|
| 97 |
SUPPORTED_PARAMS = {
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 101 |
}
|
| 102 |
|
| 103 |
|
|
@@ -106,6 +134,7 @@ class IFlowProvider(IFlowAuthBase, ProviderInterface):
|
|
| 106 |
iFlow provider using OAuth authentication with local callback server.
|
| 107 |
API requests use the derived API key (NOT OAuth access_token).
|
| 108 |
"""
|
|
|
|
| 109 |
skip_cost_calculation = True
|
| 110 |
|
| 111 |
def __init__(self):
|
|
@@ -128,7 +157,9 @@ class IFlowProvider(IFlowAuthBase, ProviderInterface):
|
|
| 128 |
Validates OAuth credentials if applicable.
|
| 129 |
"""
|
| 130 |
models = []
|
| 131 |
-
env_var_ids =
|
|
|
|
|
|
|
| 132 |
|
| 133 |
def extract_model_id(item) -> str:
|
| 134 |
"""Extract model ID from various formats (dict, string with/without provider prefix)."""
|
|
@@ -154,7 +185,9 @@ class IFlowProvider(IFlowAuthBase, ProviderInterface):
|
|
| 154 |
# Track the ID to prevent hardcoded/dynamic duplicates
|
| 155 |
if model_id:
|
| 156 |
env_var_ids.add(model_id)
|
| 157 |
-
lib_logger.info(
|
|
|
|
|
|
|
| 158 |
|
| 159 |
# Source 2: Add hardcoded models (only if ID not already in env vars)
|
| 160 |
for model_id in HARDCODED_MODELS:
|
|
@@ -172,14 +205,17 @@ class IFlowProvider(IFlowAuthBase, ProviderInterface):
|
|
| 172 |
models_url = f"{api_base.rstrip('/')}/models"
|
| 173 |
|
| 174 |
response = await client.get(
|
| 175 |
-
models_url,
|
| 176 |
-
headers={"Authorization": f"Bearer {api_key}"}
|
| 177 |
)
|
| 178 |
response.raise_for_status()
|
| 179 |
|
| 180 |
dynamic_data = response.json()
|
| 181 |
# Handle both {data: [...]} and direct [...] formats
|
| 182 |
-
model_list =
|
|
|
|
|
|
|
|
|
|
|
|
|
| 183 |
|
| 184 |
dynamic_count = 0
|
| 185 |
for model in model_list:
|
|
@@ -190,7 +226,9 @@ class IFlowProvider(IFlowAuthBase, ProviderInterface):
|
|
| 190 |
dynamic_count += 1
|
| 191 |
|
| 192 |
if dynamic_count > 0:
|
| 193 |
-
lib_logger.debug(
|
|
|
|
|
|
|
| 194 |
|
| 195 |
except Exception as e:
|
| 196 |
# Silently ignore dynamic discovery errors
|
|
@@ -255,7 +293,7 @@ class IFlowProvider(IFlowAuthBase, ProviderInterface):
|
|
| 255 |
payload = {k: v for k, v in kwargs.items() if k in SUPPORTED_PARAMS}
|
| 256 |
|
| 257 |
# Always force streaming for internal processing
|
| 258 |
-
payload[
|
| 259 |
|
| 260 |
# NOTE: iFlow API does not support stream_options parameter
|
| 261 |
# Unlike other providers, we don't include it to avoid HTTP 406 errors
|
|
@@ -264,16 +302,22 @@ class IFlowProvider(IFlowAuthBase, ProviderInterface):
|
|
| 264 |
if "tools" in payload and payload["tools"]:
|
| 265 |
payload["tools"] = self._clean_tool_schemas(payload["tools"])
|
| 266 |
lib_logger.debug(f"Cleaned {len(payload['tools'])} tool schemas")
|
| 267 |
-
elif
|
|
|
|
|
|
|
|
|
|
|
|
|
| 268 |
# Inject dummy tool for empty arrays to prevent streaming issues (similar to Qwen's behavior)
|
| 269 |
-
payload["tools"] = [
|
| 270 |
-
|
| 271 |
-
|
| 272 |
-
"
|
| 273 |
-
|
| 274 |
-
|
|
|
|
|
|
|
| 275 |
}
|
| 276 |
-
|
| 277 |
lib_logger.debug("Injected placeholder tool for empty tools array")
|
| 278 |
|
| 279 |
return payload
|
|
@@ -282,7 +326,7 @@ class IFlowProvider(IFlowAuthBase, ProviderInterface):
|
|
| 282 |
"""
|
| 283 |
Converts a raw iFlow SSE chunk to an OpenAI-compatible chunk.
|
| 284 |
Since iFlow is OpenAI-compatible, minimal conversion is needed.
|
| 285 |
-
|
| 286 |
CRITICAL FIX: Handle chunks with BOTH usage and choices (final chunk)
|
| 287 |
without early return to ensure finish_reason is properly processed.
|
| 288 |
"""
|
|
@@ -302,32 +346,36 @@ class IFlowProvider(IFlowAuthBase, ProviderInterface):
|
|
| 302 |
"model": model_id,
|
| 303 |
"object": "chat.completion.chunk",
|
| 304 |
"id": chunk.get("id", f"chatcmpl-iflow-{time.time()}"),
|
| 305 |
-
"created": chunk.get("created", int(time.time()))
|
| 306 |
}
|
| 307 |
# Then yield the usage chunk
|
| 308 |
yield {
|
| 309 |
-
"choices": [],
|
|
|
|
|
|
|
| 310 |
"id": chunk.get("id", f"chatcmpl-iflow-{time.time()}"),
|
| 311 |
"created": chunk.get("created", int(time.time())),
|
| 312 |
"usage": {
|
| 313 |
"prompt_tokens": usage_data.get("prompt_tokens", 0),
|
| 314 |
"completion_tokens": usage_data.get("completion_tokens", 0),
|
| 315 |
"total_tokens": usage_data.get("total_tokens", 0),
|
| 316 |
-
}
|
| 317 |
}
|
| 318 |
return
|
| 319 |
|
| 320 |
# Handle usage-only chunks
|
| 321 |
if usage_data:
|
| 322 |
yield {
|
| 323 |
-
"choices": [],
|
|
|
|
|
|
|
| 324 |
"id": chunk.get("id", f"chatcmpl-iflow-{time.time()}"),
|
| 325 |
"created": chunk.get("created", int(time.time())),
|
| 326 |
"usage": {
|
| 327 |
"prompt_tokens": usage_data.get("prompt_tokens", 0),
|
| 328 |
"completion_tokens": usage_data.get("completion_tokens", 0),
|
| 329 |
"total_tokens": usage_data.get("total_tokens", 0),
|
| 330 |
-
}
|
| 331 |
}
|
| 332 |
return
|
| 333 |
|
|
@@ -339,13 +387,15 @@ class IFlowProvider(IFlowAuthBase, ProviderInterface):
|
|
| 339 |
"model": model_id,
|
| 340 |
"object": "chat.completion.chunk",
|
| 341 |
"id": chunk.get("id", f"chatcmpl-iflow-{time.time()}"),
|
| 342 |
-
"created": chunk.get("created", int(time.time()))
|
| 343 |
}
|
| 344 |
|
| 345 |
-
def _stream_to_completion_response(
|
|
|
|
|
|
|
| 346 |
"""
|
| 347 |
Manually reassembles streaming chunks into a complete response.
|
| 348 |
-
|
| 349 |
Key improvements:
|
| 350 |
- Determines finish_reason based on accumulated state (tool_calls vs stop)
|
| 351 |
- Properly initializes tool_calls with type field
|
|
@@ -358,14 +408,16 @@ class IFlowProvider(IFlowAuthBase, ProviderInterface):
|
|
| 358 |
final_message = {"role": "assistant"}
|
| 359 |
aggregated_tool_calls = {}
|
| 360 |
usage_data = None
|
| 361 |
-
chunk_finish_reason =
|
|
|
|
|
|
|
| 362 |
|
| 363 |
# Get the first chunk for basic response metadata
|
| 364 |
first_chunk = chunks[0]
|
| 365 |
|
| 366 |
# Process each chunk to aggregate content
|
| 367 |
for chunk in chunks:
|
| 368 |
-
if not hasattr(chunk,
|
| 369 |
continue
|
| 370 |
|
| 371 |
choice = chunk.choices[0]
|
|
@@ -389,25 +441,48 @@ class IFlowProvider(IFlowAuthBase, ProviderInterface):
|
|
| 389 |
index = tc_chunk.get("index", 0)
|
| 390 |
if index not in aggregated_tool_calls:
|
| 391 |
# Initialize with type field for OpenAI compatibility
|
| 392 |
-
aggregated_tool_calls[index] = {
|
|
|
|
|
|
|
|
|
|
| 393 |
if "id" in tc_chunk:
|
| 394 |
aggregated_tool_calls[index]["id"] = tc_chunk["id"]
|
| 395 |
if "type" in tc_chunk:
|
| 396 |
aggregated_tool_calls[index]["type"] = tc_chunk["type"]
|
| 397 |
if "function" in tc_chunk:
|
| 398 |
-
if
|
| 399 |
-
|
| 400 |
-
|
| 401 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 402 |
|
| 403 |
# Aggregate function calls (legacy format)
|
| 404 |
if "function_call" in delta and delta["function_call"] is not None:
|
| 405 |
if "function_call" not in final_message:
|
| 406 |
final_message["function_call"] = {"name": "", "arguments": ""}
|
| 407 |
-
if
|
| 408 |
-
|
| 409 |
-
|
| 410 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 411 |
|
| 412 |
# Track finish_reason from chunks (for reference only)
|
| 413 |
if choice.get("finish_reason"):
|
|
@@ -415,7 +490,7 @@ class IFlowProvider(IFlowAuthBase, ProviderInterface):
|
|
| 415 |
|
| 416 |
# Handle usage data from the last chunk that has it
|
| 417 |
for chunk in reversed(chunks):
|
| 418 |
-
if hasattr(chunk,
|
| 419 |
usage_data = chunk.usage
|
| 420 |
break
|
| 421 |
|
|
@@ -441,7 +516,7 @@ class IFlowProvider(IFlowAuthBase, ProviderInterface):
|
|
| 441 |
final_choice = {
|
| 442 |
"index": 0,
|
| 443 |
"message": final_message,
|
| 444 |
-
"finish_reason": finish_reason
|
| 445 |
}
|
| 446 |
|
| 447 |
# Create the final ModelResponse
|
|
@@ -451,21 +526,20 @@ class IFlowProvider(IFlowAuthBase, ProviderInterface):
|
|
| 451 |
"created": first_chunk.created,
|
| 452 |
"model": first_chunk.model,
|
| 453 |
"choices": [final_choice],
|
| 454 |
-
"usage": usage_data
|
| 455 |
}
|
| 456 |
|
| 457 |
return litellm.ModelResponse(**final_response_data)
|
| 458 |
|
| 459 |
-
async def acompletion(
|
|
|
|
|
|
|
| 460 |
credential_path = kwargs.pop("credential_identifier")
|
| 461 |
enable_request_logging = kwargs.pop("enable_request_logging", False)
|
| 462 |
model = kwargs["model"]
|
| 463 |
|
| 464 |
# Create dedicated file logger for this request
|
| 465 |
-
file_logger = _IFlowFileLogger(
|
| 466 |
-
model_name=model,
|
| 467 |
-
enabled=enable_request_logging
|
| 468 |
-
)
|
| 469 |
|
| 470 |
async def make_request():
|
| 471 |
"""Prepares and makes the actual API call."""
|
|
@@ -473,8 +547,8 @@ class IFlowProvider(IFlowAuthBase, ProviderInterface):
|
|
| 473 |
api_base, api_key = await self.get_api_details(credential_path)
|
| 474 |
|
| 475 |
# Strip provider prefix from model name (e.g., "iflow/Qwen3-Coder-Plus" -> "Qwen3-Coder-Plus")
|
| 476 |
-
model_name = model.split(
|
| 477 |
-
kwargs_with_stripped_model = {**kwargs,
|
| 478 |
|
| 479 |
# Build clean payload with only supported parameters
|
| 480 |
payload = self._build_request_payload(**kwargs_with_stripped_model)
|
|
@@ -483,7 +557,7 @@ class IFlowProvider(IFlowAuthBase, ProviderInterface):
|
|
| 483 |
"Authorization": f"Bearer {api_key}", # Uses api_key from user info
|
| 484 |
"Content-Type": "application/json",
|
| 485 |
"Accept": "text/event-stream",
|
| 486 |
-
"User-Agent": "iFlow-Cli"
|
| 487 |
}
|
| 488 |
|
| 489 |
url = f"{api_base.rstrip('/')}/chat/completions"
|
|
@@ -492,7 +566,13 @@ class IFlowProvider(IFlowAuthBase, ProviderInterface):
|
|
| 492 |
file_logger.log_request(payload)
|
| 493 |
lib_logger.debug(f"iFlow Request URL: {url}")
|
| 494 |
|
| 495 |
-
return client.stream(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 496 |
|
| 497 |
async def stream_handler(response_stream, attempt=1):
|
| 498 |
"""Handles the streaming response and converts chunks."""
|
|
@@ -501,11 +581,17 @@ class IFlowProvider(IFlowAuthBase, ProviderInterface):
|
|
| 501 |
# Check for HTTP errors before processing stream
|
| 502 |
if response.status_code >= 400:
|
| 503 |
error_text = await response.aread()
|
| 504 |
-
error_text =
|
|
|
|
|
|
|
|
|
|
|
|
|
| 505 |
|
| 506 |
# Handle 401: Force token refresh and retry once
|
| 507 |
if response.status_code == 401 and attempt == 1:
|
| 508 |
-
lib_logger.warning(
|
|
|
|
|
|
|
| 509 |
await self._refresh_token(credential_path, force=True)
|
| 510 |
retry_stream = await make_request()
|
| 511 |
async for chunk in stream_handler(retry_stream, attempt=2):
|
|
@@ -513,50 +599,61 @@ class IFlowProvider(IFlowAuthBase, ProviderInterface):
|
|
| 513 |
return
|
| 514 |
|
| 515 |
# Handle 429: Rate limit
|
| 516 |
-
elif
|
|
|
|
|
|
|
|
|
|
| 517 |
raise RateLimitError(
|
| 518 |
f"iFlow rate limit exceeded: {error_text}",
|
| 519 |
llm_provider="iflow",
|
| 520 |
model=model,
|
| 521 |
-
response=response
|
| 522 |
)
|
| 523 |
|
| 524 |
# Handle other errors
|
| 525 |
else:
|
| 526 |
-
error_msg =
|
|
|
|
|
|
|
| 527 |
file_logger.log_error(error_msg)
|
| 528 |
raise httpx.HTTPStatusError(
|
| 529 |
f"HTTP {response.status_code}: {error_text}",
|
| 530 |
request=response.request,
|
| 531 |
-
response=response
|
| 532 |
)
|
| 533 |
|
| 534 |
# Process successful streaming response
|
| 535 |
async for line in response.aiter_lines():
|
| 536 |
file_logger.log_response_chunk(line)
|
| 537 |
-
|
| 538 |
# CRITICAL FIX: Handle both "data:" (no space) and "data: " (with space)
|
| 539 |
-
if line.startswith(
|
| 540 |
# Extract data after "data:" prefix, handling both formats
|
| 541 |
-
if line.startswith(
|
| 542 |
data_str = line[6:] # Skip "data: "
|
| 543 |
else:
|
| 544 |
data_str = line[5:] # Skip "data:"
|
| 545 |
-
|
| 546 |
if data_str.strip() == "[DONE]":
|
| 547 |
break
|
| 548 |
try:
|
| 549 |
chunk = json.loads(data_str)
|
| 550 |
-
for openai_chunk in self._convert_chunk_to_openai(
|
|
|
|
|
|
|
| 551 |
yield litellm.ModelResponse(**openai_chunk)
|
| 552 |
except json.JSONDecodeError:
|
| 553 |
-
lib_logger.warning(
|
|
|
|
|
|
|
| 554 |
|
| 555 |
except httpx.HTTPStatusError:
|
| 556 |
raise # Re-raise HTTP errors we already handled
|
| 557 |
except Exception as e:
|
| 558 |
file_logger.log_error(f"Error during iFlow stream processing: {e}")
|
| 559 |
-
lib_logger.error(
|
|
|
|
|
|
|
| 560 |
raise
|
| 561 |
|
| 562 |
async def logging_stream_wrapper():
|
|
@@ -574,7 +671,9 @@ class IFlowProvider(IFlowAuthBase, ProviderInterface):
|
|
| 574 |
if kwargs.get("stream"):
|
| 575 |
return logging_stream_wrapper()
|
| 576 |
else:
|
|
|
|
| 577 |
async def non_stream_wrapper():
|
| 578 |
chunks = [chunk async for chunk in logging_stream_wrapper()]
|
| 579 |
return self._stream_to_completion_response(chunks)
|
|
|
|
| 580 |
return await non_stream_wrapper()
|
|
|
|
| 10 |
from .provider_interface import ProviderInterface
|
| 11 |
from .iflow_auth_base import IFlowAuthBase
|
| 12 |
from ..model_definitions import ModelDefinitions
|
| 13 |
+
from ..timeout_config import TimeoutConfig
|
| 14 |
+
from ..utils.paths import get_logs_dir
|
| 15 |
import litellm
|
| 16 |
from litellm.exceptions import RateLimitError, AuthenticationError
|
| 17 |
from pathlib import Path
|
| 18 |
import uuid
|
| 19 |
from datetime import datetime
|
| 20 |
|
| 21 |
+
lib_logger = logging.getLogger("rotator_library")
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def _get_iflow_logs_dir() -> Path:
|
| 25 |
+
"""Get the iFlow logs directory."""
|
| 26 |
+
logs_dir = get_logs_dir() / "iflow_logs"
|
| 27 |
+
logs_dir.mkdir(parents=True, exist_ok=True)
|
| 28 |
+
return logs_dir
|
| 29 |
|
|
|
|
|
|
|
| 30 |
|
| 31 |
class _IFlowFileLogger:
|
| 32 |
"""A simple file logger for a single iFlow transaction."""
|
| 33 |
+
|
| 34 |
def __init__(self, model_name: str, enabled: bool = True):
|
| 35 |
self.enabled = enabled
|
| 36 |
if not self.enabled:
|
|
|
|
| 39 |
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
|
| 40 |
request_id = str(uuid.uuid4())
|
| 41 |
# Sanitize model name for directory
|
| 42 |
+
safe_model_name = model_name.replace("/", "_").replace(":", "_")
|
| 43 |
+
self.log_dir = (
|
| 44 |
+
_get_iflow_logs_dir() / f"{timestamp}_{safe_model_name}_{request_id}"
|
| 45 |
+
)
|
| 46 |
try:
|
| 47 |
self.log_dir.mkdir(parents=True, exist_ok=True)
|
| 48 |
except Exception as e:
|
|
|
|
| 51 |
|
| 52 |
def log_request(self, payload: Dict[str, Any]):
|
| 53 |
"""Logs the request payload sent to iFlow."""
|
| 54 |
+
if not self.enabled:
|
| 55 |
+
return
|
| 56 |
try:
|
| 57 |
+
with open(
|
| 58 |
+
self.log_dir / "request_payload.json", "w", encoding="utf-8"
|
| 59 |
+
) as f:
|
| 60 |
json.dump(payload, f, indent=2, ensure_ascii=False)
|
| 61 |
except Exception as e:
|
| 62 |
lib_logger.error(f"_IFlowFileLogger: Failed to write request: {e}")
|
| 63 |
|
| 64 |
def log_response_chunk(self, chunk: str):
|
| 65 |
"""Logs a raw chunk from the iFlow response stream."""
|
| 66 |
+
if not self.enabled:
|
| 67 |
+
return
|
| 68 |
try:
|
| 69 |
with open(self.log_dir / "response_stream.log", "a", encoding="utf-8") as f:
|
| 70 |
f.write(chunk + "\n")
|
|
|
|
| 73 |
|
| 74 |
def log_error(self, error_message: str):
|
| 75 |
"""Logs an error message."""
|
| 76 |
+
if not self.enabled:
|
| 77 |
+
return
|
| 78 |
try:
|
| 79 |
with open(self.log_dir / "error.log", "a", encoding="utf-8") as f:
|
| 80 |
f.write(f"[{datetime.utcnow().isoformat()}] {error_message}\n")
|
|
|
|
| 83 |
|
| 84 |
def log_final_response(self, response_data: Dict[str, Any]):
|
| 85 |
"""Logs the final, reassembled response."""
|
| 86 |
+
if not self.enabled:
|
| 87 |
+
return
|
| 88 |
try:
|
| 89 |
with open(self.log_dir / "final_response.json", "w", encoding="utf-8") as f:
|
| 90 |
json.dump(response_data, f, indent=2, ensure_ascii=False)
|
| 91 |
except Exception as e:
|
| 92 |
lib_logger.error(f"_IFlowFileLogger: Failed to write final response: {e}")
|
| 93 |
|
| 94 |
+
|
| 95 |
# Model list can be expanded as iFlow supports more models
|
| 96 |
HARDCODED_MODELS = [
|
| 97 |
"glm-4.6",
|
|
|
|
| 107 |
"deepseek-v3",
|
| 108 |
"qwen3-vl-plus",
|
| 109 |
"qwen3-235b-a22b-instruct",
|
| 110 |
+
"qwen3-235b",
|
| 111 |
]
|
| 112 |
|
| 113 |
# OpenAI-compatible parameters supported by iFlow API
|
| 114 |
SUPPORTED_PARAMS = {
|
| 115 |
+
"model",
|
| 116 |
+
"messages",
|
| 117 |
+
"temperature",
|
| 118 |
+
"top_p",
|
| 119 |
+
"max_tokens",
|
| 120 |
+
"stream",
|
| 121 |
+
"tools",
|
| 122 |
+
"tool_choice",
|
| 123 |
+
"presence_penalty",
|
| 124 |
+
"frequency_penalty",
|
| 125 |
+
"n",
|
| 126 |
+
"stop",
|
| 127 |
+
"seed",
|
| 128 |
+
"response_format",
|
| 129 |
}
|
| 130 |
|
| 131 |
|
|
|
|
| 134 |
iFlow provider using OAuth authentication with local callback server.
|
| 135 |
API requests use the derived API key (NOT OAuth access_token).
|
| 136 |
"""
|
| 137 |
+
|
| 138 |
skip_cost_calculation = True
|
| 139 |
|
| 140 |
def __init__(self):
|
|
|
|
| 157 |
Validates OAuth credentials if applicable.
|
| 158 |
"""
|
| 159 |
models = []
|
| 160 |
+
env_var_ids = (
|
| 161 |
+
set()
|
| 162 |
+
) # Track IDs from env vars to prevent hardcoded/dynamic duplicates
|
| 163 |
|
| 164 |
def extract_model_id(item) -> str:
|
| 165 |
"""Extract model ID from various formats (dict, string with/without provider prefix)."""
|
|
|
|
| 185 |
# Track the ID to prevent hardcoded/dynamic duplicates
|
| 186 |
if model_id:
|
| 187 |
env_var_ids.add(model_id)
|
| 188 |
+
lib_logger.info(
|
| 189 |
+
f"Loaded {len(static_models)} static models for iflow from environment variables"
|
| 190 |
+
)
|
| 191 |
|
| 192 |
# Source 2: Add hardcoded models (only if ID not already in env vars)
|
| 193 |
for model_id in HARDCODED_MODELS:
|
|
|
|
| 205 |
models_url = f"{api_base.rstrip('/')}/models"
|
| 206 |
|
| 207 |
response = await client.get(
|
| 208 |
+
models_url, headers={"Authorization": f"Bearer {api_key}"}
|
|
|
|
| 209 |
)
|
| 210 |
response.raise_for_status()
|
| 211 |
|
| 212 |
dynamic_data = response.json()
|
| 213 |
# Handle both {data: [...]} and direct [...] formats
|
| 214 |
+
model_list = (
|
| 215 |
+
dynamic_data.get("data", dynamic_data)
|
| 216 |
+
if isinstance(dynamic_data, dict)
|
| 217 |
+
else dynamic_data
|
| 218 |
+
)
|
| 219 |
|
| 220 |
dynamic_count = 0
|
| 221 |
for model in model_list:
|
|
|
|
| 226 |
dynamic_count += 1
|
| 227 |
|
| 228 |
if dynamic_count > 0:
|
| 229 |
+
lib_logger.debug(
|
| 230 |
+
f"Discovered {dynamic_count} additional models for iflow from API"
|
| 231 |
+
)
|
| 232 |
|
| 233 |
except Exception as e:
|
| 234 |
# Silently ignore dynamic discovery errors
|
|
|
|
| 293 |
payload = {k: v for k, v in kwargs.items() if k in SUPPORTED_PARAMS}
|
| 294 |
|
| 295 |
# Always force streaming for internal processing
|
| 296 |
+
payload["stream"] = True
|
| 297 |
|
| 298 |
# NOTE: iFlow API does not support stream_options parameter
|
| 299 |
# Unlike other providers, we don't include it to avoid HTTP 406 errors
|
|
|
|
| 302 |
if "tools" in payload and payload["tools"]:
|
| 303 |
payload["tools"] = self._clean_tool_schemas(payload["tools"])
|
| 304 |
lib_logger.debug(f"Cleaned {len(payload['tools'])} tool schemas")
|
| 305 |
+
elif (
|
| 306 |
+
"tools" in payload
|
| 307 |
+
and isinstance(payload["tools"], list)
|
| 308 |
+
and len(payload["tools"]) == 0
|
| 309 |
+
):
|
| 310 |
# Inject dummy tool for empty arrays to prevent streaming issues (similar to Qwen's behavior)
|
| 311 |
+
payload["tools"] = [
|
| 312 |
+
{
|
| 313 |
+
"type": "function",
|
| 314 |
+
"function": {
|
| 315 |
+
"name": "noop",
|
| 316 |
+
"description": "Placeholder tool to stabilise streaming",
|
| 317 |
+
"parameters": {"type": "object"},
|
| 318 |
+
},
|
| 319 |
}
|
| 320 |
+
]
|
| 321 |
lib_logger.debug("Injected placeholder tool for empty tools array")
|
| 322 |
|
| 323 |
return payload
|
|
|
|
| 326 |
"""
|
| 327 |
Converts a raw iFlow SSE chunk to an OpenAI-compatible chunk.
|
| 328 |
Since iFlow is OpenAI-compatible, minimal conversion is needed.
|
| 329 |
+
|
| 330 |
CRITICAL FIX: Handle chunks with BOTH usage and choices (final chunk)
|
| 331 |
without early return to ensure finish_reason is properly processed.
|
| 332 |
"""
|
|
|
|
| 346 |
"model": model_id,
|
| 347 |
"object": "chat.completion.chunk",
|
| 348 |
"id": chunk.get("id", f"chatcmpl-iflow-{time.time()}"),
|
| 349 |
+
"created": chunk.get("created", int(time.time())),
|
| 350 |
}
|
| 351 |
# Then yield the usage chunk
|
| 352 |
yield {
|
| 353 |
+
"choices": [],
|
| 354 |
+
"model": model_id,
|
| 355 |
+
"object": "chat.completion.chunk",
|
| 356 |
"id": chunk.get("id", f"chatcmpl-iflow-{time.time()}"),
|
| 357 |
"created": chunk.get("created", int(time.time())),
|
| 358 |
"usage": {
|
| 359 |
"prompt_tokens": usage_data.get("prompt_tokens", 0),
|
| 360 |
"completion_tokens": usage_data.get("completion_tokens", 0),
|
| 361 |
"total_tokens": usage_data.get("total_tokens", 0),
|
| 362 |
+
},
|
| 363 |
}
|
| 364 |
return
|
| 365 |
|
| 366 |
# Handle usage-only chunks
|
| 367 |
if usage_data:
|
| 368 |
yield {
|
| 369 |
+
"choices": [],
|
| 370 |
+
"model": model_id,
|
| 371 |
+
"object": "chat.completion.chunk",
|
| 372 |
"id": chunk.get("id", f"chatcmpl-iflow-{time.time()}"),
|
| 373 |
"created": chunk.get("created", int(time.time())),
|
| 374 |
"usage": {
|
| 375 |
"prompt_tokens": usage_data.get("prompt_tokens", 0),
|
| 376 |
"completion_tokens": usage_data.get("completion_tokens", 0),
|
| 377 |
"total_tokens": usage_data.get("total_tokens", 0),
|
| 378 |
+
},
|
| 379 |
}
|
| 380 |
return
|
| 381 |
|
|
|
|
| 387 |
"model": model_id,
|
| 388 |
"object": "chat.completion.chunk",
|
| 389 |
"id": chunk.get("id", f"chatcmpl-iflow-{time.time()}"),
|
| 390 |
+
"created": chunk.get("created", int(time.time())),
|
| 391 |
}
|
| 392 |
|
| 393 |
+
def _stream_to_completion_response(
|
| 394 |
+
self, chunks: List[litellm.ModelResponse]
|
| 395 |
+
) -> litellm.ModelResponse:
|
| 396 |
"""
|
| 397 |
Manually reassembles streaming chunks into a complete response.
|
| 398 |
+
|
| 399 |
Key improvements:
|
| 400 |
- Determines finish_reason based on accumulated state (tool_calls vs stop)
|
| 401 |
- Properly initializes tool_calls with type field
|
|
|
|
| 408 |
final_message = {"role": "assistant"}
|
| 409 |
aggregated_tool_calls = {}
|
| 410 |
usage_data = None
|
| 411 |
+
chunk_finish_reason = (
|
| 412 |
+
None # Track finish_reason from chunks (but we'll override)
|
| 413 |
+
)
|
| 414 |
|
| 415 |
# Get the first chunk for basic response metadata
|
| 416 |
first_chunk = chunks[0]
|
| 417 |
|
| 418 |
# Process each chunk to aggregate content
|
| 419 |
for chunk in chunks:
|
| 420 |
+
if not hasattr(chunk, "choices") or not chunk.choices:
|
| 421 |
continue
|
| 422 |
|
| 423 |
choice = chunk.choices[0]
|
|
|
|
| 441 |
index = tc_chunk.get("index", 0)
|
| 442 |
if index not in aggregated_tool_calls:
|
| 443 |
# Initialize with type field for OpenAI compatibility
|
| 444 |
+
aggregated_tool_calls[index] = {
|
| 445 |
+
"type": "function",
|
| 446 |
+
"function": {"name": "", "arguments": ""},
|
| 447 |
+
}
|
| 448 |
if "id" in tc_chunk:
|
| 449 |
aggregated_tool_calls[index]["id"] = tc_chunk["id"]
|
| 450 |
if "type" in tc_chunk:
|
| 451 |
aggregated_tool_calls[index]["type"] = tc_chunk["type"]
|
| 452 |
if "function" in tc_chunk:
|
| 453 |
+
if (
|
| 454 |
+
"name" in tc_chunk["function"]
|
| 455 |
+
and tc_chunk["function"]["name"] is not None
|
| 456 |
+
):
|
| 457 |
+
aggregated_tool_calls[index]["function"]["name"] += (
|
| 458 |
+
tc_chunk["function"]["name"]
|
| 459 |
+
)
|
| 460 |
+
if (
|
| 461 |
+
"arguments" in tc_chunk["function"]
|
| 462 |
+
and tc_chunk["function"]["arguments"] is not None
|
| 463 |
+
):
|
| 464 |
+
aggregated_tool_calls[index]["function"]["arguments"] += (
|
| 465 |
+
tc_chunk["function"]["arguments"]
|
| 466 |
+
)
|
| 467 |
|
| 468 |
# Aggregate function calls (legacy format)
|
| 469 |
if "function_call" in delta and delta["function_call"] is not None:
|
| 470 |
if "function_call" not in final_message:
|
| 471 |
final_message["function_call"] = {"name": "", "arguments": ""}
|
| 472 |
+
if (
|
| 473 |
+
"name" in delta["function_call"]
|
| 474 |
+
and delta["function_call"]["name"] is not None
|
| 475 |
+
):
|
| 476 |
+
final_message["function_call"]["name"] += delta["function_call"][
|
| 477 |
+
"name"
|
| 478 |
+
]
|
| 479 |
+
if (
|
| 480 |
+
"arguments" in delta["function_call"]
|
| 481 |
+
and delta["function_call"]["arguments"] is not None
|
| 482 |
+
):
|
| 483 |
+
final_message["function_call"]["arguments"] += delta[
|
| 484 |
+
"function_call"
|
| 485 |
+
]["arguments"]
|
| 486 |
|
| 487 |
# Track finish_reason from chunks (for reference only)
|
| 488 |
if choice.get("finish_reason"):
|
|
|
|
| 490 |
|
| 491 |
# Handle usage data from the last chunk that has it
|
| 492 |
for chunk in reversed(chunks):
|
| 493 |
+
if hasattr(chunk, "usage") and chunk.usage:
|
| 494 |
usage_data = chunk.usage
|
| 495 |
break
|
| 496 |
|
|
|
|
| 516 |
final_choice = {
|
| 517 |
"index": 0,
|
| 518 |
"message": final_message,
|
| 519 |
+
"finish_reason": finish_reason,
|
| 520 |
}
|
| 521 |
|
| 522 |
# Create the final ModelResponse
|
|
|
|
| 526 |
"created": first_chunk.created,
|
| 527 |
"model": first_chunk.model,
|
| 528 |
"choices": [final_choice],
|
| 529 |
+
"usage": usage_data,
|
| 530 |
}
|
| 531 |
|
| 532 |
return litellm.ModelResponse(**final_response_data)
|
| 533 |
|
| 534 |
+
async def acompletion(
|
| 535 |
+
self, client: httpx.AsyncClient, **kwargs
|
| 536 |
+
) -> Union[litellm.ModelResponse, AsyncGenerator[litellm.ModelResponse, None]]:
|
| 537 |
credential_path = kwargs.pop("credential_identifier")
|
| 538 |
enable_request_logging = kwargs.pop("enable_request_logging", False)
|
| 539 |
model = kwargs["model"]
|
| 540 |
|
| 541 |
# Create dedicated file logger for this request
|
| 542 |
+
file_logger = _IFlowFileLogger(model_name=model, enabled=enable_request_logging)
|
|
|
|
|
|
|
|
|
|
| 543 |
|
| 544 |
async def make_request():
|
| 545 |
"""Prepares and makes the actual API call."""
|
|
|
|
| 547 |
api_base, api_key = await self.get_api_details(credential_path)
|
| 548 |
|
| 549 |
# Strip provider prefix from model name (e.g., "iflow/Qwen3-Coder-Plus" -> "Qwen3-Coder-Plus")
|
| 550 |
+
model_name = model.split("/")[-1]
|
| 551 |
+
kwargs_with_stripped_model = {**kwargs, "model": model_name}
|
| 552 |
|
| 553 |
# Build clean payload with only supported parameters
|
| 554 |
payload = self._build_request_payload(**kwargs_with_stripped_model)
|
|
|
|
| 557 |
"Authorization": f"Bearer {api_key}", # Uses api_key from user info
|
| 558 |
"Content-Type": "application/json",
|
| 559 |
"Accept": "text/event-stream",
|
| 560 |
+
"User-Agent": "iFlow-Cli",
|
| 561 |
}
|
| 562 |
|
| 563 |
url = f"{api_base.rstrip('/')}/chat/completions"
|
|
|
|
| 566 |
file_logger.log_request(payload)
|
| 567 |
lib_logger.debug(f"iFlow Request URL: {url}")
|
| 568 |
|
| 569 |
+
return client.stream(
|
| 570 |
+
"POST",
|
| 571 |
+
url,
|
| 572 |
+
headers=headers,
|
| 573 |
+
json=payload,
|
| 574 |
+
timeout=TimeoutConfig.streaming(),
|
| 575 |
+
)
|
| 576 |
|
| 577 |
async def stream_handler(response_stream, attempt=1):
|
| 578 |
"""Handles the streaming response and converts chunks."""
|
|
|
|
| 581 |
# Check for HTTP errors before processing stream
|
| 582 |
if response.status_code >= 400:
|
| 583 |
error_text = await response.aread()
|
| 584 |
+
error_text = (
|
| 585 |
+
error_text.decode("utf-8")
|
| 586 |
+
if isinstance(error_text, bytes)
|
| 587 |
+
else error_text
|
| 588 |
+
)
|
| 589 |
|
| 590 |
# Handle 401: Force token refresh and retry once
|
| 591 |
if response.status_code == 401 and attempt == 1:
|
| 592 |
+
lib_logger.warning(
|
| 593 |
+
"iFlow returned 401. Forcing token refresh and retrying once."
|
| 594 |
+
)
|
| 595 |
await self._refresh_token(credential_path, force=True)
|
| 596 |
retry_stream = await make_request()
|
| 597 |
async for chunk in stream_handler(retry_stream, attempt=2):
|
|
|
|
| 599 |
return
|
| 600 |
|
| 601 |
# Handle 429: Rate limit
|
| 602 |
+
elif (
|
| 603 |
+
response.status_code == 429
|
| 604 |
+
or "slow_down" in error_text.lower()
|
| 605 |
+
):
|
| 606 |
raise RateLimitError(
|
| 607 |
f"iFlow rate limit exceeded: {error_text}",
|
| 608 |
llm_provider="iflow",
|
| 609 |
model=model,
|
| 610 |
+
response=response,
|
| 611 |
)
|
| 612 |
|
| 613 |
# Handle other errors
|
| 614 |
else:
|
| 615 |
+
error_msg = (
|
| 616 |
+
f"iFlow HTTP {response.status_code} error: {error_text}"
|
| 617 |
+
)
|
| 618 |
file_logger.log_error(error_msg)
|
| 619 |
raise httpx.HTTPStatusError(
|
| 620 |
f"HTTP {response.status_code}: {error_text}",
|
| 621 |
request=response.request,
|
| 622 |
+
response=response,
|
| 623 |
)
|
| 624 |
|
| 625 |
# Process successful streaming response
|
| 626 |
async for line in response.aiter_lines():
|
| 627 |
file_logger.log_response_chunk(line)
|
| 628 |
+
|
| 629 |
# CRITICAL FIX: Handle both "data:" (no space) and "data: " (with space)
|
| 630 |
+
if line.startswith("data:"):
|
| 631 |
# Extract data after "data:" prefix, handling both formats
|
| 632 |
+
if line.startswith("data: "):
|
| 633 |
data_str = line[6:] # Skip "data: "
|
| 634 |
else:
|
| 635 |
data_str = line[5:] # Skip "data:"
|
| 636 |
+
|
| 637 |
if data_str.strip() == "[DONE]":
|
| 638 |
break
|
| 639 |
try:
|
| 640 |
chunk = json.loads(data_str)
|
| 641 |
+
for openai_chunk in self._convert_chunk_to_openai(
|
| 642 |
+
chunk, model
|
| 643 |
+
):
|
| 644 |
yield litellm.ModelResponse(**openai_chunk)
|
| 645 |
except json.JSONDecodeError:
|
| 646 |
+
lib_logger.warning(
|
| 647 |
+
f"Could not decode JSON from iFlow: {line}"
|
| 648 |
+
)
|
| 649 |
|
| 650 |
except httpx.HTTPStatusError:
|
| 651 |
raise # Re-raise HTTP errors we already handled
|
| 652 |
except Exception as e:
|
| 653 |
file_logger.log_error(f"Error during iFlow stream processing: {e}")
|
| 654 |
+
lib_logger.error(
|
| 655 |
+
f"Error during iFlow stream processing: {e}", exc_info=True
|
| 656 |
+
)
|
| 657 |
raise
|
| 658 |
|
| 659 |
async def logging_stream_wrapper():
|
|
|
|
| 671 |
if kwargs.get("stream"):
|
| 672 |
return logging_stream_wrapper()
|
| 673 |
else:
|
| 674 |
+
|
| 675 |
async def non_stream_wrapper():
|
| 676 |
chunks = [chunk async for chunk in logging_stream_wrapper()]
|
| 677 |
return self._stream_to_completion_response(chunks)
|
| 678 |
+
|
| 679 |
return await non_stream_wrapper()
|
src/rotator_library/providers/provider_cache.py
CHANGED
|
@@ -20,19 +20,20 @@ import asyncio
|
|
| 20 |
import json
|
| 21 |
import logging
|
| 22 |
import os
|
| 23 |
-
import shutil
|
| 24 |
-
import tempfile
|
| 25 |
import time
|
| 26 |
from pathlib import Path
|
| 27 |
from typing import Any, Dict, Optional, Tuple
|
| 28 |
|
| 29 |
-
|
|
|
|
|
|
|
| 30 |
|
| 31 |
|
| 32 |
# =============================================================================
|
| 33 |
# UTILITY FUNCTIONS
|
| 34 |
# =============================================================================
|
| 35 |
|
|
|
|
| 36 |
def _env_bool(key: str, default: bool = False) -> bool:
|
| 37 |
"""Get boolean from environment variable."""
|
| 38 |
return os.getenv(key, str(default).lower()).lower() in ("true", "1", "yes")
|
|
@@ -47,18 +48,19 @@ def _env_int(key: str, default: int) -> int:
|
|
| 47 |
# PROVIDER CACHE CLASS
|
| 48 |
# =============================================================================
|
| 49 |
|
|
|
|
| 50 |
class ProviderCache:
|
| 51 |
"""
|
| 52 |
Server-side cache for provider conversation state preservation.
|
| 53 |
-
|
| 54 |
A generic, modular cache supporting any key-value data that providers need
|
| 55 |
to persist across requests. Features:
|
| 56 |
-
|
| 57 |
- Dual-TTL system: configurable memory TTL, longer disk TTL
|
| 58 |
- Async disk persistence with batched writes
|
| 59 |
- Background cleanup task for expired entries
|
| 60 |
- Statistics tracking (hits, misses, writes)
|
| 61 |
-
|
| 62 |
Args:
|
| 63 |
cache_file: Path to disk cache file
|
| 64 |
memory_ttl_seconds: In-memory entry lifetime (default: 1 hour)
|
|
@@ -67,13 +69,13 @@ class ProviderCache:
|
|
| 67 |
write_interval: Seconds between background disk writes (default: 60)
|
| 68 |
cleanup_interval: Seconds between expired entry cleanup (default: 30 min)
|
| 69 |
env_prefix: Environment variable prefix for configuration overrides
|
| 70 |
-
|
| 71 |
Environment Variables (with default prefix "PROVIDER_CACHE"):
|
| 72 |
{PREFIX}_ENABLE: Enable/disable disk persistence
|
| 73 |
{PREFIX}_WRITE_INTERVAL: Background write interval in seconds
|
| 74 |
{PREFIX}_CLEANUP_INTERVAL: Cleanup interval in seconds
|
| 75 |
"""
|
| 76 |
-
|
| 77 |
def __init__(
|
| 78 |
self,
|
| 79 |
cache_file: Path,
|
|
@@ -82,7 +84,7 @@ class ProviderCache:
|
|
| 82 |
enable_disk: Optional[bool] = None,
|
| 83 |
write_interval: Optional[int] = None,
|
| 84 |
cleanup_interval: Optional[int] = None,
|
| 85 |
-
env_prefix: str = "PROVIDER_CACHE"
|
| 86 |
):
|
| 87 |
# In-memory cache: {cache_key: (data, timestamp)}
|
| 88 |
self._cache: Dict[str, Tuple[str, float]] = {}
|
|
@@ -90,25 +92,42 @@ class ProviderCache:
|
|
| 90 |
self._disk_ttl = disk_ttl_seconds
|
| 91 |
self._lock = asyncio.Lock()
|
| 92 |
self._disk_lock = asyncio.Lock()
|
| 93 |
-
|
| 94 |
# Disk persistence configuration
|
| 95 |
self._cache_file = cache_file
|
| 96 |
-
self._enable_disk =
|
|
|
|
|
|
|
|
|
|
|
|
|
| 97 |
self._dirty = False
|
| 98 |
-
self._write_interval = write_interval or _env_int(
|
| 99 |
-
|
| 100 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 101 |
# Background tasks
|
| 102 |
self._writer_task: Optional[asyncio.Task] = None
|
| 103 |
self._cleanup_task: Optional[asyncio.Task] = None
|
| 104 |
self._running = False
|
| 105 |
-
|
| 106 |
# Statistics
|
| 107 |
-
self._stats = {
|
| 108 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 109 |
# Metadata about this cache instance
|
| 110 |
self._cache_name = cache_file.stem if cache_file else "unnamed"
|
| 111 |
-
|
| 112 |
if self._enable_disk:
|
| 113 |
lib_logger.debug(
|
| 114 |
f"ProviderCache[{self._cache_name}]: Disk enabled "
|
|
@@ -117,123 +136,120 @@ class ProviderCache:
|
|
| 117 |
asyncio.create_task(self._async_init())
|
| 118 |
else:
|
| 119 |
lib_logger.debug(f"ProviderCache[{self._cache_name}]: Memory-only mode")
|
| 120 |
-
|
| 121 |
# =========================================================================
|
| 122 |
# INITIALIZATION
|
| 123 |
# =========================================================================
|
| 124 |
-
|
| 125 |
async def _async_init(self) -> None:
|
| 126 |
"""Async initialization: load from disk and start background tasks."""
|
| 127 |
try:
|
| 128 |
await self._load_from_disk()
|
| 129 |
await self._start_background_tasks()
|
| 130 |
except Exception as e:
|
| 131 |
-
lib_logger.error(
|
| 132 |
-
|
|
|
|
|
|
|
| 133 |
async def _load_from_disk(self) -> None:
|
| 134 |
"""Load cache from disk file with TTL validation."""
|
| 135 |
if not self._enable_disk or not self._cache_file.exists():
|
| 136 |
return
|
| 137 |
-
|
| 138 |
try:
|
| 139 |
async with self._disk_lock:
|
| 140 |
-
with open(self._cache_file,
|
| 141 |
data = json.load(f)
|
| 142 |
-
|
| 143 |
if data.get("version") != "1.0":
|
| 144 |
-
lib_logger.warning(
|
|
|
|
|
|
|
| 145 |
return
|
| 146 |
-
|
| 147 |
now = time.time()
|
| 148 |
entries = data.get("entries", {})
|
| 149 |
loaded = expired = 0
|
| 150 |
-
|
| 151 |
for cache_key, entry in entries.items():
|
| 152 |
age = now - entry.get("timestamp", 0)
|
| 153 |
if age <= self._disk_ttl:
|
| 154 |
-
value = entry.get(
|
|
|
|
|
|
|
| 155 |
if value:
|
| 156 |
self._cache[cache_key] = (value, entry["timestamp"])
|
| 157 |
loaded += 1
|
| 158 |
else:
|
| 159 |
expired += 1
|
| 160 |
-
|
| 161 |
lib_logger.debug(
|
| 162 |
f"ProviderCache[{self._cache_name}]: Loaded {loaded} entries ({expired} expired)"
|
| 163 |
)
|
| 164 |
except json.JSONDecodeError as e:
|
| 165 |
-
lib_logger.warning(
|
|
|
|
|
|
|
| 166 |
except Exception as e:
|
| 167 |
lib_logger.error(f"ProviderCache[{self._cache_name}]: Load failed: {e}")
|
| 168 |
-
|
| 169 |
# =========================================================================
|
| 170 |
# DISK PERSISTENCE
|
| 171 |
# =========================================================================
|
| 172 |
-
|
| 173 |
-
async def _save_to_disk(self) ->
|
| 174 |
-
"""Persist cache to disk using atomic write.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 175 |
if not self._enable_disk:
|
| 176 |
-
return
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
"
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
except (OSError, AttributeError):
|
| 209 |
-
pass
|
| 210 |
-
|
| 211 |
-
shutil.move(tmp_path, self._cache_file)
|
| 212 |
-
self._stats["writes"] += 1
|
| 213 |
-
lib_logger.debug(
|
| 214 |
-
f"ProviderCache[{self._cache_name}]: Saved {len(self._cache)} entries"
|
| 215 |
-
)
|
| 216 |
-
except Exception:
|
| 217 |
-
if tmp_path and os.path.exists(tmp_path):
|
| 218 |
-
os.unlink(tmp_path)
|
| 219 |
-
raise
|
| 220 |
-
except Exception as e:
|
| 221 |
-
lib_logger.error(f"ProviderCache[{self._cache_name}]: Disk save failed: {e}")
|
| 222 |
-
|
| 223 |
# =========================================================================
|
| 224 |
# BACKGROUND TASKS
|
| 225 |
# =========================================================================
|
| 226 |
-
|
| 227 |
async def _start_background_tasks(self) -> None:
|
| 228 |
"""Start background writer and cleanup tasks."""
|
| 229 |
if not self._enable_disk or self._running:
|
| 230 |
return
|
| 231 |
-
|
| 232 |
self._running = True
|
| 233 |
self._writer_task = asyncio.create_task(self._writer_loop())
|
| 234 |
self._cleanup_task = asyncio.create_task(self._cleanup_loop())
|
| 235 |
lib_logger.debug(f"ProviderCache[{self._cache_name}]: Started background tasks")
|
| 236 |
-
|
| 237 |
async def _writer_loop(self) -> None:
|
| 238 |
"""Background task: periodically flush dirty cache to disk."""
|
| 239 |
try:
|
|
@@ -241,13 +257,17 @@ class ProviderCache:
|
|
| 241 |
await asyncio.sleep(self._write_interval)
|
| 242 |
if self._dirty:
|
| 243 |
try:
|
| 244 |
-
await self._save_to_disk()
|
| 245 |
-
|
|
|
|
|
|
|
| 246 |
except Exception as e:
|
| 247 |
-
lib_logger.error(
|
|
|
|
|
|
|
| 248 |
except asyncio.CancelledError:
|
| 249 |
pass
|
| 250 |
-
|
| 251 |
async def _cleanup_loop(self) -> None:
|
| 252 |
"""Background task: periodically clean up expired entries."""
|
| 253 |
try:
|
|
@@ -256,12 +276,14 @@ class ProviderCache:
|
|
| 256 |
await self._cleanup_expired()
|
| 257 |
except asyncio.CancelledError:
|
| 258 |
pass
|
| 259 |
-
|
| 260 |
async def _cleanup_expired(self) -> None:
|
| 261 |
"""Remove expired entries from memory cache."""
|
| 262 |
async with self._lock:
|
| 263 |
now = time.time()
|
| 264 |
-
expired = [
|
|
|
|
|
|
|
| 265 |
for k in expired:
|
| 266 |
del self._cache[k]
|
| 267 |
if expired:
|
|
@@ -269,42 +291,42 @@ class ProviderCache:
|
|
| 269 |
lib_logger.debug(
|
| 270 |
f"ProviderCache[{self._cache_name}]: Cleaned {len(expired)} expired entries"
|
| 271 |
)
|
| 272 |
-
|
| 273 |
# =========================================================================
|
| 274 |
# CORE OPERATIONS
|
| 275 |
# =========================================================================
|
| 276 |
-
|
| 277 |
def store(self, key: str, value: str) -> None:
|
| 278 |
"""
|
| 279 |
Store a value synchronously (schedules async storage).
|
| 280 |
-
|
| 281 |
Args:
|
| 282 |
key: Cache key
|
| 283 |
value: Value to store (typically JSON-serialized data)
|
| 284 |
"""
|
| 285 |
asyncio.create_task(self._async_store(key, value))
|
| 286 |
-
|
| 287 |
async def _async_store(self, key: str, value: str) -> None:
|
| 288 |
"""Async implementation of store."""
|
| 289 |
async with self._lock:
|
| 290 |
self._cache[key] = (value, time.time())
|
| 291 |
self._dirty = True
|
| 292 |
-
|
| 293 |
async def store_async(self, key: str, value: str) -> None:
|
| 294 |
"""
|
| 295 |
Store a value asynchronously (awaitable).
|
| 296 |
-
|
| 297 |
Use this when you need to ensure the value is stored before continuing.
|
| 298 |
"""
|
| 299 |
await self._async_store(key, value)
|
| 300 |
-
|
| 301 |
def retrieve(self, key: str) -> Optional[str]:
|
| 302 |
"""
|
| 303 |
Retrieve a value by key (synchronous, with optional async disk fallback).
|
| 304 |
-
|
| 305 |
Args:
|
| 306 |
key: Cache key
|
| 307 |
-
|
| 308 |
Returns:
|
| 309 |
Cached value if found and not expired, None otherwise
|
| 310 |
"""
|
|
@@ -316,17 +338,17 @@ class ProviderCache:
|
|
| 316 |
else:
|
| 317 |
del self._cache[key]
|
| 318 |
self._dirty = True
|
| 319 |
-
|
| 320 |
self._stats["misses"] += 1
|
| 321 |
if self._enable_disk:
|
| 322 |
# Schedule async disk lookup for next time
|
| 323 |
asyncio.create_task(self._check_disk_fallback(key))
|
| 324 |
return None
|
| 325 |
-
|
| 326 |
async def retrieve_async(self, key: str) -> Optional[str]:
|
| 327 |
"""
|
| 328 |
Retrieve a value asynchronously (checks disk if not in memory).
|
| 329 |
-
|
| 330 |
Use this when you can await and need guaranteed disk fallback.
|
| 331 |
"""
|
| 332 |
# Check memory first
|
|
@@ -340,24 +362,24 @@ class ProviderCache:
|
|
| 340 |
if key in self._cache:
|
| 341 |
del self._cache[key]
|
| 342 |
self._dirty = True
|
| 343 |
-
|
| 344 |
# Check disk
|
| 345 |
if self._enable_disk:
|
| 346 |
return await self._disk_retrieve(key)
|
| 347 |
-
|
| 348 |
self._stats["misses"] += 1
|
| 349 |
return None
|
| 350 |
-
|
| 351 |
async def _check_disk_fallback(self, key: str) -> None:
|
| 352 |
"""Check disk for key and load into memory if found (background)."""
|
| 353 |
try:
|
| 354 |
if not self._cache_file.exists():
|
| 355 |
return
|
| 356 |
-
|
| 357 |
async with self._disk_lock:
|
| 358 |
-
with open(self._cache_file,
|
| 359 |
data = json.load(f)
|
| 360 |
-
|
| 361 |
entries = data.get("entries", {})
|
| 362 |
if key in entries:
|
| 363 |
entry = entries[key]
|
|
@@ -372,19 +394,21 @@ class ProviderCache:
|
|
| 372 |
f"ProviderCache[{self._cache_name}]: Loaded {key} from disk"
|
| 373 |
)
|
| 374 |
except Exception as e:
|
| 375 |
-
lib_logger.debug(
|
| 376 |
-
|
|
|
|
|
|
|
| 377 |
async def _disk_retrieve(self, key: str) -> Optional[str]:
|
| 378 |
"""Direct disk retrieval with loading into memory."""
|
| 379 |
try:
|
| 380 |
if not self._cache_file.exists():
|
| 381 |
self._stats["misses"] += 1
|
| 382 |
return None
|
| 383 |
-
|
| 384 |
async with self._disk_lock:
|
| 385 |
-
with open(self._cache_file,
|
| 386 |
data = json.load(f)
|
| 387 |
-
|
| 388 |
entries = data.get("entries", {})
|
| 389 |
if key in entries:
|
| 390 |
entry = entries[key]
|
|
@@ -396,34 +420,37 @@ class ProviderCache:
|
|
| 396 |
self._cache[key] = (value, ts)
|
| 397 |
self._stats["disk_hits"] += 1
|
| 398 |
return value
|
| 399 |
-
|
| 400 |
self._stats["misses"] += 1
|
| 401 |
return None
|
| 402 |
except Exception as e:
|
| 403 |
-
lib_logger.debug(
|
|
|
|
|
|
|
| 404 |
self._stats["misses"] += 1
|
| 405 |
return None
|
| 406 |
-
|
| 407 |
# =========================================================================
|
| 408 |
# UTILITY METHODS
|
| 409 |
# =========================================================================
|
| 410 |
-
|
| 411 |
def contains(self, key: str) -> bool:
|
| 412 |
"""Check if key exists in memory cache (without updating stats)."""
|
| 413 |
if key in self._cache:
|
| 414 |
_, timestamp = self._cache[key]
|
| 415 |
return time.time() - timestamp <= self._memory_ttl
|
| 416 |
return False
|
| 417 |
-
|
| 418 |
def get_stats(self) -> Dict[str, Any]:
|
| 419 |
-
"""Get cache statistics."""
|
| 420 |
return {
|
| 421 |
**self._stats,
|
| 422 |
"memory_entries": len(self._cache),
|
| 423 |
"dirty": self._dirty,
|
| 424 |
-
"disk_enabled": self._enable_disk
|
|
|
|
| 425 |
}
|
| 426 |
-
|
| 427 |
async def clear(self) -> None:
|
| 428 |
"""Clear all cached data."""
|
| 429 |
async with self._lock:
|
|
@@ -431,12 +458,12 @@ class ProviderCache:
|
|
| 431 |
self._dirty = True
|
| 432 |
if self._enable_disk:
|
| 433 |
await self._save_to_disk()
|
| 434 |
-
|
| 435 |
async def shutdown(self) -> None:
|
| 436 |
"""Graceful shutdown: flush pending writes and stop background tasks."""
|
| 437 |
lib_logger.info(f"ProviderCache[{self._cache_name}]: Shutting down...")
|
| 438 |
self._running = False
|
| 439 |
-
|
| 440 |
# Cancel background tasks
|
| 441 |
for task in (self._writer_task, self._cleanup_task):
|
| 442 |
if task:
|
|
@@ -445,11 +472,11 @@ class ProviderCache:
|
|
| 445 |
await task
|
| 446 |
except asyncio.CancelledError:
|
| 447 |
pass
|
| 448 |
-
|
| 449 |
# Final save
|
| 450 |
if self._dirty and self._enable_disk:
|
| 451 |
await self._save_to_disk()
|
| 452 |
-
|
| 453 |
lib_logger.info(
|
| 454 |
f"ProviderCache[{self._cache_name}]: Shutdown complete "
|
| 455 |
f"(stats: mem_hits={self._stats['memory_hits']}, "
|
|
@@ -461,38 +488,39 @@ class ProviderCache:
|
|
| 461 |
# CONVENIENCE FACTORY
|
| 462 |
# =============================================================================
|
| 463 |
|
|
|
|
| 464 |
def create_provider_cache(
|
| 465 |
name: str,
|
| 466 |
cache_dir: Optional[Path] = None,
|
| 467 |
memory_ttl_seconds: int = 3600,
|
| 468 |
disk_ttl_seconds: int = 86400,
|
| 469 |
-
env_prefix: Optional[str] = None
|
| 470 |
) -> ProviderCache:
|
| 471 |
"""
|
| 472 |
Factory function to create a provider cache with sensible defaults.
|
| 473 |
-
|
| 474 |
Args:
|
| 475 |
name: Cache name (used as filename and for logging)
|
| 476 |
cache_dir: Directory for cache file (default: project_root/cache/provider_name)
|
| 477 |
memory_ttl_seconds: In-memory TTL
|
| 478 |
disk_ttl_seconds: Disk TTL
|
| 479 |
env_prefix: Environment variable prefix (default: derived from name)
|
| 480 |
-
|
| 481 |
Returns:
|
| 482 |
Configured ProviderCache instance
|
| 483 |
"""
|
| 484 |
if cache_dir is None:
|
| 485 |
cache_dir = Path(__file__).resolve().parent.parent.parent.parent / "cache"
|
| 486 |
-
|
| 487 |
cache_file = cache_dir / f"{name}.json"
|
| 488 |
-
|
| 489 |
if env_prefix is None:
|
| 490 |
# Convert name to env prefix: "gemini3_signatures" -> "GEMINI3_SIGNATURES_CACHE"
|
| 491 |
env_prefix = f"{name.upper().replace('-', '_')}_CACHE"
|
| 492 |
-
|
| 493 |
return ProviderCache(
|
| 494 |
cache_file=cache_file,
|
| 495 |
memory_ttl_seconds=memory_ttl_seconds,
|
| 496 |
disk_ttl_seconds=disk_ttl_seconds,
|
| 497 |
-
env_prefix=env_prefix
|
| 498 |
)
|
|
|
|
| 20 |
import json
|
| 21 |
import logging
|
| 22 |
import os
|
|
|
|
|
|
|
| 23 |
import time
|
| 24 |
from pathlib import Path
|
| 25 |
from typing import Any, Dict, Optional, Tuple
|
| 26 |
|
| 27 |
+
from ..utils.resilient_io import safe_write_json
|
| 28 |
+
|
| 29 |
+
lib_logger = logging.getLogger("rotator_library")
|
| 30 |
|
| 31 |
|
| 32 |
# =============================================================================
|
| 33 |
# UTILITY FUNCTIONS
|
| 34 |
# =============================================================================
|
| 35 |
|
| 36 |
+
|
| 37 |
def _env_bool(key: str, default: bool = False) -> bool:
|
| 38 |
"""Get boolean from environment variable."""
|
| 39 |
return os.getenv(key, str(default).lower()).lower() in ("true", "1", "yes")
|
|
|
|
| 48 |
# PROVIDER CACHE CLASS
|
| 49 |
# =============================================================================
|
| 50 |
|
| 51 |
+
|
| 52 |
class ProviderCache:
|
| 53 |
"""
|
| 54 |
Server-side cache for provider conversation state preservation.
|
| 55 |
+
|
| 56 |
A generic, modular cache supporting any key-value data that providers need
|
| 57 |
to persist across requests. Features:
|
| 58 |
+
|
| 59 |
- Dual-TTL system: configurable memory TTL, longer disk TTL
|
| 60 |
- Async disk persistence with batched writes
|
| 61 |
- Background cleanup task for expired entries
|
| 62 |
- Statistics tracking (hits, misses, writes)
|
| 63 |
+
|
| 64 |
Args:
|
| 65 |
cache_file: Path to disk cache file
|
| 66 |
memory_ttl_seconds: In-memory entry lifetime (default: 1 hour)
|
|
|
|
| 69 |
write_interval: Seconds between background disk writes (default: 60)
|
| 70 |
cleanup_interval: Seconds between expired entry cleanup (default: 30 min)
|
| 71 |
env_prefix: Environment variable prefix for configuration overrides
|
| 72 |
+
|
| 73 |
Environment Variables (with default prefix "PROVIDER_CACHE"):
|
| 74 |
{PREFIX}_ENABLE: Enable/disable disk persistence
|
| 75 |
{PREFIX}_WRITE_INTERVAL: Background write interval in seconds
|
| 76 |
{PREFIX}_CLEANUP_INTERVAL: Cleanup interval in seconds
|
| 77 |
"""
|
| 78 |
+
|
| 79 |
def __init__(
|
| 80 |
self,
|
| 81 |
cache_file: Path,
|
|
|
|
| 84 |
enable_disk: Optional[bool] = None,
|
| 85 |
write_interval: Optional[int] = None,
|
| 86 |
cleanup_interval: Optional[int] = None,
|
| 87 |
+
env_prefix: str = "PROVIDER_CACHE",
|
| 88 |
):
|
| 89 |
# In-memory cache: {cache_key: (data, timestamp)}
|
| 90 |
self._cache: Dict[str, Tuple[str, float]] = {}
|
|
|
|
| 92 |
self._disk_ttl = disk_ttl_seconds
|
| 93 |
self._lock = asyncio.Lock()
|
| 94 |
self._disk_lock = asyncio.Lock()
|
| 95 |
+
|
| 96 |
# Disk persistence configuration
|
| 97 |
self._cache_file = cache_file
|
| 98 |
+
self._enable_disk = (
|
| 99 |
+
enable_disk
|
| 100 |
+
if enable_disk is not None
|
| 101 |
+
else _env_bool(f"{env_prefix}_ENABLE", True)
|
| 102 |
+
)
|
| 103 |
self._dirty = False
|
| 104 |
+
self._write_interval = write_interval or _env_int(
|
| 105 |
+
f"{env_prefix}_WRITE_INTERVAL", 60
|
| 106 |
+
)
|
| 107 |
+
self._cleanup_interval = cleanup_interval or _env_int(
|
| 108 |
+
f"{env_prefix}_CLEANUP_INTERVAL", 1800
|
| 109 |
+
)
|
| 110 |
+
|
| 111 |
# Background tasks
|
| 112 |
self._writer_task: Optional[asyncio.Task] = None
|
| 113 |
self._cleanup_task: Optional[asyncio.Task] = None
|
| 114 |
self._running = False
|
| 115 |
+
|
| 116 |
# Statistics
|
| 117 |
+
self._stats = {
|
| 118 |
+
"memory_hits": 0,
|
| 119 |
+
"disk_hits": 0,
|
| 120 |
+
"misses": 0,
|
| 121 |
+
"writes": 0,
|
| 122 |
+
"disk_errors": 0,
|
| 123 |
+
}
|
| 124 |
+
|
| 125 |
+
# Track disk health for monitoring
|
| 126 |
+
self._disk_available = True
|
| 127 |
+
|
| 128 |
# Metadata about this cache instance
|
| 129 |
self._cache_name = cache_file.stem if cache_file else "unnamed"
|
| 130 |
+
|
| 131 |
if self._enable_disk:
|
| 132 |
lib_logger.debug(
|
| 133 |
f"ProviderCache[{self._cache_name}]: Disk enabled "
|
|
|
|
| 136 |
asyncio.create_task(self._async_init())
|
| 137 |
else:
|
| 138 |
lib_logger.debug(f"ProviderCache[{self._cache_name}]: Memory-only mode")
|
| 139 |
+
|
| 140 |
# =========================================================================
|
| 141 |
# INITIALIZATION
|
| 142 |
# =========================================================================
|
| 143 |
+
|
| 144 |
async def _async_init(self) -> None:
|
| 145 |
"""Async initialization: load from disk and start background tasks."""
|
| 146 |
try:
|
| 147 |
await self._load_from_disk()
|
| 148 |
await self._start_background_tasks()
|
| 149 |
except Exception as e:
|
| 150 |
+
lib_logger.error(
|
| 151 |
+
f"ProviderCache[{self._cache_name}] async init failed: {e}"
|
| 152 |
+
)
|
| 153 |
+
|
| 154 |
async def _load_from_disk(self) -> None:
|
| 155 |
"""Load cache from disk file with TTL validation."""
|
| 156 |
if not self._enable_disk or not self._cache_file.exists():
|
| 157 |
return
|
| 158 |
+
|
| 159 |
try:
|
| 160 |
async with self._disk_lock:
|
| 161 |
+
with open(self._cache_file, "r", encoding="utf-8") as f:
|
| 162 |
data = json.load(f)
|
| 163 |
+
|
| 164 |
if data.get("version") != "1.0":
|
| 165 |
+
lib_logger.warning(
|
| 166 |
+
f"ProviderCache[{self._cache_name}]: Version mismatch, starting fresh"
|
| 167 |
+
)
|
| 168 |
return
|
| 169 |
+
|
| 170 |
now = time.time()
|
| 171 |
entries = data.get("entries", {})
|
| 172 |
loaded = expired = 0
|
| 173 |
+
|
| 174 |
for cache_key, entry in entries.items():
|
| 175 |
age = now - entry.get("timestamp", 0)
|
| 176 |
if age <= self._disk_ttl:
|
| 177 |
+
value = entry.get(
|
| 178 |
+
"value", entry.get("signature", "")
|
| 179 |
+
) # Support both formats
|
| 180 |
if value:
|
| 181 |
self._cache[cache_key] = (value, entry["timestamp"])
|
| 182 |
loaded += 1
|
| 183 |
else:
|
| 184 |
expired += 1
|
| 185 |
+
|
| 186 |
lib_logger.debug(
|
| 187 |
f"ProviderCache[{self._cache_name}]: Loaded {loaded} entries ({expired} expired)"
|
| 188 |
)
|
| 189 |
except json.JSONDecodeError as e:
|
| 190 |
+
lib_logger.warning(
|
| 191 |
+
f"ProviderCache[{self._cache_name}]: File corrupted: {e}"
|
| 192 |
+
)
|
| 193 |
except Exception as e:
|
| 194 |
lib_logger.error(f"ProviderCache[{self._cache_name}]: Load failed: {e}")
|
| 195 |
+
|
| 196 |
# =========================================================================
|
| 197 |
# DISK PERSISTENCE
|
| 198 |
# =========================================================================
|
| 199 |
+
|
| 200 |
+
async def _save_to_disk(self) -> bool:
|
| 201 |
+
"""Persist cache to disk using atomic write with health tracking.
|
| 202 |
+
|
| 203 |
+
Returns:
|
| 204 |
+
True if write succeeded, False otherwise.
|
| 205 |
+
"""
|
| 206 |
if not self._enable_disk:
|
| 207 |
+
return True # Not an error if disk is disabled
|
| 208 |
+
|
| 209 |
+
async with self._disk_lock:
|
| 210 |
+
cache_data = {
|
| 211 |
+
"version": "1.0",
|
| 212 |
+
"memory_ttl_seconds": self._memory_ttl,
|
| 213 |
+
"disk_ttl_seconds": self._disk_ttl,
|
| 214 |
+
"entries": {
|
| 215 |
+
key: {"value": val, "timestamp": ts}
|
| 216 |
+
for key, (val, ts) in self._cache.items()
|
| 217 |
+
},
|
| 218 |
+
"statistics": {
|
| 219 |
+
"total_entries": len(self._cache),
|
| 220 |
+
"last_write": time.time(),
|
| 221 |
+
**self._stats,
|
| 222 |
+
},
|
| 223 |
+
}
|
| 224 |
+
|
| 225 |
+
if safe_write_json(
|
| 226 |
+
self._cache_file, cache_data, lib_logger, secure_permissions=True
|
| 227 |
+
):
|
| 228 |
+
self._stats["writes"] += 1
|
| 229 |
+
self._disk_available = True
|
| 230 |
+
lib_logger.debug(
|
| 231 |
+
f"ProviderCache[{self._cache_name}]: Saved {len(self._cache)} entries"
|
| 232 |
+
)
|
| 233 |
+
return True
|
| 234 |
+
else:
|
| 235 |
+
self._stats["disk_errors"] += 1
|
| 236 |
+
self._disk_available = False
|
| 237 |
+
return False
|
| 238 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 239 |
# =========================================================================
|
| 240 |
# BACKGROUND TASKS
|
| 241 |
# =========================================================================
|
| 242 |
+
|
| 243 |
async def _start_background_tasks(self) -> None:
|
| 244 |
"""Start background writer and cleanup tasks."""
|
| 245 |
if not self._enable_disk or self._running:
|
| 246 |
return
|
| 247 |
+
|
| 248 |
self._running = True
|
| 249 |
self._writer_task = asyncio.create_task(self._writer_loop())
|
| 250 |
self._cleanup_task = asyncio.create_task(self._cleanup_loop())
|
| 251 |
lib_logger.debug(f"ProviderCache[{self._cache_name}]: Started background tasks")
|
| 252 |
+
|
| 253 |
async def _writer_loop(self) -> None:
|
| 254 |
"""Background task: periodically flush dirty cache to disk."""
|
| 255 |
try:
|
|
|
|
| 257 |
await asyncio.sleep(self._write_interval)
|
| 258 |
if self._dirty:
|
| 259 |
try:
|
| 260 |
+
success = await self._save_to_disk()
|
| 261 |
+
if success:
|
| 262 |
+
self._dirty = False
|
| 263 |
+
# If save failed, _dirty remains True so we retry next interval
|
| 264 |
except Exception as e:
|
| 265 |
+
lib_logger.error(
|
| 266 |
+
f"ProviderCache[{self._cache_name}]: Writer error: {e}"
|
| 267 |
+
)
|
| 268 |
except asyncio.CancelledError:
|
| 269 |
pass
|
| 270 |
+
|
| 271 |
async def _cleanup_loop(self) -> None:
|
| 272 |
"""Background task: periodically clean up expired entries."""
|
| 273 |
try:
|
|
|
|
| 276 |
await self._cleanup_expired()
|
| 277 |
except asyncio.CancelledError:
|
| 278 |
pass
|
| 279 |
+
|
| 280 |
async def _cleanup_expired(self) -> None:
|
| 281 |
"""Remove expired entries from memory cache."""
|
| 282 |
async with self._lock:
|
| 283 |
now = time.time()
|
| 284 |
+
expired = [
|
| 285 |
+
k for k, (_, ts) in self._cache.items() if now - ts > self._memory_ttl
|
| 286 |
+
]
|
| 287 |
for k in expired:
|
| 288 |
del self._cache[k]
|
| 289 |
if expired:
|
|
|
|
| 291 |
lib_logger.debug(
|
| 292 |
f"ProviderCache[{self._cache_name}]: Cleaned {len(expired)} expired entries"
|
| 293 |
)
|
| 294 |
+
|
| 295 |
# =========================================================================
|
| 296 |
# CORE OPERATIONS
|
| 297 |
# =========================================================================
|
| 298 |
+
|
| 299 |
def store(self, key: str, value: str) -> None:
|
| 300 |
"""
|
| 301 |
Store a value synchronously (schedules async storage).
|
| 302 |
+
|
| 303 |
Args:
|
| 304 |
key: Cache key
|
| 305 |
value: Value to store (typically JSON-serialized data)
|
| 306 |
"""
|
| 307 |
asyncio.create_task(self._async_store(key, value))
|
| 308 |
+
|
| 309 |
async def _async_store(self, key: str, value: str) -> None:
|
| 310 |
"""Async implementation of store."""
|
| 311 |
async with self._lock:
|
| 312 |
self._cache[key] = (value, time.time())
|
| 313 |
self._dirty = True
|
| 314 |
+
|
| 315 |
async def store_async(self, key: str, value: str) -> None:
|
| 316 |
"""
|
| 317 |
Store a value asynchronously (awaitable).
|
| 318 |
+
|
| 319 |
Use this when you need to ensure the value is stored before continuing.
|
| 320 |
"""
|
| 321 |
await self._async_store(key, value)
|
| 322 |
+
|
| 323 |
def retrieve(self, key: str) -> Optional[str]:
|
| 324 |
"""
|
| 325 |
Retrieve a value by key (synchronous, with optional async disk fallback).
|
| 326 |
+
|
| 327 |
Args:
|
| 328 |
key: Cache key
|
| 329 |
+
|
| 330 |
Returns:
|
| 331 |
Cached value if found and not expired, None otherwise
|
| 332 |
"""
|
|
|
|
| 338 |
else:
|
| 339 |
del self._cache[key]
|
| 340 |
self._dirty = True
|
| 341 |
+
|
| 342 |
self._stats["misses"] += 1
|
| 343 |
if self._enable_disk:
|
| 344 |
# Schedule async disk lookup for next time
|
| 345 |
asyncio.create_task(self._check_disk_fallback(key))
|
| 346 |
return None
|
| 347 |
+
|
| 348 |
async def retrieve_async(self, key: str) -> Optional[str]:
|
| 349 |
"""
|
| 350 |
Retrieve a value asynchronously (checks disk if not in memory).
|
| 351 |
+
|
| 352 |
Use this when you can await and need guaranteed disk fallback.
|
| 353 |
"""
|
| 354 |
# Check memory first
|
|
|
|
| 362 |
if key in self._cache:
|
| 363 |
del self._cache[key]
|
| 364 |
self._dirty = True
|
| 365 |
+
|
| 366 |
# Check disk
|
| 367 |
if self._enable_disk:
|
| 368 |
return await self._disk_retrieve(key)
|
| 369 |
+
|
| 370 |
self._stats["misses"] += 1
|
| 371 |
return None
|
| 372 |
+
|
| 373 |
async def _check_disk_fallback(self, key: str) -> None:
|
| 374 |
"""Check disk for key and load into memory if found (background)."""
|
| 375 |
try:
|
| 376 |
if not self._cache_file.exists():
|
| 377 |
return
|
| 378 |
+
|
| 379 |
async with self._disk_lock:
|
| 380 |
+
with open(self._cache_file, "r", encoding="utf-8") as f:
|
| 381 |
data = json.load(f)
|
| 382 |
+
|
| 383 |
entries = data.get("entries", {})
|
| 384 |
if key in entries:
|
| 385 |
entry = entries[key]
|
|
|
|
| 394 |
f"ProviderCache[{self._cache_name}]: Loaded {key} from disk"
|
| 395 |
)
|
| 396 |
except Exception as e:
|
| 397 |
+
lib_logger.debug(
|
| 398 |
+
f"ProviderCache[{self._cache_name}]: Disk fallback failed: {e}"
|
| 399 |
+
)
|
| 400 |
+
|
| 401 |
async def _disk_retrieve(self, key: str) -> Optional[str]:
|
| 402 |
"""Direct disk retrieval with loading into memory."""
|
| 403 |
try:
|
| 404 |
if not self._cache_file.exists():
|
| 405 |
self._stats["misses"] += 1
|
| 406 |
return None
|
| 407 |
+
|
| 408 |
async with self._disk_lock:
|
| 409 |
+
with open(self._cache_file, "r", encoding="utf-8") as f:
|
| 410 |
data = json.load(f)
|
| 411 |
+
|
| 412 |
entries = data.get("entries", {})
|
| 413 |
if key in entries:
|
| 414 |
entry = entries[key]
|
|
|
|
| 420 |
self._cache[key] = (value, ts)
|
| 421 |
self._stats["disk_hits"] += 1
|
| 422 |
return value
|
| 423 |
+
|
| 424 |
self._stats["misses"] += 1
|
| 425 |
return None
|
| 426 |
except Exception as e:
|
| 427 |
+
lib_logger.debug(
|
| 428 |
+
f"ProviderCache[{self._cache_name}]: Disk retrieve failed: {e}"
|
| 429 |
+
)
|
| 430 |
self._stats["misses"] += 1
|
| 431 |
return None
|
| 432 |
+
|
| 433 |
# =========================================================================
|
| 434 |
# UTILITY METHODS
|
| 435 |
# =========================================================================
|
| 436 |
+
|
| 437 |
def contains(self, key: str) -> bool:
|
| 438 |
"""Check if key exists in memory cache (without updating stats)."""
|
| 439 |
if key in self._cache:
|
| 440 |
_, timestamp = self._cache[key]
|
| 441 |
return time.time() - timestamp <= self._memory_ttl
|
| 442 |
return False
|
| 443 |
+
|
| 444 |
def get_stats(self) -> Dict[str, Any]:
|
| 445 |
+
"""Get cache statistics including disk health."""
|
| 446 |
return {
|
| 447 |
**self._stats,
|
| 448 |
"memory_entries": len(self._cache),
|
| 449 |
"dirty": self._dirty,
|
| 450 |
+
"disk_enabled": self._enable_disk,
|
| 451 |
+
"disk_available": self._disk_available,
|
| 452 |
}
|
| 453 |
+
|
| 454 |
async def clear(self) -> None:
|
| 455 |
"""Clear all cached data."""
|
| 456 |
async with self._lock:
|
|
|
|
| 458 |
self._dirty = True
|
| 459 |
if self._enable_disk:
|
| 460 |
await self._save_to_disk()
|
| 461 |
+
|
| 462 |
async def shutdown(self) -> None:
|
| 463 |
"""Graceful shutdown: flush pending writes and stop background tasks."""
|
| 464 |
lib_logger.info(f"ProviderCache[{self._cache_name}]: Shutting down...")
|
| 465 |
self._running = False
|
| 466 |
+
|
| 467 |
# Cancel background tasks
|
| 468 |
for task in (self._writer_task, self._cleanup_task):
|
| 469 |
if task:
|
|
|
|
| 472 |
await task
|
| 473 |
except asyncio.CancelledError:
|
| 474 |
pass
|
| 475 |
+
|
| 476 |
# Final save
|
| 477 |
if self._dirty and self._enable_disk:
|
| 478 |
await self._save_to_disk()
|
| 479 |
+
|
| 480 |
lib_logger.info(
|
| 481 |
f"ProviderCache[{self._cache_name}]: Shutdown complete "
|
| 482 |
f"(stats: mem_hits={self._stats['memory_hits']}, "
|
|
|
|
| 488 |
# CONVENIENCE FACTORY
|
| 489 |
# =============================================================================
|
| 490 |
|
| 491 |
+
|
| 492 |
def create_provider_cache(
|
| 493 |
name: str,
|
| 494 |
cache_dir: Optional[Path] = None,
|
| 495 |
memory_ttl_seconds: int = 3600,
|
| 496 |
disk_ttl_seconds: int = 86400,
|
| 497 |
+
env_prefix: Optional[str] = None,
|
| 498 |
) -> ProviderCache:
|
| 499 |
"""
|
| 500 |
Factory function to create a provider cache with sensible defaults.
|
| 501 |
+
|
| 502 |
Args:
|
| 503 |
name: Cache name (used as filename and for logging)
|
| 504 |
cache_dir: Directory for cache file (default: project_root/cache/provider_name)
|
| 505 |
memory_ttl_seconds: In-memory TTL
|
| 506 |
disk_ttl_seconds: Disk TTL
|
| 507 |
env_prefix: Environment variable prefix (default: derived from name)
|
| 508 |
+
|
| 509 |
Returns:
|
| 510 |
Configured ProviderCache instance
|
| 511 |
"""
|
| 512 |
if cache_dir is None:
|
| 513 |
cache_dir = Path(__file__).resolve().parent.parent.parent.parent / "cache"
|
| 514 |
+
|
| 515 |
cache_file = cache_dir / f"{name}.json"
|
| 516 |
+
|
| 517 |
if env_prefix is None:
|
| 518 |
# Convert name to env prefix: "gemini3_signatures" -> "GEMINI3_SIGNATURES_CACHE"
|
| 519 |
env_prefix = f"{name.upper().replace('-', '_')}_CACHE"
|
| 520 |
+
|
| 521 |
return ProviderCache(
|
| 522 |
cache_file=cache_file,
|
| 523 |
memory_ttl_seconds=memory_ttl_seconds,
|
| 524 |
disk_ttl_seconds=disk_ttl_seconds,
|
| 525 |
+
env_prefix=env_prefix,
|
| 526 |
)
|
src/rotator_library/providers/qwen_auth_base.py
CHANGED
|
@@ -9,10 +9,11 @@ import asyncio
|
|
| 9 |
import logging
|
| 10 |
import webbrowser
|
| 11 |
import os
|
|
|
|
|
|
|
| 12 |
from pathlib import Path
|
| 13 |
-
from
|
| 14 |
-
import
|
| 15 |
-
import shutil
|
| 16 |
|
| 17 |
import httpx
|
| 18 |
from rich.console import Console
|
|
@@ -23,6 +24,7 @@ from rich.markup import escape as rich_escape
|
|
| 23 |
|
| 24 |
from ..utils.headless_detection import is_headless_environment
|
| 25 |
from ..utils.reauth_coordinator import get_reauth_coordinator
|
|
|
|
| 26 |
|
| 27 |
lib_logger = logging.getLogger("rotator_library")
|
| 28 |
|
|
@@ -36,6 +38,20 @@ REFRESH_EXPIRY_BUFFER_SECONDS = 3 * 60 * 60 # 3 hours buffer before expiry
|
|
| 36 |
console = Console()
|
| 37 |
|
| 38 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 39 |
class QwenAuthBase:
|
| 40 |
def __init__(self):
|
| 41 |
self._credentials_cache: Dict[str, Dict[str, Any]] = {}
|
|
@@ -51,19 +67,36 @@ class QwenAuthBase:
|
|
| 51 |
str, float
|
| 52 |
] = {} # Track backoff timers (Unix timestamp)
|
| 53 |
|
| 54 |
-
# [QUEUE SYSTEM] Sequential refresh processing
|
|
|
|
| 55 |
self._refresh_queue: asyncio.Queue = asyncio.Queue()
|
| 56 |
-
self.
|
| 57 |
-
|
| 58 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
self._unavailable_credentials: Dict[
|
| 60 |
str, float
|
| 61 |
] = {} # Maps credential path -> timestamp when marked unavailable
|
| 62 |
-
|
|
|
|
| 63 |
self._queue_tracking_lock = asyncio.Lock() # Protects queue sets
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
|
| 68 |
def _parse_env_credential_path(self, path: str) -> Optional[str]:
|
| 69 |
"""
|
|
@@ -188,81 +221,54 @@ class QwenAuthBase:
|
|
| 188 |
f"Environment variables for Qwen Code credential index {credential_index} not found"
|
| 189 |
)
|
| 190 |
|
| 191 |
-
#
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
)
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
|
|
|
|
|
|
| 202 |
|
| 203 |
async def _save_credentials(self, path: str, creds: Dict[str, Any]):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 204 |
# Don't save to file if credentials were loaded from environment
|
| 205 |
if creds.get("_proxy_metadata", {}).get("loaded_from_env"):
|
| 206 |
lib_logger.debug("Credentials loaded from env, skipping file save")
|
| 207 |
-
# Still update cache for in-memory consistency
|
| 208 |
-
self._credentials_cache[path] = creds
|
| 209 |
return
|
| 210 |
|
| 211 |
-
#
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
dir=parent_dir, prefix=".tmp_", suffix=".json", text=True
|
| 221 |
-
)
|
| 222 |
-
|
| 223 |
-
# Write JSON to temp file
|
| 224 |
-
with os.fdopen(tmp_fd, "w") as f:
|
| 225 |
-
json.dump(creds, f, indent=2)
|
| 226 |
-
tmp_fd = None # fdopen closes the fd
|
| 227 |
-
|
| 228 |
-
# Set secure permissions (0600 = owner read/write only)
|
| 229 |
-
try:
|
| 230 |
-
os.chmod(tmp_path, 0o600)
|
| 231 |
-
except (OSError, AttributeError):
|
| 232 |
-
# Windows may not support chmod, ignore
|
| 233 |
-
pass
|
| 234 |
-
|
| 235 |
-
# Atomic move (overwrites target if it exists)
|
| 236 |
-
shutil.move(tmp_path, path)
|
| 237 |
-
tmp_path = None # Successfully moved
|
| 238 |
-
|
| 239 |
-
# Update cache AFTER successful file write
|
| 240 |
-
self._credentials_cache[path] = creds
|
| 241 |
-
lib_logger.debug(
|
| 242 |
-
f"Saved updated Qwen OAuth credentials to '{path}' (atomic write)."
|
| 243 |
-
)
|
| 244 |
-
|
| 245 |
-
except Exception as e:
|
| 246 |
-
lib_logger.error(
|
| 247 |
-
f"Failed to save updated Qwen OAuth credentials to '{path}': {e}"
|
| 248 |
)
|
| 249 |
-
# Clean up temp file if it still exists
|
| 250 |
-
if tmp_fd is not None:
|
| 251 |
-
try:
|
| 252 |
-
os.close(tmp_fd)
|
| 253 |
-
except:
|
| 254 |
-
pass
|
| 255 |
-
if tmp_path and os.path.exists(tmp_path):
|
| 256 |
-
try:
|
| 257 |
-
os.unlink(tmp_path)
|
| 258 |
-
except:
|
| 259 |
-
pass
|
| 260 |
-
raise
|
| 261 |
|
| 262 |
def _is_token_expired(self, creds: Dict[str, Any]) -> bool:
|
| 263 |
expiry_timestamp = creds.get("expiry_date", 0) / 1000
|
| 264 |
return expiry_timestamp < time.time() + REFRESH_EXPIRY_BUFFER_SECONDS
|
| 265 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 266 |
async def _refresh_token(self, path: str, force: bool = False) -> Dict[str, Any]:
|
| 267 |
async with await self._get_lock(path):
|
| 268 |
cached_creds = self._credentials_cache.get(path)
|
|
@@ -476,7 +482,7 @@ class QwenAuthBase:
|
|
| 476 |
Proactively refreshes tokens if they're close to expiry.
|
| 477 |
Only applies to OAuth credentials (file paths or env:// paths). Direct API keys are skipped.
|
| 478 |
"""
|
| 479 |
-
lib_logger.debug(f"proactively_refresh called for: {credential_identifier}")
|
| 480 |
|
| 481 |
# Try to load credentials - this will fail for direct API keys
|
| 482 |
# and succeed for OAuth credentials (file paths or env:// paths)
|
|
@@ -484,21 +490,21 @@ class QwenAuthBase:
|
|
| 484 |
creds = await self._load_credentials(credential_identifier)
|
| 485 |
except IOError as e:
|
| 486 |
# Not a valid credential path (likely a direct API key string)
|
| 487 |
-
lib_logger.debug(
|
| 488 |
-
|
| 489 |
-
)
|
| 490 |
return
|
| 491 |
|
| 492 |
is_expired = self._is_token_expired(creds)
|
| 493 |
-
lib_logger.debug(
|
| 494 |
-
|
| 495 |
-
)
|
| 496 |
|
| 497 |
if is_expired:
|
| 498 |
-
lib_logger.debug(
|
| 499 |
-
|
| 500 |
-
)
|
| 501 |
-
#
|
| 502 |
await self._queue_refresh(
|
| 503 |
credential_identifier, force=False, needs_reauth=False
|
| 504 |
)
|
|
@@ -511,30 +517,55 @@ class QwenAuthBase:
|
|
| 511 |
return self._refresh_locks[path]
|
| 512 |
|
| 513 |
def is_credential_available(self, path: str) -> bool:
|
| 514 |
-
"""Check if a credential is available for rotation
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 515 |
|
| 516 |
-
|
| 517 |
-
|
| 518 |
-
|
|
|
|
| 519 |
"""
|
| 520 |
-
|
| 521 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 522 |
|
| 523 |
-
#
|
| 524 |
-
|
| 525 |
-
if
|
| 526 |
-
|
| 527 |
-
|
| 528 |
-
|
| 529 |
-
lib_logger.
|
| 530 |
-
|
| 531 |
-
|
| 532 |
-
|
|
|
|
| 533 |
)
|
| 534 |
-
|
| 535 |
-
return True
|
| 536 |
|
| 537 |
-
return
|
| 538 |
|
| 539 |
async def _ensure_queue_processor_running(self):
|
| 540 |
"""Lazily starts the queue processor if not already running."""
|
|
@@ -543,15 +574,27 @@ class QwenAuthBase:
|
|
| 543 |
self._process_refresh_queue()
|
| 544 |
)
|
| 545 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 546 |
async def _queue_refresh(
|
| 547 |
self, path: str, force: bool = False, needs_reauth: bool = False
|
| 548 |
):
|
| 549 |
-
"""Add a credential to the refresh queue if not already queued.
|
| 550 |
|
| 551 |
Args:
|
| 552 |
path: Credential file path
|
| 553 |
force: Force refresh even if not expired
|
| 554 |
-
needs_reauth: True if full re-authentication needed (
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 555 |
"""
|
| 556 |
# IMPORTANT: Only check backoff for simple automated refreshes
|
| 557 |
# Re-authentication (interactive OAuth) should BYPASS backoff since it needs user input
|
|
@@ -561,114 +604,223 @@ class QwenAuthBase:
|
|
| 561 |
backoff_until = self._next_refresh_after[path]
|
| 562 |
if now < backoff_until:
|
| 563 |
# Credential is in backoff for automated refresh, do not queue
|
| 564 |
-
remaining = int(backoff_until - now)
|
| 565 |
-
lib_logger.debug(
|
| 566 |
-
|
| 567 |
-
)
|
| 568 |
return
|
| 569 |
|
| 570 |
async with self._queue_tracking_lock:
|
| 571 |
if path not in self._queued_credentials:
|
| 572 |
self._queued_credentials.add(path)
|
| 573 |
-
|
| 574 |
-
|
| 575 |
-
|
| 576 |
-
|
| 577 |
-
|
| 578 |
-
|
| 579 |
-
|
| 580 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 581 |
|
| 582 |
async def _process_refresh_queue(self):
|
| 583 |
-
"""Background worker that processes refresh requests sequentially.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 584 |
while True:
|
| 585 |
path = None
|
| 586 |
try:
|
| 587 |
# Wait for an item with timeout to allow graceful shutdown
|
| 588 |
try:
|
| 589 |
-
path, force
|
| 590 |
self._refresh_queue.get(), timeout=60.0
|
| 591 |
)
|
| 592 |
except asyncio.TimeoutError:
|
| 593 |
-
#
|
| 594 |
-
# If we're idle for 60s, no refreshes are in progress
|
| 595 |
async with self._queue_tracking_lock:
|
| 596 |
-
|
| 597 |
-
|
| 598 |
-
lib_logger.warning(
|
| 599 |
-
f"Queue processor idle timeout. Cleaning {stale_count} "
|
| 600 |
-
f"stale unavailable credentials: {list(self._unavailable_credentials.keys())}"
|
| 601 |
-
)
|
| 602 |
-
self._unavailable_credentials.clear()
|
| 603 |
-
# [FIX BUG#6] Also clear queued credentials to prevent stuck state
|
| 604 |
-
if self._queued_credentials:
|
| 605 |
-
lib_logger.debug(
|
| 606 |
-
f"Clearing {len(self._queued_credentials)} queued credentials on timeout"
|
| 607 |
-
)
|
| 608 |
-
self._queued_credentials.clear()
|
| 609 |
self._queue_processor_task = None
|
|
|
|
| 610 |
return
|
| 611 |
|
| 612 |
try:
|
| 613 |
-
#
|
| 614 |
-
|
| 615 |
-
|
| 616 |
-
|
| 617 |
-
|
| 618 |
-
|
| 619 |
-
|
| 620 |
-
|
| 621 |
-
|
| 622 |
-
|
| 623 |
-
|
| 624 |
-
|
| 625 |
-
|
|
|
|
|
|
|
| 626 |
|
| 627 |
-
#
|
| 628 |
-
|
| 629 |
-
|
| 630 |
-
await self._refresh_token(path, force=force)
|
| 631 |
|
| 632 |
-
|
| 633 |
-
|
| 634 |
-
self.
|
| 635 |
-
|
| 636 |
-
|
| 637 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 638 |
)
|
| 639 |
|
|
|
|
|
|
|
|
|
|
| 640 |
finally:
|
| 641 |
-
#
|
| 642 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 643 |
async with self._queue_tracking_lock:
|
| 644 |
self._queued_credentials.discard(path)
|
| 645 |
-
# [FIX PR#34] Always clean up unavailable credentials in finally block
|
| 646 |
self._unavailable_credentials.pop(path, None)
|
| 647 |
-
lib_logger.debug(
|
| 648 |
-
|
| 649 |
-
|
| 650 |
-
)
|
| 651 |
-
self.
|
|
|
|
| 652 |
except asyncio.CancelledError:
|
| 653 |
-
#
|
| 654 |
if path:
|
| 655 |
async with self._queue_tracking_lock:
|
|
|
|
| 656 |
self._unavailable_credentials.pop(path, None)
|
| 657 |
-
|
| 658 |
-
f"CancelledError cleanup for '{Path(path).name}'. "
|
| 659 |
-
f"Remaining unavailable: {len(self._unavailable_credentials)}"
|
| 660 |
-
)
|
| 661 |
break
|
| 662 |
except Exception as e:
|
| 663 |
-
lib_logger.error(f"Error in queue processor: {e}")
|
| 664 |
-
# Even on error, mark as available (backoff will prevent immediate retry)
|
| 665 |
if path:
|
| 666 |
async with self._queue_tracking_lock:
|
|
|
|
| 667 |
self._unavailable_credentials.pop(path, None)
|
| 668 |
-
lib_logger.debug(
|
| 669 |
-
f"Error cleanup for '{Path(path).name}': {e}. "
|
| 670 |
-
f"Remaining unavailable: {len(self._unavailable_credentials)}"
|
| 671 |
-
)
|
| 672 |
|
| 673 |
async def _perform_interactive_oauth(
|
| 674 |
self, path: str, creds: Dict[str, Any], display_name: str
|
|
@@ -965,3 +1117,251 @@ class QwenAuthBase:
|
|
| 965 |
except Exception as e:
|
| 966 |
lib_logger.error(f"Failed to get Qwen user info from credentials: {e}")
|
| 967 |
return {"email": None}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
import logging
|
| 10 |
import webbrowser
|
| 11 |
import os
|
| 12 |
+
import re
|
| 13 |
+
from dataclasses import dataclass, field
|
| 14 |
from pathlib import Path
|
| 15 |
+
from glob import glob
|
| 16 |
+
from typing import Dict, Any, Tuple, Union, Optional, List
|
|
|
|
| 17 |
|
| 18 |
import httpx
|
| 19 |
from rich.console import Console
|
|
|
|
| 24 |
|
| 25 |
from ..utils.headless_detection import is_headless_environment
|
| 26 |
from ..utils.reauth_coordinator import get_reauth_coordinator
|
| 27 |
+
from ..utils.resilient_io import safe_write_json
|
| 28 |
|
| 29 |
lib_logger = logging.getLogger("rotator_library")
|
| 30 |
|
|
|
|
| 38 |
console = Console()
|
| 39 |
|
| 40 |
|
| 41 |
+
@dataclass
|
| 42 |
+
class QwenCredentialSetupResult:
|
| 43 |
+
"""
|
| 44 |
+
Standardized result structure for Qwen credential setup operations.
|
| 45 |
+
"""
|
| 46 |
+
|
| 47 |
+
success: bool
|
| 48 |
+
file_path: Optional[str] = None
|
| 49 |
+
email: Optional[str] = None
|
| 50 |
+
is_update: bool = False
|
| 51 |
+
error: Optional[str] = None
|
| 52 |
+
credentials: Optional[Dict[str, Any]] = field(default=None, repr=False)
|
| 53 |
+
|
| 54 |
+
|
| 55 |
class QwenAuthBase:
|
| 56 |
def __init__(self):
|
| 57 |
self._credentials_cache: Dict[str, Dict[str, Any]] = {}
|
|
|
|
| 67 |
str, float
|
| 68 |
] = {} # Track backoff timers (Unix timestamp)
|
| 69 |
|
| 70 |
+
# [QUEUE SYSTEM] Sequential refresh processing with two separate queues
|
| 71 |
+
# Normal refresh queue: for proactive token refresh (old token still valid)
|
| 72 |
self._refresh_queue: asyncio.Queue = asyncio.Queue()
|
| 73 |
+
self._queue_processor_task: Optional[asyncio.Task] = None
|
| 74 |
+
|
| 75 |
+
# Re-auth queue: for invalid refresh tokens (requires user interaction)
|
| 76 |
+
self._reauth_queue: asyncio.Queue = asyncio.Queue()
|
| 77 |
+
self._reauth_processor_task: Optional[asyncio.Task] = None
|
| 78 |
+
|
| 79 |
+
# Tracking sets/dicts
|
| 80 |
+
self._queued_credentials: set = set() # Track credentials in either queue
|
| 81 |
+
# Only credentials in re-auth queue are marked unavailable (not normal refresh)
|
| 82 |
+
# TTL cleanup is defense-in-depth for edge cases where re-auth processor crashes
|
| 83 |
self._unavailable_credentials: Dict[
|
| 84 |
str, float
|
| 85 |
] = {} # Maps credential path -> timestamp when marked unavailable
|
| 86 |
+
# TTL should exceed reauth timeout (300s) to avoid premature cleanup
|
| 87 |
+
self._unavailable_ttl_seconds: int = 360 # 6 minutes TTL for stale entries
|
| 88 |
self._queue_tracking_lock = asyncio.Lock() # Protects queue sets
|
| 89 |
+
|
| 90 |
+
# Retry tracking for normal refresh queue
|
| 91 |
+
self._queue_retry_count: Dict[
|
| 92 |
+
str, int
|
| 93 |
+
] = {} # Track retry attempts per credential
|
| 94 |
+
|
| 95 |
+
# Configuration constants
|
| 96 |
+
self._refresh_timeout_seconds: int = 15 # Max time for single refresh
|
| 97 |
+
self._refresh_interval_seconds: int = 30 # Delay between queue items
|
| 98 |
+
self._refresh_max_retries: int = 3 # Attempts before kicked out
|
| 99 |
+
self._reauth_timeout_seconds: int = 300 # Time for user to complete OAuth
|
| 100 |
|
| 101 |
def _parse_env_credential_path(self, path: str) -> Optional[str]:
|
| 102 |
"""
|
|
|
|
| 221 |
f"Environment variables for Qwen Code credential index {credential_index} not found"
|
| 222 |
)
|
| 223 |
|
| 224 |
+
# Try file-based loading first (preferred for explicit file paths)
|
| 225 |
+
try:
|
| 226 |
+
return await self._read_creds_from_file(path)
|
| 227 |
+
except IOError:
|
| 228 |
+
# File not found - fall back to legacy env vars for backwards compatibility
|
| 229 |
+
env_creds = self._load_from_env()
|
| 230 |
+
if env_creds:
|
| 231 |
+
lib_logger.info(
|
| 232 |
+
f"File '{path}' not found, using Qwen Code credentials from environment variables"
|
| 233 |
+
)
|
| 234 |
+
self._credentials_cache[path] = env_creds
|
| 235 |
+
return env_creds
|
| 236 |
+
raise # Re-raise the original file not found error
|
| 237 |
|
| 238 |
async def _save_credentials(self, path: str, creds: Dict[str, Any]):
|
| 239 |
+
"""Save credentials with in-memory fallback if disk unavailable."""
|
| 240 |
+
# Always update cache first (memory is reliable)
|
| 241 |
+
self._credentials_cache[path] = creds
|
| 242 |
+
|
| 243 |
# Don't save to file if credentials were loaded from environment
|
| 244 |
if creds.get("_proxy_metadata", {}).get("loaded_from_env"):
|
| 245 |
lib_logger.debug("Credentials loaded from env, skipping file save")
|
|
|
|
|
|
|
| 246 |
return
|
| 247 |
|
| 248 |
+
# Attempt disk write - if it fails, we still have the cache
|
| 249 |
+
# buffer_on_failure ensures data is retried periodically and saved on shutdown
|
| 250 |
+
if safe_write_json(
|
| 251 |
+
path, creds, lib_logger, secure_permissions=True, buffer_on_failure=True
|
| 252 |
+
):
|
| 253 |
+
lib_logger.debug(f"Saved updated Qwen OAuth credentials to '{path}'.")
|
| 254 |
+
else:
|
| 255 |
+
lib_logger.warning(
|
| 256 |
+
"Qwen credentials cached in memory only (buffered for retry)."
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 257 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 258 |
|
| 259 |
def _is_token_expired(self, creds: Dict[str, Any]) -> bool:
|
| 260 |
expiry_timestamp = creds.get("expiry_date", 0) / 1000
|
| 261 |
return expiry_timestamp < time.time() + REFRESH_EXPIRY_BUFFER_SECONDS
|
| 262 |
|
| 263 |
+
def _is_token_truly_expired(self, creds: Dict[str, Any]) -> bool:
|
| 264 |
+
"""Check if token is TRULY expired (past actual expiry, not just threshold).
|
| 265 |
+
|
| 266 |
+
This is different from _is_token_expired() which uses a buffer for proactive refresh.
|
| 267 |
+
This method checks if the token is actually unusable.
|
| 268 |
+
"""
|
| 269 |
+
expiry_timestamp = creds.get("expiry_date", 0) / 1000
|
| 270 |
+
return expiry_timestamp < time.time()
|
| 271 |
+
|
| 272 |
async def _refresh_token(self, path: str, force: bool = False) -> Dict[str, Any]:
|
| 273 |
async with await self._get_lock(path):
|
| 274 |
cached_creds = self._credentials_cache.get(path)
|
|
|
|
| 482 |
Proactively refreshes tokens if they're close to expiry.
|
| 483 |
Only applies to OAuth credentials (file paths or env:// paths). Direct API keys are skipped.
|
| 484 |
"""
|
| 485 |
+
# lib_logger.debug(f"proactively_refresh called for: {credential_identifier}")
|
| 486 |
|
| 487 |
# Try to load credentials - this will fail for direct API keys
|
| 488 |
# and succeed for OAuth credentials (file paths or env:// paths)
|
|
|
|
| 490 |
creds = await self._load_credentials(credential_identifier)
|
| 491 |
except IOError as e:
|
| 492 |
# Not a valid credential path (likely a direct API key string)
|
| 493 |
+
# lib_logger.debug(
|
| 494 |
+
# f"Skipping refresh for '{credential_identifier}' - not an OAuth credential: {e}"
|
| 495 |
+
# )
|
| 496 |
return
|
| 497 |
|
| 498 |
is_expired = self._is_token_expired(creds)
|
| 499 |
+
# lib_logger.debug(
|
| 500 |
+
# f"Token expired check for '{Path(credential_identifier).name}': {is_expired}"
|
| 501 |
+
# )
|
| 502 |
|
| 503 |
if is_expired:
|
| 504 |
+
# lib_logger.debug(
|
| 505 |
+
# f"Queueing refresh for '{Path(credential_identifier).name}'"
|
| 506 |
+
# )
|
| 507 |
+
# lib_logger.info(f"Proactive refresh triggered for '{Path(credential_identifier).name}'")
|
| 508 |
await self._queue_refresh(
|
| 509 |
credential_identifier, force=False, needs_reauth=False
|
| 510 |
)
|
|
|
|
| 517 |
return self._refresh_locks[path]
|
| 518 |
|
| 519 |
def is_credential_available(self, path: str) -> bool:
|
| 520 |
+
"""Check if a credential is available for rotation.
|
| 521 |
+
|
| 522 |
+
Credentials are unavailable if:
|
| 523 |
+
1. In re-auth queue (token is truly broken, requires user interaction)
|
| 524 |
+
2. Token is TRULY expired (past actual expiry, not just threshold)
|
| 525 |
+
|
| 526 |
+
Note: Credentials in normal refresh queue are still available because
|
| 527 |
+
the old token is valid until actual expiry.
|
| 528 |
|
| 529 |
+
TTL cleanup (defense-in-depth): If a credential has been in the re-auth
|
| 530 |
+
queue longer than _unavailable_ttl_seconds without being processed, it's
|
| 531 |
+
cleaned up. This should only happen if the re-auth processor crashes or
|
| 532 |
+
is cancelled without proper cleanup.
|
| 533 |
"""
|
| 534 |
+
# Check if in re-auth queue (truly unavailable)
|
| 535 |
+
if path in self._unavailable_credentials:
|
| 536 |
+
marked_time = self._unavailable_credentials.get(path)
|
| 537 |
+
if marked_time is not None:
|
| 538 |
+
now = time.time()
|
| 539 |
+
if now - marked_time > self._unavailable_ttl_seconds:
|
| 540 |
+
# Entry is stale - clean it up and return available
|
| 541 |
+
# This is a defense-in-depth for edge cases where re-auth
|
| 542 |
+
# processor crashed or was cancelled without cleanup
|
| 543 |
+
lib_logger.warning(
|
| 544 |
+
f"Credential '{Path(path).name}' stuck in re-auth queue for "
|
| 545 |
+
f"{int(now - marked_time)}s (TTL: {self._unavailable_ttl_seconds}s). "
|
| 546 |
+
f"Re-auth processor may have crashed. Auto-cleaning stale entry."
|
| 547 |
+
)
|
| 548 |
+
# Clean up both tracking structures for consistency
|
| 549 |
+
self._unavailable_credentials.pop(path, None)
|
| 550 |
+
self._queued_credentials.discard(path)
|
| 551 |
+
else:
|
| 552 |
+
return False # Still in re-auth, not available
|
| 553 |
|
| 554 |
+
# Check if token is TRULY expired (not just threshold-expired)
|
| 555 |
+
creds = self._credentials_cache.get(path)
|
| 556 |
+
if creds and self._is_token_truly_expired(creds):
|
| 557 |
+
# Token is actually expired - should not be used
|
| 558 |
+
# Queue for refresh if not already queued
|
| 559 |
+
if path not in self._queued_credentials:
|
| 560 |
+
# lib_logger.debug(
|
| 561 |
+
# f"Credential '{Path(path).name}' is truly expired, queueing for refresh"
|
| 562 |
+
# )
|
| 563 |
+
asyncio.create_task(
|
| 564 |
+
self._queue_refresh(path, force=True, needs_reauth=False)
|
| 565 |
)
|
| 566 |
+
return False
|
|
|
|
| 567 |
|
| 568 |
+
return True
|
| 569 |
|
| 570 |
async def _ensure_queue_processor_running(self):
|
| 571 |
"""Lazily starts the queue processor if not already running."""
|
|
|
|
| 574 |
self._process_refresh_queue()
|
| 575 |
)
|
| 576 |
|
| 577 |
+
async def _ensure_reauth_processor_running(self):
|
| 578 |
+
"""Lazily starts the re-auth queue processor if not already running."""
|
| 579 |
+
if self._reauth_processor_task is None or self._reauth_processor_task.done():
|
| 580 |
+
self._reauth_processor_task = asyncio.create_task(
|
| 581 |
+
self._process_reauth_queue()
|
| 582 |
+
)
|
| 583 |
+
|
| 584 |
async def _queue_refresh(
|
| 585 |
self, path: str, force: bool = False, needs_reauth: bool = False
|
| 586 |
):
|
| 587 |
+
"""Add a credential to the appropriate refresh queue if not already queued.
|
| 588 |
|
| 589 |
Args:
|
| 590 |
path: Credential file path
|
| 591 |
force: Force refresh even if not expired
|
| 592 |
+
needs_reauth: True if full re-authentication needed (routes to re-auth queue)
|
| 593 |
+
|
| 594 |
+
Queue routing:
|
| 595 |
+
- needs_reauth=True: Goes to re-auth queue, marks as unavailable
|
| 596 |
+
- needs_reauth=False: Goes to normal refresh queue, does NOT mark unavailable
|
| 597 |
+
(old token is still valid until actual expiry)
|
| 598 |
"""
|
| 599 |
# IMPORTANT: Only check backoff for simple automated refreshes
|
| 600 |
# Re-authentication (interactive OAuth) should BYPASS backoff since it needs user input
|
|
|
|
| 604 |
backoff_until = self._next_refresh_after[path]
|
| 605 |
if now < backoff_until:
|
| 606 |
# Credential is in backoff for automated refresh, do not queue
|
| 607 |
+
# remaining = int(backoff_until - now)
|
| 608 |
+
# lib_logger.debug(
|
| 609 |
+
# f"Skipping automated refresh for '{Path(path).name}' (in backoff for {remaining}s)"
|
| 610 |
+
# )
|
| 611 |
return
|
| 612 |
|
| 613 |
async with self._queue_tracking_lock:
|
| 614 |
if path not in self._queued_credentials:
|
| 615 |
self._queued_credentials.add(path)
|
| 616 |
+
|
| 617 |
+
if needs_reauth:
|
| 618 |
+
# Re-auth queue: mark as unavailable (token is truly broken)
|
| 619 |
+
self._unavailable_credentials[path] = time.time()
|
| 620 |
+
# lib_logger.debug(
|
| 621 |
+
# f"Queued '{Path(path).name}' for RE-AUTH (marked unavailable). "
|
| 622 |
+
# f"Total unavailable: {len(self._unavailable_credentials)}"
|
| 623 |
+
# )
|
| 624 |
+
await self._reauth_queue.put(path)
|
| 625 |
+
await self._ensure_reauth_processor_running()
|
| 626 |
+
else:
|
| 627 |
+
# Normal refresh queue: do NOT mark unavailable (old token still valid)
|
| 628 |
+
# lib_logger.debug(
|
| 629 |
+
# f"Queued '{Path(path).name}' for refresh (still available). "
|
| 630 |
+
# f"Queue size: {self._refresh_queue.qsize() + 1}"
|
| 631 |
+
# )
|
| 632 |
+
await self._refresh_queue.put((path, force))
|
| 633 |
+
await self._ensure_queue_processor_running()
|
| 634 |
|
| 635 |
async def _process_refresh_queue(self):
|
| 636 |
+
"""Background worker that processes normal refresh requests sequentially.
|
| 637 |
+
|
| 638 |
+
Key behaviors:
|
| 639 |
+
- 15s timeout per refresh operation
|
| 640 |
+
- 30s delay between processing credentials (prevents thundering herd)
|
| 641 |
+
- On failure: back of queue, max 3 retries before kicked
|
| 642 |
+
- If 401/403 detected: routes to re-auth queue
|
| 643 |
+
- Does NOT mark credentials unavailable (old token still valid)
|
| 644 |
+
"""
|
| 645 |
+
# lib_logger.info("Refresh queue processor started")
|
| 646 |
while True:
|
| 647 |
path = None
|
| 648 |
try:
|
| 649 |
# Wait for an item with timeout to allow graceful shutdown
|
| 650 |
try:
|
| 651 |
+
path, force = await asyncio.wait_for(
|
| 652 |
self._refresh_queue.get(), timeout=60.0
|
| 653 |
)
|
| 654 |
except asyncio.TimeoutError:
|
| 655 |
+
# Queue is empty and idle for 60s - clean up and exit
|
|
|
|
| 656 |
async with self._queue_tracking_lock:
|
| 657 |
+
# Clear any stale retry counts
|
| 658 |
+
self._queue_retry_count.clear()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 659 |
self._queue_processor_task = None
|
| 660 |
+
# lib_logger.debug("Refresh queue processor idle, shutting down")
|
| 661 |
return
|
| 662 |
|
| 663 |
try:
|
| 664 |
+
# Quick check if still expired (optimization to avoid unnecessary refresh)
|
| 665 |
+
creds = self._credentials_cache.get(path)
|
| 666 |
+
if creds and not self._is_token_expired(creds):
|
| 667 |
+
# No longer expired, skip refresh
|
| 668 |
+
# lib_logger.debug(
|
| 669 |
+
# f"Credential '{Path(path).name}' no longer expired, skipping refresh"
|
| 670 |
+
# )
|
| 671 |
+
# Clear retry count on skip (not a failure)
|
| 672 |
+
self._queue_retry_count.pop(path, None)
|
| 673 |
+
continue
|
| 674 |
+
|
| 675 |
+
# Perform refresh with timeout
|
| 676 |
+
try:
|
| 677 |
+
async with asyncio.timeout(self._refresh_timeout_seconds):
|
| 678 |
+
await self._refresh_token(path, force=force)
|
| 679 |
|
| 680 |
+
# SUCCESS: Clear retry count
|
| 681 |
+
self._queue_retry_count.pop(path, None)
|
| 682 |
+
# lib_logger.info(f"Refresh SUCCESS for '{Path(path).name}'")
|
|
|
|
| 683 |
|
| 684 |
+
except asyncio.TimeoutError:
|
| 685 |
+
lib_logger.warning(
|
| 686 |
+
f"Refresh timeout ({self._refresh_timeout_seconds}s) for '{Path(path).name}'"
|
| 687 |
+
)
|
| 688 |
+
await self._handle_refresh_failure(path, force, "timeout")
|
| 689 |
+
|
| 690 |
+
except httpx.HTTPStatusError as e:
|
| 691 |
+
status_code = e.response.status_code
|
| 692 |
+
if status_code in (401, 403):
|
| 693 |
+
# Invalid refresh token - route to re-auth queue
|
| 694 |
+
lib_logger.warning(
|
| 695 |
+
f"Refresh token invalid for '{Path(path).name}' (HTTP {status_code}). "
|
| 696 |
+
f"Routing to re-auth queue."
|
| 697 |
+
)
|
| 698 |
+
self._queue_retry_count.pop(path, None) # Clear retry count
|
| 699 |
+
async with self._queue_tracking_lock:
|
| 700 |
+
self._queued_credentials.discard(
|
| 701 |
+
path
|
| 702 |
+
) # Remove from queued
|
| 703 |
+
await self._queue_refresh(
|
| 704 |
+
path, force=True, needs_reauth=True
|
| 705 |
+
)
|
| 706 |
+
else:
|
| 707 |
+
await self._handle_refresh_failure(
|
| 708 |
+
path, force, f"HTTP {status_code}"
|
| 709 |
)
|
| 710 |
|
| 711 |
+
except Exception as e:
|
| 712 |
+
await self._handle_refresh_failure(path, force, str(e))
|
| 713 |
+
|
| 714 |
finally:
|
| 715 |
+
# Remove from queued set (unless re-queued by failure handler)
|
| 716 |
+
async with self._queue_tracking_lock:
|
| 717 |
+
# Only discard if not re-queued (check if still in queue set from retry)
|
| 718 |
+
if (
|
| 719 |
+
path in self._queued_credentials
|
| 720 |
+
and self._queue_retry_count.get(path, 0) == 0
|
| 721 |
+
):
|
| 722 |
+
self._queued_credentials.discard(path)
|
| 723 |
+
self._refresh_queue.task_done()
|
| 724 |
+
|
| 725 |
+
# Wait between credentials to spread load
|
| 726 |
+
await asyncio.sleep(self._refresh_interval_seconds)
|
| 727 |
+
|
| 728 |
+
except asyncio.CancelledError:
|
| 729 |
+
# lib_logger.debug("Refresh queue processor cancelled")
|
| 730 |
+
break
|
| 731 |
+
except Exception as e:
|
| 732 |
+
lib_logger.error(f"Error in refresh queue processor: {e}")
|
| 733 |
+
if path:
|
| 734 |
+
async with self._queue_tracking_lock:
|
| 735 |
+
self._queued_credentials.discard(path)
|
| 736 |
+
|
| 737 |
+
async def _handle_refresh_failure(self, path: str, force: bool, error: str):
|
| 738 |
+
"""Handle a refresh failure with back-of-line retry logic.
|
| 739 |
+
|
| 740 |
+
- Increments retry count
|
| 741 |
+
- If under max retries: re-adds to END of queue
|
| 742 |
+
- If at max retries: kicks credential out (retried next BackgroundRefresher cycle)
|
| 743 |
+
"""
|
| 744 |
+
retry_count = self._queue_retry_count.get(path, 0) + 1
|
| 745 |
+
self._queue_retry_count[path] = retry_count
|
| 746 |
+
|
| 747 |
+
if retry_count >= self._refresh_max_retries:
|
| 748 |
+
# Kicked out until next BackgroundRefresher cycle
|
| 749 |
+
lib_logger.error(
|
| 750 |
+
f"Max retries ({self._refresh_max_retries}) reached for '{Path(path).name}' "
|
| 751 |
+
f"(last error: {error}). Will retry next refresh cycle."
|
| 752 |
+
)
|
| 753 |
+
self._queue_retry_count.pop(path, None)
|
| 754 |
+
async with self._queue_tracking_lock:
|
| 755 |
+
self._queued_credentials.discard(path)
|
| 756 |
+
return
|
| 757 |
+
|
| 758 |
+
# Re-add to END of queue for retry
|
| 759 |
+
lib_logger.warning(
|
| 760 |
+
f"Refresh failed for '{Path(path).name}' ({error}). "
|
| 761 |
+
f"Retry {retry_count}/{self._refresh_max_retries}, back of queue."
|
| 762 |
+
)
|
| 763 |
+
# Keep in queued_credentials set, add back to queue
|
| 764 |
+
await self._refresh_queue.put((path, force))
|
| 765 |
+
|
| 766 |
+
async def _process_reauth_queue(self):
|
| 767 |
+
"""Background worker that processes re-auth requests.
|
| 768 |
+
|
| 769 |
+
Key behaviors:
|
| 770 |
+
- Credentials ARE marked unavailable (token is truly broken)
|
| 771 |
+
- Uses ReauthCoordinator for interactive OAuth
|
| 772 |
+
- No automatic retry (requires user action)
|
| 773 |
+
- Cleans up unavailable status when done
|
| 774 |
+
"""
|
| 775 |
+
# lib_logger.info("Re-auth queue processor started")
|
| 776 |
+
while True:
|
| 777 |
+
path = None
|
| 778 |
+
try:
|
| 779 |
+
# Wait for an item with timeout to allow graceful shutdown
|
| 780 |
+
try:
|
| 781 |
+
path = await asyncio.wait_for(
|
| 782 |
+
self._reauth_queue.get(), timeout=60.0
|
| 783 |
+
)
|
| 784 |
+
except asyncio.TimeoutError:
|
| 785 |
+
# Queue is empty and idle for 60s - exit
|
| 786 |
+
self._reauth_processor_task = None
|
| 787 |
+
# lib_logger.debug("Re-auth queue processor idle, shutting down")
|
| 788 |
+
return
|
| 789 |
+
|
| 790 |
+
try:
|
| 791 |
+
lib_logger.info(f"Starting re-auth for '{Path(path).name}'...")
|
| 792 |
+
await self.initialize_token(path)
|
| 793 |
+
lib_logger.info(f"Re-auth SUCCESS for '{Path(path).name}'")
|
| 794 |
+
|
| 795 |
+
except Exception as e:
|
| 796 |
+
lib_logger.error(f"Re-auth FAILED for '{Path(path).name}': {e}")
|
| 797 |
+
# No automatic retry for re-auth (requires user action)
|
| 798 |
+
|
| 799 |
+
finally:
|
| 800 |
+
# Always clean up
|
| 801 |
async with self._queue_tracking_lock:
|
| 802 |
self._queued_credentials.discard(path)
|
|
|
|
| 803 |
self._unavailable_credentials.pop(path, None)
|
| 804 |
+
# lib_logger.debug(
|
| 805 |
+
# f"Re-auth cleanup for '{Path(path).name}'. "
|
| 806 |
+
# f"Remaining unavailable: {len(self._unavailable_credentials)}"
|
| 807 |
+
# )
|
| 808 |
+
self._reauth_queue.task_done()
|
| 809 |
+
|
| 810 |
except asyncio.CancelledError:
|
| 811 |
+
# Clean up current credential before breaking
|
| 812 |
if path:
|
| 813 |
async with self._queue_tracking_lock:
|
| 814 |
+
self._queued_credentials.discard(path)
|
| 815 |
self._unavailable_credentials.pop(path, None)
|
| 816 |
+
# lib_logger.debug("Re-auth queue processor cancelled")
|
|
|
|
|
|
|
|
|
|
| 817 |
break
|
| 818 |
except Exception as e:
|
| 819 |
+
lib_logger.error(f"Error in re-auth queue processor: {e}")
|
|
|
|
| 820 |
if path:
|
| 821 |
async with self._queue_tracking_lock:
|
| 822 |
+
self._queued_credentials.discard(path)
|
| 823 |
self._unavailable_credentials.pop(path, None)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 824 |
|
| 825 |
async def _perform_interactive_oauth(
|
| 826 |
self, path: str, creds: Dict[str, Any], display_name: str
|
|
|
|
| 1117 |
except Exception as e:
|
| 1118 |
lib_logger.error(f"Failed to get Qwen user info from credentials: {e}")
|
| 1119 |
return {"email": None}
|
| 1120 |
+
|
| 1121 |
+
# =========================================================================
|
| 1122 |
+
# CREDENTIAL MANAGEMENT METHODS
|
| 1123 |
+
# =========================================================================
|
| 1124 |
+
|
| 1125 |
+
def _get_provider_file_prefix(self) -> str:
|
| 1126 |
+
"""Return the file prefix for Qwen credentials."""
|
| 1127 |
+
return "qwen_code"
|
| 1128 |
+
|
| 1129 |
+
def _get_oauth_base_dir(self) -> Path:
|
| 1130 |
+
"""Get the base directory for OAuth credential files."""
|
| 1131 |
+
return Path.cwd() / "oauth_creds"
|
| 1132 |
+
|
| 1133 |
+
def _find_existing_credential_by_email(
|
| 1134 |
+
self, email: str, base_dir: Optional[Path] = None
|
| 1135 |
+
) -> Optional[Path]:
|
| 1136 |
+
"""Find an existing credential file for the given email."""
|
| 1137 |
+
if base_dir is None:
|
| 1138 |
+
base_dir = self._get_oauth_base_dir()
|
| 1139 |
+
|
| 1140 |
+
prefix = self._get_provider_file_prefix()
|
| 1141 |
+
pattern = str(base_dir / f"{prefix}_oauth_*.json")
|
| 1142 |
+
|
| 1143 |
+
for cred_file in glob(pattern):
|
| 1144 |
+
try:
|
| 1145 |
+
with open(cred_file, "r") as f:
|
| 1146 |
+
creds = json.load(f)
|
| 1147 |
+
existing_email = creds.get("_proxy_metadata", {}).get("email")
|
| 1148 |
+
if existing_email == email:
|
| 1149 |
+
return Path(cred_file)
|
| 1150 |
+
except (json.JSONDecodeError, IOError) as e:
|
| 1151 |
+
lib_logger.debug(f"Could not read credential file {cred_file}: {e}")
|
| 1152 |
+
continue
|
| 1153 |
+
|
| 1154 |
+
return None
|
| 1155 |
+
|
| 1156 |
+
def _get_next_credential_number(self, base_dir: Optional[Path] = None) -> int:
|
| 1157 |
+
"""Get the next available credential number."""
|
| 1158 |
+
if base_dir is None:
|
| 1159 |
+
base_dir = self._get_oauth_base_dir()
|
| 1160 |
+
|
| 1161 |
+
prefix = self._get_provider_file_prefix()
|
| 1162 |
+
pattern = str(base_dir / f"{prefix}_oauth_*.json")
|
| 1163 |
+
|
| 1164 |
+
existing_numbers = []
|
| 1165 |
+
for cred_file in glob(pattern):
|
| 1166 |
+
match = re.search(r"_oauth_(\d+)\.json$", cred_file)
|
| 1167 |
+
if match:
|
| 1168 |
+
existing_numbers.append(int(match.group(1)))
|
| 1169 |
+
|
| 1170 |
+
if not existing_numbers:
|
| 1171 |
+
return 1
|
| 1172 |
+
return max(existing_numbers) + 1
|
| 1173 |
+
|
| 1174 |
+
def _build_credential_path(
|
| 1175 |
+
self, base_dir: Optional[Path] = None, number: Optional[int] = None
|
| 1176 |
+
) -> Path:
|
| 1177 |
+
"""Build a path for a new credential file."""
|
| 1178 |
+
if base_dir is None:
|
| 1179 |
+
base_dir = self._get_oauth_base_dir()
|
| 1180 |
+
|
| 1181 |
+
if number is None:
|
| 1182 |
+
number = self._get_next_credential_number(base_dir)
|
| 1183 |
+
|
| 1184 |
+
prefix = self._get_provider_file_prefix()
|
| 1185 |
+
filename = f"{prefix}_oauth_{number}.json"
|
| 1186 |
+
return base_dir / filename
|
| 1187 |
+
|
| 1188 |
+
async def setup_credential(
|
| 1189 |
+
self, base_dir: Optional[Path] = None
|
| 1190 |
+
) -> QwenCredentialSetupResult:
|
| 1191 |
+
"""
|
| 1192 |
+
Complete credential setup flow: OAuth -> save.
|
| 1193 |
+
|
| 1194 |
+
This is the main entry point for setting up new credentials.
|
| 1195 |
+
"""
|
| 1196 |
+
if base_dir is None:
|
| 1197 |
+
base_dir = self._get_oauth_base_dir()
|
| 1198 |
+
|
| 1199 |
+
# Ensure directory exists
|
| 1200 |
+
base_dir.mkdir(exist_ok=True)
|
| 1201 |
+
|
| 1202 |
+
try:
|
| 1203 |
+
# Step 1: Perform OAuth authentication
|
| 1204 |
+
temp_creds = {
|
| 1205 |
+
"_proxy_metadata": {"display_name": "new Qwen Code credential"}
|
| 1206 |
+
}
|
| 1207 |
+
new_creds = await self.initialize_token(temp_creds)
|
| 1208 |
+
|
| 1209 |
+
# Step 2: Get user info for deduplication
|
| 1210 |
+
email = new_creds.get("_proxy_metadata", {}).get("email")
|
| 1211 |
+
|
| 1212 |
+
if not email:
|
| 1213 |
+
return QwenCredentialSetupResult(
|
| 1214 |
+
success=False, error="Could not retrieve email from OAuth response"
|
| 1215 |
+
)
|
| 1216 |
+
|
| 1217 |
+
# Step 3: Check for existing credential with same email
|
| 1218 |
+
existing_path = self._find_existing_credential_by_email(email, base_dir)
|
| 1219 |
+
is_update = existing_path is not None
|
| 1220 |
+
|
| 1221 |
+
if is_update:
|
| 1222 |
+
file_path = existing_path
|
| 1223 |
+
lib_logger.info(
|
| 1224 |
+
f"Found existing credential for {email}, updating {file_path.name}"
|
| 1225 |
+
)
|
| 1226 |
+
else:
|
| 1227 |
+
file_path = self._build_credential_path(base_dir)
|
| 1228 |
+
lib_logger.info(
|
| 1229 |
+
f"Creating new credential for {email} at {file_path.name}"
|
| 1230 |
+
)
|
| 1231 |
+
|
| 1232 |
+
# Step 4: Save credentials to file
|
| 1233 |
+
await self._save_credentials(str(file_path), new_creds)
|
| 1234 |
+
|
| 1235 |
+
return QwenCredentialSetupResult(
|
| 1236 |
+
success=True,
|
| 1237 |
+
file_path=str(file_path),
|
| 1238 |
+
email=email,
|
| 1239 |
+
is_update=is_update,
|
| 1240 |
+
credentials=new_creds,
|
| 1241 |
+
)
|
| 1242 |
+
|
| 1243 |
+
except Exception as e:
|
| 1244 |
+
lib_logger.error(f"Credential setup failed: {e}")
|
| 1245 |
+
return QwenCredentialSetupResult(success=False, error=str(e))
|
| 1246 |
+
|
| 1247 |
+
def build_env_lines(self, creds: Dict[str, Any], cred_number: int) -> List[str]:
|
| 1248 |
+
"""Generate .env file lines for a Qwen credential."""
|
| 1249 |
+
email = creds.get("_proxy_metadata", {}).get("email", "unknown")
|
| 1250 |
+
prefix = f"QWEN_CODE_{cred_number}"
|
| 1251 |
+
|
| 1252 |
+
lines = [
|
| 1253 |
+
f"# QWEN_CODE Credential #{cred_number} for: {email}",
|
| 1254 |
+
f"# Exported from: qwen_code_oauth_{cred_number}.json",
|
| 1255 |
+
f"# Generated at: {time.strftime('%Y-%m-%d %H:%M:%S')}",
|
| 1256 |
+
"#",
|
| 1257 |
+
"# To combine multiple credentials into one .env file, copy these lines",
|
| 1258 |
+
"# and ensure each credential has a unique number (1, 2, 3, etc.)",
|
| 1259 |
+
"",
|
| 1260 |
+
f"{prefix}_ACCESS_TOKEN={creds.get('access_token', '')}",
|
| 1261 |
+
f"{prefix}_REFRESH_TOKEN={creds.get('refresh_token', '')}",
|
| 1262 |
+
f"{prefix}_EXPIRY_DATE={creds.get('expiry_date', 0)}",
|
| 1263 |
+
f"{prefix}_RESOURCE_URL={creds.get('resource_url', 'https://portal.qwen.ai/v1')}",
|
| 1264 |
+
f"{prefix}_EMAIL={email}",
|
| 1265 |
+
]
|
| 1266 |
+
|
| 1267 |
+
return lines
|
| 1268 |
+
|
| 1269 |
+
def export_credential_to_env(
|
| 1270 |
+
self, credential_path: str, output_dir: Optional[Path] = None
|
| 1271 |
+
) -> Optional[str]:
|
| 1272 |
+
"""Export a credential file to .env format."""
|
| 1273 |
+
try:
|
| 1274 |
+
cred_path = Path(credential_path)
|
| 1275 |
+
|
| 1276 |
+
# Load credential
|
| 1277 |
+
with open(cred_path, "r") as f:
|
| 1278 |
+
creds = json.load(f)
|
| 1279 |
+
|
| 1280 |
+
# Extract metadata
|
| 1281 |
+
email = creds.get("_proxy_metadata", {}).get("email", "unknown")
|
| 1282 |
+
|
| 1283 |
+
# Get credential number from filename
|
| 1284 |
+
match = re.search(r"_oauth_(\d+)\.json$", cred_path.name)
|
| 1285 |
+
cred_number = int(match.group(1)) if match else 1
|
| 1286 |
+
|
| 1287 |
+
# Build output path
|
| 1288 |
+
if output_dir is None:
|
| 1289 |
+
output_dir = cred_path.parent
|
| 1290 |
+
|
| 1291 |
+
safe_email = email.replace("@", "_at_").replace(".", "_")
|
| 1292 |
+
env_filename = f"qwen_code_{cred_number}_{safe_email}.env"
|
| 1293 |
+
env_path = output_dir / env_filename
|
| 1294 |
+
|
| 1295 |
+
# Build and write content
|
| 1296 |
+
env_lines = self.build_env_lines(creds, cred_number)
|
| 1297 |
+
with open(env_path, "w") as f:
|
| 1298 |
+
f.write("\n".join(env_lines))
|
| 1299 |
+
|
| 1300 |
+
lib_logger.info(f"Exported credential to {env_path}")
|
| 1301 |
+
return str(env_path)
|
| 1302 |
+
|
| 1303 |
+
except Exception as e:
|
| 1304 |
+
lib_logger.error(f"Failed to export credential: {e}")
|
| 1305 |
+
return None
|
| 1306 |
+
|
| 1307 |
+
def list_credentials(self, base_dir: Optional[Path] = None) -> List[Dict[str, Any]]:
|
| 1308 |
+
"""List all Qwen credential files."""
|
| 1309 |
+
if base_dir is None:
|
| 1310 |
+
base_dir = self._get_oauth_base_dir()
|
| 1311 |
+
|
| 1312 |
+
prefix = self._get_provider_file_prefix()
|
| 1313 |
+
pattern = str(base_dir / f"{prefix}_oauth_*.json")
|
| 1314 |
+
|
| 1315 |
+
credentials = []
|
| 1316 |
+
for cred_file in sorted(glob(pattern)):
|
| 1317 |
+
try:
|
| 1318 |
+
with open(cred_file, "r") as f:
|
| 1319 |
+
creds = json.load(f)
|
| 1320 |
+
|
| 1321 |
+
metadata = creds.get("_proxy_metadata", {})
|
| 1322 |
+
|
| 1323 |
+
# Extract number from filename
|
| 1324 |
+
match = re.search(r"_oauth_(\d+)\.json$", cred_file)
|
| 1325 |
+
number = int(match.group(1)) if match else 0
|
| 1326 |
+
|
| 1327 |
+
credentials.append(
|
| 1328 |
+
{
|
| 1329 |
+
"file_path": cred_file,
|
| 1330 |
+
"email": metadata.get("email", "unknown"),
|
| 1331 |
+
"number": number,
|
| 1332 |
+
}
|
| 1333 |
+
)
|
| 1334 |
+
except Exception as e:
|
| 1335 |
+
lib_logger.debug(f"Could not read credential file {cred_file}: {e}")
|
| 1336 |
+
continue
|
| 1337 |
+
|
| 1338 |
+
return credentials
|
| 1339 |
+
|
| 1340 |
+
def delete_credential(self, credential_path: str) -> bool:
|
| 1341 |
+
"""Delete a credential file."""
|
| 1342 |
+
try:
|
| 1343 |
+
cred_path = Path(credential_path)
|
| 1344 |
+
|
| 1345 |
+
# Validate that it's one of our credential files
|
| 1346 |
+
prefix = self._get_provider_file_prefix()
|
| 1347 |
+
if not cred_path.name.startswith(f"{prefix}_oauth_"):
|
| 1348 |
+
lib_logger.error(
|
| 1349 |
+
f"File {cred_path.name} does not appear to be a Qwen Code credential"
|
| 1350 |
+
)
|
| 1351 |
+
return False
|
| 1352 |
+
|
| 1353 |
+
if not cred_path.exists():
|
| 1354 |
+
lib_logger.warning(f"Credential file does not exist: {credential_path}")
|
| 1355 |
+
return False
|
| 1356 |
+
|
| 1357 |
+
# Remove from cache if present
|
| 1358 |
+
self._credentials_cache.pop(credential_path, None)
|
| 1359 |
+
|
| 1360 |
+
# Delete the file
|
| 1361 |
+
cred_path.unlink()
|
| 1362 |
+
lib_logger.info(f"Deleted credential file: {credential_path}")
|
| 1363 |
+
return True
|
| 1364 |
+
|
| 1365 |
+
except Exception as e:
|
| 1366 |
+
lib_logger.error(f"Failed to delete credential: {e}")
|
| 1367 |
+
return False
|
src/rotator_library/providers/qwen_code_provider.py
CHANGED
|
@@ -10,19 +10,27 @@ from typing import Union, AsyncGenerator, List, Dict, Any
|
|
| 10 |
from .provider_interface import ProviderInterface
|
| 11 |
from .qwen_auth_base import QwenAuthBase
|
| 12 |
from ..model_definitions import ModelDefinitions
|
|
|
|
|
|
|
| 13 |
import litellm
|
| 14 |
from litellm.exceptions import RateLimitError, AuthenticationError
|
| 15 |
from pathlib import Path
|
| 16 |
import uuid
|
| 17 |
from datetime import datetime
|
| 18 |
|
| 19 |
-
lib_logger = logging.getLogger(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
|
| 21 |
-
LOGS_DIR = Path(__file__).resolve().parent.parent.parent.parent / "logs"
|
| 22 |
-
QWEN_CODE_LOGS_DIR = LOGS_DIR / "qwen_code_logs"
|
| 23 |
|
| 24 |
class _QwenCodeFileLogger:
|
| 25 |
"""A simple file logger for a single Qwen Code transaction."""
|
|
|
|
| 26 |
def __init__(self, model_name: str, enabled: bool = True):
|
| 27 |
self.enabled = enabled
|
| 28 |
if not self.enabled:
|
|
@@ -31,8 +39,10 @@ class _QwenCodeFileLogger:
|
|
| 31 |
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
|
| 32 |
request_id = str(uuid.uuid4())
|
| 33 |
# Sanitize model name for directory
|
| 34 |
-
safe_model_name = model_name.replace(
|
| 35 |
-
self.log_dir =
|
|
|
|
|
|
|
| 36 |
try:
|
| 37 |
self.log_dir.mkdir(parents=True, exist_ok=True)
|
| 38 |
except Exception as e:
|
|
@@ -41,25 +51,32 @@ class _QwenCodeFileLogger:
|
|
| 41 |
|
| 42 |
def log_request(self, payload: Dict[str, Any]):
|
| 43 |
"""Logs the request payload sent to Qwen Code."""
|
| 44 |
-
if not self.enabled:
|
|
|
|
| 45 |
try:
|
| 46 |
-
with open(
|
|
|
|
|
|
|
| 47 |
json.dump(payload, f, indent=2, ensure_ascii=False)
|
| 48 |
except Exception as e:
|
| 49 |
lib_logger.error(f"_QwenCodeFileLogger: Failed to write request: {e}")
|
| 50 |
|
| 51 |
def log_response_chunk(self, chunk: str):
|
| 52 |
"""Logs a raw chunk from the Qwen Code response stream."""
|
| 53 |
-
if not self.enabled:
|
|
|
|
| 54 |
try:
|
| 55 |
with open(self.log_dir / "response_stream.log", "a", encoding="utf-8") as f:
|
| 56 |
f.write(chunk + "\n")
|
| 57 |
except Exception as e:
|
| 58 |
-
lib_logger.error(
|
|
|
|
|
|
|
| 59 |
|
| 60 |
def log_error(self, error_message: str):
|
| 61 |
"""Logs an error message."""
|
| 62 |
-
if not self.enabled:
|
|
|
|
| 63 |
try:
|
| 64 |
with open(self.log_dir / "error.log", "a", encoding="utf-8") as f:
|
| 65 |
f.write(f"[{datetime.utcnow().isoformat()}] {error_message}\n")
|
|
@@ -68,28 +85,41 @@ class _QwenCodeFileLogger:
|
|
| 68 |
|
| 69 |
def log_final_response(self, response_data: Dict[str, Any]):
|
| 70 |
"""Logs the final, reassembled response."""
|
| 71 |
-
if not self.enabled:
|
|
|
|
| 72 |
try:
|
| 73 |
with open(self.log_dir / "final_response.json", "w", encoding="utf-8") as f:
|
| 74 |
json.dump(response_data, f, indent=2, ensure_ascii=False)
|
| 75 |
except Exception as e:
|
| 76 |
-
lib_logger.error(
|
|
|
|
|
|
|
|
|
|
| 77 |
|
| 78 |
-
HARDCODED_MODELS = [
|
| 79 |
-
"qwen3-coder-plus",
|
| 80 |
-
"qwen3-coder-flash"
|
| 81 |
-
]
|
| 82 |
|
| 83 |
# OpenAI-compatible parameters supported by Qwen Code API
|
| 84 |
SUPPORTED_PARAMS = {
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 88 |
}
|
| 89 |
|
|
|
|
| 90 |
class QwenCodeProvider(QwenAuthBase, ProviderInterface):
|
| 91 |
skip_cost_calculation = True
|
| 92 |
-
REASONING_START_MARKER =
|
| 93 |
|
| 94 |
def __init__(self):
|
| 95 |
super().__init__()
|
|
@@ -111,7 +141,9 @@ class QwenCodeProvider(QwenAuthBase, ProviderInterface):
|
|
| 111 |
Validates OAuth credentials if applicable.
|
| 112 |
"""
|
| 113 |
models = []
|
| 114 |
-
env_var_ids =
|
|
|
|
|
|
|
| 115 |
|
| 116 |
def extract_model_id(item) -> str:
|
| 117 |
"""Extract model ID from various formats (dict, string with/without provider prefix)."""
|
|
@@ -137,7 +169,9 @@ class QwenCodeProvider(QwenAuthBase, ProviderInterface):
|
|
| 137 |
# Track the ID to prevent hardcoded/dynamic duplicates
|
| 138 |
if model_id:
|
| 139 |
env_var_ids.add(model_id)
|
| 140 |
-
lib_logger.info(
|
|
|
|
|
|
|
| 141 |
|
| 142 |
# Source 2: Add hardcoded models (only if ID not already in env vars)
|
| 143 |
for model_id in HARDCODED_MODELS:
|
|
@@ -155,14 +189,17 @@ class QwenCodeProvider(QwenAuthBase, ProviderInterface):
|
|
| 155 |
models_url = f"{api_base.rstrip('/')}/v1/models"
|
| 156 |
|
| 157 |
response = await client.get(
|
| 158 |
-
models_url,
|
| 159 |
-
headers={"Authorization": f"Bearer {access_token}"}
|
| 160 |
)
|
| 161 |
response.raise_for_status()
|
| 162 |
|
| 163 |
dynamic_data = response.json()
|
| 164 |
# Handle both {data: [...]} and direct [...] formats
|
| 165 |
-
model_list =
|
|
|
|
|
|
|
|
|
|
|
|
|
| 166 |
|
| 167 |
dynamic_count = 0
|
| 168 |
for model in model_list:
|
|
@@ -173,7 +210,9 @@ class QwenCodeProvider(QwenAuthBase, ProviderInterface):
|
|
| 173 |
dynamic_count += 1
|
| 174 |
|
| 175 |
if dynamic_count > 0:
|
| 176 |
-
lib_logger.debug(
|
|
|
|
|
|
|
| 177 |
|
| 178 |
except Exception as e:
|
| 179 |
# Silently ignore dynamic discovery errors
|
|
@@ -238,10 +277,10 @@ class QwenCodeProvider(QwenAuthBase, ProviderInterface):
|
|
| 238 |
payload = {k: v for k, v in kwargs.items() if k in SUPPORTED_PARAMS}
|
| 239 |
|
| 240 |
# Always force streaming for internal processing
|
| 241 |
-
payload[
|
| 242 |
|
| 243 |
# Always include usage data in stream
|
| 244 |
-
payload[
|
| 245 |
|
| 246 |
# Handle tool schema cleaning
|
| 247 |
if "tools" in payload and payload["tools"]:
|
|
@@ -250,22 +289,26 @@ class QwenCodeProvider(QwenAuthBase, ProviderInterface):
|
|
| 250 |
elif not payload.get("tools"):
|
| 251 |
# Per Qwen Code API bug (see: https://github.com/qianwen-team/flash-dance/issues/2),
|
| 252 |
# injecting a dummy tool prevents stream corruption when no tools are provided
|
| 253 |
-
payload["tools"] = [
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
"
|
| 257 |
-
|
| 258 |
-
|
|
|
|
|
|
|
| 259 |
}
|
| 260 |
-
|
| 261 |
-
lib_logger.debug(
|
|
|
|
|
|
|
| 262 |
|
| 263 |
return payload
|
| 264 |
|
| 265 |
def _convert_chunk_to_openai(self, chunk: Dict[str, Any], model_id: str):
|
| 266 |
"""
|
| 267 |
Converts a raw Qwen SSE chunk to an OpenAI-compatible chunk.
|
| 268 |
-
|
| 269 |
CRITICAL FIX: Handle chunks with BOTH usage and choices (final chunk)
|
| 270 |
without early return to ensure finish_reason is properly processed.
|
| 271 |
"""
|
|
@@ -287,32 +330,42 @@ class QwenCodeProvider(QwenAuthBase, ProviderInterface):
|
|
| 287 |
|
| 288 |
# Yield the choice chunk first (contains finish_reason)
|
| 289 |
yield {
|
| 290 |
-
"choices": [
|
| 291 |
-
|
| 292 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 293 |
}
|
| 294 |
# Then yield the usage chunk
|
| 295 |
yield {
|
| 296 |
-
"choices": [],
|
| 297 |
-
"
|
|
|
|
|
|
|
|
|
|
| 298 |
"usage": {
|
| 299 |
"prompt_tokens": usage_data.get("prompt_tokens", 0),
|
| 300 |
"completion_tokens": usage_data.get("completion_tokens", 0),
|
| 301 |
"total_tokens": usage_data.get("total_tokens", 0),
|
| 302 |
-
}
|
| 303 |
}
|
| 304 |
return
|
| 305 |
|
| 306 |
# Handle usage-only chunks
|
| 307 |
if usage_data:
|
| 308 |
yield {
|
| 309 |
-
"choices": [],
|
| 310 |
-
"
|
|
|
|
|
|
|
|
|
|
| 311 |
"usage": {
|
| 312 |
"prompt_tokens": usage_data.get("prompt_tokens", 0),
|
| 313 |
"completion_tokens": usage_data.get("completion_tokens", 0),
|
| 314 |
"total_tokens": usage_data.get("total_tokens", 0),
|
| 315 |
-
}
|
| 316 |
}
|
| 317 |
return
|
| 318 |
|
|
@@ -327,35 +380,52 @@ class QwenCodeProvider(QwenAuthBase, ProviderInterface):
|
|
| 327 |
# Handle <think> tags for reasoning content
|
| 328 |
content = delta.get("content")
|
| 329 |
if content and ("<think>" in content or "</think>" in content):
|
| 330 |
-
parts =
|
|
|
|
|
|
|
|
|
|
|
|
|
| 331 |
for part in parts:
|
| 332 |
-
if not part:
|
| 333 |
-
|
|
|
|
| 334 |
new_delta = {}
|
| 335 |
if part.startswith(self.REASONING_START_MARKER):
|
| 336 |
-
new_delta[
|
|
|
|
|
|
|
| 337 |
elif part.startswith(f"/{self.REASONING_START_MARKER}"):
|
| 338 |
continue
|
| 339 |
else:
|
| 340 |
-
new_delta[
|
| 341 |
-
|
| 342 |
yield {
|
| 343 |
-
"choices": [
|
| 344 |
-
|
| 345 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 346 |
}
|
| 347 |
else:
|
| 348 |
# Standard content chunk
|
| 349 |
yield {
|
| 350 |
-
"choices": [
|
| 351 |
-
|
| 352 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 353 |
}
|
| 354 |
|
| 355 |
-
def _stream_to_completion_response(
|
|
|
|
|
|
|
| 356 |
"""
|
| 357 |
Manually reassembles streaming chunks into a complete response.
|
| 358 |
-
|
| 359 |
Key improvements:
|
| 360 |
- Determines finish_reason based on accumulated state (tool_calls vs stop)
|
| 361 |
- Properly initializes tool_calls with type field
|
|
@@ -368,14 +438,16 @@ class QwenCodeProvider(QwenAuthBase, ProviderInterface):
|
|
| 368 |
final_message = {"role": "assistant"}
|
| 369 |
aggregated_tool_calls = {}
|
| 370 |
usage_data = None
|
| 371 |
-
chunk_finish_reason =
|
|
|
|
|
|
|
| 372 |
|
| 373 |
# Get the first chunk for basic response metadata
|
| 374 |
first_chunk = chunks[0]
|
| 375 |
|
| 376 |
# Process each chunk to aggregate content
|
| 377 |
for chunk in chunks:
|
| 378 |
-
if not hasattr(chunk,
|
| 379 |
continue
|
| 380 |
|
| 381 |
choice = chunk.choices[0]
|
|
@@ -399,25 +471,48 @@ class QwenCodeProvider(QwenAuthBase, ProviderInterface):
|
|
| 399 |
index = tc_chunk.get("index", 0)
|
| 400 |
if index not in aggregated_tool_calls:
|
| 401 |
# Initialize with type field for OpenAI compatibility
|
| 402 |
-
aggregated_tool_calls[index] = {
|
|
|
|
|
|
|
|
|
|
| 403 |
if "id" in tc_chunk:
|
| 404 |
aggregated_tool_calls[index]["id"] = tc_chunk["id"]
|
| 405 |
if "type" in tc_chunk:
|
| 406 |
aggregated_tool_calls[index]["type"] = tc_chunk["type"]
|
| 407 |
if "function" in tc_chunk:
|
| 408 |
-
if
|
| 409 |
-
|
| 410 |
-
|
| 411 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 412 |
|
| 413 |
# Aggregate function calls (legacy format)
|
| 414 |
if "function_call" in delta and delta["function_call"] is not None:
|
| 415 |
if "function_call" not in final_message:
|
| 416 |
final_message["function_call"] = {"name": "", "arguments": ""}
|
| 417 |
-
if
|
| 418 |
-
|
| 419 |
-
|
| 420 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 421 |
|
| 422 |
# Track finish_reason from chunks (for reference only)
|
| 423 |
if choice.get("finish_reason"):
|
|
@@ -425,7 +520,7 @@ class QwenCodeProvider(QwenAuthBase, ProviderInterface):
|
|
| 425 |
|
| 426 |
# Handle usage data from the last chunk that has it
|
| 427 |
for chunk in reversed(chunks):
|
| 428 |
-
if hasattr(chunk,
|
| 429 |
usage_data = chunk.usage
|
| 430 |
break
|
| 431 |
|
|
@@ -451,7 +546,7 @@ class QwenCodeProvider(QwenAuthBase, ProviderInterface):
|
|
| 451 |
final_choice = {
|
| 452 |
"index": 0,
|
| 453 |
"message": final_message,
|
| 454 |
-
"finish_reason": finish_reason
|
| 455 |
}
|
| 456 |
|
| 457 |
# Create the final ModelResponse
|
|
@@ -461,20 +556,21 @@ class QwenCodeProvider(QwenAuthBase, ProviderInterface):
|
|
| 461 |
"created": first_chunk.created,
|
| 462 |
"model": first_chunk.model,
|
| 463 |
"choices": [final_choice],
|
| 464 |
-
"usage": usage_data
|
| 465 |
}
|
| 466 |
|
| 467 |
return litellm.ModelResponse(**final_response_data)
|
| 468 |
|
| 469 |
-
async def acompletion(
|
|
|
|
|
|
|
| 470 |
credential_path = kwargs.pop("credential_identifier")
|
| 471 |
enable_request_logging = kwargs.pop("enable_request_logging", False)
|
| 472 |
model = kwargs["model"]
|
| 473 |
|
| 474 |
# Create dedicated file logger for this request
|
| 475 |
file_logger = _QwenCodeFileLogger(
|
| 476 |
-
model_name=model,
|
| 477 |
-
enabled=enable_request_logging
|
| 478 |
)
|
| 479 |
|
| 480 |
async def make_request():
|
|
@@ -482,8 +578,8 @@ class QwenCodeProvider(QwenAuthBase, ProviderInterface):
|
|
| 482 |
api_base, access_token = await self.get_api_details(credential_path)
|
| 483 |
|
| 484 |
# Strip provider prefix from model name (e.g., "qwen_code/qwen3-coder-plus" -> "qwen3-coder-plus")
|
| 485 |
-
model_name = model.split(
|
| 486 |
-
kwargs_with_stripped_model = {**kwargs,
|
| 487 |
|
| 488 |
# Build clean payload with only supported parameters
|
| 489 |
payload = self._build_request_payload(**kwargs_with_stripped_model)
|
|
@@ -503,7 +599,13 @@ class QwenCodeProvider(QwenAuthBase, ProviderInterface):
|
|
| 503 |
file_logger.log_request(payload)
|
| 504 |
lib_logger.debug(f"Qwen Code Request URL: {url}")
|
| 505 |
|
| 506 |
-
return client.stream(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 507 |
|
| 508 |
async def stream_handler(response_stream, attempt=1):
|
| 509 |
"""Handles the streaming response and converts chunks."""
|
|
@@ -512,11 +614,17 @@ class QwenCodeProvider(QwenAuthBase, ProviderInterface):
|
|
| 512 |
# Check for HTTP errors before processing stream
|
| 513 |
if response.status_code >= 400:
|
| 514 |
error_text = await response.aread()
|
| 515 |
-
error_text =
|
|
|
|
|
|
|
|
|
|
|
|
|
| 516 |
|
| 517 |
# Handle 401: Force token refresh and retry once
|
| 518 |
if response.status_code == 401 and attempt == 1:
|
| 519 |
-
lib_logger.warning(
|
|
|
|
|
|
|
| 520 |
await self._refresh_token(credential_path, force=True)
|
| 521 |
retry_stream = await make_request()
|
| 522 |
async for chunk in stream_handler(retry_stream, attempt=2):
|
|
@@ -524,12 +632,15 @@ class QwenCodeProvider(QwenAuthBase, ProviderInterface):
|
|
| 524 |
return
|
| 525 |
|
| 526 |
# Handle 429: Rate limit
|
| 527 |
-
elif
|
|
|
|
|
|
|
|
|
|
| 528 |
raise RateLimitError(
|
| 529 |
f"Qwen Code rate limit exceeded: {error_text}",
|
| 530 |
llm_provider="qwen_code",
|
| 531 |
model=model,
|
| 532 |
-
response=response
|
| 533 |
)
|
| 534 |
|
| 535 |
# Handle other errors
|
|
@@ -539,28 +650,34 @@ class QwenCodeProvider(QwenAuthBase, ProviderInterface):
|
|
| 539 |
raise httpx.HTTPStatusError(
|
| 540 |
f"HTTP {response.status_code}: {error_text}",
|
| 541 |
request=response.request,
|
| 542 |
-
response=response
|
| 543 |
)
|
| 544 |
|
| 545 |
# Process successful streaming response
|
| 546 |
async for line in response.aiter_lines():
|
| 547 |
file_logger.log_response_chunk(line)
|
| 548 |
-
if line.startswith(
|
| 549 |
data_str = line[6:]
|
| 550 |
if data_str == "[DONE]":
|
| 551 |
break
|
| 552 |
try:
|
| 553 |
chunk = json.loads(data_str)
|
| 554 |
-
for openai_chunk in self._convert_chunk_to_openai(
|
|
|
|
|
|
|
| 555 |
yield litellm.ModelResponse(**openai_chunk)
|
| 556 |
except json.JSONDecodeError:
|
| 557 |
-
lib_logger.warning(
|
|
|
|
|
|
|
| 558 |
|
| 559 |
except httpx.HTTPStatusError:
|
| 560 |
raise # Re-raise HTTP errors we already handled
|
| 561 |
except Exception as e:
|
| 562 |
file_logger.log_error(f"Error during Qwen Code stream processing: {e}")
|
| 563 |
-
lib_logger.error(
|
|
|
|
|
|
|
| 564 |
raise
|
| 565 |
|
| 566 |
async def logging_stream_wrapper():
|
|
@@ -578,7 +695,9 @@ class QwenCodeProvider(QwenAuthBase, ProviderInterface):
|
|
| 578 |
if kwargs.get("stream"):
|
| 579 |
return logging_stream_wrapper()
|
| 580 |
else:
|
|
|
|
| 581 |
async def non_stream_wrapper():
|
| 582 |
chunks = [chunk async for chunk in logging_stream_wrapper()]
|
| 583 |
return self._stream_to_completion_response(chunks)
|
| 584 |
-
|
|
|
|
|
|
| 10 |
from .provider_interface import ProviderInterface
|
| 11 |
from .qwen_auth_base import QwenAuthBase
|
| 12 |
from ..model_definitions import ModelDefinitions
|
| 13 |
+
from ..timeout_config import TimeoutConfig
|
| 14 |
+
from ..utils.paths import get_logs_dir
|
| 15 |
import litellm
|
| 16 |
from litellm.exceptions import RateLimitError, AuthenticationError
|
| 17 |
from pathlib import Path
|
| 18 |
import uuid
|
| 19 |
from datetime import datetime
|
| 20 |
|
| 21 |
+
lib_logger = logging.getLogger("rotator_library")
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def _get_qwen_code_logs_dir() -> Path:
|
| 25 |
+
"""Get the Qwen Code logs directory."""
|
| 26 |
+
logs_dir = get_logs_dir() / "qwen_code_logs"
|
| 27 |
+
logs_dir.mkdir(parents=True, exist_ok=True)
|
| 28 |
+
return logs_dir
|
| 29 |
|
|
|
|
|
|
|
| 30 |
|
| 31 |
class _QwenCodeFileLogger:
|
| 32 |
"""A simple file logger for a single Qwen Code transaction."""
|
| 33 |
+
|
| 34 |
def __init__(self, model_name: str, enabled: bool = True):
|
| 35 |
self.enabled = enabled
|
| 36 |
if not self.enabled:
|
|
|
|
| 39 |
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
|
| 40 |
request_id = str(uuid.uuid4())
|
| 41 |
# Sanitize model name for directory
|
| 42 |
+
safe_model_name = model_name.replace("/", "_").replace(":", "_")
|
| 43 |
+
self.log_dir = (
|
| 44 |
+
_get_qwen_code_logs_dir() / f"{timestamp}_{safe_model_name}_{request_id}"
|
| 45 |
+
)
|
| 46 |
try:
|
| 47 |
self.log_dir.mkdir(parents=True, exist_ok=True)
|
| 48 |
except Exception as e:
|
|
|
|
| 51 |
|
| 52 |
def log_request(self, payload: Dict[str, Any]):
|
| 53 |
"""Logs the request payload sent to Qwen Code."""
|
| 54 |
+
if not self.enabled:
|
| 55 |
+
return
|
| 56 |
try:
|
| 57 |
+
with open(
|
| 58 |
+
self.log_dir / "request_payload.json", "w", encoding="utf-8"
|
| 59 |
+
) as f:
|
| 60 |
json.dump(payload, f, indent=2, ensure_ascii=False)
|
| 61 |
except Exception as e:
|
| 62 |
lib_logger.error(f"_QwenCodeFileLogger: Failed to write request: {e}")
|
| 63 |
|
| 64 |
def log_response_chunk(self, chunk: str):
|
| 65 |
"""Logs a raw chunk from the Qwen Code response stream."""
|
| 66 |
+
if not self.enabled:
|
| 67 |
+
return
|
| 68 |
try:
|
| 69 |
with open(self.log_dir / "response_stream.log", "a", encoding="utf-8") as f:
|
| 70 |
f.write(chunk + "\n")
|
| 71 |
except Exception as e:
|
| 72 |
+
lib_logger.error(
|
| 73 |
+
f"_QwenCodeFileLogger: Failed to write response chunk: {e}"
|
| 74 |
+
)
|
| 75 |
|
| 76 |
def log_error(self, error_message: str):
|
| 77 |
"""Logs an error message."""
|
| 78 |
+
if not self.enabled:
|
| 79 |
+
return
|
| 80 |
try:
|
| 81 |
with open(self.log_dir / "error.log", "a", encoding="utf-8") as f:
|
| 82 |
f.write(f"[{datetime.utcnow().isoformat()}] {error_message}\n")
|
|
|
|
| 85 |
|
| 86 |
def log_final_response(self, response_data: Dict[str, Any]):
|
| 87 |
"""Logs the final, reassembled response."""
|
| 88 |
+
if not self.enabled:
|
| 89 |
+
return
|
| 90 |
try:
|
| 91 |
with open(self.log_dir / "final_response.json", "w", encoding="utf-8") as f:
|
| 92 |
json.dump(response_data, f, indent=2, ensure_ascii=False)
|
| 93 |
except Exception as e:
|
| 94 |
+
lib_logger.error(
|
| 95 |
+
f"_QwenCodeFileLogger: Failed to write final response: {e}"
|
| 96 |
+
)
|
| 97 |
+
|
| 98 |
|
| 99 |
+
HARDCODED_MODELS = ["qwen3-coder-plus", "qwen3-coder-flash"]
|
|
|
|
|
|
|
|
|
|
| 100 |
|
| 101 |
# OpenAI-compatible parameters supported by Qwen Code API
|
| 102 |
SUPPORTED_PARAMS = {
|
| 103 |
+
"model",
|
| 104 |
+
"messages",
|
| 105 |
+
"temperature",
|
| 106 |
+
"top_p",
|
| 107 |
+
"max_tokens",
|
| 108 |
+
"stream",
|
| 109 |
+
"tools",
|
| 110 |
+
"tool_choice",
|
| 111 |
+
"presence_penalty",
|
| 112 |
+
"frequency_penalty",
|
| 113 |
+
"n",
|
| 114 |
+
"stop",
|
| 115 |
+
"seed",
|
| 116 |
+
"response_format",
|
| 117 |
}
|
| 118 |
|
| 119 |
+
|
| 120 |
class QwenCodeProvider(QwenAuthBase, ProviderInterface):
|
| 121 |
skip_cost_calculation = True
|
| 122 |
+
REASONING_START_MARKER = "THINK||"
|
| 123 |
|
| 124 |
def __init__(self):
|
| 125 |
super().__init__()
|
|
|
|
| 141 |
Validates OAuth credentials if applicable.
|
| 142 |
"""
|
| 143 |
models = []
|
| 144 |
+
env_var_ids = (
|
| 145 |
+
set()
|
| 146 |
+
) # Track IDs from env vars to prevent hardcoded/dynamic duplicates
|
| 147 |
|
| 148 |
def extract_model_id(item) -> str:
|
| 149 |
"""Extract model ID from various formats (dict, string with/without provider prefix)."""
|
|
|
|
| 169 |
# Track the ID to prevent hardcoded/dynamic duplicates
|
| 170 |
if model_id:
|
| 171 |
env_var_ids.add(model_id)
|
| 172 |
+
lib_logger.info(
|
| 173 |
+
f"Loaded {len(static_models)} static models for qwen_code from environment variables"
|
| 174 |
+
)
|
| 175 |
|
| 176 |
# Source 2: Add hardcoded models (only if ID not already in env vars)
|
| 177 |
for model_id in HARDCODED_MODELS:
|
|
|
|
| 189 |
models_url = f"{api_base.rstrip('/')}/v1/models"
|
| 190 |
|
| 191 |
response = await client.get(
|
| 192 |
+
models_url, headers={"Authorization": f"Bearer {access_token}"}
|
|
|
|
| 193 |
)
|
| 194 |
response.raise_for_status()
|
| 195 |
|
| 196 |
dynamic_data = response.json()
|
| 197 |
# Handle both {data: [...]} and direct [...] formats
|
| 198 |
+
model_list = (
|
| 199 |
+
dynamic_data.get("data", dynamic_data)
|
| 200 |
+
if isinstance(dynamic_data, dict)
|
| 201 |
+
else dynamic_data
|
| 202 |
+
)
|
| 203 |
|
| 204 |
dynamic_count = 0
|
| 205 |
for model in model_list:
|
|
|
|
| 210 |
dynamic_count += 1
|
| 211 |
|
| 212 |
if dynamic_count > 0:
|
| 213 |
+
lib_logger.debug(
|
| 214 |
+
f"Discovered {dynamic_count} additional models for qwen_code from API"
|
| 215 |
+
)
|
| 216 |
|
| 217 |
except Exception as e:
|
| 218 |
# Silently ignore dynamic discovery errors
|
|
|
|
| 277 |
payload = {k: v for k, v in kwargs.items() if k in SUPPORTED_PARAMS}
|
| 278 |
|
| 279 |
# Always force streaming for internal processing
|
| 280 |
+
payload["stream"] = True
|
| 281 |
|
| 282 |
# Always include usage data in stream
|
| 283 |
+
payload["stream_options"] = {"include_usage": True}
|
| 284 |
|
| 285 |
# Handle tool schema cleaning
|
| 286 |
if "tools" in payload and payload["tools"]:
|
|
|
|
| 289 |
elif not payload.get("tools"):
|
| 290 |
# Per Qwen Code API bug (see: https://github.com/qianwen-team/flash-dance/issues/2),
|
| 291 |
# injecting a dummy tool prevents stream corruption when no tools are provided
|
| 292 |
+
payload["tools"] = [
|
| 293 |
+
{
|
| 294 |
+
"type": "function",
|
| 295 |
+
"function": {
|
| 296 |
+
"name": "do_not_call_me",
|
| 297 |
+
"description": "Do not call this tool.",
|
| 298 |
+
"parameters": {"type": "object", "properties": {}},
|
| 299 |
+
},
|
| 300 |
}
|
| 301 |
+
]
|
| 302 |
+
lib_logger.debug(
|
| 303 |
+
"Injected dummy tool to prevent Qwen API stream corruption"
|
| 304 |
+
)
|
| 305 |
|
| 306 |
return payload
|
| 307 |
|
| 308 |
def _convert_chunk_to_openai(self, chunk: Dict[str, Any], model_id: str):
|
| 309 |
"""
|
| 310 |
Converts a raw Qwen SSE chunk to an OpenAI-compatible chunk.
|
| 311 |
+
|
| 312 |
CRITICAL FIX: Handle chunks with BOTH usage and choices (final chunk)
|
| 313 |
without early return to ensure finish_reason is properly processed.
|
| 314 |
"""
|
|
|
|
| 330 |
|
| 331 |
# Yield the choice chunk first (contains finish_reason)
|
| 332 |
yield {
|
| 333 |
+
"choices": [
|
| 334 |
+
{"index": 0, "delta": delta, "finish_reason": finish_reason}
|
| 335 |
+
],
|
| 336 |
+
"model": model_id,
|
| 337 |
+
"object": "chat.completion.chunk",
|
| 338 |
+
"id": chunk_id,
|
| 339 |
+
"created": chunk_created,
|
| 340 |
}
|
| 341 |
# Then yield the usage chunk
|
| 342 |
yield {
|
| 343 |
+
"choices": [],
|
| 344 |
+
"model": model_id,
|
| 345 |
+
"object": "chat.completion.chunk",
|
| 346 |
+
"id": chunk_id,
|
| 347 |
+
"created": chunk_created,
|
| 348 |
"usage": {
|
| 349 |
"prompt_tokens": usage_data.get("prompt_tokens", 0),
|
| 350 |
"completion_tokens": usage_data.get("completion_tokens", 0),
|
| 351 |
"total_tokens": usage_data.get("total_tokens", 0),
|
| 352 |
+
},
|
| 353 |
}
|
| 354 |
return
|
| 355 |
|
| 356 |
# Handle usage-only chunks
|
| 357 |
if usage_data:
|
| 358 |
yield {
|
| 359 |
+
"choices": [],
|
| 360 |
+
"model": model_id,
|
| 361 |
+
"object": "chat.completion.chunk",
|
| 362 |
+
"id": chunk_id,
|
| 363 |
+
"created": chunk_created,
|
| 364 |
"usage": {
|
| 365 |
"prompt_tokens": usage_data.get("prompt_tokens", 0),
|
| 366 |
"completion_tokens": usage_data.get("completion_tokens", 0),
|
| 367 |
"total_tokens": usage_data.get("total_tokens", 0),
|
| 368 |
+
},
|
| 369 |
}
|
| 370 |
return
|
| 371 |
|
|
|
|
| 380 |
# Handle <think> tags for reasoning content
|
| 381 |
content = delta.get("content")
|
| 382 |
if content and ("<think>" in content or "</think>" in content):
|
| 383 |
+
parts = (
|
| 384 |
+
content.replace("<think>", f"||{self.REASONING_START_MARKER}")
|
| 385 |
+
.replace("</think>", f"||/{self.REASONING_START_MARKER}")
|
| 386 |
+
.split("||")
|
| 387 |
+
)
|
| 388 |
for part in parts:
|
| 389 |
+
if not part:
|
| 390 |
+
continue
|
| 391 |
+
|
| 392 |
new_delta = {}
|
| 393 |
if part.startswith(self.REASONING_START_MARKER):
|
| 394 |
+
new_delta["reasoning_content"] = part.replace(
|
| 395 |
+
self.REASONING_START_MARKER, ""
|
| 396 |
+
)
|
| 397 |
elif part.startswith(f"/{self.REASONING_START_MARKER}"):
|
| 398 |
continue
|
| 399 |
else:
|
| 400 |
+
new_delta["content"] = part
|
| 401 |
+
|
| 402 |
yield {
|
| 403 |
+
"choices": [
|
| 404 |
+
{"index": 0, "delta": new_delta, "finish_reason": None}
|
| 405 |
+
],
|
| 406 |
+
"model": model_id,
|
| 407 |
+
"object": "chat.completion.chunk",
|
| 408 |
+
"id": chunk_id,
|
| 409 |
+
"created": chunk_created,
|
| 410 |
}
|
| 411 |
else:
|
| 412 |
# Standard content chunk
|
| 413 |
yield {
|
| 414 |
+
"choices": [
|
| 415 |
+
{"index": 0, "delta": delta, "finish_reason": finish_reason}
|
| 416 |
+
],
|
| 417 |
+
"model": model_id,
|
| 418 |
+
"object": "chat.completion.chunk",
|
| 419 |
+
"id": chunk_id,
|
| 420 |
+
"created": chunk_created,
|
| 421 |
}
|
| 422 |
|
| 423 |
+
def _stream_to_completion_response(
|
| 424 |
+
self, chunks: List[litellm.ModelResponse]
|
| 425 |
+
) -> litellm.ModelResponse:
|
| 426 |
"""
|
| 427 |
Manually reassembles streaming chunks into a complete response.
|
| 428 |
+
|
| 429 |
Key improvements:
|
| 430 |
- Determines finish_reason based on accumulated state (tool_calls vs stop)
|
| 431 |
- Properly initializes tool_calls with type field
|
|
|
|
| 438 |
final_message = {"role": "assistant"}
|
| 439 |
aggregated_tool_calls = {}
|
| 440 |
usage_data = None
|
| 441 |
+
chunk_finish_reason = (
|
| 442 |
+
None # Track finish_reason from chunks (but we'll override)
|
| 443 |
+
)
|
| 444 |
|
| 445 |
# Get the first chunk for basic response metadata
|
| 446 |
first_chunk = chunks[0]
|
| 447 |
|
| 448 |
# Process each chunk to aggregate content
|
| 449 |
for chunk in chunks:
|
| 450 |
+
if not hasattr(chunk, "choices") or not chunk.choices:
|
| 451 |
continue
|
| 452 |
|
| 453 |
choice = chunk.choices[0]
|
|
|
|
| 471 |
index = tc_chunk.get("index", 0)
|
| 472 |
if index not in aggregated_tool_calls:
|
| 473 |
# Initialize with type field for OpenAI compatibility
|
| 474 |
+
aggregated_tool_calls[index] = {
|
| 475 |
+
"type": "function",
|
| 476 |
+
"function": {"name": "", "arguments": ""},
|
| 477 |
+
}
|
| 478 |
if "id" in tc_chunk:
|
| 479 |
aggregated_tool_calls[index]["id"] = tc_chunk["id"]
|
| 480 |
if "type" in tc_chunk:
|
| 481 |
aggregated_tool_calls[index]["type"] = tc_chunk["type"]
|
| 482 |
if "function" in tc_chunk:
|
| 483 |
+
if (
|
| 484 |
+
"name" in tc_chunk["function"]
|
| 485 |
+
and tc_chunk["function"]["name"] is not None
|
| 486 |
+
):
|
| 487 |
+
aggregated_tool_calls[index]["function"]["name"] += (
|
| 488 |
+
tc_chunk["function"]["name"]
|
| 489 |
+
)
|
| 490 |
+
if (
|
| 491 |
+
"arguments" in tc_chunk["function"]
|
| 492 |
+
and tc_chunk["function"]["arguments"] is not None
|
| 493 |
+
):
|
| 494 |
+
aggregated_tool_calls[index]["function"]["arguments"] += (
|
| 495 |
+
tc_chunk["function"]["arguments"]
|
| 496 |
+
)
|
| 497 |
|
| 498 |
# Aggregate function calls (legacy format)
|
| 499 |
if "function_call" in delta and delta["function_call"] is not None:
|
| 500 |
if "function_call" not in final_message:
|
| 501 |
final_message["function_call"] = {"name": "", "arguments": ""}
|
| 502 |
+
if (
|
| 503 |
+
"name" in delta["function_call"]
|
| 504 |
+
and delta["function_call"]["name"] is not None
|
| 505 |
+
):
|
| 506 |
+
final_message["function_call"]["name"] += delta["function_call"][
|
| 507 |
+
"name"
|
| 508 |
+
]
|
| 509 |
+
if (
|
| 510 |
+
"arguments" in delta["function_call"]
|
| 511 |
+
and delta["function_call"]["arguments"] is not None
|
| 512 |
+
):
|
| 513 |
+
final_message["function_call"]["arguments"] += delta[
|
| 514 |
+
"function_call"
|
| 515 |
+
]["arguments"]
|
| 516 |
|
| 517 |
# Track finish_reason from chunks (for reference only)
|
| 518 |
if choice.get("finish_reason"):
|
|
|
|
| 520 |
|
| 521 |
# Handle usage data from the last chunk that has it
|
| 522 |
for chunk in reversed(chunks):
|
| 523 |
+
if hasattr(chunk, "usage") and chunk.usage:
|
| 524 |
usage_data = chunk.usage
|
| 525 |
break
|
| 526 |
|
|
|
|
| 546 |
final_choice = {
|
| 547 |
"index": 0,
|
| 548 |
"message": final_message,
|
| 549 |
+
"finish_reason": finish_reason,
|
| 550 |
}
|
| 551 |
|
| 552 |
# Create the final ModelResponse
|
|
|
|
| 556 |
"created": first_chunk.created,
|
| 557 |
"model": first_chunk.model,
|
| 558 |
"choices": [final_choice],
|
| 559 |
+
"usage": usage_data,
|
| 560 |
}
|
| 561 |
|
| 562 |
return litellm.ModelResponse(**final_response_data)
|
| 563 |
|
| 564 |
+
async def acompletion(
|
| 565 |
+
self, client: httpx.AsyncClient, **kwargs
|
| 566 |
+
) -> Union[litellm.ModelResponse, AsyncGenerator[litellm.ModelResponse, None]]:
|
| 567 |
credential_path = kwargs.pop("credential_identifier")
|
| 568 |
enable_request_logging = kwargs.pop("enable_request_logging", False)
|
| 569 |
model = kwargs["model"]
|
| 570 |
|
| 571 |
# Create dedicated file logger for this request
|
| 572 |
file_logger = _QwenCodeFileLogger(
|
| 573 |
+
model_name=model, enabled=enable_request_logging
|
|
|
|
| 574 |
)
|
| 575 |
|
| 576 |
async def make_request():
|
|
|
|
| 578 |
api_base, access_token = await self.get_api_details(credential_path)
|
| 579 |
|
| 580 |
# Strip provider prefix from model name (e.g., "qwen_code/qwen3-coder-plus" -> "qwen3-coder-plus")
|
| 581 |
+
model_name = model.split("/")[-1]
|
| 582 |
+
kwargs_with_stripped_model = {**kwargs, "model": model_name}
|
| 583 |
|
| 584 |
# Build clean payload with only supported parameters
|
| 585 |
payload = self._build_request_payload(**kwargs_with_stripped_model)
|
|
|
|
| 599 |
file_logger.log_request(payload)
|
| 600 |
lib_logger.debug(f"Qwen Code Request URL: {url}")
|
| 601 |
|
| 602 |
+
return client.stream(
|
| 603 |
+
"POST",
|
| 604 |
+
url,
|
| 605 |
+
headers=headers,
|
| 606 |
+
json=payload,
|
| 607 |
+
timeout=TimeoutConfig.streaming(),
|
| 608 |
+
)
|
| 609 |
|
| 610 |
async def stream_handler(response_stream, attempt=1):
|
| 611 |
"""Handles the streaming response and converts chunks."""
|
|
|
|
| 614 |
# Check for HTTP errors before processing stream
|
| 615 |
if response.status_code >= 400:
|
| 616 |
error_text = await response.aread()
|
| 617 |
+
error_text = (
|
| 618 |
+
error_text.decode("utf-8")
|
| 619 |
+
if isinstance(error_text, bytes)
|
| 620 |
+
else error_text
|
| 621 |
+
)
|
| 622 |
|
| 623 |
# Handle 401: Force token refresh and retry once
|
| 624 |
if response.status_code == 401 and attempt == 1:
|
| 625 |
+
lib_logger.warning(
|
| 626 |
+
"Qwen Code returned 401. Forcing token refresh and retrying once."
|
| 627 |
+
)
|
| 628 |
await self._refresh_token(credential_path, force=True)
|
| 629 |
retry_stream = await make_request()
|
| 630 |
async for chunk in stream_handler(retry_stream, attempt=2):
|
|
|
|
| 632 |
return
|
| 633 |
|
| 634 |
# Handle 429: Rate limit
|
| 635 |
+
elif (
|
| 636 |
+
response.status_code == 429
|
| 637 |
+
or "slow_down" in error_text.lower()
|
| 638 |
+
):
|
| 639 |
raise RateLimitError(
|
| 640 |
f"Qwen Code rate limit exceeded: {error_text}",
|
| 641 |
llm_provider="qwen_code",
|
| 642 |
model=model,
|
| 643 |
+
response=response,
|
| 644 |
)
|
| 645 |
|
| 646 |
# Handle other errors
|
|
|
|
| 650 |
raise httpx.HTTPStatusError(
|
| 651 |
f"HTTP {response.status_code}: {error_text}",
|
| 652 |
request=response.request,
|
| 653 |
+
response=response,
|
| 654 |
)
|
| 655 |
|
| 656 |
# Process successful streaming response
|
| 657 |
async for line in response.aiter_lines():
|
| 658 |
file_logger.log_response_chunk(line)
|
| 659 |
+
if line.startswith("data: "):
|
| 660 |
data_str = line[6:]
|
| 661 |
if data_str == "[DONE]":
|
| 662 |
break
|
| 663 |
try:
|
| 664 |
chunk = json.loads(data_str)
|
| 665 |
+
for openai_chunk in self._convert_chunk_to_openai(
|
| 666 |
+
chunk, model
|
| 667 |
+
):
|
| 668 |
yield litellm.ModelResponse(**openai_chunk)
|
| 669 |
except json.JSONDecodeError:
|
| 670 |
+
lib_logger.warning(
|
| 671 |
+
f"Could not decode JSON from Qwen Code: {line}"
|
| 672 |
+
)
|
| 673 |
|
| 674 |
except httpx.HTTPStatusError:
|
| 675 |
raise # Re-raise HTTP errors we already handled
|
| 676 |
except Exception as e:
|
| 677 |
file_logger.log_error(f"Error during Qwen Code stream processing: {e}")
|
| 678 |
+
lib_logger.error(
|
| 679 |
+
f"Error during Qwen Code stream processing: {e}", exc_info=True
|
| 680 |
+
)
|
| 681 |
raise
|
| 682 |
|
| 683 |
async def logging_stream_wrapper():
|
|
|
|
| 695 |
if kwargs.get("stream"):
|
| 696 |
return logging_stream_wrapper()
|
| 697 |
else:
|
| 698 |
+
|
| 699 |
async def non_stream_wrapper():
|
| 700 |
chunks = [chunk async for chunk in logging_stream_wrapper()]
|
| 701 |
return self._stream_to_completion_response(chunks)
|
| 702 |
+
|
| 703 |
+
return await non_stream_wrapper()
|
src/rotator_library/timeout_config.py
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# src/rotator_library/timeout_config.py
|
| 2 |
+
"""
|
| 3 |
+
Centralized timeout configuration for HTTP requests.
|
| 4 |
+
|
| 5 |
+
All values can be overridden via environment variables:
|
| 6 |
+
TIMEOUT_CONNECT - Connection establishment timeout (default: 30s)
|
| 7 |
+
TIMEOUT_WRITE - Request body send timeout (default: 30s)
|
| 8 |
+
TIMEOUT_POOL - Connection pool acquisition timeout (default: 60s)
|
| 9 |
+
TIMEOUT_READ_STREAMING - Read timeout between chunks for streaming (default: 180s / 3 min)
|
| 10 |
+
TIMEOUT_READ_NON_STREAMING - Read timeout for non-streaming responses (default: 600s / 10 min)
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
import os
|
| 14 |
+
import logging
|
| 15 |
+
import httpx
|
| 16 |
+
|
| 17 |
+
lib_logger = logging.getLogger("rotator_library")
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class TimeoutConfig:
|
| 21 |
+
"""
|
| 22 |
+
Centralized timeout configuration for HTTP requests.
|
| 23 |
+
|
| 24 |
+
All values can be overridden via environment variables.
|
| 25 |
+
"""
|
| 26 |
+
|
| 27 |
+
# Default values (in seconds)
|
| 28 |
+
_CONNECT = 30.0
|
| 29 |
+
_WRITE = 30.0
|
| 30 |
+
_POOL = 60.0
|
| 31 |
+
_READ_STREAMING = 180.0 # 3 minutes between chunks
|
| 32 |
+
_READ_NON_STREAMING = 600.0 # 10 minutes for full response
|
| 33 |
+
|
| 34 |
+
@classmethod
|
| 35 |
+
def _get_env_float(cls, key: str, default: float) -> float:
|
| 36 |
+
"""Get a float value from environment variable, or return default."""
|
| 37 |
+
value = os.environ.get(key)
|
| 38 |
+
if value is not None:
|
| 39 |
+
try:
|
| 40 |
+
return float(value)
|
| 41 |
+
except ValueError:
|
| 42 |
+
lib_logger.warning(
|
| 43 |
+
f"Invalid value for {key}: {value}. Using default: {default}"
|
| 44 |
+
)
|
| 45 |
+
return default
|
| 46 |
+
|
| 47 |
+
@classmethod
|
| 48 |
+
def connect(cls) -> float:
|
| 49 |
+
"""Connection establishment timeout."""
|
| 50 |
+
return cls._get_env_float("TIMEOUT_CONNECT", cls._CONNECT)
|
| 51 |
+
|
| 52 |
+
@classmethod
|
| 53 |
+
def write(cls) -> float:
|
| 54 |
+
"""Request body send timeout."""
|
| 55 |
+
return cls._get_env_float("TIMEOUT_WRITE", cls._WRITE)
|
| 56 |
+
|
| 57 |
+
@classmethod
|
| 58 |
+
def pool(cls) -> float:
|
| 59 |
+
"""Connection pool acquisition timeout."""
|
| 60 |
+
return cls._get_env_float("TIMEOUT_POOL", cls._POOL)
|
| 61 |
+
|
| 62 |
+
@classmethod
|
| 63 |
+
def read_streaming(cls) -> float:
|
| 64 |
+
"""Read timeout between chunks for streaming requests."""
|
| 65 |
+
return cls._get_env_float("TIMEOUT_READ_STREAMING", cls._READ_STREAMING)
|
| 66 |
+
|
| 67 |
+
@classmethod
|
| 68 |
+
def read_non_streaming(cls) -> float:
|
| 69 |
+
"""Read timeout for non-streaming responses."""
|
| 70 |
+
return cls._get_env_float("TIMEOUT_READ_NON_STREAMING", cls._READ_NON_STREAMING)
|
| 71 |
+
|
| 72 |
+
@classmethod
|
| 73 |
+
def streaming(cls) -> httpx.Timeout:
|
| 74 |
+
"""
|
| 75 |
+
Timeout configuration for streaming LLM requests.
|
| 76 |
+
|
| 77 |
+
Uses a shorter read timeout (default 3 min) since we expect
|
| 78 |
+
periodic chunks. If no data arrives for this duration, the
|
| 79 |
+
connection is considered stalled.
|
| 80 |
+
"""
|
| 81 |
+
return httpx.Timeout(
|
| 82 |
+
connect=cls.connect(),
|
| 83 |
+
read=cls.read_streaming(),
|
| 84 |
+
write=cls.write(),
|
| 85 |
+
pool=cls.pool(),
|
| 86 |
+
)
|
| 87 |
+
|
| 88 |
+
@classmethod
|
| 89 |
+
def non_streaming(cls) -> httpx.Timeout:
|
| 90 |
+
"""
|
| 91 |
+
Timeout configuration for non-streaming LLM requests.
|
| 92 |
+
|
| 93 |
+
Uses a longer read timeout (default 10 min) since the server
|
| 94 |
+
may take significant time to generate the complete response
|
| 95 |
+
before sending anything back.
|
| 96 |
+
"""
|
| 97 |
+
return httpx.Timeout(
|
| 98 |
+
connect=cls.connect(),
|
| 99 |
+
read=cls.read_non_streaming(),
|
| 100 |
+
write=cls.write(),
|
| 101 |
+
pool=cls.pool(),
|
| 102 |
+
)
|
src/rotator_library/usage_manager.py
CHANGED
|
@@ -5,12 +5,15 @@ import logging
|
|
| 5 |
import asyncio
|
| 6 |
import random
|
| 7 |
from datetime import date, datetime, timezone, time as dt_time
|
| 8 |
-
from
|
|
|
|
| 9 |
import aiofiles
|
| 10 |
import litellm
|
| 11 |
|
| 12 |
from .error_handler import ClassifiedError, NoAvailableKeysError, mask_credential
|
| 13 |
from .providers import PROVIDER_PLUGINS
|
|
|
|
|
|
|
| 14 |
|
| 15 |
lib_logger = logging.getLogger("rotator_library")
|
| 16 |
lib_logger.propagate = False
|
|
@@ -50,7 +53,7 @@ class UsageManager:
|
|
| 50 |
|
| 51 |
def __init__(
|
| 52 |
self,
|
| 53 |
-
file_path: str =
|
| 54 |
daily_reset_time_utc: Optional[str] = "03:00",
|
| 55 |
rotation_tolerance: float = 0.0,
|
| 56 |
provider_rotation_modes: Optional[Dict[str, str]] = None,
|
|
@@ -65,7 +68,8 @@ class UsageManager:
|
|
| 65 |
Initialize the UsageManager.
|
| 66 |
|
| 67 |
Args:
|
| 68 |
-
file_path: Path to the usage data JSON file
|
|
|
|
| 69 |
daily_reset_time_utc: Time in UTC when daily stats should reset (HH:MM format)
|
| 70 |
rotation_tolerance: Tolerance for weighted random credential rotation.
|
| 71 |
- 0.0: Deterministic, least-used credential always selected
|
|
@@ -85,7 +89,14 @@ class UsageManager:
|
|
| 85 |
Used in sequential mode when priority not in priority_multipliers.
|
| 86 |
Example: {"antigravity": 2}
|
| 87 |
"""
|
| 88 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 89 |
self.rotation_tolerance = rotation_tolerance
|
| 90 |
self.provider_rotation_modes = provider_rotation_modes or {}
|
| 91 |
self.provider_plugins = provider_plugins or PROVIDER_PLUGINS
|
|
@@ -103,6 +114,9 @@ class UsageManager:
|
|
| 103 |
self._timeout_lock = asyncio.Lock()
|
| 104 |
self._claimed_on_timeout: Set[str] = set()
|
| 105 |
|
|
|
|
|
|
|
|
|
|
| 106 |
if daily_reset_time_utc:
|
| 107 |
hour, minute = map(int, daily_reset_time_utc.split(":"))
|
| 108 |
self.daily_reset_time_utc = dt_time(
|
|
@@ -540,27 +554,40 @@ class UsageManager:
|
|
| 540 |
self._initialized.set()
|
| 541 |
|
| 542 |
async def _load_usage(self):
|
| 543 |
-
"""Loads usage data from the JSON file asynchronously."""
|
| 544 |
async with self._data_lock:
|
| 545 |
if not os.path.exists(self.file_path):
|
| 546 |
self._usage_data = {}
|
| 547 |
return
|
|
|
|
| 548 |
try:
|
| 549 |
async with aiofiles.open(self.file_path, "r") as f:
|
| 550 |
content = await f.read()
|
| 551 |
-
self._usage_data = json.loads(content)
|
| 552 |
-
except
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 553 |
self._usage_data = {}
|
| 554 |
|
| 555 |
async def _save_usage(self):
|
| 556 |
-
"""Saves the current usage data
|
| 557 |
if self._usage_data is None:
|
| 558 |
return
|
|
|
|
| 559 |
async with self._data_lock:
|
| 560 |
# Add human-readable timestamp fields before saving
|
| 561 |
self._add_readable_timestamps(self._usage_data)
|
| 562 |
-
|
| 563 |
-
|
| 564 |
|
| 565 |
async def _reset_daily_stats_if_needed(self):
|
| 566 |
"""
|
|
|
|
| 5 |
import asyncio
|
| 6 |
import random
|
| 7 |
from datetime import date, datetime, timezone, time as dt_time
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from typing import Any, Dict, List, Optional, Set, Tuple, Union
|
| 10 |
import aiofiles
|
| 11 |
import litellm
|
| 12 |
|
| 13 |
from .error_handler import ClassifiedError, NoAvailableKeysError, mask_credential
|
| 14 |
from .providers import PROVIDER_PLUGINS
|
| 15 |
+
from .utils.resilient_io import ResilientStateWriter
|
| 16 |
+
from .utils.paths import get_data_file
|
| 17 |
|
| 18 |
lib_logger = logging.getLogger("rotator_library")
|
| 19 |
lib_logger.propagate = False
|
|
|
|
| 53 |
|
| 54 |
def __init__(
|
| 55 |
self,
|
| 56 |
+
file_path: Optional[Union[str, Path]] = None,
|
| 57 |
daily_reset_time_utc: Optional[str] = "03:00",
|
| 58 |
rotation_tolerance: float = 0.0,
|
| 59 |
provider_rotation_modes: Optional[Dict[str, str]] = None,
|
|
|
|
| 68 |
Initialize the UsageManager.
|
| 69 |
|
| 70 |
Args:
|
| 71 |
+
file_path: Path to the usage data JSON file. If None, uses get_data_file("key_usage.json").
|
| 72 |
+
Can be absolute Path, relative Path, or string.
|
| 73 |
daily_reset_time_utc: Time in UTC when daily stats should reset (HH:MM format)
|
| 74 |
rotation_tolerance: Tolerance for weighted random credential rotation.
|
| 75 |
- 0.0: Deterministic, least-used credential always selected
|
|
|
|
| 89 |
Used in sequential mode when priority not in priority_multipliers.
|
| 90 |
Example: {"antigravity": 2}
|
| 91 |
"""
|
| 92 |
+
# Resolve file_path - use default if not provided
|
| 93 |
+
if file_path is None:
|
| 94 |
+
self.file_path = str(get_data_file("key_usage.json"))
|
| 95 |
+
elif isinstance(file_path, Path):
|
| 96 |
+
self.file_path = str(file_path)
|
| 97 |
+
else:
|
| 98 |
+
# String path - could be relative or absolute
|
| 99 |
+
self.file_path = file_path
|
| 100 |
self.rotation_tolerance = rotation_tolerance
|
| 101 |
self.provider_rotation_modes = provider_rotation_modes or {}
|
| 102 |
self.provider_plugins = provider_plugins or PROVIDER_PLUGINS
|
|
|
|
| 114 |
self._timeout_lock = asyncio.Lock()
|
| 115 |
self._claimed_on_timeout: Set[str] = set()
|
| 116 |
|
| 117 |
+
# Resilient writer for usage data persistence
|
| 118 |
+
self._state_writer = ResilientStateWriter(file_path, lib_logger)
|
| 119 |
+
|
| 120 |
if daily_reset_time_utc:
|
| 121 |
hour, minute = map(int, daily_reset_time_utc.split(":"))
|
| 122 |
self.daily_reset_time_utc = dt_time(
|
|
|
|
| 554 |
self._initialized.set()
|
| 555 |
|
| 556 |
async def _load_usage(self):
|
| 557 |
+
"""Loads usage data from the JSON file asynchronously with resilience."""
|
| 558 |
async with self._data_lock:
|
| 559 |
if not os.path.exists(self.file_path):
|
| 560 |
self._usage_data = {}
|
| 561 |
return
|
| 562 |
+
|
| 563 |
try:
|
| 564 |
async with aiofiles.open(self.file_path, "r") as f:
|
| 565 |
content = await f.read()
|
| 566 |
+
self._usage_data = json.loads(content) if content.strip() else {}
|
| 567 |
+
except FileNotFoundError:
|
| 568 |
+
# File deleted between exists check and open
|
| 569 |
+
self._usage_data = {}
|
| 570 |
+
except json.JSONDecodeError as e:
|
| 571 |
+
lib_logger.warning(
|
| 572 |
+
f"Corrupted usage file {self.file_path}: {e}. Starting fresh."
|
| 573 |
+
)
|
| 574 |
+
self._usage_data = {}
|
| 575 |
+
except (OSError, PermissionError, IOError) as e:
|
| 576 |
+
lib_logger.warning(
|
| 577 |
+
f"Cannot read usage file {self.file_path}: {e}. Using empty state."
|
| 578 |
+
)
|
| 579 |
self._usage_data = {}
|
| 580 |
|
| 581 |
async def _save_usage(self):
|
| 582 |
+
"""Saves the current usage data using the resilient state writer."""
|
| 583 |
if self._usage_data is None:
|
| 584 |
return
|
| 585 |
+
|
| 586 |
async with self._data_lock:
|
| 587 |
# Add human-readable timestamp fields before saving
|
| 588 |
self._add_readable_timestamps(self._usage_data)
|
| 589 |
+
# Hand off to resilient writer - handles retries and disk failures
|
| 590 |
+
self._state_writer.write(self._usage_data)
|
| 591 |
|
| 592 |
async def _reset_daily_stats_if_needed(self):
|
| 593 |
"""
|
src/rotator_library/utils/__init__.py
CHANGED
|
@@ -1,6 +1,34 @@
|
|
| 1 |
# src/rotator_library/utils/__init__.py
|
| 2 |
|
| 3 |
from .headless_detection import is_headless_environment
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
from .reauth_coordinator import get_reauth_coordinator, ReauthCoordinator
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
|
| 6 |
-
__all__ = [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
# src/rotator_library/utils/__init__.py
|
| 2 |
|
| 3 |
from .headless_detection import is_headless_environment
|
| 4 |
+
from .paths import (
|
| 5 |
+
get_default_root,
|
| 6 |
+
get_logs_dir,
|
| 7 |
+
get_cache_dir,
|
| 8 |
+
get_oauth_dir,
|
| 9 |
+
get_data_file,
|
| 10 |
+
)
|
| 11 |
from .reauth_coordinator import get_reauth_coordinator, ReauthCoordinator
|
| 12 |
+
from .resilient_io import (
|
| 13 |
+
BufferedWriteRegistry,
|
| 14 |
+
ResilientStateWriter,
|
| 15 |
+
safe_write_json,
|
| 16 |
+
safe_log_write,
|
| 17 |
+
safe_mkdir,
|
| 18 |
+
)
|
| 19 |
|
| 20 |
+
__all__ = [
|
| 21 |
+
"is_headless_environment",
|
| 22 |
+
"get_default_root",
|
| 23 |
+
"get_logs_dir",
|
| 24 |
+
"get_cache_dir",
|
| 25 |
+
"get_oauth_dir",
|
| 26 |
+
"get_data_file",
|
| 27 |
+
"get_reauth_coordinator",
|
| 28 |
+
"ReauthCoordinator",
|
| 29 |
+
"BufferedWriteRegistry",
|
| 30 |
+
"ResilientStateWriter",
|
| 31 |
+
"safe_write_json",
|
| 32 |
+
"safe_log_write",
|
| 33 |
+
"safe_mkdir",
|
| 34 |
+
]
|
src/rotator_library/utils/paths.py
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# src/rotator_library/utils/paths.py
|
| 2 |
+
"""
|
| 3 |
+
Centralized path management for the rotator library.
|
| 4 |
+
|
| 5 |
+
Supports two runtime modes:
|
| 6 |
+
1. PyInstaller EXE -> files in the directory containing the executable
|
| 7 |
+
2. Script/Library -> files in the current working directory (overridable)
|
| 8 |
+
|
| 9 |
+
Library users can override by passing `data_dir` to RotatingClient.
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import sys
|
| 13 |
+
from pathlib import Path
|
| 14 |
+
from typing import Optional, Union
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def get_default_root() -> Path:
|
| 18 |
+
"""
|
| 19 |
+
Get the default root directory for data files.
|
| 20 |
+
|
| 21 |
+
- EXE mode (PyInstaller): directory containing the executable
|
| 22 |
+
- Otherwise: current working directory
|
| 23 |
+
|
| 24 |
+
Returns:
|
| 25 |
+
Path to the root directory
|
| 26 |
+
"""
|
| 27 |
+
if getattr(sys, "frozen", False):
|
| 28 |
+
# Running as PyInstaller bundle - use executable's directory
|
| 29 |
+
return Path(sys.executable).parent
|
| 30 |
+
# Running as script or library - use current working directory
|
| 31 |
+
return Path.cwd()
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def get_logs_dir(root: Optional[Union[Path, str]] = None) -> Path:
|
| 35 |
+
"""
|
| 36 |
+
Get the logs directory, creating it if needed.
|
| 37 |
+
|
| 38 |
+
Args:
|
| 39 |
+
root: Optional root directory. If None, uses get_default_root().
|
| 40 |
+
|
| 41 |
+
Returns:
|
| 42 |
+
Path to the logs directory
|
| 43 |
+
"""
|
| 44 |
+
base = Path(root) if root else get_default_root()
|
| 45 |
+
logs_dir = base / "logs"
|
| 46 |
+
logs_dir.mkdir(exist_ok=True)
|
| 47 |
+
return logs_dir
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def get_cache_dir(
|
| 51 |
+
root: Optional[Union[Path, str]] = None, subdir: Optional[str] = None
|
| 52 |
+
) -> Path:
|
| 53 |
+
"""
|
| 54 |
+
Get the cache directory, optionally with a subdirectory.
|
| 55 |
+
|
| 56 |
+
Args:
|
| 57 |
+
root: Optional root directory. If None, uses get_default_root().
|
| 58 |
+
subdir: Optional subdirectory name (e.g., "gemini_cli", "antigravity")
|
| 59 |
+
|
| 60 |
+
Returns:
|
| 61 |
+
Path to the cache directory (or subdirectory)
|
| 62 |
+
"""
|
| 63 |
+
base = Path(root) if root else get_default_root()
|
| 64 |
+
cache_dir = base / "cache"
|
| 65 |
+
if subdir:
|
| 66 |
+
cache_dir = cache_dir / subdir
|
| 67 |
+
cache_dir.mkdir(parents=True, exist_ok=True)
|
| 68 |
+
return cache_dir
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def get_oauth_dir(root: Optional[Union[Path, str]] = None) -> Path:
|
| 72 |
+
"""
|
| 73 |
+
Get the OAuth credentials directory, creating it if needed.
|
| 74 |
+
|
| 75 |
+
Args:
|
| 76 |
+
root: Optional root directory. If None, uses get_default_root().
|
| 77 |
+
|
| 78 |
+
Returns:
|
| 79 |
+
Path to the oauth_creds directory
|
| 80 |
+
"""
|
| 81 |
+
base = Path(root) if root else get_default_root()
|
| 82 |
+
oauth_dir = base / "oauth_creds"
|
| 83 |
+
oauth_dir.mkdir(exist_ok=True)
|
| 84 |
+
return oauth_dir
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def get_data_file(filename: str, root: Optional[Union[Path, str]] = None) -> Path:
|
| 88 |
+
"""
|
| 89 |
+
Get the path to a data file in the root directory.
|
| 90 |
+
|
| 91 |
+
Args:
|
| 92 |
+
filename: Name of the file (e.g., "key_usage.json", ".env")
|
| 93 |
+
root: Optional root directory. If None, uses get_default_root().
|
| 94 |
+
|
| 95 |
+
Returns:
|
| 96 |
+
Path to the file (does not create the file)
|
| 97 |
+
"""
|
| 98 |
+
base = Path(root) if root else get_default_root()
|
| 99 |
+
return base / filename
|
src/rotator_library/utils/resilient_io.py
ADDED
|
@@ -0,0 +1,665 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# src/rotator_library/utils/resilient_io.py
|
| 2 |
+
"""
|
| 3 |
+
Resilient I/O utilities for handling file operations gracefully.
|
| 4 |
+
|
| 5 |
+
Provides three main patterns:
|
| 6 |
+
1. BufferedWriteRegistry - Global singleton for buffered writes with periodic
|
| 7 |
+
retry and shutdown flush. Ensures data is saved on app exit (Ctrl+C).
|
| 8 |
+
2. ResilientStateWriter - For stateful files (usage.json) that should be
|
| 9 |
+
buffered in memory and retried on disk failure.
|
| 10 |
+
3. safe_write_json (with buffer_on_failure) - For critical files (auth tokens)
|
| 11 |
+
that should be buffered and retried if write fails.
|
| 12 |
+
4. safe_log_write - For logs that can be dropped on failure.
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
import atexit
|
| 16 |
+
import json
|
| 17 |
+
import os
|
| 18 |
+
import shutil
|
| 19 |
+
import tempfile
|
| 20 |
+
import threading
|
| 21 |
+
import time
|
| 22 |
+
import logging
|
| 23 |
+
from pathlib import Path
|
| 24 |
+
from typing import Any, Callable, Dict, Optional, Tuple, Union
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
# =============================================================================
|
| 28 |
+
# BUFFERED WRITE REGISTRY (SINGLETON)
|
| 29 |
+
# =============================================================================
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class BufferedWriteRegistry:
|
| 33 |
+
"""
|
| 34 |
+
Global singleton registry for buffered writes with periodic retry and shutdown flush.
|
| 35 |
+
|
| 36 |
+
This ensures that critical data (auth tokens, usage stats) is saved even if
|
| 37 |
+
disk writes fail temporarily. On app exit (including Ctrl+C), all pending
|
| 38 |
+
writes are flushed.
|
| 39 |
+
|
| 40 |
+
Features:
|
| 41 |
+
- Per-file buffering: each file path has its own pending write
|
| 42 |
+
- Periodic retries: background thread retries failed writes every N seconds
|
| 43 |
+
- Shutdown flush: atexit hook ensures final write attempt on app exit
|
| 44 |
+
- Thread-safe: safe for concurrent access from multiple threads
|
| 45 |
+
|
| 46 |
+
Usage:
|
| 47 |
+
# Get the singleton instance
|
| 48 |
+
registry = BufferedWriteRegistry.get_instance()
|
| 49 |
+
|
| 50 |
+
# Register a pending write (usually called by safe_write_json on failure)
|
| 51 |
+
registry.register_pending(path, data, serializer_fn, options)
|
| 52 |
+
|
| 53 |
+
# Manual flush (optional - atexit handles this automatically)
|
| 54 |
+
results = registry.flush_all()
|
| 55 |
+
"""
|
| 56 |
+
|
| 57 |
+
_instance: Optional["BufferedWriteRegistry"] = None
|
| 58 |
+
_instance_lock = threading.Lock()
|
| 59 |
+
|
| 60 |
+
def __init__(self, retry_interval: float = 30.0):
|
| 61 |
+
"""
|
| 62 |
+
Initialize the registry. Use get_instance() instead of direct construction.
|
| 63 |
+
|
| 64 |
+
Args:
|
| 65 |
+
retry_interval: Seconds between retry attempts (default: 30)
|
| 66 |
+
"""
|
| 67 |
+
self._pending: Dict[str, Tuple[Any, Callable[[Any], str], Dict[str, Any]]] = {}
|
| 68 |
+
self._retry_interval = retry_interval
|
| 69 |
+
self._lock = threading.Lock()
|
| 70 |
+
self._running = False
|
| 71 |
+
self._retry_thread: Optional[threading.Thread] = None
|
| 72 |
+
self._logger = logging.getLogger("rotator_library.resilient_io")
|
| 73 |
+
|
| 74 |
+
# Start background retry thread
|
| 75 |
+
self._start_retry_thread()
|
| 76 |
+
|
| 77 |
+
# Register atexit handler for shutdown flush
|
| 78 |
+
atexit.register(self._atexit_handler)
|
| 79 |
+
|
| 80 |
+
@classmethod
|
| 81 |
+
def get_instance(cls, retry_interval: float = 30.0) -> "BufferedWriteRegistry":
|
| 82 |
+
"""
|
| 83 |
+
Get or create the singleton instance.
|
| 84 |
+
|
| 85 |
+
Args:
|
| 86 |
+
retry_interval: Seconds between retry attempts (only used on first call)
|
| 87 |
+
|
| 88 |
+
Returns:
|
| 89 |
+
The singleton BufferedWriteRegistry instance
|
| 90 |
+
"""
|
| 91 |
+
if cls._instance is None:
|
| 92 |
+
with cls._instance_lock:
|
| 93 |
+
if cls._instance is None:
|
| 94 |
+
cls._instance = cls(retry_interval)
|
| 95 |
+
return cls._instance
|
| 96 |
+
|
| 97 |
+
def _start_retry_thread(self) -> None:
|
| 98 |
+
"""Start the background retry thread."""
|
| 99 |
+
if self._running:
|
| 100 |
+
return
|
| 101 |
+
|
| 102 |
+
self._running = True
|
| 103 |
+
self._retry_thread = threading.Thread(
|
| 104 |
+
target=self._retry_loop,
|
| 105 |
+
name="BufferedWriteRegistry-Retry",
|
| 106 |
+
daemon=True, # Daemon so it doesn't block app exit
|
| 107 |
+
)
|
| 108 |
+
self._retry_thread.start()
|
| 109 |
+
|
| 110 |
+
def _retry_loop(self) -> None:
|
| 111 |
+
"""Background thread: periodically retry pending writes."""
|
| 112 |
+
while self._running:
|
| 113 |
+
time.sleep(self._retry_interval)
|
| 114 |
+
if not self._running:
|
| 115 |
+
break
|
| 116 |
+
self._retry_pending()
|
| 117 |
+
|
| 118 |
+
def _retry_pending(self) -> None:
|
| 119 |
+
"""Attempt to write all pending files."""
|
| 120 |
+
with self._lock:
|
| 121 |
+
if not self._pending:
|
| 122 |
+
return
|
| 123 |
+
|
| 124 |
+
# Copy paths to avoid modifying dict during iteration
|
| 125 |
+
paths = list(self._pending.keys())
|
| 126 |
+
|
| 127 |
+
for path_str in paths:
|
| 128 |
+
self._try_write(path_str, remove_on_success=True)
|
| 129 |
+
|
| 130 |
+
def register_pending(
|
| 131 |
+
self,
|
| 132 |
+
path: Union[str, Path],
|
| 133 |
+
data: Any,
|
| 134 |
+
serializer: Callable[[Any], str],
|
| 135 |
+
options: Optional[Dict[str, Any]] = None,
|
| 136 |
+
) -> None:
|
| 137 |
+
"""
|
| 138 |
+
Register a pending write for later retry.
|
| 139 |
+
|
| 140 |
+
If a write is already pending for this path, it is replaced with the new data
|
| 141 |
+
(we always want to write the latest state).
|
| 142 |
+
|
| 143 |
+
Args:
|
| 144 |
+
path: File path to write to
|
| 145 |
+
data: Data to serialize and write
|
| 146 |
+
serializer: Function to serialize data to string
|
| 147 |
+
options: Additional options (e.g., secure_permissions)
|
| 148 |
+
"""
|
| 149 |
+
path_str = str(Path(path).resolve())
|
| 150 |
+
with self._lock:
|
| 151 |
+
self._pending[path_str] = (data, serializer, options or {})
|
| 152 |
+
self._logger.debug(f"Registered pending write for {Path(path).name}")
|
| 153 |
+
|
| 154 |
+
def unregister(self, path: Union[str, Path]) -> None:
|
| 155 |
+
"""
|
| 156 |
+
Remove a pending write (called when write succeeds elsewhere).
|
| 157 |
+
|
| 158 |
+
Args:
|
| 159 |
+
path: File path to remove from pending
|
| 160 |
+
"""
|
| 161 |
+
path_str = str(Path(path).resolve())
|
| 162 |
+
with self._lock:
|
| 163 |
+
self._pending.pop(path_str, None)
|
| 164 |
+
|
| 165 |
+
def _try_write(self, path_str: str, remove_on_success: bool = True) -> bool:
|
| 166 |
+
"""
|
| 167 |
+
Attempt to write a pending file.
|
| 168 |
+
|
| 169 |
+
Args:
|
| 170 |
+
path_str: Resolved path string
|
| 171 |
+
remove_on_success: Remove from pending if successful
|
| 172 |
+
|
| 173 |
+
Returns:
|
| 174 |
+
True if write succeeded, False otherwise
|
| 175 |
+
"""
|
| 176 |
+
with self._lock:
|
| 177 |
+
if path_str not in self._pending:
|
| 178 |
+
return True
|
| 179 |
+
data, serializer, options = self._pending[path_str]
|
| 180 |
+
|
| 181 |
+
path = Path(path_str)
|
| 182 |
+
try:
|
| 183 |
+
# Ensure directory exists
|
| 184 |
+
path.parent.mkdir(parents=True, exist_ok=True)
|
| 185 |
+
|
| 186 |
+
# Serialize data
|
| 187 |
+
content = serializer(data)
|
| 188 |
+
|
| 189 |
+
# Atomic write
|
| 190 |
+
tmp_fd = None
|
| 191 |
+
tmp_path = None
|
| 192 |
+
try:
|
| 193 |
+
tmp_fd, tmp_path = tempfile.mkstemp(
|
| 194 |
+
dir=path.parent, prefix=".tmp_", suffix=".json", text=True
|
| 195 |
+
)
|
| 196 |
+
with os.fdopen(tmp_fd, "w", encoding="utf-8") as f:
|
| 197 |
+
f.write(content)
|
| 198 |
+
tmp_fd = None
|
| 199 |
+
|
| 200 |
+
# Set secure permissions if requested
|
| 201 |
+
if options.get("secure_permissions"):
|
| 202 |
+
try:
|
| 203 |
+
os.chmod(tmp_path, 0o600)
|
| 204 |
+
except (OSError, AttributeError):
|
| 205 |
+
pass
|
| 206 |
+
|
| 207 |
+
shutil.move(tmp_path, path)
|
| 208 |
+
tmp_path = None
|
| 209 |
+
|
| 210 |
+
finally:
|
| 211 |
+
if tmp_fd is not None:
|
| 212 |
+
try:
|
| 213 |
+
os.close(tmp_fd)
|
| 214 |
+
except OSError:
|
| 215 |
+
pass
|
| 216 |
+
if tmp_path and os.path.exists(tmp_path):
|
| 217 |
+
try:
|
| 218 |
+
os.unlink(tmp_path)
|
| 219 |
+
except OSError:
|
| 220 |
+
pass
|
| 221 |
+
|
| 222 |
+
# Success - remove from pending
|
| 223 |
+
if remove_on_success:
|
| 224 |
+
with self._lock:
|
| 225 |
+
self._pending.pop(path_str, None)
|
| 226 |
+
|
| 227 |
+
self._logger.debug(f"Retry succeeded for {path.name}")
|
| 228 |
+
return True
|
| 229 |
+
|
| 230 |
+
except (OSError, PermissionError, IOError) as e:
|
| 231 |
+
self._logger.debug(f"Retry failed for {path.name}: {e}")
|
| 232 |
+
return False
|
| 233 |
+
|
| 234 |
+
def flush_all(self) -> Dict[str, bool]:
|
| 235 |
+
"""
|
| 236 |
+
Attempt to write all pending files immediately.
|
| 237 |
+
|
| 238 |
+
Returns:
|
| 239 |
+
Dict mapping file paths to success status
|
| 240 |
+
"""
|
| 241 |
+
with self._lock:
|
| 242 |
+
paths = list(self._pending.keys())
|
| 243 |
+
|
| 244 |
+
results = {}
|
| 245 |
+
for path_str in paths:
|
| 246 |
+
results[path_str] = self._try_write(path_str, remove_on_success=True)
|
| 247 |
+
|
| 248 |
+
return results
|
| 249 |
+
|
| 250 |
+
def _atexit_handler(self) -> None:
|
| 251 |
+
"""Called on app exit to flush pending writes."""
|
| 252 |
+
self._running = False
|
| 253 |
+
|
| 254 |
+
with self._lock:
|
| 255 |
+
pending_count = len(self._pending)
|
| 256 |
+
|
| 257 |
+
if pending_count == 0:
|
| 258 |
+
return
|
| 259 |
+
|
| 260 |
+
self._logger.info(f"Flushing {pending_count} pending write(s) on shutdown...")
|
| 261 |
+
results = self.flush_all()
|
| 262 |
+
|
| 263 |
+
succeeded = sum(1 for v in results.values() if v)
|
| 264 |
+
failed = pending_count - succeeded
|
| 265 |
+
|
| 266 |
+
if failed > 0:
|
| 267 |
+
self._logger.warning(
|
| 268 |
+
f"Shutdown flush: {succeeded} succeeded, {failed} failed"
|
| 269 |
+
)
|
| 270 |
+
for path_str, success in results.items():
|
| 271 |
+
if not success:
|
| 272 |
+
self._logger.warning(f" Failed to save: {Path(path_str).name}")
|
| 273 |
+
else:
|
| 274 |
+
self._logger.info(f"Shutdown flush: all {succeeded} write(s) succeeded")
|
| 275 |
+
|
| 276 |
+
def get_pending_count(self) -> int:
|
| 277 |
+
"""Get the number of pending writes."""
|
| 278 |
+
with self._lock:
|
| 279 |
+
return len(self._pending)
|
| 280 |
+
|
| 281 |
+
def get_pending_paths(self) -> list:
|
| 282 |
+
"""Get list of paths with pending writes (for monitoring)."""
|
| 283 |
+
with self._lock:
|
| 284 |
+
return [Path(p).name for p in self._pending.keys()]
|
| 285 |
+
|
| 286 |
+
def shutdown(self) -> Dict[str, bool]:
|
| 287 |
+
"""
|
| 288 |
+
Manually trigger shutdown: stop retry thread and flush all pending writes.
|
| 289 |
+
|
| 290 |
+
Returns:
|
| 291 |
+
Dict mapping file paths to success status
|
| 292 |
+
"""
|
| 293 |
+
self._running = False
|
| 294 |
+
if self._retry_thread and self._retry_thread.is_alive():
|
| 295 |
+
self._retry_thread.join(timeout=1.0)
|
| 296 |
+
return self.flush_all()
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
# =============================================================================
|
| 300 |
+
# RESILIENT STATE WRITER
|
| 301 |
+
# =============================================================================
|
| 302 |
+
|
| 303 |
+
|
| 304 |
+
class ResilientStateWriter:
|
| 305 |
+
"""
|
| 306 |
+
Manages resilient writes for stateful files (usage stats, credentials, cache).
|
| 307 |
+
|
| 308 |
+
Design:
|
| 309 |
+
- Caller hands off data via write() - always succeeds (memory update)
|
| 310 |
+
- Attempts disk write immediately
|
| 311 |
+
- If disk fails, retries periodically in background
|
| 312 |
+
- On recovery, writes full current state (not just new data)
|
| 313 |
+
|
| 314 |
+
Thread-safe for use in async contexts with sync file I/O.
|
| 315 |
+
|
| 316 |
+
Usage:
|
| 317 |
+
writer = ResilientStateWriter("data.json", logger)
|
| 318 |
+
writer.write({"key": "value"}) # Always succeeds
|
| 319 |
+
# ... later ...
|
| 320 |
+
if not writer.is_healthy:
|
| 321 |
+
logger.warning("Disk writes failing, data in memory only")
|
| 322 |
+
"""
|
| 323 |
+
|
| 324 |
+
def __init__(
|
| 325 |
+
self,
|
| 326 |
+
path: Union[str, Path],
|
| 327 |
+
logger: logging.Logger,
|
| 328 |
+
retry_interval: float = 30.0,
|
| 329 |
+
serializer: Optional[Callable[[Any], str]] = None,
|
| 330 |
+
):
|
| 331 |
+
"""
|
| 332 |
+
Initialize the resilient writer.
|
| 333 |
+
|
| 334 |
+
Args:
|
| 335 |
+
path: File path to write to
|
| 336 |
+
logger: Logger for warnings/errors
|
| 337 |
+
retry_interval: Seconds between retry attempts when disk is unhealthy
|
| 338 |
+
serializer: Custom serializer function (defaults to JSON with indent=2)
|
| 339 |
+
"""
|
| 340 |
+
self.path = Path(path)
|
| 341 |
+
self.logger = logger
|
| 342 |
+
self.retry_interval = retry_interval
|
| 343 |
+
self._serializer = serializer or (lambda d: json.dumps(d, indent=2))
|
| 344 |
+
|
| 345 |
+
self._current_state: Optional[Any] = None
|
| 346 |
+
self._disk_healthy = True
|
| 347 |
+
self._last_attempt: float = 0
|
| 348 |
+
self._last_success: Optional[float] = None
|
| 349 |
+
self._failure_count = 0
|
| 350 |
+
self._lock = threading.Lock()
|
| 351 |
+
|
| 352 |
+
def write(self, data: Any) -> bool:
|
| 353 |
+
"""
|
| 354 |
+
Update state and attempt disk write.
|
| 355 |
+
|
| 356 |
+
Always updates in-memory state (guaranteed to succeed).
|
| 357 |
+
Attempts disk write - if disk is unhealthy, respects retry_interval
|
| 358 |
+
before attempting again to avoid flooding with failed writes.
|
| 359 |
+
|
| 360 |
+
Args:
|
| 361 |
+
data: Data to persist (must be serializable)
|
| 362 |
+
|
| 363 |
+
Returns:
|
| 364 |
+
True if disk write succeeded, False if failed (data still in memory)
|
| 365 |
+
"""
|
| 366 |
+
with self._lock:
|
| 367 |
+
self._current_state = data
|
| 368 |
+
|
| 369 |
+
# If disk is unhealthy, only retry after retry_interval has passed
|
| 370 |
+
if not self._disk_healthy:
|
| 371 |
+
now = time.time()
|
| 372 |
+
if now - self._last_attempt < self.retry_interval:
|
| 373 |
+
# Too soon to retry, data is safe in memory
|
| 374 |
+
return False
|
| 375 |
+
|
| 376 |
+
return self._try_disk_write()
|
| 377 |
+
|
| 378 |
+
def retry_if_needed(self) -> bool:
|
| 379 |
+
"""
|
| 380 |
+
Retry disk write if unhealthy and retry interval has passed.
|
| 381 |
+
|
| 382 |
+
Call this periodically (e.g., on each save attempt) to recover
|
| 383 |
+
from transient disk failures.
|
| 384 |
+
|
| 385 |
+
Returns:
|
| 386 |
+
True if healthy (no retry needed or retry succeeded)
|
| 387 |
+
"""
|
| 388 |
+
with self._lock:
|
| 389 |
+
if self._disk_healthy:
|
| 390 |
+
return True
|
| 391 |
+
|
| 392 |
+
if self._current_state is None:
|
| 393 |
+
return True
|
| 394 |
+
|
| 395 |
+
now = time.time()
|
| 396 |
+
if now - self._last_attempt < self.retry_interval:
|
| 397 |
+
return False
|
| 398 |
+
|
| 399 |
+
return self._try_disk_write()
|
| 400 |
+
|
| 401 |
+
def _try_disk_write(self) -> bool:
|
| 402 |
+
"""
|
| 403 |
+
Attempt atomic write to disk. Updates health status.
|
| 404 |
+
|
| 405 |
+
Uses tempfile + move pattern for atomic writes on POSIX systems.
|
| 406 |
+
On Windows, uses direct write (still safe for our use case).
|
| 407 |
+
|
| 408 |
+
Also registers/unregisters with BufferedWriteRegistry for shutdown flush.
|
| 409 |
+
"""
|
| 410 |
+
if self._current_state is None:
|
| 411 |
+
return True
|
| 412 |
+
|
| 413 |
+
self._last_attempt = time.time()
|
| 414 |
+
|
| 415 |
+
try:
|
| 416 |
+
# Ensure directory exists
|
| 417 |
+
self.path.parent.mkdir(parents=True, exist_ok=True)
|
| 418 |
+
|
| 419 |
+
# Serialize data
|
| 420 |
+
content = self._serializer(self._current_state)
|
| 421 |
+
|
| 422 |
+
# Atomic write: write to temp file, then move
|
| 423 |
+
tmp_fd = None
|
| 424 |
+
tmp_path = None
|
| 425 |
+
try:
|
| 426 |
+
tmp_fd, tmp_path = tempfile.mkstemp(
|
| 427 |
+
dir=self.path.parent, prefix=".tmp_", suffix=".json", text=True
|
| 428 |
+
)
|
| 429 |
+
|
| 430 |
+
with os.fdopen(tmp_fd, "w", encoding="utf-8") as f:
|
| 431 |
+
f.write(content)
|
| 432 |
+
tmp_fd = None # fdopen closes the fd
|
| 433 |
+
|
| 434 |
+
# Atomic move
|
| 435 |
+
shutil.move(tmp_path, self.path)
|
| 436 |
+
tmp_path = None
|
| 437 |
+
|
| 438 |
+
finally:
|
| 439 |
+
# Cleanup on failure
|
| 440 |
+
if tmp_fd is not None:
|
| 441 |
+
try:
|
| 442 |
+
os.close(tmp_fd)
|
| 443 |
+
except OSError:
|
| 444 |
+
pass
|
| 445 |
+
if tmp_path and os.path.exists(tmp_path):
|
| 446 |
+
try:
|
| 447 |
+
os.unlink(tmp_path)
|
| 448 |
+
except OSError:
|
| 449 |
+
pass
|
| 450 |
+
|
| 451 |
+
# Success - update health and unregister from shutdown flush
|
| 452 |
+
self._disk_healthy = True
|
| 453 |
+
self._last_success = time.time()
|
| 454 |
+
self._failure_count = 0
|
| 455 |
+
BufferedWriteRegistry.get_instance().unregister(self.path)
|
| 456 |
+
return True
|
| 457 |
+
|
| 458 |
+
except (OSError, PermissionError, IOError) as e:
|
| 459 |
+
self._disk_healthy = False
|
| 460 |
+
self._failure_count += 1
|
| 461 |
+
|
| 462 |
+
# Register with BufferedWriteRegistry for shutdown flush
|
| 463 |
+
registry = BufferedWriteRegistry.get_instance()
|
| 464 |
+
registry.register_pending(
|
| 465 |
+
self.path,
|
| 466 |
+
self._current_state,
|
| 467 |
+
self._serializer,
|
| 468 |
+
{}, # No special options for ResilientStateWriter
|
| 469 |
+
)
|
| 470 |
+
|
| 471 |
+
# Log warning (rate-limited to avoid flooding)
|
| 472 |
+
if self._failure_count == 1 or self._failure_count % 10 == 0:
|
| 473 |
+
self.logger.warning(
|
| 474 |
+
f"Failed to write {self.path.name}: {e}. "
|
| 475 |
+
f"Data retained in memory (failure #{self._failure_count})."
|
| 476 |
+
)
|
| 477 |
+
return False
|
| 478 |
+
|
| 479 |
+
@property
def is_healthy(self) -> bool:
    """
    Whether disk writes are currently succeeding.

    True after the most recent write attempt completed; False while the
    writer is holding state in memory because the disk is unavailable.
    """
    return self._disk_healthy
|
| 483 |
+
|
| 484 |
+
@property
def current_state(self) -> Optional[Any]:
    """
    Expose the current in-memory state snapshot.

    Intended for inspection and debugging; this is the value that would be
    flushed to disk on the next successful write.
    """
    return self._current_state
|
| 488 |
+
|
| 489 |
+
def get_health_info(self) -> Dict[str, Any]:
    """
    Report writer health metrics for monitoring.

    Returns a dict with the keys:
        healthy (bool): whether the last write attempt succeeded
        failure_count (int): number of failed writes recorded so far
        last_success (Optional[float]): timestamp of the last successful write
        last_attempt (float): timestamp of the most recent write attempt
        path (str): target file path as a string
    """
    info: Dict[str, Any] = {}
    info["healthy"] = self._disk_healthy
    info["failure_count"] = self._failure_count
    info["last_success"] = self._last_success
    info["last_attempt"] = self._last_attempt
    info["path"] = str(self.path)
    return info
|
| 507 |
+
|
| 508 |
+
|
| 509 |
+
def safe_write_json(
    path: Union[str, Path],
    data: Dict[str, Any],
    logger: logging.Logger,
    atomic: bool = True,
    indent: int = 2,
    ensure_ascii: bool = True,
    secure_permissions: bool = False,
    buffer_on_failure: bool = False,
) -> bool:
    """
    Write JSON data to file with error handling and optional buffering.

    When buffer_on_failure is True, failed writes are registered with the
    BufferedWriteRegistry for periodic retry and shutdown flush. This ensures
    critical data (like auth tokens) is eventually saved.

    Args:
        path: File path to write to
        data: JSON-serializable data
        logger: Logger for warnings
        atomic: Use atomic write pattern (tempfile + os.replace)
        indent: JSON indentation level (default: 2)
        ensure_ascii: Escape non-ASCII characters (default: True)
        secure_permissions: Set file permissions to 0o600 (default: False)
        buffer_on_failure: Register with BufferedWriteRegistry on failure (default: False)

    Returns:
        True on success, False on failure (never raises)
    """
    path = Path(path)

    # Serializer matching the requested formatting. Also handed to the
    # BufferedWriteRegistry so a buffered retry reproduces identical output.
    def serializer(d: Any) -> str:
        return json.dumps(d, indent=indent, ensure_ascii=ensure_ascii)

    try:
        path.parent.mkdir(parents=True, exist_ok=True)
        content = serializer(data)

        if atomic:
            tmp_fd = None
            tmp_path = None
            try:
                # Create the temp file in the destination directory so the
                # final replace stays on one filesystem (required for atomicity).
                tmp_fd, tmp_path = tempfile.mkstemp(
                    dir=path.parent, prefix=".tmp_", suffix=".json", text=True
                )
                with os.fdopen(tmp_fd, "w", encoding="utf-8") as f:
                    # fdopen now owns (and will close) the descriptor. Clear
                    # tmp_fd immediately so the finally block cannot
                    # double-close a descriptor number that may have been
                    # reused, even if f.write() raises.
                    tmp_fd = None
                    f.write(content)

                # Set secure permissions if requested (before move for security)
                if secure_permissions:
                    try:
                        os.chmod(tmp_path, 0o600)
                    except (OSError, AttributeError):
                        # Windows may not support chmod, ignore
                        pass

                # os.replace atomically overwrites an existing destination on
                # all platforms. shutil.move would fall back to a non-atomic
                # copy on Windows when the target file already exists.
                os.replace(tmp_path, path)
                tmp_path = None
            finally:
                # Cleanup on failure
                if tmp_fd is not None:
                    try:
                        os.close(tmp_fd)
                    except OSError:
                        pass
                if tmp_path and os.path.exists(tmp_path):
                    try:
                        os.unlink(tmp_path)
                    except OSError:
                        pass
        else:
            with open(path, "w", encoding="utf-8") as f:
                f.write(content)

            # Set secure permissions if requested
            if secure_permissions:
                try:
                    os.chmod(path, 0o600)
                except (OSError, AttributeError):
                    pass

        # Success - remove from pending if it was there
        if buffer_on_failure:
            BufferedWriteRegistry.get_instance().unregister(path)

        return True

    # PermissionError and IOError are both OSError in Python 3, so OSError
    # alone covers every I/O failure; TypeError/ValueError cover
    # non-serializable data.
    except (OSError, TypeError, ValueError) as e:
        logger.warning(f"Failed to write JSON to {path}: {e}")

        # Register for retry if buffering is enabled
        if buffer_on_failure:
            registry = BufferedWriteRegistry.get_instance()
            registry.register_pending(
                path,
                data,
                serializer,
                {"secure_permissions": secure_permissions},
            )
            logger.debug(f"Buffered {path.name} for retry on next interval or shutdown")

        return False
|
| 613 |
+
|
| 614 |
+
|
| 615 |
+
def safe_log_write(
    path: Union[str, Path],
    content: str,
    logger: logging.Logger,
    mode: str = "a",
) -> bool:
    """
    Write content to a log file, swallowing I/O errors. No buffering or retry.

    Intended for log files where occasionally losing a line is acceptable.
    Parent directories are created on demand.

    Args:
        path: File path to write to
        content: String content to write
        logger: Logger for warnings
        mode: File mode ('a' for append, 'w' for overwrite)

    Returns:
        True on success, False on failure (never raises)
    """
    target = Path(path)
    try:
        target.parent.mkdir(parents=True, exist_ok=True)
        with open(target, mode, encoding="utf-8") as handle:
            handle.write(content)
    except (OSError, PermissionError, IOError) as exc:
        logger.warning(f"Failed to write log to {target}: {exc}")
        return False
    return True
|
| 647 |
+
|
| 648 |
+
|
| 649 |
+
def safe_mkdir(path: Union[str, Path], logger: logging.Logger) -> bool:
    """
    Create a directory (including parents), swallowing I/O errors.

    Args:
        path: Directory path to create
        logger: Logger for warnings

    Returns:
        True on success (including when the directory already exists),
        False on failure (never raises)
    """
    try:
        Path(path).mkdir(parents=True, exist_ok=True)
    except (OSError, PermissionError) as exc:
        logger.warning(f"Failed to create directory {path}: {exc}")
        return False
    return True
|