Spaces:
Running
Running
Upload 86 files
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- .gitattributes +1 -0
- app/__init__.py +1 -0
- app/api/v1/admin.py +1299 -0
- app/api/v1/chat.py +251 -0
- app/api/v1/files.py +72 -0
- app/api/v1/image.py +1065 -0
- app/api/v1/models.py +51 -0
- app/api/v1/uploads.py +64 -0
- app/api/v1/video.py +3 -0
- app/core/auth.py +159 -0
- app/core/config.py +329 -0
- app/core/exceptions.py +221 -0
- app/core/legacy_migration.py +285 -0
- app/core/logger.py +117 -0
- app/core/response_middleware.py +71 -0
- app/core/storage.py +720 -0
- app/services/api_keys.py +432 -0
- app/services/base.py +2 -0
- app/services/grok/assets.py +875 -0
- app/services/grok/chat.py +571 -0
- app/services/grok/imagine_experimental.py +416 -0
- app/services/grok/imagine_generation.py +137 -0
- app/services/grok/media.py +512 -0
- app/services/grok/model.py +226 -0
- app/services/grok/processor.py +596 -0
- app/services/grok/retry.py +178 -0
- app/services/grok/statsig.py +46 -0
- app/services/grok/usage.py +162 -0
- app/services/quota.py +70 -0
- app/services/register/__init__.py +5 -0
- app/services/register/account_settings_refresh.py +267 -0
- app/services/register/manager.py +332 -0
- app/services/register/runner.py +415 -0
- app/services/register/services/__init__.py +15 -0
- app/services/register/services/birth_date_service.py +97 -0
- app/services/register/services/email_service.py +90 -0
- app/services/register/services/nsfw_service.py +118 -0
- app/services/register/services/turnstile_service.py +161 -0
- app/services/register/services/user_agreement_service.py +115 -0
- app/services/register/solver.py +296 -0
- app/services/request_logger.py +143 -0
- app/services/request_stats.py +205 -0
- app/services/token/__init__.py +36 -0
- app/services/token/manager.py +654 -0
- app/services/token/models.py +221 -0
- app/services/token/pool.py +112 -0
- app/services/token/scheduler.py +104 -0
- app/services/token/service.py +156 -0
- app/static/.assetsignore +2 -0
- app/static/_worker.js +4 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
app/template/favicon.png filter=lfs diff=lfs merge=lfs -text
|
app/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""App Package"""
|
app/api/v1/admin.py
ADDED
|
@@ -0,0 +1,1299 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import APIRouter, Depends, HTTPException, Request, Query, Body, WebSocket
|
| 2 |
+
from fastapi.responses import HTMLResponse, RedirectResponse
|
| 3 |
+
from pydantic import BaseModel
|
| 4 |
+
from typing import Any, Optional
|
| 5 |
+
|
| 6 |
+
from app.core.auth import verify_api_key
|
| 7 |
+
from app.core.config import config, get_config
|
| 8 |
+
from app.core.storage import get_storage, LocalStorage, RedisStorage, SQLStorage
|
| 9 |
+
import os
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
import aiofiles
|
| 12 |
+
import asyncio
|
| 13 |
+
import json
|
| 14 |
+
import time
|
| 15 |
+
import uuid
|
| 16 |
+
import orjson
|
| 17 |
+
from starlette.websockets import WebSocketDisconnect, WebSocketState
|
| 18 |
+
from app.core.logger import logger
|
| 19 |
+
from app.services.register import get_auto_register_manager
|
| 20 |
+
from app.services.register.account_settings_refresh import (
|
| 21 |
+
refresh_account_settings_for_tokens,
|
| 22 |
+
normalize_sso_token as normalize_refresh_token,
|
| 23 |
+
)
|
| 24 |
+
from app.services.api_keys import api_key_manager
|
| 25 |
+
from app.services.grok.model import ModelService
|
| 26 |
+
from app.services.grok.imagine_generation import (
|
| 27 |
+
collect_experimental_generation_images,
|
| 28 |
+
is_valid_image_value as is_valid_imagine_image_value,
|
| 29 |
+
resolve_aspect_ratio as resolve_imagine_aspect_ratio,
|
| 30 |
+
)
|
| 31 |
+
from app.services.token import get_token_manager
|
| 32 |
+
from app.core.auth import _load_legacy_api_keys
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
router = APIRouter()
|
| 36 |
+
|
| 37 |
+
TEMPLATE_DIR = Path(__file__).parent.parent.parent / "static"
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class AdminLoginBody(BaseModel):
    """Request body for the admin login endpoint.

    Both fields are optional so the legacy ``Authorization: Bearer <password>``
    login path (no JSON body) keeps working.
    """
    username: str | None = None
    password: str | None = None
|
| 43 |
+
|
| 44 |
+
async def render_template(filename: str):
    """Read the static template *filename* (relative to TEMPLATE_DIR) and
    return it as an HTML response.

    Returns a 404 HTML response when the template file does not exist.
    """
    template_path = TEMPLATE_DIR / filename
    if not template_path.exists():
        # Fix: the message was an f-string with no placeholder, so the 404
        # never said which template was missing. Include the requested name.
        return HTMLResponse(f"Template {filename} not found.", status_code=404)

    async with aiofiles.open(template_path, "r", encoding="utf-8") as f:
        content = await f.read()
    return HTMLResponse(content)
|
| 53 |
+
|
| 54 |
+
@router.get("/", include_in_schema=False)
async def root_redirect():
    """Send the bare root path to the login page (consistent with Workers/Pages)."""
    target = "/login"
    return RedirectResponse(url=target, status_code=302)
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
@router.get("/login", response_class=HTMLResponse, include_in_schema=False)
async def login_page():
    """Serve the default login page."""
    template = "login/login.html"
    return await render_template(template)
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
@router.get("/admin", response_class=HTMLResponse, include_in_schema=False)
async def admin_login_page():
    """Legacy login entry point; clients are redirected to /login."""
    target = "/login"
    return RedirectResponse(url=target, status_code=302)
|
| 70 |
+
|
| 71 |
+
@router.get("/admin/config", response_class=HTMLResponse, include_in_schema=False)
async def admin_config_page():
    """Serve the configuration management page."""
    template = "config/config.html"
    return await render_template(template)
|
| 75 |
+
|
| 76 |
+
@router.get("/admin/token", response_class=HTMLResponse, include_in_schema=False)
async def admin_token_page():
    """Serve the token management page."""
    template = "token/token.html"
    return await render_template(template)
|
| 80 |
+
|
| 81 |
+
@router.get("/admin/datacenter", response_class=HTMLResponse, include_in_schema=False)
async def admin_datacenter_page():
    """Serve the data center page."""
    template = "datacenter/datacenter.html"
    return await render_template(template)
|
| 85 |
+
|
| 86 |
+
@router.get("/admin/keys", response_class=HTMLResponse, include_in_schema=False)
async def admin_keys_page():
    """Serve the API-key management page."""
    template = "keys/keys.html"
    return await render_template(template)
|
| 90 |
+
|
| 91 |
+
@router.get("/chat", response_class=HTMLResponse, include_in_schema=False)
async def chat_page():
    """Serve the online chat page (public entry point)."""
    template = "chat/chat.html"
    return await render_template(template)
|
| 95 |
+
|
| 96 |
+
@router.get("/admin/chat", response_class=HTMLResponse, include_in_schema=False)
async def admin_chat_page():
    """Serve the online chat page (admin entry point)."""
    template = "chat/chat_admin.html"
    return await render_template(template)
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
async def _verify_ws_api_key(websocket: WebSocket) -> bool:
    """Authenticate a websocket client via its ``?api_key=`` query parameter.

    Accepts the configured app key, any legacy key, or a key validated by
    ``api_key_manager``. When no keys are configured at all, access is open.
    """
    configured_key = str(get_config("app.api_key", "") or "").strip()
    legacy_keys = await _load_legacy_api_keys()
    if not configured_key and not legacy_keys:
        # Nothing configured anywhere -> authentication is effectively disabled.
        return True

    candidate = str(websocket.query_params.get("api_key") or "").strip()
    if not candidate:
        return False
    if configured_key and candidate == configured_key:
        return True
    if candidate in legacy_keys:
        return True

    # Last resort: check the managed key store; failures only log a warning.
    try:
        await api_key_manager.init()
        if api_key_manager.validate_key(candidate):
            return True
    except Exception as e:
        logger.warning(f"Imagine ws api_key validation fallback failed: {e}")
    return False
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
async def _collect_imagine_batch(token: str, prompt: str, aspect_ratio: str) -> list[str]:
    """Fetch one batch of six base64-encoded images for *prompt* using *token*."""
    batch_kwargs = {
        "token": token,
        "prompt": prompt,
        "n": 6,
        "response_format": "b64_json",
        "aspect_ratio": aspect_ratio,
        "concurrency": 1,
    }
    return await collect_experimental_generation_images(**batch_kwargs)
| 130 |
+
|
| 131 |
+
|
| 132 |
+
@router.websocket("/api/v1/admin/imagine/ws")
async def admin_imagine_ws(websocket: WebSocket):
    """Admin websocket that streams imagine-model images for a prompt.

    Protocol (JSON messages):
      client -> server: {"type": "start", "prompt", "aspect_ratio"?},
                        {"type": "stop"}, {"type": "ping"}
      server -> client: {"type": "status"|"image"|"error"|"pong", ...}

    A "start" launches a background loop that repeatedly generates image
    batches until stopped, the client disconnects, or a send fails.
    """
    if not await _verify_ws_api_key(websocket):
        # 1008 = policy violation (failed auth).
        await websocket.close(code=1008)
        return

    await websocket.accept()
    stop_event = asyncio.Event()
    run_task: Optional[asyncio.Task] = None

    async def _send(payload: dict) -> bool:
        # Serialize with orjson; report False instead of raising so callers
        # can treat a failed send as "client gone".
        try:
            await websocket.send_text(orjson.dumps(payload).decode())
            return True
        except Exception:
            return False

    async def _stop_run():
        # Cancel the current generation loop (if any) and reset state so a
        # new "start" can begin cleanly.
        nonlocal run_task
        stop_event.set()
        if run_task and not run_task.done():
            run_task.cancel()
            try:
                await run_task
            except asyncio.CancelledError:
                pass
            except Exception:
                pass
        run_task = None
        stop_event.clear()

    async def _run(prompt: str, aspect_ratio: str):
        # Background generation loop: acquire a token, produce a batch,
        # push each image frame, and sync token usage.
        model_id = "grok-imagine-1.0"
        model_info = ModelService.get(model_id)
        if not model_info or not model_info.is_image:
            await _send(
                {
                    "type": "error",
                    "message": "Image model is not available.",
                    "code": "model_not_supported",
                }
            )
            return

        token_mgr = await get_token_manager()
        sequence = 0
        run_id = uuid.uuid4().hex
        await _send(
            {
                "type": "status",
                "status": "running",
                "prompt": prompt,
                "aspect_ratio": aspect_ratio,
                "run_id": run_id,
            }
        )

        while not stop_event.is_set():
            try:
                await token_mgr.reload_if_stale()
                token = token_mgr.get_token_for_model(model_info.model_id)
                if not token:
                    # No capacity right now; tell the client and retry shortly.
                    await _send(
                        {
                            "type": "error",
                            "message": "No available tokens. Please try again later.",
                            "code": "rate_limit_exceeded",
                        }
                    )
                    await asyncio.sleep(2)
                    continue

                start_at = time.time()
                images = await _collect_imagine_batch(token, prompt, aspect_ratio)
                elapsed_ms = int((time.time() - start_at) * 1000)

                sent_any = False
                for image_b64 in images:
                    if not is_valid_imagine_image_value(image_b64):
                        continue
                    sent_any = True
                    sequence += 1
                    ok = await _send(
                        {
                            "type": "image",
                            "b64_json": image_b64,
                            "sequence": sequence,
                            "created_at": int(time.time() * 1000),
                            "elapsed_ms": elapsed_ms,
                            "aspect_ratio": aspect_ratio,
                            "run_id": run_id,
                        }
                    )
                    if not ok:
                        # Send failed -> client is gone; end the loop.
                        stop_event.set()
                        break

                if sent_any:
                    # Record usage only when something was actually delivered.
                    try:
                        await token_mgr.sync_usage(
                            token,
                            model_info.model_id,
                            consume_on_fail=True,
                            is_usage=True,
                        )
                    except Exception as e:
                        logger.warning(f"Imagine ws token sync failed: {e}")
                else:
                    await _send(
                        {
                            "type": "error",
                            "message": "Image generation returned empty data.",
                            "code": "empty_image",
                        }
                    )
            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.warning(f"Imagine stream error: {e}")
                await _send(
                    {
                        "type": "error",
                        "message": str(e),
                        "code": "internal_error",
                    }
                )
            # Pace consecutive batches.
            await asyncio.sleep(1.5)

        await _send({"type": "status", "status": "stopped", "run_id": run_id})

    # Command loop: react to client messages until disconnect.
    try:
        while True:
            try:
                raw = await websocket.receive_text()
            except (RuntimeError, WebSocketDisconnect):
                break

            try:
                payload = orjson.loads(raw)
            except Exception:
                await _send(
                    {
                        "type": "error",
                        "message": "Invalid message format.",
                        "code": "invalid_payload",
                    }
                )
                continue

            msg_type = payload.get("type")
            if msg_type == "start":
                prompt = str(payload.get("prompt") or "").strip()
                if not prompt:
                    await _send(
                        {
                            "type": "error",
                            "message": "Prompt cannot be empty.",
                            "code": "empty_prompt",
                        }
                    )
                    continue
                ratio = resolve_imagine_aspect_ratio(str(payload.get("aspect_ratio") or "2:3").strip())
                # Restart semantics: a new start always cancels the old run.
                await _stop_run()
                run_task = asyncio.create_task(_run(prompt, ratio))
            elif msg_type == "stop":
                await _stop_run()
            elif msg_type == "ping":
                await _send({"type": "pong"})
            else:
                await _send(
                    {
                        "type": "error",
                        "message": "Unknown command.",
                        "code": "unknown_command",
                    }
                )
    except WebSocketDisconnect:
        logger.debug("WebSocket disconnected by client")
    except asyncio.CancelledError:
        logger.debug("WebSocket handler cancelled")
    except Exception as e:
        logger.warning(f"WebSocket error: {e}")
    finally:
        # Always tear down the generation task and close the socket politely.
        await _stop_run()
        try:
            if websocket.client_state == WebSocketState.CONNECTED:
                await websocket.close(code=1000, reason="Server closing connection")
        except Exception as e:
            logger.debug(f"WebSocket close ignored: {e}")
|
| 321 |
+
|
| 322 |
+
|
| 323 |
+
@router.post("/api/v1/admin/login")
async def admin_login_api(request: Request, body: AdminLoginBody | None = Body(default=None)):
    """Admin console login (username + password).

    - Default credentials: admin/admin (changeable under "app settings" on the
      config page).
    - Legacy compatibility: ``Authorization: Bearer <password>`` password-only
      login is accepted; the username then defaults to "admin".

    Returns the configured API key on success. Raises HTTP 400 when
    credentials are missing and HTTP 401 when they do not match.
    """
    import secrets  # local import: used only for constant-time comparison

    admin_username = str(get_config("app.admin_username", "admin") or "admin").strip() or "admin"
    admin_password = str(get_config("app.app_key", "admin") or "admin").strip()

    # Fix: removed a redundant second .strip() on each value.
    username = body.username.strip() if body and isinstance(body.username, str) else ""
    password = body.password.strip() if body and isinstance(body.password, str) else ""

    # Legacy: password-only via Bearer token.
    if not password:
        auth = request.headers.get("Authorization") or ""
        if auth.lower().startswith("bearer "):
            password = auth[7:].strip()
            if not username:
                username = "admin"

    if not username or not password:
        raise HTTPException(status_code=400, detail="Missing username or password")

    # Security fix: use constant-time comparison so response timing does not
    # leak how much of a guessed credential matched (encode to bytes because
    # compare_digest rejects non-ASCII str input).
    username_ok = secrets.compare_digest(username.encode("utf-8"), admin_username.encode("utf-8"))
    password_ok = secrets.compare_digest(password.encode("utf-8"), admin_password.encode("utf-8"))
    if not (username_ok and password_ok):
        raise HTTPException(status_code=401, detail="Invalid username or password")

    return {"status": "success", "api_key": get_config("app.api_key", "")}
|
| 352 |
+
|
| 353 |
+
@router.get("/api/v1/admin/config", dependencies=[Depends(verify_api_key)])
async def get_config_api():
    """Return the current configuration for the admin UI."""
    # Deliberately exposes the raw underlying config dict (admin-only route).
    return config._config
|
| 358 |
+
|
| 359 |
+
@router.post("/api/v1/admin/config", dependencies=[Depends(verify_api_key)])
async def update_config_api(data: dict):
    """Apply a configuration update.

    Raises HTTP 500 carrying the underlying error message when the update
    fails for any reason.
    """
    try:
        await config.update(data)
        return {"status": "success", "message": "配置已更新"}
    except Exception as e:
        # Fix: chain the original exception so logs keep the real traceback.
        raise HTTPException(status_code=500, detail=str(e)) from e
|
| 367 |
+
|
| 368 |
+
|
| 369 |
+
def _display_key(key: str) -> str:
|
| 370 |
+
k = str(key or "")
|
| 371 |
+
if len(k) <= 12:
|
| 372 |
+
return k
|
| 373 |
+
return f"{k[:6]}...{k[-4:]}"
|
| 374 |
+
|
| 375 |
+
|
| 376 |
+
def _normalize_limit(v: Any) -> int:
|
| 377 |
+
if v is None or v == "":
|
| 378 |
+
return -1
|
| 379 |
+
try:
|
| 380 |
+
return max(-1, int(v))
|
| 381 |
+
except Exception:
|
| 382 |
+
return -1
|
| 383 |
+
|
| 384 |
+
|
| 385 |
+
def _pool_to_token_type(pool_name: str) -> str:
|
| 386 |
+
return "ssoSuper" if str(pool_name or "").strip() == "ssoSuper" else "sso"
|
| 387 |
+
|
| 388 |
+
|
| 389 |
+
def _parse_quota_value(v: Any) -> tuple[int, bool]:
|
| 390 |
+
if v is None or v == "":
|
| 391 |
+
return -1, False
|
| 392 |
+
try:
|
| 393 |
+
n = int(v)
|
| 394 |
+
except Exception:
|
| 395 |
+
return -1, False
|
| 396 |
+
if n < 0:
|
| 397 |
+
return -1, False
|
| 398 |
+
return n, True
|
| 399 |
+
|
| 400 |
+
|
| 401 |
+
def _safe_int(v: Any, default: int = 0) -> int:
|
| 402 |
+
try:
|
| 403 |
+
return int(v)
|
| 404 |
+
except Exception:
|
| 405 |
+
return default
|
| 406 |
+
|
| 407 |
+
|
| 408 |
+
def _normalize_token_status(raw_status: Any) -> str:
|
| 409 |
+
s = str(raw_status or "active").strip().lower()
|
| 410 |
+
if s == "expired":
|
| 411 |
+
return "invalid"
|
| 412 |
+
if s in ("active", "cooling", "invalid", "disabled"):
|
| 413 |
+
return s
|
| 414 |
+
return "active"
|
| 415 |
+
|
| 416 |
+
|
| 417 |
+
def _normalize_admin_token_item(pool_name: str, item: Any) -> dict | None:
    """Normalize one raw pool entry (plain string or dict) into the admin token schema.

    Returns None for blank tokens and unrecognized entry types. A leading
    "sso=" prefix is stripped from the token value.
    """
    token_type = _pool_to_token_type(pool_name)

    if isinstance(item, str):
        token = item.strip()
        if not token:
            return None
        # Bare string entries get conservative defaults (unknown quotas).
        return {
            "token": token.removeprefix("sso="),
            "status": "active",
            "quota": 0,
            "quota_known": False,
            "heavy_quota": -1,
            "heavy_quota_known": False,
            "token_type": token_type,
            "note": "",
            "fail_count": 0,
            "use_count": 0,
        }

    if not isinstance(item, dict):
        return None

    token = str(item.get("token") or "").strip()
    if not token:
        return None

    quota, quota_known = _parse_quota_value(item.get("quota"))
    heavy_quota, heavy_quota_known = _parse_quota_value(item.get("heavy_quota"))

    return {
        "token": token.removeprefix("sso="),
        "status": _normalize_token_status(item.get("status")),
        "quota": quota if quota_known else 0,
        "quota_known": quota_known,
        "heavy_quota": heavy_quota,
        "heavy_quota_known": heavy_quota_known,
        "token_type": token_type,
        "note": str(item.get("note") or ""),
        "fail_count": _safe_int(item.get("fail_count") or 0, 0),
        "use_count": _safe_int(item.get("use_count") or 0, 0),
    }
|
| 463 |
+
|
| 464 |
+
|
| 465 |
+
def _collect_tokens_from_pool_payload(payload: Any) -> list[str]:
    """Flatten a ``{pool_name: [entries]}`` payload into a de-duplicated token list.

    Entries may be plain strings or dicts carrying a "token" field; first-seen
    order is preserved and blank/duplicate tokens are skipped.
    """
    if not isinstance(payload, dict):
        return []

    seen: set[str] = set()
    ordered: list[str] = []
    for entries in payload.values():
        if not isinstance(entries, list):
            continue
        for entry in entries:
            if isinstance(entry, str):
                raw = entry
            elif isinstance(entry, dict):
                raw = entry.get("token")
            else:
                raw = ""
            token = normalize_refresh_token(str(raw or "").strip())
            if token and token not in seen:
                seen.add(token)
                ordered.append(token)
    return ordered
|
| 482 |
+
|
| 483 |
+
|
| 484 |
+
def _resolve_nsfw_refresh_concurrency(override: Any = None) -> int:
|
| 485 |
+
source = override if override is not None else get_config("token.nsfw_refresh_concurrency", 10)
|
| 486 |
+
try:
|
| 487 |
+
value = int(source)
|
| 488 |
+
except Exception:
|
| 489 |
+
value = 10
|
| 490 |
+
return max(1, value)
|
| 491 |
+
|
| 492 |
+
|
| 493 |
+
def _resolve_nsfw_refresh_retries(override: Any = None) -> int:
|
| 494 |
+
source = override if override is not None else get_config("token.nsfw_refresh_retries", 3)
|
| 495 |
+
try:
|
| 496 |
+
value = int(source)
|
| 497 |
+
except Exception:
|
| 498 |
+
value = 3
|
| 499 |
+
return max(0, value)
|
| 500 |
+
|
| 501 |
+
|
| 502 |
+
# Strong references to in-flight refresh tasks. asyncio.create_task only keeps
# a weak reference to its result, so without this a fire-and-forget task can be
# garbage-collected before it finishes.
_background_refresh_tasks: set[asyncio.Task] = set()


def _trigger_account_settings_refresh_background(
    tokens: list[str],
    concurrency: int,
    retries: int,
) -> None:
    """Kick off an account-settings refresh for *tokens* as a background task.

    No-op for an empty token list. The task logs a summary on completion and a
    warning on failure; it never raises into the caller.
    """
    if not tokens:
        return

    async def _run() -> None:
        try:
            result = await refresh_account_settings_for_tokens(
                tokens=tokens,
                concurrency=concurrency,
                retries=retries,
            )
            summary = result.get("summary") or {}
            logger.info(
                "Background account-settings refresh finished: total={} success={} failed={} invalidated={}",
                summary.get("total", 0),
                summary.get("success", 0),
                summary.get("failed", 0),
                summary.get("invalidated", 0),
            )
        except Exception as exc:
            logger.warning("Background account-settings refresh failed: {}", exc)

    # Fix: keep a strong reference until the task completes (see note above).
    task = asyncio.create_task(_run())
    _background_refresh_tasks.add(task)
    task.add_done_callback(_background_refresh_tasks.discard)
|
| 529 |
+
|
| 530 |
+
|
| 531 |
+
@router.get("/api/v1/admin/keys", dependencies=[Depends(verify_api_key)])
async def list_api_keys():
    """List API keys with today's usage and remaining quota for the admin UI.

    Response shape is ``{"success": True, "data": [...]}`` as the new UI expects.
    """
    await api_key_manager.init()
    day, usage_map = await api_key_manager.usage_today()
    categories = ("chat", "heavy", "image", "video")

    data = []
    for row in api_key_manager.get_all_keys():
        key = str(row.get("key") or "")
        used = usage_map.get(key) or {}

        usage_today = {
            f"{cat}_used": int(used.get(f"{cat}_used", 0) or 0) for cat in categories
        }
        limits = {cat: _normalize_limit(row.get(f"{cat}_limit", -1)) for cat in categories}
        # A negative limit means "unlimited", reported as None remaining.
        remaining = {
            cat: None if limits[cat] < 0 else max(0, limits[cat] - usage_today[f"{cat}_used"])
            for cat in categories
        }

        data.append({
            **row,
            "is_active": bool(row.get("is_active", True)),
            "display_key": _display_key(key),
            "usage_today": usage_today,
            "remaining_today": remaining,
            "day": day,
        })

    return {"success": True, "data": data}
|
| 574 |
+
|
| 575 |
+
|
| 576 |
+
@router.post("/api/v1/admin/keys", dependencies=[Depends(verify_api_key)])
|
| 577 |
+
async def create_api_key(data: dict):
|
| 578 |
+
"""Create a new API key (optional name/key/limits)."""
|
| 579 |
+
await api_key_manager.init()
|
| 580 |
+
data = data or {}
|
| 581 |
+
|
| 582 |
+
name = str(data.get("name") or "").strip() or api_key_manager.generate_name()
|
| 583 |
+
key_val = str(data.get("key") or "").strip() or None
|
| 584 |
+
is_active = bool(data.get("is_active", True))
|
| 585 |
+
|
| 586 |
+
limits = data.get("limits") if isinstance(data.get("limits"), dict) else {}
|
| 587 |
+
try:
|
| 588 |
+
row = await api_key_manager.add_key(
|
| 589 |
+
name=name,
|
| 590 |
+
key=key_val,
|
| 591 |
+
is_active=is_active,
|
| 592 |
+
limits={
|
| 593 |
+
"chat_per_day": limits.get("chat_per_day"),
|
| 594 |
+
"heavy_per_day": limits.get("heavy_per_day"),
|
| 595 |
+
"image_per_day": limits.get("image_per_day"),
|
| 596 |
+
"video_per_day": limits.get("video_per_day"),
|
| 597 |
+
},
|
| 598 |
+
)
|
| 599 |
+
except ValueError as e:
|
| 600 |
+
raise HTTPException(status_code=400, detail=str(e))
|
| 601 |
+
|
| 602 |
+
return {"success": True, "data": {**row, "display_key": _display_key(row.get("key", ""))}}
|
| 603 |
+
|
| 604 |
+
|
| 605 |
+
@router.post("/api/v1/admin/keys/update", dependencies=[Depends(verify_api_key)])
|
| 606 |
+
async def update_api_key(data: dict):
|
| 607 |
+
"""Update name/status/limits for an API key."""
|
| 608 |
+
await api_key_manager.init()
|
| 609 |
+
data = data or {}
|
| 610 |
+
key = str(data.get("key") or "").strip()
|
| 611 |
+
if not key:
|
| 612 |
+
raise HTTPException(status_code=400, detail="Missing key")
|
| 613 |
+
|
| 614 |
+
existing = api_key_manager.get_key_row(key)
|
| 615 |
+
if not existing:
|
| 616 |
+
raise HTTPException(status_code=404, detail="Key not found")
|
| 617 |
+
|
| 618 |
+
if "name" in data and data.get("name") is not None:
|
| 619 |
+
name = str(data.get("name") or "").strip()
|
| 620 |
+
if name:
|
| 621 |
+
await api_key_manager.update_key_name(key, name)
|
| 622 |
+
|
| 623 |
+
if "is_active" in data:
|
| 624 |
+
await api_key_manager.update_key_status(key, bool(data.get("is_active")))
|
| 625 |
+
|
| 626 |
+
limits = data.get("limits") if isinstance(data.get("limits"), dict) else None
|
| 627 |
+
if limits is not None:
|
| 628 |
+
await api_key_manager.update_key_limits(
|
| 629 |
+
key,
|
| 630 |
+
{
|
| 631 |
+
"chat_per_day": limits.get("chat_per_day"),
|
| 632 |
+
"heavy_per_day": limits.get("heavy_per_day"),
|
| 633 |
+
"image_per_day": limits.get("image_per_day"),
|
| 634 |
+
"video_per_day": limits.get("video_per_day"),
|
| 635 |
+
},
|
| 636 |
+
)
|
| 637 |
+
|
| 638 |
+
return {"success": True}
|
| 639 |
+
|
| 640 |
+
|
| 641 |
+
@router.post("/api/v1/admin/keys/delete", dependencies=[Depends(verify_api_key)])
|
| 642 |
+
async def delete_api_key(data: dict):
|
| 643 |
+
"""Delete an API key."""
|
| 644 |
+
await api_key_manager.init()
|
| 645 |
+
data = data or {}
|
| 646 |
+
key = str(data.get("key") or "").strip()
|
| 647 |
+
if not key:
|
| 648 |
+
raise HTTPException(status_code=400, detail="Missing key")
|
| 649 |
+
|
| 650 |
+
ok = await api_key_manager.delete_key(key)
|
| 651 |
+
if not ok:
|
| 652 |
+
raise HTTPException(status_code=404, detail="Key not found")
|
| 653 |
+
return {"success": True}
|
| 654 |
+
|
| 655 |
+
@router.get("/api/v1/admin/storage", dependencies=[Depends(verify_api_key)])
|
| 656 |
+
async def get_storage_info():
|
| 657 |
+
"""获取当前存储模式"""
|
| 658 |
+
storage_type = os.getenv("SERVER_STORAGE_TYPE", "local").lower()
|
| 659 |
+
logger.info(f"Storage type: {storage_type}")
|
| 660 |
+
if not storage_type:
|
| 661 |
+
storage_type = str(get_config("storage.type", "")).lower()
|
| 662 |
+
if not storage_type:
|
| 663 |
+
storage = get_storage()
|
| 664 |
+
if isinstance(storage, LocalStorage):
|
| 665 |
+
storage_type = "local"
|
| 666 |
+
elif isinstance(storage, RedisStorage):
|
| 667 |
+
storage_type = "redis"
|
| 668 |
+
elif isinstance(storage, SQLStorage):
|
| 669 |
+
if storage.dialect in ("mysql", "mariadb"):
|
| 670 |
+
storage_type = "mysql"
|
| 671 |
+
elif storage.dialect in ("postgres", "postgresql", "pgsql"):
|
| 672 |
+
storage_type = "pgsql"
|
| 673 |
+
else:
|
| 674 |
+
storage_type = storage.dialect
|
| 675 |
+
return {"type": storage_type or "local"}
|
| 676 |
+
|
| 677 |
+
@router.get("/api/v1/admin/tokens", dependencies=[Depends(verify_api_key)])
|
| 678 |
+
async def get_tokens_api():
|
| 679 |
+
"""获取所有 Token"""
|
| 680 |
+
storage = get_storage()
|
| 681 |
+
tokens = await storage.load_tokens()
|
| 682 |
+
data = tokens if isinstance(tokens, dict) else {}
|
| 683 |
+
out: dict[str, list[dict]] = {}
|
| 684 |
+
for pool_name, raw_items in data.items():
|
| 685 |
+
arr = raw_items if isinstance(raw_items, list) else []
|
| 686 |
+
normalized: list[dict] = []
|
| 687 |
+
for item in arr:
|
| 688 |
+
obj = _normalize_admin_token_item(pool_name, item)
|
| 689 |
+
if obj:
|
| 690 |
+
normalized.append(obj)
|
| 691 |
+
out[str(pool_name)] = normalized
|
| 692 |
+
return out
|
| 693 |
+
|
| 694 |
+
@router.post("/api/v1/admin/tokens", dependencies=[Depends(verify_api_key)])
|
| 695 |
+
async def update_tokens_api(data: dict):
|
| 696 |
+
"""Update token payload and trigger background account-settings refresh for new tokens."""
|
| 697 |
+
storage = get_storage()
|
| 698 |
+
try:
|
| 699 |
+
from app.services.token.manager import get_token_manager
|
| 700 |
+
|
| 701 |
+
posted_data = data if isinstance(data, dict) else {}
|
| 702 |
+
existing_tokens: list[str] = []
|
| 703 |
+
added_tokens: list[str] = []
|
| 704 |
+
|
| 705 |
+
async with storage.acquire_lock("tokens_save", timeout=10):
|
| 706 |
+
old_data = await storage.load_tokens()
|
| 707 |
+
existing_tokens = _collect_tokens_from_pool_payload(
|
| 708 |
+
old_data if isinstance(old_data, dict) else {}
|
| 709 |
+
)
|
| 710 |
+
|
| 711 |
+
await storage.save_tokens(posted_data)
|
| 712 |
+
mgr = await get_token_manager()
|
| 713 |
+
await mgr.reload()
|
| 714 |
+
|
| 715 |
+
new_tokens = _collect_tokens_from_pool_payload(posted_data)
|
| 716 |
+
existing_set = set(existing_tokens)
|
| 717 |
+
added_tokens = [token for token in new_tokens if token not in existing_set]
|
| 718 |
+
|
| 719 |
+
concurrency = _resolve_nsfw_refresh_concurrency()
|
| 720 |
+
retries = _resolve_nsfw_refresh_retries()
|
| 721 |
+
_trigger_account_settings_refresh_background(
|
| 722 |
+
tokens=added_tokens,
|
| 723 |
+
concurrency=concurrency,
|
| 724 |
+
retries=retries,
|
| 725 |
+
)
|
| 726 |
+
|
| 727 |
+
return {
|
| 728 |
+
"status": "success",
|
| 729 |
+
"message": "Token updated",
|
| 730 |
+
"nsfw_refresh": {
|
| 731 |
+
"mode": "background",
|
| 732 |
+
"triggered": len(added_tokens),
|
| 733 |
+
"concurrency": concurrency,
|
| 734 |
+
"retries": retries,
|
| 735 |
+
},
|
| 736 |
+
}
|
| 737 |
+
except Exception as e:
|
| 738 |
+
raise HTTPException(status_code=500, detail=str(e))
|
| 739 |
+
|
| 740 |
+
@router.post("/api/v1/admin/tokens/refresh", dependencies=[Depends(verify_api_key)])
|
| 741 |
+
async def refresh_tokens_api(data: dict):
|
| 742 |
+
"""刷新 Token 状态"""
|
| 743 |
+
from app.services.token.manager import get_token_manager
|
| 744 |
+
|
| 745 |
+
try:
|
| 746 |
+
mgr = await get_token_manager()
|
| 747 |
+
tokens = []
|
| 748 |
+
if "token" in data:
|
| 749 |
+
tokens.append(data["token"])
|
| 750 |
+
if "tokens" in data and isinstance(data["tokens"], list):
|
| 751 |
+
tokens.extend(data["tokens"])
|
| 752 |
+
|
| 753 |
+
if not tokens:
|
| 754 |
+
raise HTTPException(status_code=400, detail="No tokens provided")
|
| 755 |
+
|
| 756 |
+
unique_tokens = list(set(tokens))
|
| 757 |
+
|
| 758 |
+
sem = asyncio.Semaphore(10)
|
| 759 |
+
|
| 760 |
+
async def _refresh_one(t):
|
| 761 |
+
async with sem:
|
| 762 |
+
return t, await mgr.sync_usage(t, "grok-3", consume_on_fail=False, is_usage=False)
|
| 763 |
+
|
| 764 |
+
results_list = await asyncio.gather(*[_refresh_one(t) for t in unique_tokens])
|
| 765 |
+
results = dict(results_list)
|
| 766 |
+
|
| 767 |
+
return {"status": "success", "results": results}
|
| 768 |
+
except Exception as e:
|
| 769 |
+
raise HTTPException(status_code=500, detail=str(e))
|
| 770 |
+
|
| 771 |
+
|
| 772 |
+
@router.post("/api/v1/admin/tokens/nsfw/refresh", dependencies=[Depends(verify_api_key)])
|
| 773 |
+
async def refresh_tokens_nsfw_api(data: dict):
|
| 774 |
+
"""Refresh account settings (TOS + birth date + NSFW) for selected/all tokens."""
|
| 775 |
+
payload = data if isinstance(data, dict) else {}
|
| 776 |
+
mgr = await get_token_manager()
|
| 777 |
+
|
| 778 |
+
tokens: list[str] = []
|
| 779 |
+
seen: set[str] = set()
|
| 780 |
+
|
| 781 |
+
if bool(payload.get("all")):
|
| 782 |
+
for pool in mgr.pools.values():
|
| 783 |
+
for info in pool.list():
|
| 784 |
+
token = normalize_refresh_token(str(info.token or "").strip())
|
| 785 |
+
if not token or token in seen:
|
| 786 |
+
continue
|
| 787 |
+
seen.add(token)
|
| 788 |
+
tokens.append(token)
|
| 789 |
+
else:
|
| 790 |
+
candidates: list[str] = []
|
| 791 |
+
single = payload.get("token")
|
| 792 |
+
if isinstance(single, str):
|
| 793 |
+
candidates.append(single)
|
| 794 |
+
batch = payload.get("tokens")
|
| 795 |
+
if isinstance(batch, list):
|
| 796 |
+
candidates.extend([item for item in batch if isinstance(item, str)])
|
| 797 |
+
|
| 798 |
+
for raw in candidates:
|
| 799 |
+
token = normalize_refresh_token(str(raw or "").strip())
|
| 800 |
+
if not token or token in seen:
|
| 801 |
+
continue
|
| 802 |
+
seen.add(token)
|
| 803 |
+
tokens.append(token)
|
| 804 |
+
|
| 805 |
+
if not tokens:
|
| 806 |
+
raise HTTPException(status_code=400, detail="No tokens provided")
|
| 807 |
+
|
| 808 |
+
concurrency = _resolve_nsfw_refresh_concurrency(payload.get("concurrency"))
|
| 809 |
+
retries = _resolve_nsfw_refresh_retries(payload.get("retries"))
|
| 810 |
+
result = await refresh_account_settings_for_tokens(
|
| 811 |
+
tokens=tokens,
|
| 812 |
+
concurrency=concurrency,
|
| 813 |
+
retries=retries,
|
| 814 |
+
)
|
| 815 |
+
return {
|
| 816 |
+
"status": "success",
|
| 817 |
+
"summary": result.get("summary") or {},
|
| 818 |
+
"failed": result.get("failed") or [],
|
| 819 |
+
}
|
| 820 |
+
|
| 821 |
+
|
| 822 |
+
@router.post("/api/v1/admin/tokens/auto-register", dependencies=[Depends(verify_api_key)])
|
| 823 |
+
async def auto_register_tokens_api(data: dict):
|
| 824 |
+
"""Start auto registration."""
|
| 825 |
+
try:
|
| 826 |
+
data = data or {}
|
| 827 |
+
count = data.get("count")
|
| 828 |
+
concurrency = data.get("concurrency")
|
| 829 |
+
pool = (data.get("pool") or "ssoBasic").strip() or "ssoBasic"
|
| 830 |
+
|
| 831 |
+
try:
|
| 832 |
+
count_val = int(count)
|
| 833 |
+
except Exception:
|
| 834 |
+
count_val = int(get_config("register.default_count", 100) or 100)
|
| 835 |
+
|
| 836 |
+
if count_val <= 0:
|
| 837 |
+
count_val = int(get_config("register.default_count", 100) or 100)
|
| 838 |
+
|
| 839 |
+
try:
|
| 840 |
+
concurrency_val = int(concurrency)
|
| 841 |
+
except Exception:
|
| 842 |
+
concurrency_val = None
|
| 843 |
+
if concurrency_val is not None and concurrency_val <= 0:
|
| 844 |
+
concurrency_val = None
|
| 845 |
+
|
| 846 |
+
manager = get_auto_register_manager()
|
| 847 |
+
job = await manager.start_job(count=count_val, pool=pool, concurrency=concurrency_val)
|
| 848 |
+
return {"status": "started", "job": job.to_dict()}
|
| 849 |
+
except RuntimeError as e:
|
| 850 |
+
raise HTTPException(status_code=409, detail=str(e))
|
| 851 |
+
except Exception as e:
|
| 852 |
+
raise HTTPException(status_code=500, detail=str(e))
|
| 853 |
+
|
| 854 |
+
|
| 855 |
+
@router.get("/api/v1/admin/tokens/auto-register/status", dependencies=[Depends(verify_api_key)])
|
| 856 |
+
async def auto_register_status_api(job_id: str | None = None):
|
| 857 |
+
"""Get auto registration status."""
|
| 858 |
+
manager = get_auto_register_manager()
|
| 859 |
+
status = manager.get_status(job_id)
|
| 860 |
+
if status.get("status") == "not_found":
|
| 861 |
+
raise HTTPException(status_code=404, detail="Job not found")
|
| 862 |
+
return status
|
| 863 |
+
|
| 864 |
+
|
| 865 |
+
@router.post("/api/v1/admin/tokens/auto-register/stop", dependencies=[Depends(verify_api_key)])
|
| 866 |
+
async def auto_register_stop_api(job_id: str | None = None):
|
| 867 |
+
"""Stop auto registration (best-effort)."""
|
| 868 |
+
manager = get_auto_register_manager()
|
| 869 |
+
status = manager.get_status(job_id)
|
| 870 |
+
if status.get("status") == "not_found":
|
| 871 |
+
raise HTTPException(status_code=404, detail="Job not found")
|
| 872 |
+
await manager.stop_job()
|
| 873 |
+
return {"status": "stopping"}
|
| 874 |
+
|
| 875 |
+
@router.get("/admin/cache", response_class=HTMLResponse, include_in_schema=False)
|
| 876 |
+
async def admin_cache_page():
|
| 877 |
+
"""缓存管理页"""
|
| 878 |
+
return await render_template("cache/cache.html")
|
| 879 |
+
|
| 880 |
+
@router.get("/api/v1/admin/cache", dependencies=[Depends(verify_api_key)])
|
| 881 |
+
async def get_cache_stats_api(request: Request):
|
| 882 |
+
"""获取缓存统计"""
|
| 883 |
+
from app.services.grok.assets import DownloadService, ListService
|
| 884 |
+
from app.services.token.manager import get_token_manager
|
| 885 |
+
|
| 886 |
+
try:
|
| 887 |
+
dl_service = DownloadService()
|
| 888 |
+
image_stats = dl_service.get_stats("image")
|
| 889 |
+
video_stats = dl_service.get_stats("video")
|
| 890 |
+
|
| 891 |
+
mgr = await get_token_manager()
|
| 892 |
+
pools = mgr.pools
|
| 893 |
+
accounts = []
|
| 894 |
+
for pool_name, pool in pools.items():
|
| 895 |
+
for info in pool.list():
|
| 896 |
+
raw_token = info.token[4:] if info.token.startswith("sso=") else info.token
|
| 897 |
+
masked = f"{raw_token[:8]}...{raw_token[-16:]}" if len(raw_token) > 24 else raw_token
|
| 898 |
+
accounts.append({
|
| 899 |
+
"token": raw_token,
|
| 900 |
+
"token_masked": masked,
|
| 901 |
+
"pool": pool_name,
|
| 902 |
+
"status": info.status,
|
| 903 |
+
"last_asset_clear_at": info.last_asset_clear_at
|
| 904 |
+
})
|
| 905 |
+
|
| 906 |
+
scope = request.query_params.get("scope")
|
| 907 |
+
selected_token = request.query_params.get("token")
|
| 908 |
+
tokens_param = request.query_params.get("tokens")
|
| 909 |
+
selected_tokens = []
|
| 910 |
+
if tokens_param:
|
| 911 |
+
selected_tokens = [t.strip() for t in tokens_param.split(",") if t.strip()]
|
| 912 |
+
|
| 913 |
+
online_stats = {"count": 0, "status": "unknown", "token": None, "last_asset_clear_at": None}
|
| 914 |
+
online_details = []
|
| 915 |
+
account_map = {a["token"]: a for a in accounts}
|
| 916 |
+
batch_size = get_config("performance.admin_assets_batch_size", 10)
|
| 917 |
+
try:
|
| 918 |
+
batch_size = int(batch_size)
|
| 919 |
+
except Exception:
|
| 920 |
+
batch_size = 10
|
| 921 |
+
batch_size = max(1, batch_size)
|
| 922 |
+
|
| 923 |
+
async def _fetch_assets(token: str):
|
| 924 |
+
list_service = ListService()
|
| 925 |
+
try:
|
| 926 |
+
return await list_service.count(token)
|
| 927 |
+
finally:
|
| 928 |
+
await list_service.close()
|
| 929 |
+
|
| 930 |
+
async def _fetch_detail(token: str):
|
| 931 |
+
account = account_map.get(token)
|
| 932 |
+
try:
|
| 933 |
+
count = await _fetch_assets(token)
|
| 934 |
+
return ({
|
| 935 |
+
"token": token,
|
| 936 |
+
"token_masked": account["token_masked"] if account else token,
|
| 937 |
+
"count": count,
|
| 938 |
+
"status": "ok",
|
| 939 |
+
"last_asset_clear_at": account["last_asset_clear_at"] if account else None
|
| 940 |
+
}, count)
|
| 941 |
+
except Exception as e:
|
| 942 |
+
return ({
|
| 943 |
+
"token": token,
|
| 944 |
+
"token_masked": account["token_masked"] if account else token,
|
| 945 |
+
"count": 0,
|
| 946 |
+
"status": f"error: {str(e)}",
|
| 947 |
+
"last_asset_clear_at": account["last_asset_clear_at"] if account else None
|
| 948 |
+
}, 0)
|
| 949 |
+
|
| 950 |
+
if selected_tokens:
|
| 951 |
+
total = 0
|
| 952 |
+
for i in range(0, len(selected_tokens), batch_size):
|
| 953 |
+
chunk = selected_tokens[i:i + batch_size]
|
| 954 |
+
results = await asyncio.gather(*[_fetch_detail(token) for token in chunk])
|
| 955 |
+
for detail, count in results:
|
| 956 |
+
online_details.append(detail)
|
| 957 |
+
total += count
|
| 958 |
+
online_stats = {"count": total, "status": "ok" if selected_tokens else "no_token", "token": None, "last_asset_clear_at": None}
|
| 959 |
+
scope = "selected"
|
| 960 |
+
elif scope == "all":
|
| 961 |
+
total = 0
|
| 962 |
+
tokens = [account["token"] for account in accounts]
|
| 963 |
+
for i in range(0, len(tokens), batch_size):
|
| 964 |
+
chunk = tokens[i:i + batch_size]
|
| 965 |
+
results = await asyncio.gather(*[_fetch_detail(token) for token in chunk])
|
| 966 |
+
for detail, count in results:
|
| 967 |
+
online_details.append(detail)
|
| 968 |
+
total += count
|
| 969 |
+
online_stats = {"count": total, "status": "ok" if accounts else "no_token", "token": None, "last_asset_clear_at": None}
|
| 970 |
+
else:
|
| 971 |
+
token = selected_token
|
| 972 |
+
if token:
|
| 973 |
+
try:
|
| 974 |
+
count = await _fetch_assets(token)
|
| 975 |
+
match = next((a for a in accounts if a["token"] == token), None)
|
| 976 |
+
online_stats = {
|
| 977 |
+
"count": count,
|
| 978 |
+
"status": "ok",
|
| 979 |
+
"token": token,
|
| 980 |
+
"token_masked": match["token_masked"] if match else token,
|
| 981 |
+
"last_asset_clear_at": match["last_asset_clear_at"] if match else None
|
| 982 |
+
}
|
| 983 |
+
except Exception as e:
|
| 984 |
+
match = next((a for a in accounts if a["token"] == token), None)
|
| 985 |
+
online_stats = {
|
| 986 |
+
"count": 0,
|
| 987 |
+
"status": f"error: {str(e)}",
|
| 988 |
+
"token": token,
|
| 989 |
+
"token_masked": match["token_masked"] if match else token,
|
| 990 |
+
"last_asset_clear_at": match["last_asset_clear_at"] if match else None
|
| 991 |
+
}
|
| 992 |
+
else:
|
| 993 |
+
online_stats = {"count": 0, "status": "not_loaded", "token": None, "last_asset_clear_at": None}
|
| 994 |
+
|
| 995 |
+
return {
|
| 996 |
+
"local_image": image_stats,
|
| 997 |
+
"local_video": video_stats,
|
| 998 |
+
"online": online_stats,
|
| 999 |
+
"online_accounts": accounts,
|
| 1000 |
+
"online_scope": scope or "none",
|
| 1001 |
+
"online_details": online_details
|
| 1002 |
+
}
|
| 1003 |
+
except Exception as e:
|
| 1004 |
+
raise HTTPException(status_code=500, detail=str(e))
|
| 1005 |
+
|
| 1006 |
+
@router.post("/api/v1/admin/cache/clear", dependencies=[Depends(verify_api_key)])
|
| 1007 |
+
async def clear_local_cache_api(data: dict):
|
| 1008 |
+
"""清理本地缓存"""
|
| 1009 |
+
from app.services.grok.assets import DownloadService
|
| 1010 |
+
cache_type = data.get("type", "image")
|
| 1011 |
+
|
| 1012 |
+
try:
|
| 1013 |
+
dl_service = DownloadService()
|
| 1014 |
+
result = dl_service.clear(cache_type)
|
| 1015 |
+
return {"status": "success", "result": result}
|
| 1016 |
+
except Exception as e:
|
| 1017 |
+
raise HTTPException(status_code=500, detail=str(e))
|
| 1018 |
+
|
| 1019 |
+
@router.get("/api/v1/admin/cache/list", dependencies=[Depends(verify_api_key)])
|
| 1020 |
+
async def list_local_cache_api(
|
| 1021 |
+
cache_type: str = "image",
|
| 1022 |
+
type_: str = Query(default=None, alias="type"),
|
| 1023 |
+
page: int = 1,
|
| 1024 |
+
page_size: int = 1000
|
| 1025 |
+
):
|
| 1026 |
+
"""列出本地缓存文件"""
|
| 1027 |
+
from app.services.grok.assets import DownloadService
|
| 1028 |
+
try:
|
| 1029 |
+
if type_:
|
| 1030 |
+
cache_type = type_
|
| 1031 |
+
dl_service = DownloadService()
|
| 1032 |
+
result = dl_service.list_files(cache_type, page, page_size)
|
| 1033 |
+
return {"status": "success", **result}
|
| 1034 |
+
except Exception as e:
|
| 1035 |
+
raise HTTPException(status_code=500, detail=str(e))
|
| 1036 |
+
|
| 1037 |
+
@router.post("/api/v1/admin/cache/item/delete", dependencies=[Depends(verify_api_key)])
|
| 1038 |
+
async def delete_local_cache_item_api(data: dict):
|
| 1039 |
+
"""删除单个本地缓存文件"""
|
| 1040 |
+
from app.services.grok.assets import DownloadService
|
| 1041 |
+
cache_type = data.get("type", "image")
|
| 1042 |
+
name = data.get("name")
|
| 1043 |
+
if not name:
|
| 1044 |
+
raise HTTPException(status_code=400, detail="Missing file name")
|
| 1045 |
+
try:
|
| 1046 |
+
dl_service = DownloadService()
|
| 1047 |
+
result = dl_service.delete_file(cache_type, name)
|
| 1048 |
+
return {"status": "success", "result": result}
|
| 1049 |
+
except Exception as e:
|
| 1050 |
+
raise HTTPException(status_code=500, detail=str(e))
|
| 1051 |
+
|
| 1052 |
+
@router.post("/api/v1/admin/cache/online/clear", dependencies=[Depends(verify_api_key)])
|
| 1053 |
+
async def clear_online_cache_api(data: dict):
|
| 1054 |
+
"""清理在线缓存"""
|
| 1055 |
+
from app.services.grok.assets import DeleteService
|
| 1056 |
+
from app.services.token.manager import get_token_manager
|
| 1057 |
+
|
| 1058 |
+
delete_service = None
|
| 1059 |
+
try:
|
| 1060 |
+
mgr = await get_token_manager()
|
| 1061 |
+
tokens = data.get("tokens")
|
| 1062 |
+
delete_service = DeleteService()
|
| 1063 |
+
|
| 1064 |
+
if isinstance(tokens, list):
|
| 1065 |
+
token_list = [t.strip() for t in tokens if isinstance(t, str) and t.strip()]
|
| 1066 |
+
if not token_list:
|
| 1067 |
+
raise HTTPException(status_code=400, detail="No tokens provided")
|
| 1068 |
+
|
| 1069 |
+
results = {}
|
| 1070 |
+
batch_size = get_config("performance.admin_assets_batch_size", 10)
|
| 1071 |
+
try:
|
| 1072 |
+
batch_size = int(batch_size)
|
| 1073 |
+
except Exception:
|
| 1074 |
+
batch_size = 10
|
| 1075 |
+
batch_size = max(1, batch_size)
|
| 1076 |
+
|
| 1077 |
+
async def _clear_one(t: str):
|
| 1078 |
+
try:
|
| 1079 |
+
result = await delete_service.delete_all(t)
|
| 1080 |
+
await mgr.mark_asset_clear(t)
|
| 1081 |
+
return t, {"status": "success", "result": result}
|
| 1082 |
+
except Exception as e:
|
| 1083 |
+
return t, {"status": "error", "error": str(e)}
|
| 1084 |
+
|
| 1085 |
+
for i in range(0, len(token_list), batch_size):
|
| 1086 |
+
chunk = token_list[i:i + batch_size]
|
| 1087 |
+
res_list = await asyncio.gather(*[_clear_one(t) for t in chunk])
|
| 1088 |
+
for t, res in res_list:
|
| 1089 |
+
results[t] = res
|
| 1090 |
+
|
| 1091 |
+
return {"status": "success", "results": results}
|
| 1092 |
+
|
| 1093 |
+
token = data.get("token") or mgr.get_token()
|
| 1094 |
+
if not token:
|
| 1095 |
+
raise HTTPException(status_code=400, detail="No available token to perform cleanup")
|
| 1096 |
+
|
| 1097 |
+
result = await delete_service.delete_all(token)
|
| 1098 |
+
await mgr.mark_asset_clear(token)
|
| 1099 |
+
return {"status": "success", "result": result}
|
| 1100 |
+
except Exception as e:
|
| 1101 |
+
raise HTTPException(status_code=500, detail=str(e))
|
| 1102 |
+
finally:
|
| 1103 |
+
if delete_service:
|
| 1104 |
+
await delete_service.close()
|
| 1105 |
+
|
| 1106 |
+
|
| 1107 |
+
@router.get("/api/v1/admin/metrics", dependencies=[Depends(verify_api_key)])
|
| 1108 |
+
async def get_metrics_api():
|
| 1109 |
+
"""数据中心:聚合常用指标(token/cache/request_stats)。"""
|
| 1110 |
+
try:
|
| 1111 |
+
from app.services.request_stats import request_stats
|
| 1112 |
+
from app.services.token.manager import get_token_manager
|
| 1113 |
+
from app.services.token.models import TokenStatus
|
| 1114 |
+
from app.services.grok.assets import DownloadService
|
| 1115 |
+
|
| 1116 |
+
mgr = await get_token_manager()
|
| 1117 |
+
await mgr.reload_if_stale()
|
| 1118 |
+
|
| 1119 |
+
total = 0
|
| 1120 |
+
active = 0
|
| 1121 |
+
cooling = 0
|
| 1122 |
+
expired = 0
|
| 1123 |
+
disabled = 0
|
| 1124 |
+
chat_quota = 0
|
| 1125 |
+
total_calls = 0
|
| 1126 |
+
|
| 1127 |
+
for pool in mgr.pools.values():
|
| 1128 |
+
for info in pool.list():
|
| 1129 |
+
total += 1
|
| 1130 |
+
total_calls += int(getattr(info, "use_count", 0) or 0)
|
| 1131 |
+
if info.status == TokenStatus.ACTIVE:
|
| 1132 |
+
active += 1
|
| 1133 |
+
chat_quota += int(getattr(info, "quota", 0) or 0)
|
| 1134 |
+
elif info.status == TokenStatus.COOLING:
|
| 1135 |
+
cooling += 1
|
| 1136 |
+
elif info.status == TokenStatus.EXPIRED:
|
| 1137 |
+
expired += 1
|
| 1138 |
+
elif info.status == TokenStatus.DISABLED:
|
| 1139 |
+
disabled += 1
|
| 1140 |
+
|
| 1141 |
+
dl = DownloadService()
|
| 1142 |
+
local_image = dl.get_stats("image")
|
| 1143 |
+
local_video = dl.get_stats("video")
|
| 1144 |
+
|
| 1145 |
+
await request_stats.init()
|
| 1146 |
+
stats = request_stats.get_stats(hours=24, days=7)
|
| 1147 |
+
|
| 1148 |
+
return {
|
| 1149 |
+
"tokens": {
|
| 1150 |
+
"total": total,
|
| 1151 |
+
"active": active,
|
| 1152 |
+
"cooling": cooling,
|
| 1153 |
+
"expired": expired,
|
| 1154 |
+
"disabled": disabled,
|
| 1155 |
+
"chat_quota": chat_quota,
|
| 1156 |
+
"image_quota": int(chat_quota // 2),
|
| 1157 |
+
"total_calls": total_calls,
|
| 1158 |
+
},
|
| 1159 |
+
"cache": {
|
| 1160 |
+
"local_image": local_image,
|
| 1161 |
+
"local_video": local_video,
|
| 1162 |
+
},
|
| 1163 |
+
"request_stats": stats,
|
| 1164 |
+
}
|
| 1165 |
+
except Exception as e:
|
| 1166 |
+
raise HTTPException(status_code=500, detail=str(e))
|
| 1167 |
+
|
| 1168 |
+
|
| 1169 |
+
@router.get("/api/v1/admin/cache/local", dependencies=[Depends(verify_api_key)])
|
| 1170 |
+
async def get_cache_local_stats_api():
|
| 1171 |
+
"""仅获取本地缓存统计(用于前端实时刷新)。"""
|
| 1172 |
+
from app.services.grok.assets import DownloadService
|
| 1173 |
+
|
| 1174 |
+
try:
|
| 1175 |
+
dl_service = DownloadService()
|
| 1176 |
+
image_stats = dl_service.get_stats("image")
|
| 1177 |
+
video_stats = dl_service.get_stats("video")
|
| 1178 |
+
return {"local_image": image_stats, "local_video": video_stats}
|
| 1179 |
+
except Exception as e:
|
| 1180 |
+
raise HTTPException(status_code=500, detail=str(e))
|
| 1181 |
+
|
| 1182 |
+
|
| 1183 |
+
def _safe_log_file_path(name: str) -> Path:
    """Resolve a log file name under ./logs safely.

    Raises ValueError for empty/traversal-style names and FileNotFoundError
    when the resolved path is not an existing regular file.
    """
    from app.core.logger import LOG_DIR

    cleaned = (name or "").strip()
    if not cleaned:
        raise ValueError("Missing log file")
    # Reject separators and parent references so nothing can escape logs/.
    if "/" in cleaned or "\\" in cleaned or ".." in cleaned:
        raise ValueError("Invalid log file name")

    resolved = (LOG_DIR / cleaned).resolve()
    if LOG_DIR.resolve() not in resolved.parents:
        raise ValueError("Invalid log file path")
    if not (resolved.exists() and resolved.is_file()):
        raise FileNotFoundError(cleaned)
    return resolved
|
| 1200 |
+
|
| 1201 |
+
|
| 1202 |
+
def _format_log_line(raw: str) -> str:
|
| 1203 |
+
raw = (raw or "").rstrip("\r\n")
|
| 1204 |
+
if not raw:
|
| 1205 |
+
return ""
|
| 1206 |
+
|
| 1207 |
+
# Try JSON log line (our file sink uses json lines).
|
| 1208 |
+
try:
|
| 1209 |
+
obj = json.loads(raw)
|
| 1210 |
+
if not isinstance(obj, dict):
|
| 1211 |
+
return raw
|
| 1212 |
+
ts = str(obj.get("time", "") or "")
|
| 1213 |
+
ts = ts.replace("T", " ")
|
| 1214 |
+
if len(ts) >= 19:
|
| 1215 |
+
ts = ts[:19]
|
| 1216 |
+
level = str(obj.get("level", "") or "").upper()
|
| 1217 |
+
caller = str(obj.get("caller", "") or "")
|
| 1218 |
+
msg = str(obj.get("msg", "") or "")
|
| 1219 |
+
if not (ts and level and msg):
|
| 1220 |
+
return raw
|
| 1221 |
+
return f"{ts} | {level:<8} | {caller} - {msg}".rstrip()
|
| 1222 |
+
except Exception:
|
| 1223 |
+
return raw
|
| 1224 |
+
|
| 1225 |
+
|
| 1226 |
+
def _tail_lines(path: Path, max_lines: int = 2000, max_bytes: int = 1024 * 1024) -> list[str]:
    """Best-effort tail for a text file.

    Reads at most `max_bytes` from the end of the file, drops any leading
    partial line, and returns up to `max_lines` formatted lines.
    """
    # Clamp the requested line count / byte budget to sane ranges.
    try:
        line_limit = int(max_lines)
    except Exception:
        line_limit = 2000
    line_limit = max(1, min(5000, line_limit))
    byte_budget = max(16 * 1024, min(5 * 1024 * 1024, int(max_bytes)))

    # Read only the final `byte_budget` bytes of the file.
    with open(path, "rb") as fh:
        fh.seek(0, os.SEEK_END)
        file_end = fh.tell()
        read_start = max(0, file_end - byte_budget)
        fh.seek(read_start, os.SEEK_SET)
        blob = fh.read()

    rows = blob.decode("utf-8", errors="replace").splitlines()
    # If we started mid-line, the first entry is a partial line: drop it.
    if read_start > 0 and rows:
        rows = rows[1:]
    return [_format_log_line(row) for row in rows[-line_limit:] if row is not None]
|
| 1249 |
+
|
| 1250 |
+
|
| 1251 |
+
@router.get("/api/v1/admin/logs/files", dependencies=[Depends(verify_api_key)])
|
| 1252 |
+
async def list_log_files_api():
|
| 1253 |
+
"""列出可查看的日志文件(logs/*.log)。"""
|
| 1254 |
+
from app.core.logger import LOG_DIR
|
| 1255 |
+
|
| 1256 |
+
try:
|
| 1257 |
+
items = []
|
| 1258 |
+
for p in LOG_DIR.glob("*.log"):
|
| 1259 |
+
try:
|
| 1260 |
+
stat = p.stat()
|
| 1261 |
+
items.append(
|
| 1262 |
+
{
|
| 1263 |
+
"name": p.name,
|
| 1264 |
+
"size_bytes": stat.st_size,
|
| 1265 |
+
"mtime_ms": int(stat.st_mtime * 1000),
|
| 1266 |
+
}
|
| 1267 |
+
)
|
| 1268 |
+
except Exception:
|
| 1269 |
+
continue
|
| 1270 |
+
items.sort(key=lambda x: x["mtime_ms"], reverse=True)
|
| 1271 |
+
return {"files": items}
|
| 1272 |
+
except Exception as e:
|
| 1273 |
+
raise HTTPException(status_code=500, detail=str(e))
|
| 1274 |
+
|
| 1275 |
+
|
| 1276 |
+
@router.get("/api/v1/admin/logs/tail", dependencies=[Depends(verify_api_key)])
|
| 1277 |
+
async def tail_log_api(file: str | None = None, lines: int = 500):
|
| 1278 |
+
"""读取后台日志(尾部)。"""
|
| 1279 |
+
from app.core.logger import LOG_DIR
|
| 1280 |
+
|
| 1281 |
+
try:
|
| 1282 |
+
# Default to latest log.
|
| 1283 |
+
if not file:
|
| 1284 |
+
candidates = sorted(LOG_DIR.glob("*.log"), key=lambda p: p.stat().st_mtime if p.exists() else 0, reverse=True)
|
| 1285 |
+
if not candidates:
|
| 1286 |
+
return {"file": None, "lines": []}
|
| 1287 |
+
path = candidates[0]
|
| 1288 |
+
file = path.name
|
| 1289 |
+
else:
|
| 1290 |
+
path = _safe_log_file_path(file)
|
| 1291 |
+
|
| 1292 |
+
data = await asyncio.to_thread(_tail_lines, path, lines)
|
| 1293 |
+
return {"file": str(file), "lines": data}
|
| 1294 |
+
except FileNotFoundError:
|
| 1295 |
+
raise HTTPException(status_code=404, detail="Log file not found")
|
| 1296 |
+
except ValueError as ve:
|
| 1297 |
+
raise HTTPException(status_code=400, detail=str(ve))
|
| 1298 |
+
except Exception as e:
|
| 1299 |
+
raise HTTPException(status_code=500, detail=str(e))
|
app/api/v1/chat.py
ADDED
|
@@ -0,0 +1,251 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Chat Completions API 路由
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from typing import Any, Dict, List, Optional, Union
|
| 6 |
+
|
| 7 |
+
from fastapi import APIRouter, Depends
|
| 8 |
+
from fastapi.responses import StreamingResponse, JSONResponse
|
| 9 |
+
from pydantic import BaseModel, Field, field_validator
|
| 10 |
+
|
| 11 |
+
from app.core.auth import verify_api_key
|
| 12 |
+
from app.services.grok.chat import ChatService
|
| 13 |
+
from app.services.grok.model import ModelService
|
| 14 |
+
from app.core.exceptions import ValidationException
|
| 15 |
+
from app.services.quota import enforce_daily_quota
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
# Router for the OpenAI-compatible chat endpoints.
router = APIRouter(tags=["Chat"])
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
# Roles accepted in the OpenAI-compatible messages array.
VALID_ROLES = ["developer", "system", "user", "assistant"]
# Content block types allowed inside a `user` message; other roles are text-only.
USER_CONTENT_TYPES = ["text", "image_url", "input_audio", "file"]
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class MessageItem(BaseModel):
    """A single chat message: a role plus plain-text or multimodal content."""
    # Must be one of VALID_ROLES; enforced by the validator below.
    role: str
    # Either a plain string or a list of typed content blocks (dicts).
    content: Union[str, List[Dict[str, Any]]]

    @field_validator("role")
    @classmethod
    def validate_role(cls, v):
        # Reject roles outside the OpenAI-compatible set.
        if v not in VALID_ROLES:
            raise ValueError(f"role must be one of {VALID_ROLES}")
        return v
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class VideoConfig(BaseModel):
    """Video generation options (aspect ratio, length, resolution, preset)."""
    aspect_ratio: Optional[str] = Field("3:2", description="视频比例: 3:2, 16:9, 1:1 等")
    video_length: Optional[int] = Field(6, description="视频时长(秒): 5-15")
    resolution: Optional[str] = Field("SD", description="视频分辨率: SD, HD")
    preset: Optional[str] = Field("custom", description="风格预设: fun, normal, spicy")

    @field_validator("aspect_ratio")
    @classmethod
    def validate_aspect_ratio(cls, v):
        # NOTE(review): these validators raise ValidationException rather than
        # ValueError, which bypasses pydantic's error aggregation — presumably
        # intentional so the API error format is returned directly; confirm.
        allowed = ["2:3", "3:2", "1:1", "9:16", "16:9"]
        if v and v not in allowed:
            raise ValidationException(
                message=f"aspect_ratio must be one of {allowed}",
                param="video_config.aspect_ratio",
                code="invalid_aspect_ratio"
            )
        return v

    @field_validator("video_length")
    @classmethod
    def validate_video_length(cls, v):
        # Only validate explicit values; None keeps the declared default.
        if v is not None:
            if v < 5 or v > 15:
                raise ValidationException(
                    message="video_length must be between 5 and 15 seconds",
                    param="video_config.video_length",
                    code="invalid_video_length"
                )
        return v

    @field_validator("resolution")
    @classmethod
    def validate_resolution(cls, v):
        allowed = ["SD", "HD"]
        if v and v not in allowed:
            raise ValidationException(
                message=f"resolution must be one of {allowed}",
                param="video_config.resolution",
                code="invalid_resolution"
            )
        return v

    @field_validator("preset")
    @classmethod
    def validate_preset(cls, v):
        # Empty/None falls back to the default preset "custom".
        if not v:
            return "custom"
        allowed = ["fun", "normal", "spicy", "custom"]
        if v not in allowed:
            raise ValidationException(
                message=f"preset must be one of {allowed}",
                param="video_config.preset",
                code="invalid_preset"
            )
        return v
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
class ChatCompletionRequest(BaseModel):
    """OpenAI-compatible chat completion request body."""
    model: str = Field(..., description="模型名称")
    messages: List[MessageItem] = Field(..., description="消息数组")
    stream: Optional[bool] = Field(None, description="是否流式输出")
    thinking: Optional[str] = Field(None, description="思考模式: enabled/disabled/None")

    # Video generation options; only consulted for video-capable models.
    video_config: Optional[VideoConfig] = Field(None, description="视频生成参数")

    # Ignore unknown fields (temperature, top_p, ...) instead of rejecting them.
    model_config = {
        "extra": "ignore"
    }
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def validate_request(request: ChatCompletionRequest):
    """Validate the model and every message payload of a completion request."""
    # Reject unknown models up front.
    if not ModelService.valid(request.model):
        raise ValidationException(
            message=f"The model `{request.model}` does not exist or you do not have access to it.",
            param="model",
            code="model_not_found"
        )

    for m_idx, message in enumerate(request.messages):
        payload = message.content

        # Plain string content: must contain non-whitespace characters.
        if isinstance(payload, str):
            if not payload.strip():
                raise ValidationException(
                    message="Message content cannot be empty",
                    param=f"messages.{m_idx}.content",
                    code="empty_content"
                )
            continue

        # Structured content: a non-empty list of typed blocks.
        if not isinstance(payload, list):
            continue

        if not payload:
            raise ValidationException(
                message="Message content cannot be an empty array",
                param=f"messages.{m_idx}.content",
                code="empty_content"
            )

        for b_idx, part in enumerate(payload):
            if not part:
                raise ValidationException(
                    message="Content block cannot be empty",
                    param=f"messages.{m_idx}.content.{b_idx}",
                    code="empty_block"
                )

            if "type" not in part:
                raise ValidationException(
                    message="Content block must have a 'type' field",
                    param=f"messages.{m_idx}.content.{b_idx}",
                    code="missing_type"
                )

            part_type = part.get("type")
            if not part_type or not isinstance(part_type, str) or not part_type.strip():
                raise ValidationException(
                    message="Content block 'type' cannot be empty",
                    param=f"messages.{m_idx}.content.{b_idx}.type",
                    code="empty_type"
                )

            # User messages may carry multimodal blocks; other roles are text-only.
            if message.role == "user":
                if part_type not in USER_CONTENT_TYPES:
                    raise ValidationException(
                        message=f"Invalid content block type: '{part_type}'",
                        param=f"messages.{m_idx}.content.{b_idx}.type",
                        code="invalid_type"
                    )
            elif part_type != "text":
                raise ValidationException(
                    message=f"The `{message.role}` role only supports 'text' type, got '{part_type}'",
                    param=f"messages.{m_idx}.content.{b_idx}.type",
                    code="invalid_type"
                )

            # Per-type payload checks: the named field must exist and be non-empty.
            if part_type == "text":
                body_text = part.get("text", "")
                if not isinstance(body_text, str) or not body_text.strip():
                    raise ValidationException(
                        message="Text content cannot be empty",
                        param=f"messages.{m_idx}.content.{b_idx}.text",
                        code="empty_text"
                    )
            elif part_type == "image_url":
                image_ref = part.get("image_url")
                if not image_ref or not (isinstance(image_ref, dict) and image_ref.get("url")):
                    raise ValidationException(
                        message="image_url must have a 'url' field",
                        param=f"messages.{m_idx}.content.{b_idx}.image_url",
                        code="missing_url"
                    )
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
@router.post("/chat/completions")
async def chat_completions(request: ChatCompletionRequest, api_key: Optional[str] = Depends(verify_api_key)):
    """OpenAI-compatible Chat Completions endpoint."""
    # Fail fast on malformed requests before touching quota.
    validate_request(request)

    # Daily quota check (best-effort).
    await enforce_daily_quota(api_key, request.model)

    payload_messages = [item.model_dump() for item in request.messages]
    model_info = ModelService.get(request.model)

    # Video-capable models are routed to the dedicated video pipeline.
    if model_info and model_info.is_video:
        from app.services.grok.media import VideoService

        # Defaults for missing options come from the VideoConfig field defaults.
        video_cfg = request.video_config or VideoConfig()

        result = await VideoService.completions(
            model=request.model,
            messages=payload_messages,
            stream=request.stream,
            thinking=request.thinking,
            aspect_ratio=video_cfg.aspect_ratio,
            video_length=video_cfg.video_length,
            resolution=video_cfg.resolution,
            preset=video_cfg.preset
        )
    else:
        result = await ChatService.completions(
            model=request.model,
            messages=payload_messages,
            stream=request.stream,
            thinking=request.thinking
        )

    # A dict result is a complete response; anything else is an SSE stream.
    if isinstance(result, dict):
        return JSONResponse(content=result)
    return StreamingResponse(
        result,
        media_type="text/event-stream",
        headers={"Cache-Control": "no-cache", "Connection": "keep-alive"}
    )
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
# Explicit public API of this module.
__all__ = ["router"]
|
app/api/v1/files.py
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
文件服务 API 路由
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import aiofiles.os
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from fastapi import APIRouter, HTTPException
|
| 8 |
+
from fastapi.responses import FileResponse
|
| 9 |
+
|
| 10 |
+
from app.core.logger import logger
|
| 11 |
+
|
| 12 |
+
router = APIRouter(tags=["Files"])

# Cache root for locally stored media (repo_root/data/tmp).
BASE_DIR = Path(__file__).parent.parent.parent.parent / "data" / "tmp"
IMAGE_DIR = BASE_DIR / "image"
VIDEO_DIR = BASE_DIR / "video"
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
@router.get("/image/{filename:path}")
async def get_image(filename: str):
    """
    Serve a cached image file.

    Path separators in the requested name are flattened to "-" so the lookup
    stays inside IMAGE_DIR. Returns 404 when the file does not exist.
    """
    if "/" in filename:
        filename = filename.replace("/", "-")

    file_path = IMAGE_DIR / filename

    if await aiofiles.os.path.exists(file_path) and await aiofiles.os.path.isfile(file_path):
        # Pick the content type from the extension; default to JPEG.
        content_type = {
            ".png": "image/png",
            ".webp": "image/webp",
        }.get(file_path.suffix.lower(), "image/jpeg")

        # Long-lived cache headers for browser/CDN caching under load.
        return FileResponse(
            file_path,
            media_type=content_type,
            headers={
                "Cache-Control": "public, max-age=31536000, immutable"
            }
        )

    # Fix: the previous message logged a literal "(unknown)" and never said
    # which file was requested; include the resolved path for diagnosis.
    logger.warning(f"Image not found: {file_path}")
    raise HTTPException(status_code=404, detail="Image not found")
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
@router.get("/video/{filename:path}")
async def get_video(filename: str):
    """
    Serve a cached video file.

    Path separators in the requested name are flattened to "-" so the lookup
    stays inside VIDEO_DIR. Returns 404 when the file does not exist.
    """
    if "/" in filename:
        filename = filename.replace("/", "-")

    file_path = VIDEO_DIR / filename

    if await aiofiles.os.path.exists(file_path) and await aiofiles.os.path.isfile(file_path):
        # Long-lived cache headers: generated files never change in place.
        return FileResponse(
            file_path,
            media_type="video/mp4",
            headers={
                "Cache-Control": "public, max-age=31536000, immutable"
            }
        )

    # Fix: the previous message logged a literal "(unknown)" and never said
    # which file was requested; include the resolved path for diagnosis.
    logger.warning(f"Video not found: {file_path}")
    raise HTTPException(status_code=404, detail="Video not found")
|
app/api/v1/image.py
ADDED
|
@@ -0,0 +1,1065 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Image Generation API 路由
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import asyncio
|
| 6 |
+
import base64
|
| 7 |
+
import random
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from typing import Any, Awaitable, Callable, Dict, List, Optional, Union
|
| 10 |
+
|
| 11 |
+
import orjson
|
| 12 |
+
from fastapi import APIRouter, Depends, File, Form, UploadFile
|
| 13 |
+
from fastapi.responses import JSONResponse, StreamingResponse
|
| 14 |
+
from pydantic import BaseModel, Field, ValidationError
|
| 15 |
+
|
| 16 |
+
from app.core.auth import verify_api_key
|
| 17 |
+
from app.core.config import get_config
|
| 18 |
+
from app.core.exceptions import AppException, ErrorType, UpstreamException, ValidationException
|
| 19 |
+
from app.core.logger import logger
|
| 20 |
+
from app.services.grok.assets import UploadService
|
| 21 |
+
from app.services.grok.chat import GrokChatService
|
| 22 |
+
from app.services.grok.imagine_experimental import (
|
| 23 |
+
IMAGE_METHOD_IMAGINE_WS_EXPERIMENTAL,
|
| 24 |
+
IMAGE_METHOD_LEGACY,
|
| 25 |
+
ImagineExperimentalService,
|
| 26 |
+
resolve_image_generation_method,
|
| 27 |
+
)
|
| 28 |
+
from app.services.grok.imagine_generation import (
|
| 29 |
+
call_experimental_generation_once,
|
| 30 |
+
collect_experimental_generation_images,
|
| 31 |
+
dedupe_images as dedupe_imagine_images,
|
| 32 |
+
is_valid_image_value as is_valid_imagine_image_value,
|
| 33 |
+
resolve_aspect_ratio as resolve_imagine_aspect_ratio,
|
| 34 |
+
)
|
| 35 |
+
from app.services.grok.model import ModelService
|
| 36 |
+
from app.services.grok.processor import ImageCollectProcessor, ImageStreamProcessor
|
| 37 |
+
from app.services.quota import enforce_daily_quota
|
| 38 |
+
from app.services.request_stats import request_stats
|
| 39 |
+
from app.services.token import get_token_manager
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
# Router for the OpenAI-compatible image endpoints.
router = APIRouter(tags=["Images"])
# Accepted values for the `response_format` request field.
ALLOWED_RESPONSE_FORMATS = {"b64_json", "base64", "url"}
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class ImageGenerationRequest(BaseModel):
    """Image generation request - OpenAI compatible."""

    prompt: str = Field(..., description="Image prompt")
    model: Optional[str] = Field("grok-imagine-1.0", description="Model name")
    n: Optional[int] = Field(1, ge=1, le=10, description="Image count (1-10)")
    size: Optional[str] = Field("1024x1024", description="Image size / ratio")
    # Accepted for OpenAI compatibility but not used by this backend.
    quality: Optional[str] = Field("standard", description="Reserved")
    # None means "use the configured default format".
    response_format: Optional[str] = Field(None, description="Response format")
    style: Optional[str] = Field(None, description="Reserved")
    stream: Optional[bool] = Field(False, description="Enable streaming")
    # Parallelism for the experimental generation path (1-3).
    concurrency: Optional[int] = Field(1, ge=1, le=3, description="Experimental concurrency")
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
class ImageEditRequest(BaseModel):
    """Image edit request - OpenAI compatible."""

    prompt: str = Field(..., description="Edit prompt")
    model: Optional[str] = Field("grok-imagine-1.0-edit", description="Model name")
    # Source image(s) as string(s); uploads may also arrive as multipart files.
    image: Optional[Union[str, List[str]]] = Field(None, description="Input image(s)")
    n: Optional[int] = Field(1, ge=1, le=10, description="Image count (1-10)")
    # Accepted for OpenAI compatibility but not used by this backend.
    size: Optional[str] = Field("1024x1024", description="Reserved")
    quality: Optional[str] = Field("standard", description="Reserved")
    # None means "use the configured default format".
    response_format: Optional[str] = Field(None, description="Response format")
    style: Optional[str] = Field(None, description="Reserved")
    stream: Optional[bool] = Field(False, description="Enable streaming")
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def validate_generation_request(request: ImageGenerationRequest):
    """Validate an image generation request: model, prompt, counts, format."""
    requested_model = request.model or "grok-imagine-1.0"
    # Only the canonical generation model is accepted.
    if requested_model != "grok-imagine-1.0":
        raise ValidationException(
            message="The model `grok-imagine-1.0` is required for image generation.",
            param="model",
            code="model_not_supported",
        )

    info = ModelService.get(requested_model)
    if not info or not info.is_image:
        raise ValidationException(
            message=f"The model `{requested_model}` is not supported for image generation.",
            param="model",
            code="model_not_supported",
        )

    if not request.prompt or not request.prompt.strip():
        raise ValidationException(
            message="Prompt cannot be empty",
            param="prompt",
            code="empty_prompt",
        )

    # Normalize n and enforce the OpenAI-compatible bounds.
    request.n = 1 if request.n is None else request.n
    if not 1 <= request.n <= 10:
        raise ValidationException(
            message="n must be between 1 and 10",
            param="n",
            code="invalid_n",
        )

    if request.stream and request.n not in (1, 2):
        raise ValidationException(
            message="Streaming is only supported when n=1 or n=2",
            param="stream",
            code="invalid_stream_n",
        )

    # Normalize concurrency for the experimental path.
    request.concurrency = 1 if request.concurrency is None else request.concurrency
    if not 1 <= request.concurrency <= 3:
        raise ValidationException(
            message="concurrency must be between 1 and 3",
            param="concurrency",
            code="invalid_concurrency",
        )

    if request.response_format:
        normalized = request.response_format.lower()
        if normalized not in ALLOWED_RESPONSE_FORMATS:
            raise ValidationException(
                message=f"response_format must be one of {sorted(ALLOWED_RESPONSE_FORMATS)}",
                param="response_format",
                code="invalid_response_format",
            )
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
def validate_edit_request(request: ImageEditRequest, images: List[UploadFile]):
    """Validate an image edit request, including the uploaded source images."""
    requested_model = request.model or "grok-imagine-1.0-edit"
    # Only the canonical edit model is accepted.
    if requested_model != "grok-imagine-1.0-edit":
        raise ValidationException(
            message="The model `grok-imagine-1.0-edit` is required for image edits.",
            param="model",
            code="model_not_supported",
        )

    info = ModelService.get(requested_model)
    if not info or not info.is_image:
        raise ValidationException(
            message=f"The model `{requested_model}` is not supported for image edits.",
            param="model",
            code="model_not_supported",
        )

    if not request.prompt or not request.prompt.strip():
        raise ValidationException(
            message="Prompt cannot be empty",
            param="prompt",
            code="empty_prompt",
        )

    # Normalize n and enforce the OpenAI-compatible bounds.
    request.n = 1 if request.n is None else request.n
    if not 1 <= request.n <= 10:
        raise ValidationException(
            message="n must be between 1 and 10",
            param="n",
            code="invalid_n",
        )

    if request.stream and request.n not in (1, 2):
        raise ValidationException(
            message="Streaming is only supported when n=1 or n=2",
            param="stream",
            code="invalid_stream_n",
        )

    if request.response_format:
        normalized = request.response_format.lower()
        if normalized not in ALLOWED_RESPONSE_FORMATS:
            raise ValidationException(
                message=f"response_format must be one of {sorted(ALLOWED_RESPONSE_FORMATS)}",
                param="response_format",
                code="invalid_response_format",
            )

    # At least one and at most 16 source images are required.
    if not images:
        raise ValidationException(
            message="Image is required",
            param="image",
            code="missing_image",
        )
    if len(images) > 16:
        raise ValidationException(
            message="Too many images. Maximum is 16.",
            param="image",
            code="invalid_image_count",
        )
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
def resolve_response_format(response_format: Optional[str]) -> str:
    """Resolve the effective response format, defaulting from app config."""
    chosen = response_format if response_format else get_config("app.image_format", "url")
    if isinstance(chosen, str):
        chosen = chosen.lower()
    if chosen in ALLOWED_RESPONSE_FORMATS:
        return chosen
    raise ValidationException(
        message=f"response_format must be one of {sorted(ALLOWED_RESPONSE_FORMATS)}",
        param="response_format",
        code="invalid_response_format",
    )
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
def resolve_image_response_format(
    response_format: Optional[str],
    image_method: str,
) -> str:
    """
    Resolve the response format for image endpoints.

    Legacy behavior is preserved, except on the experimental imagine path:
    when the caller gives no explicit format and the global default is `url`,
    prefer `b64_json` so local deployments avoid loopback URL rendering issues.
    """
    trimmed = response_format.strip() if isinstance(response_format, str) else response_format
    if not trimmed and image_method == IMAGE_METHOD_IMAGINE_WS_EXPERIMENTAL:
        configured = str(get_config("app.image_format", "url") or "url").strip().lower()
        if configured == "url":
            return "b64_json"
    return resolve_response_format(response_format)
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
def response_field_name(response_format: str) -> str:
    """Return the payload field name used for the given response format."""
    field_by_format = {"url": "url", "base64": "base64"}
    return field_by_format.get(response_format, "b64_json")
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
def _image_generation_method() -> str:
    """Read the configured image generation method (legacy by default)."""
    configured = get_config("grok.image_generation_method", IMAGE_METHOD_LEGACY)
    return resolve_image_generation_method(configured)
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
def resolve_aspect_ratio(size: Optional[str]) -> str:
    """Delegate aspect-ratio resolution to the shared imagine helper."""
    return resolve_imagine_aspect_ratio(size)
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
def _is_valid_image_value(value: Any) -> bool:
    """Delegate image-value validation to the shared imagine helper."""
    return is_valid_imagine_image_value(value)
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
def _dedupe_images(images: List[str]) -> List[str]:
    """Delegate image de-duplication to the shared imagine helper."""
    return dedupe_imagine_images(images)
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
async def _gather_limited(
|
| 257 |
+
task_factories: List[Callable[[], Awaitable[List[str]]]],
|
| 258 |
+
max_concurrency: int,
|
| 259 |
+
) -> List[Any]:
|
| 260 |
+
sem = asyncio.Semaphore(max(1, int(max_concurrency or 1)))
|
| 261 |
+
|
| 262 |
+
async def _run(factory: Callable[[], Awaitable[List[str]]]) -> Any:
|
| 263 |
+
async with sem:
|
| 264 |
+
return await factory()
|
| 265 |
+
|
| 266 |
+
return await asyncio.gather(*[_run(factory) for factory in task_factories], return_exceptions=True)
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
async def call_grok_legacy(
    token: str,
    prompt: str,
    model_info,
    file_attachments: Optional[List[str]] = None,
    response_format: str = "b64_json",
) -> List[str]:
    """
    Call Grok via the legacy chat pipeline and collect generated images.

    Best-effort: any upstream failure is logged and an empty list is returned.
    """
    service = GrokChatService()

    try:
        stream = await service.chat(
            token=token,
            message=prompt,
            model=model_info.grok_model,
            mode=model_info.model_mode,
            think=False,
            stream=True,
            file_attachments=file_attachments,
        )

        collector = ImageCollectProcessor(
            model_info.model_id,
            token,
            response_format=response_format,
        )
        return await collector.process(stream)
    except Exception as e:
        logger.error(f"Grok image call failed: {e}")
        return []
|
| 301 |
+
|
| 302 |
+
|
| 303 |
+
async def call_grok_experimental_ws(
    token: str,
    prompt: str,
    response_format: str = "b64_json",
    n: int = 4,
    aspect_ratio: str = "2:3",
) -> List[str]:
    """Delegate a single experimental WS generation call to the shared helper."""
    return await call_experimental_generation_once(
        token=token,
        prompt=prompt,
        response_format=response_format,
        n=n,
        aspect_ratio=aspect_ratio,
    )
|
| 317 |
+
|
| 318 |
+
|
| 319 |
+
async def call_grok_experimental_edit(
    token: str,
    prompt: str,
    model_id: str,
    file_uris: List[str],
    response_format: str = "b64_json",
) -> List[str]:
    """Run an experimental image edit and collect the resulting images."""
    experimental = ImagineExperimentalService()
    stream = await experimental.chat_edit(token=token, prompt=prompt, file_uris=file_uris)

    collector = ImageCollectProcessor(
        model_id,
        token,
        response_format=response_format,
    )
    return await collector.process(stream)
|
| 334 |
+
|
| 335 |
+
|
| 336 |
+
async def _collect_experimental_generation_images(
    token: str,
    prompt: str,
    n: int,
    response_format: str,
    aspect_ratio: str,
    concurrency: int,
) -> List[str]:
    """Collect up to *n* experimental-generation images (non-streaming).

    Pure pass-through to the shared collector; kept so local call sites
    stay stable if the underlying helper moves.
    """
    collected = await collect_experimental_generation_images(
        token=token,
        prompt=prompt,
        n=n,
        response_format=response_format,
        aspect_ratio=aspect_ratio,
        concurrency=concurrency,
    )
    return collected
|
| 352 |
+
|
| 353 |
+
|
| 354 |
+
async def _experimental_stream_generation(
    token: str,
    prompt: str,
    n: int,
    response_format: str,
    response_field: str,
    aspect_ratio: str,
    state: dict[str, Any],
):
    """Stream SSE chunks for an experimental websocket image generation.

    A background producer task drives ``service.generate_ws`` and pushes
    formatted SSE strings into a queue via the progress/completed callbacks;
    this generator drains the queue and yields each chunk. ``state["success"]``
    is set to True once at least one valid image completes, so the caller can
    decide whether to account the request as successful. If nothing succeeded,
    the producer's exception (if any) is re-raised, otherwise an
    UpstreamException is raised.
    """
    service = ImagineExperimentalService()
    # None is the sentinel the producer pushes to signal end-of-stream.
    queue: asyncio.Queue[Optional[str]] = asyncio.Queue()
    # Maps the upstream's arbitrary image indices to dense output indices 0..n-1.
    index_map: Dict[int, int] = {}
    map_lock = asyncio.Lock()
    next_output_index = 0

    async def _resolve_output_index(raw_index: int) -> int:
        # Assign output slots in first-seen order, clamped to the last valid
        # slot (n - 1) if upstream reports more distinct indices than n.
        nonlocal next_output_index
        async with map_lock:
            if raw_index not in index_map:
                index_map[raw_index] = min(next_output_index, max(0, n - 1))
                next_output_index += 1
            return index_map[raw_index]

    async def _progress_cb(raw_index: int, progress: float):
        # Partial-image events carry an empty payload; only progress matters.
        idx = await _resolve_output_index(raw_index)
        await queue.put(
            _sse_event(
                "image_generation.partial_image",
                {
                    "type": "image_generation.partial_image",
                    response_field: "",
                    "index": idx,
                    # Clamp to the 0..100 range expected by clients.
                    "progress": max(0, min(100, int(progress))),
                },
            )
        )

    async def _completed_cb(raw_index: int, raw_url: str):
        idx = await _resolve_output_index(raw_index)
        # Convert the upstream URL into the caller's requested format
        # (b64_json data or a proxied URL).
        converted = await service.convert_url(
            token=token,
            url=raw_url,
            response_format=response_format,
        )
        if not _is_valid_image_value(converted):
            # Silently drop unusable results; success flag stays unchanged.
            return

        state["success"] = True
        await queue.put(
            _sse_event(
                "image_generation.completed",
                {
                    "type": "image_generation.completed",
                    response_field: converted,
                    "index": idx,
                    # NOTE(review): usage numbers are fixed placeholders, not
                    # real metering — presumably for OpenAI-shape compatibility.
                    "usage": {
                        "total_tokens": 50,
                        "input_tokens": 25,
                        "output_tokens": 25,
                        "input_tokens_details": {"text_tokens": 5, "image_tokens": 20},
                    },
                },
            )
        )

    producer_error: Optional[Exception] = None

    async def _producer():
        # Runs the websocket session to completion; any failure is captured
        # (not raised here) and the end-of-stream sentinel is always queued.
        nonlocal producer_error
        try:
            await service.generate_ws(
                token=token,
                prompt=prompt,
                n=n,
                aspect_ratio=aspect_ratio,
                progress_cb=_progress_cb,
                completed_cb=_completed_cb,
            )
        except Exception as exc:
            producer_error = exc
        finally:
            await queue.put(None)

    producer_task = asyncio.create_task(_producer())
    try:
        while True:
            chunk = await queue.get()
            if chunk is None:
                break
            yield chunk
    finally:
        # Always join the producer so its exception is captured before the
        # success check below.
        await producer_task

    if not state.get("success", False):
        if isinstance(producer_error, Exception):
            raise producer_error
        raise UpstreamException("Experimental imagine websocket returned no images")
|
| 451 |
+
|
| 452 |
+
|
| 453 |
+
def _sse_event(event: str, data: dict) -> str:
    """Format a single Server-Sent Events frame with a JSON payload."""
    payload = orjson.dumps(data).decode()
    return "event: " + event + "\ndata: " + payload + "\n\n"
|
| 455 |
+
|
| 456 |
+
|
| 457 |
+
async def _synthetic_image_stream(
    selected_images: List[str],
    response_field: str,
):
    """Replay pre-collected images as a synthetic SSE stream.

    For every usable image a 100% partial event is yielded, immediately
    followed by its completed event. If no usable image exists at all, a
    single completed event carrying the "error" sentinel at index 0 is
    emitted so clients always receive a terminal event.
    """
    produced_any = False
    for slot, candidate in enumerate(selected_images):
        usable = isinstance(candidate, str) and bool(candidate) and candidate != "error"
        if not usable:
            continue
        produced_any = True
        yield _sse_event(
            "image_generation.partial_image",
            {
                "type": "image_generation.partial_image",
                response_field: "",
                "index": slot,
                "progress": 100,
            },
        )
        yield _sse_event(
            "image_generation.completed",
            {
                "type": "image_generation.completed",
                response_field: candidate,
                "index": slot,
                # Placeholder usage numbers (no real metering here).
                "usage": {
                    "total_tokens": 50,
                    "input_tokens": 25,
                    "output_tokens": 25,
                    "input_tokens_details": {"text_tokens": 5, "image_tokens": 20},
                },
            },
        )
    if produced_any:
        return
    # Nothing usable: emit a terminal error event with zeroed usage.
    yield _sse_event(
        "image_generation.completed",
        {
            "type": "image_generation.completed",
            response_field: "error",
            "index": 0,
            "usage": {
                "total_tokens": 0,
                "input_tokens": 0,
                "output_tokens": 0,
                "input_tokens_details": {"text_tokens": 0, "image_tokens": 0},
            },
        },
    )
|
| 504 |
+
|
| 505 |
+
|
| 506 |
+
async def _record_request(model_id: str, success: bool):
    """Record one request outcome in the stats service (best effort)."""
    try:
        await request_stats.record_request(model_id, success=success)
    except Exception:
        # Stats are advisory only; accounting failures must never propagate.
        pass
|
| 511 |
+
|
| 512 |
+
|
| 513 |
+
async def _get_token_for_model(model_id: str):
    """Obtain a usable token for *model_id*; raise a unified exception on failure.

    Returns a ``(token_manager, token)`` pair. Both failure paths record a
    failed request stat first: manager/lookup errors become a generic 500-style
    AppException, while "no token available" becomes a 429 rate-limit error.
    """
    try:
        token_mgr = await get_token_manager()
        # Refresh the pool if the cached state is stale before picking a token.
        await token_mgr.reload_if_stale()
        token = token_mgr.get_token_for_model(model_id)
    except Exception as e:
        logger.error(f"Failed to get token: {e}")
        # Fall back to the generic "image" bucket when model_id is empty.
        await _record_request(model_id or "image", False)
        raise AppException(
            message="Internal service error obtaining token",
            error_type=ErrorType.SERVER.value,
            code="internal_error",
        )

    if not token:
        # Pool exhausted (or no token supports this model): surface as 429.
        await _record_request(model_id or "image", False)
        raise AppException(
            message="No available tokens. Please try again later.",
            error_type=ErrorType.RATE_LIMIT.value,
            code="rate_limit_exceeded",
            status_code=429,
        )
    return token_mgr, token
|
| 537 |
+
|
| 538 |
+
|
| 539 |
+
def _pick_images(all_images: List[str], n: int) -> List[str]:
|
| 540 |
+
if len(all_images) >= n:
|
| 541 |
+
return random.sample(all_images, n)
|
| 542 |
+
selected = all_images.copy()
|
| 543 |
+
while len(selected) < n:
|
| 544 |
+
selected.append("error")
|
| 545 |
+
return selected
|
| 546 |
+
|
| 547 |
+
|
| 548 |
+
def _build_image_response(selected_images: List[str], response_field: str) -> JSONResponse:
    """Build the OpenAI-style image-generation JSON payload.

    Each entry of *selected_images* becomes a ``data`` item keyed by
    *response_field* (``b64_json`` or ``url``); failed slots carry the
    literal "error" sentinel. Usage metering is not implemented for image
    calls, so every usage counter is reported as 0 — the original
    ``0 * len(...)`` expressions always evaluated to 0 and were simplified.
    """
    import time  # local import kept: only needed for the creation timestamp

    return JSONResponse(
        content={
            "created": int(time.time()),
            "data": [{response_field: img} for img in selected_images],
            "usage": {
                "total_tokens": 0,
                "input_tokens": 0,
                "output_tokens": 0,
                "input_tokens_details": {"text_tokens": 0, "image_tokens": 0},
            },
        }
    )
|
| 563 |
+
|
| 564 |
+
|
| 565 |
+
@router.get("/images/method")
async def get_image_generation_method():
    """Expose the currently configured image-generation backend name."""
    return {"image_generation_method": _image_generation_method()}
|
| 568 |
+
|
| 569 |
+
|
| 570 |
+
@router.post("/images/generations")
async def create_image(
    request: ImageGenerationRequest,
    api_key: Optional[str] = Depends(verify_api_key),
):
    """Image Generation API (OpenAI-compatible ``/images/generations``).

    Flow: validate the request, enforce the daily quota, pick a token, then
    either stream SSE events (``stream=True``) or return a JSON payload.
    The experimental websocket backend is tried first when configured, with
    a fallback chain: realtime websocket stream -> synthetic stream from a
    batch collection -> legacy chat-based generation.
    """
    if request.stream is None:
        request.stream = False

    validate_generation_request(request)
    model_id = request.model or "grok-imagine-1.0"
    n = int(request.n or 1)
    # Concurrency is clamped to 1..3 upstream calls.
    concurrency = max(1, min(3, int(request.concurrency or 1)))
    image_method = _image_generation_method()
    response_format = resolve_image_response_format(request.response_format, image_method)
    request.response_format = response_format
    response_field = response_field_name(response_format)
    aspect_ratio = resolve_aspect_ratio(request.size)

    await enforce_daily_quota(api_key, model_id, image_count=n)
    token_mgr, token = await _get_token_for_model(model_id)
    model_info = ModelService.get(model_id)

    if request.stream:
        if image_method == IMAGE_METHOD_IMAGINE_WS_EXPERIMENTAL:
            # Shared mutable flag: set to True by whichever stream stage
            # manages to produce at least one valid image.
            stream_state: Dict[str, Any] = {"success": False}

            async def _wrapped_experimental_stream():
                try:
                    try:
                        # Stage 1: realtime websocket stream.
                        async for chunk in _experimental_stream_generation(
                            token=token,
                            prompt=request.prompt,
                            n=n,
                            response_format=response_format,
                            response_field=response_field,
                            aspect_ratio=aspect_ratio,
                            state=stream_state,
                        ):
                            yield chunk
                    except Exception as stream_err:
                        logger.warning(
                            f"Experimental image generation realtime stream failed: {stream_err}. "
                            "Fallback to synthetic stream."
                        )
                        try:
                            # Stage 2: collect images in batch, then replay
                            # them as a synthetic SSE stream.
                            all_images = await _collect_experimental_generation_images(
                                token=token,
                                prompt=request.prompt,
                                n=n,
                                response_format=response_format,
                                aspect_ratio=aspect_ratio,
                                concurrency=concurrency,
                            )
                            selected_images = _pick_images(_dedupe_images(all_images), n)
                            stream_state["success"] = any(
                                _is_valid_image_value(item) for item in selected_images
                            )
                            async for chunk in _synthetic_image_stream(selected_images, response_field):
                                yield chunk
                        except Exception as synthetic_err:
                            logger.warning(
                                f"Experimental synthetic stream failed: {synthetic_err}. "
                                "Fallback to legacy stream."
                            )
                            # Stage 3: legacy chat-based streaming generation.
                            chat_service = GrokChatService()
                            response = await chat_service.chat(
                                token=token,
                                message=f"Image Generation: {request.prompt}",
                                model=model_info.grok_model,
                                mode=model_info.model_mode,
                                think=False,
                                stream=True,
                            )
                            processor = ImageStreamProcessor(
                                model_info.model_id,
                                token,
                                n=n,
                                response_format=response_format,
                            )
                            async for chunk in processor.process(response):
                                yield chunk
                            # NOTE(review): marked successful once the legacy
                            # stream finishes, regardless of image validity.
                            stream_state["success"] = True
                finally:
                    # Best-effort accounting: sync token usage only on
                    # success; never let accounting errors break the stream.
                    try:
                        if stream_state.get("success"):
                            await token_mgr.sync_usage(
                                token,
                                model_info.model_id,
                                consume_on_fail=True,
                                is_usage=True,
                            )
                            await _record_request(model_info.model_id, True)
                        else:
                            await _record_request(model_info.model_id, False)
                    except Exception:
                        pass

            return StreamingResponse(
                _wrapped_experimental_stream(),
                media_type="text/event-stream",
                headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
            )

        # Non-experimental streaming path: single legacy chat call.
        chat_service = GrokChatService()
        try:
            response = await chat_service.chat(
                token=token,
                message=f"Image Generation: {request.prompt}",
                model=model_info.grok_model,
                mode=model_info.model_mode,
                think=False,
                stream=True,
            )
        except Exception:
            await _record_request(model_info.model_id, False)
            raise

        processor = ImageStreamProcessor(
            model_info.model_id,
            token,
            n=n,
            response_format=response_format,
        )

        async def _wrapped_stream():
            # ``completed`` only becomes True if the stream drains fully.
            completed = False
            try:
                async for chunk in processor.process(response):
                    yield chunk
                completed = True
            finally:
                try:
                    if completed:
                        await token_mgr.sync_usage(
                            token,
                            model_info.model_id,
                            consume_on_fail=True,
                            is_usage=True,
                        )
                        await _record_request(model_info.model_id, True)
                    else:
                        await _record_request(model_info.model_id, False)
                except Exception:
                    pass

        return StreamingResponse(
            _wrapped_stream(),
            media_type="text/event-stream",
            headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
        )

    # --- Non-streaming path ---
    all_images: List[str] = []
    if image_method == IMAGE_METHOD_IMAGINE_WS_EXPERIMENTAL:
        try:
            all_images = await _collect_experimental_generation_images(
                token=token,
                prompt=request.prompt,
                n=n,
                response_format=response_format,
                aspect_ratio=aspect_ratio,
                concurrency=concurrency,
            )
        except Exception as e:
            logger.warning(f"Experimental image generation failed, fallback to legacy: {e}")

    if not all_images:
        # Legacy fallback: each chat call yields up to 2 images, so issue
        # ceil(n / 2) calls with bounded concurrency.
        calls_needed = (n + 1) // 2
        task_factories: List[Callable[[], Awaitable[List[str]]]] = [
            lambda: call_grok_legacy(
                token,
                f"Image Generation: {request.prompt}",
                model_info,
                response_format=response_format,
            )
            for _ in range(calls_needed)
        ]
        results = await _gather_limited(
            task_factories,
            max_concurrency=min(calls_needed, concurrency),
        )

        all_images = []
        for result in results:
            if isinstance(result, Exception):
                logger.error(f"Concurrent call failed: {result}")
            elif isinstance(result, list):
                all_images.extend(result)

    selected_images = _pick_images(_dedupe_images(all_images), n)
    success = any(_is_valid_image_value(img) for img in selected_images)
    try:
        if success:
            await token_mgr.sync_usage(
                token,
                model_info.model_id,
                consume_on_fail=True,
                is_usage=True,
            )
        await _record_request(model_info.model_id, bool(success))
    except Exception:
        pass

    return _build_image_response(selected_images, response_field)
|
| 774 |
+
|
| 775 |
+
|
| 776 |
+
@router.post("/images/edits")
async def edit_image(
    prompt: str = Form(...),
    image: Optional[List[UploadFile]] = File(None),
    image_alias: Optional[List[UploadFile]] = File(None, alias="image[]"),
    model: Optional[str] = Form("grok-imagine-1.0-edit"),
    n: int = Form(1),
    size: str = Form("1024x1024"),
    quality: str = Form("standard"),
    response_format: Optional[str] = Form(None),
    style: Optional[str] = Form(None),
    stream: Optional[bool] = Form(False),
    api_key: Optional[str] = Depends(verify_api_key),
):
    """
    Image Edits API.

    Mirrors the official API shape; only supports multipart/form-data file
    uploads (both ``image`` and ``image[]`` field names are accepted).
    Uploaded files are validated, converted to data URLs, uploaded to Grok,
    then edited via the experimental service (with a legacy chat fallback).
    """
    # Re-validate form fields through the pydantic model so error payloads
    # match the JSON endpoint's validation format.
    try:
        edit_request = ImageEditRequest(
            prompt=prompt,
            model=model,
            n=n,
            size=size,
            quality=quality,
            response_format=response_format,
            style=style,
            stream=stream,
        )
    except ValidationError as exc:
        errors = exc.errors()
        if errors:
            # Surface only the first validation error, OpenAI-style.
            first = errors[0]
            loc = first.get("loc", [])
            msg = first.get("msg", "Invalid request")
            code = first.get("type", "invalid_value")
            # Drop numeric path segments (list indices) from the param path.
            param_parts = [str(x) for x in loc if not (isinstance(x, int) or str(x).isdigit())]
            param = ".".join(param_parts) if param_parts else None
            raise ValidationException(message=msg, param=param, code=code)
        raise ValidationException(message="Invalid request", code="invalid_value")

    if edit_request.stream is None:
        edit_request.stream = False
    if edit_request.n is None:
        edit_request.n = 1

    image_method = _image_generation_method()
    response_format = resolve_image_response_format(edit_request.response_format, image_method)
    edit_request.response_format = response_format
    response_field = response_field_name(response_format)
    # Merge files from both accepted multipart field names.
    images = (image or []) + (image_alias or [])
    validate_edit_request(edit_request, images)

    model_id = edit_request.model or "grok-imagine-1.0-edit"
    n = int(edit_request.n or 1)

    await enforce_daily_quota(api_key, model_id, image_count=n)

    max_image_bytes = 50 * 1024 * 1024  # 50 MB per uploaded image
    allowed_types = {"image/png", "image/jpeg", "image/webp", "image/jpg"}
    image_payloads: List[str] = []

    for item in images:
        content = await item.read()
        await item.close()
        if not content:
            raise ValidationException(
                message="File content is empty",
                param="image",
                code="empty_file",
            )
        if len(content) > max_image_bytes:
            raise ValidationException(
                message="Image file too large. Maximum is 50MB.",
                param="image",
                code="file_too_large",
            )

        # Normalize the MIME type; fall back to the filename extension when
        # the declared content type is unknown.
        mime = (item.content_type or "").lower()
        if mime == "image/jpg":
            mime = "image/jpeg"
        ext = Path(item.filename or "").suffix.lower()
        if mime not in allowed_types:
            if ext in (".jpg", ".jpeg"):
                mime = "image/jpeg"
            elif ext == ".png":
                mime = "image/png"
            elif ext == ".webp":
                mime = "image/webp"
            else:
                raise ValidationException(
                    message="Unsupported image type. Supported: png, jpg, webp.",
                    param="image",
                    code="invalid_image_type",
                )

        # Inline the file as a base64 data URL for the upload service.
        image_payloads.append(f"data:{mime};base64,{base64.b64encode(content).decode()}")

    token_mgr, token = await _get_token_for_model(model_id)
    model_info = ModelService.get(model_id)

    # Upload every payload; legacy chat edits use file_ids, the
    # experimental service uses file_uris.
    file_ids: List[str] = []
    file_uris: List[str] = []
    upload_service = UploadService()
    try:
        for payload in image_payloads:
            file_id, file_uri = await upload_service.upload(payload, token)
            if file_id:
                file_ids.append(file_id)
            if file_uri:
                file_uris.append(file_uri)
    finally:
        await upload_service.close()

    if edit_request.stream:
        if image_method == IMAGE_METHOD_IMAGINE_WS_EXPERIMENTAL:
            try:
                service = ImagineExperimentalService()
                response = await service.chat_edit(
                    token=token,
                    prompt=edit_request.prompt,
                    file_uris=file_uris,
                )
                processor = ImageStreamProcessor(
                    model_info.model_id,
                    token,
                    n=n,
                    response_format=response_format,
                )

                async def _wrapped_experimental_stream():
                    completed = False
                    try:
                        async for chunk in processor.process(response):
                            yield chunk
                        completed = True
                    finally:
                        # Best-effort accounting mirroring the legacy path.
                        try:
                            if completed:
                                await token_mgr.sync_usage(
                                    token,
                                    model_info.model_id,
                                    consume_on_fail=True,
                                    is_usage=True,
                                )
                                await _record_request(model_info.model_id, True)
                            else:
                                await _record_request(model_info.model_id, False)
                        except Exception:
                            pass

                return StreamingResponse(
                    _wrapped_experimental_stream(),
                    media_type="text/event-stream",
                    headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
                )
            except Exception as e:
                # Setup failure only (the return above exits on success);
                # fall through to the legacy streaming path.
                logger.warning(f"Experimental image edit stream failed, fallback to legacy: {e}")

        chat_service = GrokChatService()
        try:
            response = await chat_service.chat(
                token=token,
                message=f"Image Edit: {edit_request.prompt}",
                model=model_info.grok_model,
                mode=model_info.model_mode,
                think=False,
                stream=True,
                file_attachments=file_ids,
            )
        except Exception:
            await _record_request(model_info.model_id, False)
            raise

        processor = ImageStreamProcessor(
            model_info.model_id,
            token,
            n=n,
            response_format=response_format,
        )

        async def _wrapped_stream():
            completed = False
            try:
                async for chunk in processor.process(response):
                    yield chunk
                completed = True
            finally:
                try:
                    if completed:
                        await token_mgr.sync_usage(
                            token,
                            model_info.model_id,
                            consume_on_fail=True,
                            is_usage=True,
                        )
                        await _record_request(model_info.model_id, True)
                    else:
                        await _record_request(model_info.model_id, False)
                except Exception:
                    pass

        return StreamingResponse(
            _wrapped_stream(),
            media_type="text/event-stream",
            headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
        )

    # --- Non-streaming path ---
    all_images: List[str] = []
    if image_method == IMAGE_METHOD_IMAGINE_WS_EXPERIMENTAL:
        try:
            # Each edit call yields up to 2 images, so issue ceil(n / 2) calls.
            calls_needed = (n + 1) // 2
            if calls_needed == 1:
                all_images = await call_grok_experimental_edit(
                    token=token,
                    prompt=edit_request.prompt,
                    model_id=model_info.model_id,
                    file_uris=file_uris,
                    response_format=response_format,
                )
            else:
                tasks = [
                    call_grok_experimental_edit(
                        token=token,
                        prompt=edit_request.prompt,
                        model_id=model_info.model_id,
                        file_uris=file_uris,
                        response_format=response_format,
                    )
                    for _ in range(calls_needed)
                ]
                results = await asyncio.gather(*tasks, return_exceptions=True)
                for result in results:
                    if isinstance(result, Exception):
                        logger.warning(f"Experimental image edit call failed: {result}")
                    elif isinstance(result, list):
                        all_images.extend(result)
            if not all_images:
                # Force the legacy fallback below via the except handler.
                raise UpstreamException("Experimental image edit returned no images")
        except Exception as e:
            logger.warning(f"Experimental image edit failed, fallback to legacy: {e}")

    if not all_images:
        calls_needed = (n + 1) // 2
        if calls_needed == 1:
            all_images = await call_grok_legacy(
                token,
                f"Image Edit: {edit_request.prompt}",
                model_info,
                file_attachments=file_ids,
                response_format=response_format,
            )
        else:
            tasks = [
                call_grok_legacy(
                    token,
                    f"Image Edit: {edit_request.prompt}",
                    model_info,
                    file_attachments=file_ids,
                    response_format=response_format,
                )
                for _ in range(calls_needed)
            ]
            results = await asyncio.gather(*tasks, return_exceptions=True)
            all_images = []
            for result in results:
                if isinstance(result, Exception):
                    logger.error(f"Concurrent call failed: {result}")
                elif isinstance(result, list):
                    all_images.extend(result)

    # NOTE(review): unlike create_image, this path does not deduplicate via
    # _dedupe_images before picking — confirm whether that is intentional.
    selected_images = _pick_images(all_images, n)
    success = any(isinstance(img, str) and img and img != "error" for img in selected_images)
    try:
        if success:
            await token_mgr.sync_usage(
                token,
                model_info.model_id,
                consume_on_fail=True,
                is_usage=True,
            )
        await _record_request(model_info.model_id, bool(success))
    except Exception:
        pass

    return _build_image_response(selected_images, response_field)
|
| 1063 |
+
|
| 1064 |
+
|
| 1065 |
+
__all__ = ["router"]
|
app/api/v1/models.py
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Models API 路由
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import time
|
| 6 |
+
|
| 7 |
+
from fastapi import APIRouter, HTTPException
|
| 8 |
+
|
| 9 |
+
from app.services.grok.model import ModelService
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
router = APIRouter(tags=["Models"])
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
@router.get("/models")
async def list_models():
    """OpenAI-compatible model listing endpoint."""
    created_at = int(time.time())
    entries = []
    for m in ModelService.list():
        entries.append(
            {
                "id": m.model_id,
                "object": "model",
                "created": created_at,
                "owned_by": "grok2api",
                "display_name": m.display_name,
                "description": m.description,
            }
        )
    return {"object": "list", "data": entries}
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
@router.get("/models/{model_id}")
async def get_model(model_id: str):
    """OpenAI compatible: single model detail (404 when unknown)."""
    m = ModelService.get(model_id)
    if not m:
        raise HTTPException(status_code=404, detail=f"Model '{model_id}' not found")

    return {
        "id": m.model_id,
        "object": "model",
        "created": int(time.time()),
        "owned_by": "grok2api",
        "display_name": m.display_name,
        "description": m.description,
    }
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
__all__ = ["router"]
|
app/api/v1/uploads.py
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Uploads API (used by the web chat UI)
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import uuid
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
|
| 8 |
+
import aiofiles
|
| 9 |
+
from fastapi import APIRouter, UploadFile, File, HTTPException
|
| 10 |
+
|
| 11 |
+
from app.services.grok.assets import DownloadService
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
router = APIRouter(tags=["Uploads"])
|
| 15 |
+
|
| 16 |
+
BASE_DIR = Path(__file__).parent.parent.parent.parent / "data" / "tmp"
|
| 17 |
+
IMAGE_DIR = BASE_DIR / "image"
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def _ext_from_mime(mime: str) -> str:
|
| 21 |
+
m = (mime or "").lower()
|
| 22 |
+
if m == "image/png":
|
| 23 |
+
return "png"
|
| 24 |
+
if m == "image/webp":
|
| 25 |
+
return "webp"
|
| 26 |
+
if m == "image/gif":
|
| 27 |
+
return "gif"
|
| 28 |
+
if m in ("image/jpeg", "image/jpg"):
|
| 29 |
+
return "jpg"
|
| 30 |
+
return "jpg"
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
@router.post("/uploads/image")
async def upload_image(file: UploadFile = File(...)):
    """Accept a single image upload from the web chat UI.

    The file is streamed in 1 MB chunks to ``data/tmp/image`` under a
    random name, the shared cache size limit is enforced best-effort, and
    the served URL plus name/size metadata are returned. Only content types
    starting with ``image/`` are accepted (400 otherwise).
    """
    content_type = (file.content_type or "").lower()
    if not content_type.startswith("image/"):
        raise HTTPException(status_code=400, detail=f"Unsupported file type: {file.content_type}")

    IMAGE_DIR.mkdir(parents=True, exist_ok=True)
    # Random name prevents collisions and path traversal via client filenames.
    name = f"upload-{uuid.uuid4().hex}.{_ext_from_mime(content_type)}"
    path = IMAGE_DIR / name

    size = 0
    async with aiofiles.open(path, "wb") as f:
        while True:
            chunk = await file.read(1024 * 1024)
            if not chunk:
                break
            size += len(chunk)
            await f.write(chunk)

    # Best-effort: reuse existing cache cleanup policy (size-based).
    try:
        dl = DownloadService()
        await dl.check_limit()
        await dl.close()
    except Exception:
        pass

    return {"url": f"/v1/files/image/{name}", "name": name, "size_bytes": size}
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
__all__ = ["router"]
|
| 64 |
+
|
app/api/v1/video.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
TODO:Video Generation API 路由
|
| 3 |
+
"""
|
app/core/auth.py
ADDED
|
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
API 认证模块
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from __future__ import annotations
|
| 6 |
+
|
| 7 |
+
import asyncio
|
| 8 |
+
import json
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
from typing import Optional, Set
|
| 11 |
+
|
| 12 |
+
from fastapi import HTTPException, Security, status
|
| 13 |
+
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
| 14 |
+
|
| 15 |
+
from app.core.config import get_config
|
| 16 |
+
|
| 17 |
+
# 定义 Bearer Scheme
|
| 18 |
+
security = HTTPBearer(
|
| 19 |
+
auto_error=False,
|
| 20 |
+
scheme_name="API Key",
|
| 21 |
+
description="Enter your API Key in the format: Bearer <key>",
|
| 22 |
+
)
|
| 23 |
+
|
| 24 |
+
LEGACY_API_KEYS_FILE = Path(__file__).parent.parent.parent / "data" / "api_keys.json"
|
| 25 |
+
_legacy_api_keys_cache: Set[str] | None = None
|
| 26 |
+
_legacy_api_keys_mtime: float | None = None
|
| 27 |
+
_legacy_api_keys_lock = asyncio.Lock()
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
async def _load_legacy_api_keys() -> Set[str]:
    """
    Backward-compatible API keys loader.

    Older versions stored multiple API keys in `data/api_keys.json` with a shape like:
    [{"key": "...", "is_active": true, ...}, ...]

    Returns the set of active keys. The parsed set is cached in module globals
    and invalidated via the file's mtime; concurrent reloads are serialized with
    an asyncio lock using a double-checked pattern (check before and inside it).
    """
    global _legacy_api_keys_cache, _legacy_api_keys_mtime

    # Fast path: no legacy file at all -> empty key set (and reset the cache).
    if not LEGACY_API_KEYS_FILE.exists():
        _legacy_api_keys_cache = set()
        _legacy_api_keys_mtime = None
        return set()

    try:
        stat = LEGACY_API_KEYS_FILE.stat()
        mtime = stat.st_mtime
    except Exception:
        # stat() can fail if the file disappears between exists() and here.
        mtime = None

    # Lock-free fast path: cache is valid while the file's mtime is unchanged.
    if _legacy_api_keys_cache is not None and mtime is not None and _legacy_api_keys_mtime == mtime:
        return _legacy_api_keys_cache

    async with _legacy_api_keys_lock:
        # Re-check in lock (another coroutine may have reloaded meanwhile).
        if not LEGACY_API_KEYS_FILE.exists():
            _legacy_api_keys_cache = set()
            _legacy_api_keys_mtime = None
            return set()

        try:
            stat = LEGACY_API_KEYS_FILE.stat()
            mtime = stat.st_mtime
        except Exception:
            mtime = None

        if _legacy_api_keys_cache is not None and mtime is not None and _legacy_api_keys_mtime == mtime:
            return _legacy_api_keys_cache

        try:
            # Read off the event loop; a missing/corrupt file counts as "no keys".
            raw = await asyncio.to_thread(LEGACY_API_KEYS_FILE.read_text, "utf-8")
            data = json.loads(raw) if raw.strip() else []
        except Exception:
            data = []

        keys: Set[str] = set()
        if isinstance(data, list):
            for item in data:
                if not isinstance(item, dict):
                    continue
                key = item.get("key")
                is_active = item.get("is_active", True)
                # Missing/odd "is_active" values count as active; only an explicit
                # False excludes the key.
                if isinstance(key, str) and key.strip() and is_active is not False:
                    keys.add(key.strip())

        _legacy_api_keys_cache = keys
        _legacy_api_keys_mtime = mtime
        return keys
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
async def verify_api_key(
    auth: Optional[HTTPAuthorizationCredentials] = Security(security),
) -> Optional[str]:
    """
    Validate the Bearer token for API access.

    - If `app.api_key` is not configured and no legacy keys exist, verification
      is skipped (returns None).
    - If `app.api_key` is configured or legacy keys exist, the request must
      carry `Authorization: Bearer <key>`.

    Returns:
        The presented token when valid, or None when auth is disabled.

    Raises:
        HTTPException: 401 when the token is missing or matches no known key.
    """
    import hmac  # local import: only needed on the auth-enabled path

    api_key = str(get_config("app.api_key", "") or "").strip()
    legacy_keys = await _load_legacy_api_keys()

    # No configured API key and no legacy keys: auth is disabled, allow through.
    if not api_key and not legacy_keys:
        return None

    if not auth:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Missing authentication token",
            headers={"WWW-Authenticate": "Bearer"},
        )

    token = auth.credentials
    # Constant-time comparisons so response timing cannot leak key prefixes.
    if (api_key and hmac.compare_digest(token, api_key)) or any(
        hmac.compare_digest(token, legacy_key) for legacy_key in legacy_keys
    ):
        return token

    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Invalid authentication token",
        headers={"WWW-Authenticate": "Bearer"},
    )
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
async def verify_app_key(
    auth: Optional[HTTPAuthorizationCredentials] = Security(security),
) -> Optional[str]:
    """
    Validate the admin-panel login key (app_key).

    Unlike verify_api_key, an unset app_key does NOT skip verification: it is
    rejected with 401 so the admin panel can never be opened by accident.

    Returns:
        The presented token when it matches the configured app_key.

    Raises:
        HTTPException: 401 when app_key is unset, the token is missing,
            or the token does not match.
    """
    import hmac  # local import: timing-safe comparison helper

    app_key = str(get_config("app.app_key", "") or "").strip()

    if not app_key:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="App key is not configured",
            headers={"WWW-Authenticate": "Bearer"},
        )

    if not auth:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Missing authentication token",
            headers={"WWW-Authenticate": "Bearer"},
        )

    # Constant-time comparison to avoid leaking the key via response timing.
    if not hmac.compare_digest(auth.credentials, app_key):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid authentication token",
            headers={"WWW-Authenticate": "Bearer"},
        )

    return auth.credentials
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
__all__ = ["verify_api_key", "verify_app_key"]
|
| 159 |
+
|
app/core/config.py
ADDED
|
@@ -0,0 +1,329 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
配置管理
|
| 3 |
+
|
| 4 |
+
- config.toml: 运行时配置
|
| 5 |
+
- config.defaults.toml: 默认配置基线
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from copy import deepcopy
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
from typing import Any, Dict
|
| 11 |
+
import tomllib
|
| 12 |
+
|
| 13 |
+
from app.core.logger import logger
|
| 14 |
+
|
| 15 |
+
DEFAULT_CONFIG_FILE = Path(__file__).parent.parent.parent / "config.defaults.toml"
|
| 16 |
+
LEGACY_CONFIG_FILE = Path(__file__).parent.parent.parent / "data" / "setting.toml"
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def _as_str(v: Any) -> str:
|
| 20 |
+
if isinstance(v, str):
|
| 21 |
+
return v
|
| 22 |
+
return ""
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def _as_int(v: Any) -> int | None:
|
| 26 |
+
try:
|
| 27 |
+
if v is None:
|
| 28 |
+
return None
|
| 29 |
+
return int(v)
|
| 30 |
+
except Exception:
|
| 31 |
+
return None
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def _as_bool(v: Any) -> bool | None:
|
| 35 |
+
if isinstance(v, bool):
|
| 36 |
+
return v
|
| 37 |
+
return None
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def _split_csv_tags(v: Any) -> list[str] | None:
|
| 41 |
+
if not isinstance(v, str):
|
| 42 |
+
return None
|
| 43 |
+
parts = [x.strip() for x in v.split(",")]
|
| 44 |
+
tags = [x for x in parts if x]
|
| 45 |
+
return tags or None
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def _legacy_setting_to_config(legacy: Dict[str, Any]) -> Dict[str, Any]:
    """
    Migrate legacy `data/setting.toml` format (grok/global) to the new config schema.

    Best-effort mapping only for stable fields. It does not delete or rename the legacy file.

    Args:
        legacy: parsed legacy TOML, expected shape {"grok": {...}, "global": {...}}

    Returns:
        A partial config dict containing only sections/keys for which the
        legacy file had usable values (possibly empty).
    """

    # Tolerate missing or malformed sections by falling back to empty dicts.
    grok = legacy.get("grok") if isinstance(legacy.get("grok"), dict) else {}
    global_ = legacy.get("global") if isinstance(legacy.get("global"), dict) else {}

    out: Dict[str, Any] = {}

    # === app ===
    app_url = _as_str(global_.get("base_url")).strip()
    admin_username = _as_str(global_.get("admin_username")).strip()
    # Legacy admin password becomes the new app_key (admin-panel login key).
    app_key = _as_str(global_.get("admin_password")).strip()
    api_key = _as_str(grok.get("api_key")).strip()
    image_format = _as_str(global_.get("image_mode")).strip()

    # Only emit a section when at least one value is present.
    if app_url or admin_username or app_key or api_key or image_format:
        out["app"] = {}
        if app_url:
            out["app"]["app_url"] = app_url
        if admin_username:
            out["app"]["admin_username"] = admin_username
        if app_key:
            out["app"]["app_key"] = app_key
        if api_key:
            out["app"]["api_key"] = api_key
        if image_format:
            out["app"]["image_format"] = image_format

    # === grok ===
    base_proxy_url = _as_str(grok.get("proxy_url")).strip()
    asset_proxy_url = _as_str(grok.get("cache_proxy_url")).strip()
    cf_clearance = _as_str(grok.get("cf_clearance")).strip()

    # Booleans use tri-state (None = "not set in legacy file, don't migrate").
    temporary = _as_bool(grok.get("temporary"))
    thinking = _as_bool(grok.get("show_thinking"))
    dynamic_statsig = _as_bool(grok.get("dynamic_statsig"))
    filter_tags = _split_csv_tags(grok.get("filtered_tags"))

    retry_status_codes = grok.get("retry_status_codes")

    # Legacy had separate total/chunk stream timeouts; the new schema has one
    # `timeout`, preferring the total timeout when both are set.
    timeout = None
    total_timeout = _as_int(grok.get("stream_total_timeout"))
    if total_timeout and total_timeout > 0:
        timeout = total_timeout
    else:
        chunk_timeout = _as_int(grok.get("stream_chunk_timeout"))
        if chunk_timeout and chunk_timeout > 0:
            timeout = chunk_timeout

    if (
        base_proxy_url
        or asset_proxy_url
        or cf_clearance
        or temporary is not None
        or thinking is not None
        or dynamic_statsig is not None
        or filter_tags is not None
        or timeout is not None
        or isinstance(retry_status_codes, list)
    ):
        out["grok"] = {}
        if base_proxy_url:
            out["grok"]["base_proxy_url"] = base_proxy_url
        if asset_proxy_url:
            out["grok"]["asset_proxy_url"] = asset_proxy_url
        if cf_clearance:
            out["grok"]["cf_clearance"] = cf_clearance
        if temporary is not None:
            out["grok"]["temporary"] = temporary
        if thinking is not None:
            out["grok"]["thinking"] = thinking
        if dynamic_statsig is not None:
            out["grok"]["dynamic_statsig"] = dynamic_statsig
        if filter_tags is not None:
            out["grok"]["filter_tags"] = filter_tags
        if timeout is not None:
            out["grok"]["timeout"] = timeout
        if isinstance(retry_status_codes, list) and retry_status_codes:
            out["grok"]["retry_status_codes"] = retry_status_codes

    # === cache ===
    # Legacy had separate limits; new uses a single total limit_mb.
    image_mb = _as_int(global_.get("image_cache_max_size_mb")) or 0
    video_mb = _as_int(global_.get("video_cache_max_size_mb")) or 0
    if image_mb > 0 or video_mb > 0:
        out["cache"] = {"limit_mb": max(1, image_mb + video_mb)}

    return out
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
def _apply_legacy_config(
    config_data: Dict[str, Any],
    legacy_cfg: Dict[str, Any],
    defaults: Dict[str, Any],
) -> bool:
    """
    Merge legacy settings into current config:
    - fill missing keys
    - override keys that are still default values

    User-customized values (anything differing from the defaults baseline)
    are never overwritten.

    Args:
        config_data: current config; mutated in place
        legacy_cfg: migrated legacy settings (section -> {key: value})
        defaults: default config baseline used to detect "still default" values

    Returns:
        True if config_data was modified.
    """

    changed = False
    for section, items in legacy_cfg.items():
        if not isinstance(items, dict):
            continue

        # Ensure the target section exists and is a dict.
        current_section = config_data.get(section)
        if not isinstance(current_section, dict):
            current_section = {}
            config_data[section] = current_section
            changed = True

        default_section = defaults.get(section) if isinstance(defaults.get(section), dict) else {}

        for key, val in items.items():
            if val is None:
                continue
            # Missing key: simply fill it in from the legacy value.
            if key not in current_section:
                current_section[key] = val
                changed = True
                continue

            default_val = default_section.get(key) if isinstance(default_section, dict) else None
            current_val = current_section.get(key)

            # NOTE: The admin panel password default used to be `grok2api` in older versions.
            # Treat it as "still default" so legacy `data/setting.toml` can override it during migration.
            is_effective_default = current_val == default_val
            if section == "app" and key == "app_key" and current_val == "grok2api":
                is_effective_default = True

            # Override only default-valued keys, and only when the legacy
            # value actually differs from the default.
            if is_effective_default and val != default_val:
                current_section[key] = val
                changed = True

    return changed
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
def _deep_merge(base: Dict[str, Any], override: Dict[str, Any]) -> Dict[str, Any]:
|
| 191 |
+
"""深度合并字典:override 覆盖 base。"""
|
| 192 |
+
if not isinstance(base, dict):
|
| 193 |
+
return deepcopy(override) if isinstance(override, dict) else deepcopy(base)
|
| 194 |
+
|
| 195 |
+
result = deepcopy(base)
|
| 196 |
+
if not isinstance(override, dict):
|
| 197 |
+
return result
|
| 198 |
+
|
| 199 |
+
for key, val in override.items():
|
| 200 |
+
if isinstance(val, dict) and isinstance(result.get(key), dict):
|
| 201 |
+
result[key] = _deep_merge(result[key], val)
|
| 202 |
+
else:
|
| 203 |
+
result[key] = val
|
| 204 |
+
return result
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
def _load_defaults() -> Dict[str, Any]:
    """Load the baseline config from config.defaults.toml.

    Returns an empty dict when the file is missing or unparseable (a warning
    is logged on parse failure).
    """
    if not DEFAULT_CONFIG_FILE.exists():
        return {}
    try:
        with DEFAULT_CONFIG_FILE.open("rb") as fh:
            return tomllib.load(fh)
    except Exception as e:
        logger.warning(f"Failed to load defaults from {DEFAULT_CONFIG_FILE}: {e}")
        return {}
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
class Config:
    """Configuration manager.

    Merges, in increasing precedence: defaults (config.defaults.toml),
    the stored runtime config, and (best-effort) legacy data/setting.toml
    migrated values. The merged result is cached in memory and persisted
    back to storage when it changes.
    """

    # NOTE(review): never assigned anywhere in this class — looks like a
    # leftover from a singleton pattern; the module exposes a module-level
    # `config` instance instead. TODO confirm before removing.
    _instance = None
    # Class-level fallback; shadowed by the instance attribute set in __init__.
    _config = {}

    def __init__(self):
        self._config = {}           # merged effective config (defaults + stored + legacy)
        self._defaults = {}         # parsed config.defaults.toml baseline
        self._defaults_loaded = False  # lazy-load guard for _ensure_defaults

    def _ensure_defaults(self):
        # One-shot lazy load of the defaults baseline.
        if self._defaults_loaded:
            return
        self._defaults = _load_defaults()
        self._defaults_loaded = True

    async def load(self):
        """Explicitly load configuration (call once at startup).

        Never raises: on any failure the config falls back to an empty dict.
        """
        try:
            from app.core.storage import get_storage, LocalStorage

            self._ensure_defaults()

            storage = get_storage()
            config_data = await storage.load_config()
            from_remote = True

            # Bootstrap the backend from local data/config.toml when the
            # (possibly remote) storage has no config yet.
            if config_data is None:
                local_storage = LocalStorage()
                from_remote = False
                try:
                    config_data = await local_storage.load_config()
                except Exception as e:
                    logger.info(f"Failed to auto-init config from local: {e}")
                    config_data = {}

            config_data = config_data or {}
            # Snapshot before legacy migration so we can detect changes below.
            before_legacy = deepcopy(config_data)

            # Legacy migration: data/setting.toml -> config schema
            if LEGACY_CONFIG_FILE.exists():
                try:
                    with LEGACY_CONFIG_FILE.open("rb") as f:
                        legacy_raw = tomllib.load(f) or {}
                    legacy_cfg = _legacy_setting_to_config(legacy_raw)
                    if legacy_cfg and _apply_legacy_config(config_data, legacy_cfg, self._defaults):
                        logger.info(
                            "Detected legacy data/setting.toml, migrated into config (missing/default keys)."
                        )
                except Exception as e:
                    logger.warning(f"Failed to migrate legacy config from {LEGACY_CONFIG_FILE}: {e}")

            merged = _deep_merge(self._defaults, config_data)

            # Backfill missing configuration into storage: persist when we
            # bootstrapped from local, or when merging changed anything.
            should_persist = (not from_remote) or (merged != before_legacy)
            if should_persist:
                async with storage.acquire_lock("config_save", timeout=10):
                    await storage.save_config(merged)
                    if not from_remote:
                        logger.info(
                            f"Initialized remote storage ({storage.__class__.__name__}) with config baseline."
                        )

            self._config = merged
        except Exception as e:
            # Never crash startup on config problems; run with an empty config.
            logger.error(f"Error loading config: {e}")
            self._config = {}

    def get(self, key: str, default: Any = None) -> Any:
        """
        Get a configuration value.

        Args:
            key: config key — dotted "section.key" form, or a top-level key
            default: value returned when the key is missing
        """
        if "." in key:
            try:
                section, attr = key.split(".", 1)
                # Missing section falls through to {} so .get returns default.
                return self._config.get(section, {}).get(attr, default)
            except (ValueError, AttributeError):
                return default

        return self._config.get(key, default)

    async def update(self, new_config: dict):
        """Merge *new_config* over the current config and persist the result."""
        from app.core.storage import get_storage

        storage = get_storage()
        async with storage.acquire_lock("config_save", timeout=10):
            self._ensure_defaults()
            # Precedence: defaults < current config < incoming updates.
            base = _deep_merge(self._defaults, self._config or {})
            merged = _deep_merge(base, new_config or {})
            await storage.save_config(merged)
            self._config = merged
|
| 318 |
+
|
| 319 |
+
|
| 320 |
+
# 全局配置实例
|
| 321 |
+
config = Config()
|
| 322 |
+
|
| 323 |
+
|
| 324 |
+
def get_config(key: str, default: Any = None) -> Any:
    """Read a value from the module-level config instance ("section.key" or plain key)."""
    return config.get(key, default)
|
| 327 |
+
|
| 328 |
+
|
| 329 |
+
__all__ = ["Config", "config", "get_config"]
|
app/core/exceptions.py
ADDED
|
@@ -0,0 +1,221 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
全局异常处理 - OpenAI 兼容错误格式
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from typing import Any, Optional
|
| 6 |
+
from enum import Enum
|
| 7 |
+
from fastapi import Request, HTTPException
|
| 8 |
+
from fastapi.responses import JSONResponse
|
| 9 |
+
from fastapi.exceptions import RequestValidationError
|
| 10 |
+
|
| 11 |
+
from app.core.logger import logger
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
# ============= 错误类型 =============
|
| 15 |
+
|
| 16 |
+
class ErrorType(str, Enum):
    """OpenAI-compatible error `type` strings used in error payloads."""
    INVALID_REQUEST = "invalid_request_error"          # malformed/invalid input (400)
    AUTHENTICATION = "authentication_error"            # bad or missing API key (401)
    PERMISSION = "permission_error"                    # authenticated but not allowed (403)
    NOT_FOUND = "not_found_error"                      # unknown resource/model (404)
    RATE_LIMIT = "rate_limit_error"                    # too many requests (429)
    SERVER = "server_error"                            # internal/upstream failure (5xx)
    SERVICE_UNAVAILABLE = "service_unavailable_error"  # temporary outage
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
# ============= 辅助函数 =============
|
| 28 |
+
|
| 29 |
+
def error_response(
    message: str,
    error_type: str = ErrorType.INVALID_REQUEST.value,
    param: str = None,
    code: str = None
) -> dict:
    """Build an OpenAI-style error payload: {"error": {message, type, param, code}}."""
    details = {
        "message": message,
        "type": error_type,
        "param": param,
        "code": code,
    }
    return {"error": details}
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
# ============= 异常类 =============
|
| 47 |
+
|
| 48 |
+
class AppException(Exception):
    """Base application exception carrying OpenAI-style error metadata.

    Attributes map directly onto error_response() fields; status_code is the
    HTTP status the registered handler responds with.
    """

    def __init__(
        self,
        message: str,
        error_type: str = ErrorType.SERVER.value,
        code: Optional[str] = None,
        param: Optional[str] = None,
        status_code: int = 500
    ):
        self.message = message
        self.error_type = error_type
        self.code = code
        self.param = param
        self.status_code = status_code
        super().__init__(message)
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
class ValidationException(AppException):
    """Raised for invalid request input; rendered as a 400 invalid_request_error."""

    def __init__(self, message: str, param: str = None, code: str = None):
        effective_code = code if code else "invalid_value"
        super().__init__(
            message=message,
            error_type=ErrorType.INVALID_REQUEST.value,
            code=effective_code,
            param=param,
            status_code=400
        )
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
class AuthenticationException(AppException):
    """Raised when API-key authentication fails; rendered as a 401 invalid_api_key error."""

    def __init__(self, message: str = "Invalid API key"):
        super().__init__(
            message=message,
            status_code=401,
            error_type=ErrorType.AUTHENTICATION.value,
            code="invalid_api_key"
        )
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
class UpstreamException(AppException):
    """Raised when an upstream service fails; rendered as a 502 upstream_error."""

    def __init__(self, message: str, details: Any = None):
        super().__init__(
            message=message,
            status_code=502,
            error_type=ErrorType.SERVER.value,
            code="upstream_error"
        )
        # Extra upstream payload kept for logging/debugging;
        # not part of the OpenAI error shape.
        self.details = details
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
# ============= 异常处理器 =============
|
| 106 |
+
|
| 107 |
+
async def app_exception_handler(request: Request, exc: AppException) -> JSONResponse:
    """Render an AppException as an OpenAI-style JSON error response."""
    logger.warning(f"AppException: {exc.error_type} - {exc.message}")

    body = error_response(
        message=exc.message,
        error_type=exc.error_type,
        param=exc.param,
        code=exc.code
    )
    return JSONResponse(status_code=exc.status_code, content=body)
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
async def http_exception_handler(request: Request, exc: HTTPException) -> JSONResponse:
    """Render a FastAPI HTTPException in the OpenAI error format."""
    # HTTP status -> OpenAI error `type`; anything unmapped is a server error.
    status_to_type = {
        400: ErrorType.INVALID_REQUEST.value,
        401: ErrorType.AUTHENTICATION.value,
        403: ErrorType.PERMISSION.value,
        404: ErrorType.NOT_FOUND.value,
        429: ErrorType.RATE_LIMIT.value,
    }
    # HTTP status -> default machine-readable error code (None when unmapped).
    status_to_code = {
        401: "invalid_api_key",
        403: "insufficient_quota",
        404: "model_not_found",
        429: "rate_limit_exceeded",
    }

    logger.warning(f"HTTPException: {exc.status_code} - {exc.detail}")

    body = error_response(
        message=str(exc.detail),
        error_type=status_to_type.get(exc.status_code, ErrorType.SERVER.value),
        code=status_to_code.get(exc.status_code)
    )
    return JSONResponse(status_code=exc.status_code, content=body)
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
async def validation_exception_handler(request: Request, exc: RequestValidationError) -> JSONResponse:
    """Translate FastAPI/pydantic validation errors into the OpenAI error format."""
    details = exc.errors()

    # Defaults used when pydantic reports no individual errors.
    param = None
    message = "Invalid request"
    code = "invalid_value"

    if details:
        first = details[0]
        loc = first.get("loc", [])
        message = first.get("msg", "Invalid request")
        code = first.get("type", "invalid_value")

        if code == "json_invalid" or "JSON" in message:
            # Body-level JSON syntax errors get a friendlier message.
            message = "Invalid JSON in request body. Please check for trailing commas or syntax errors."
            param = "body"
        else:
            # Drop numeric path segments (list indices) to get a stable field path.
            named_parts = [str(seg) for seg in loc if not (isinstance(seg, int) or str(seg).isdigit())]
            param = ".".join(named_parts) if named_parts else None

    logger.warning(f"ValidationError: {param} - {message}")

    return JSONResponse(
        status_code=400,
        content=error_response(
            message=message,
            error_type=ErrorType.INVALID_REQUEST.value,
            param=param,
            code=code
        )
    )
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
async def generic_exception_handler(request: Request, exc: Exception) -> JSONResponse:
    """Last-resort handler: log the traceback and return a generic 500 error."""
    logger.exception(f"Unhandled: {type(exc).__name__}: {str(exc)}")

    payload = error_response(
        message="Internal server error",
        error_type=ErrorType.SERVER.value,
        code="internal_error"
    )
    return JSONResponse(status_code=500, content=payload)
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
# ============= 注册 =============
|
| 203 |
+
|
| 204 |
+
def register_exception_handlers(app):
    """Register all exception handlers on the FastAPI app.

    FastAPI dispatches to the most specific registered handler; the bare
    `Exception` handler acts as the catch-all.
    """
    app.add_exception_handler(AppException, app_exception_handler)
    app.add_exception_handler(HTTPException, http_exception_handler)
    app.add_exception_handler(RequestValidationError, validation_exception_handler)
    # Fixed: the generic handler was previously registered twice.
    app.add_exception_handler(Exception, generic_exception_handler)
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
__all__ = [
|
| 214 |
+
"ErrorType",
|
| 215 |
+
"AppException",
|
| 216 |
+
"ValidationException",
|
| 217 |
+
"AuthenticationException",
|
| 218 |
+
"UpstreamException",
|
| 219 |
+
"error_response",
|
| 220 |
+
"register_exception_handlers",
|
| 221 |
+
]
|
app/core/legacy_migration.py
ADDED
|
@@ -0,0 +1,285 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Legacy data migrations for local deployments (python/docker).
|
| 3 |
+
|
| 4 |
+
Goal: when upgrading the project, old on-disk data should still be readable and not lost.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
import asyncio
|
| 10 |
+
import os
|
| 11 |
+
import shutil
|
| 12 |
+
import time
|
| 13 |
+
from pathlib import Path
|
| 14 |
+
from typing import Any, Dict
|
| 15 |
+
|
| 16 |
+
from app.core.logger import logger
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def migrate_legacy_cache_dirs(data_dir: Path | None = None) -> Dict[str, Any]:
|
| 20 |
+
"""
|
| 21 |
+
Migrate old cache directory layout:
|
| 22 |
+
|
| 23 |
+
- legacy: data/temp/{image,video}
|
| 24 |
+
- current: data/tmp/{image,video}
|
| 25 |
+
|
| 26 |
+
This keeps existing cached files (not yet cleaned) available after upgrades.
|
| 27 |
+
"""
|
| 28 |
+
|
| 29 |
+
data_root = data_dir or (Path(__file__).parent.parent.parent / "data")
|
| 30 |
+
legacy_root = data_root / "temp"
|
| 31 |
+
current_root = data_root / "tmp"
|
| 32 |
+
|
| 33 |
+
if not legacy_root.exists() or not legacy_root.is_dir():
|
| 34 |
+
return {"migrated": False, "reason": "no_legacy_dir"}
|
| 35 |
+
|
| 36 |
+
lock_dir = data_root / ".locks"
|
| 37 |
+
lock_dir.mkdir(parents=True, exist_ok=True)
|
| 38 |
+
|
| 39 |
+
done_marker = lock_dir / "legacy_cache_dirs_v1.done"
|
| 40 |
+
if done_marker.exists():
|
| 41 |
+
return {"migrated": False, "reason": "already_done"}
|
| 42 |
+
|
| 43 |
+
lock_file = lock_dir / "legacy_cache_dirs_v1.lock"
|
| 44 |
+
|
| 45 |
+
# Best-effort cross-process lock (works on Windows/Linux).
|
| 46 |
+
fd: int | None = None
|
| 47 |
+
try:
|
| 48 |
+
try:
|
| 49 |
+
fd = os.open(str(lock_file), os.O_CREAT | os.O_EXCL | os.O_WRONLY)
|
| 50 |
+
except FileExistsError:
|
| 51 |
+
# Another worker/process is migrating. Wait briefly for completion.
|
| 52 |
+
deadline = time.monotonic() + 30.0
|
| 53 |
+
while time.monotonic() < deadline:
|
| 54 |
+
if done_marker.exists():
|
| 55 |
+
return {"migrated": False, "reason": "waited_for_other_process"}
|
| 56 |
+
time.sleep(0.2)
|
| 57 |
+
return {"migrated": False, "reason": "lock_timeout"}
|
| 58 |
+
|
| 59 |
+
current_root.mkdir(parents=True, exist_ok=True)
|
| 60 |
+
|
| 61 |
+
moved = 0
|
| 62 |
+
skipped = 0
|
| 63 |
+
errors = 0
|
| 64 |
+
|
| 65 |
+
for sub in ("image", "video"):
|
| 66 |
+
src_dir = legacy_root / sub
|
| 67 |
+
if not src_dir.exists() or not src_dir.is_dir():
|
| 68 |
+
continue
|
| 69 |
+
|
| 70 |
+
dst_dir = current_root / sub
|
| 71 |
+
dst_dir.mkdir(parents=True, exist_ok=True)
|
| 72 |
+
|
| 73 |
+
for item in src_dir.iterdir():
|
| 74 |
+
if not item.is_file():
|
| 75 |
+
continue
|
| 76 |
+
target = dst_dir / item.name
|
| 77 |
+
if target.exists():
|
| 78 |
+
skipped += 1
|
| 79 |
+
continue
|
| 80 |
+
try:
|
| 81 |
+
shutil.move(str(item), str(target))
|
| 82 |
+
moved += 1
|
| 83 |
+
except Exception:
|
| 84 |
+
errors += 1
|
| 85 |
+
|
| 86 |
+
# Cleanup empty legacy dirs (best-effort).
|
| 87 |
+
for sub in ("image", "video"):
|
| 88 |
+
p = legacy_root / sub
|
| 89 |
+
try:
|
| 90 |
+
if p.exists() and p.is_dir() and not any(p.iterdir()):
|
| 91 |
+
p.rmdir()
|
| 92 |
+
except Exception:
|
| 93 |
+
pass
|
| 94 |
+
try:
|
| 95 |
+
if legacy_root.exists() and legacy_root.is_dir() and not any(legacy_root.iterdir()):
|
| 96 |
+
legacy_root.rmdir()
|
| 97 |
+
except Exception:
|
| 98 |
+
pass
|
| 99 |
+
|
| 100 |
+
if errors == 0:
|
| 101 |
+
done_marker.write_text(str(int(time.time())), encoding="utf-8")
|
| 102 |
+
if moved or skipped or errors:
|
| 103 |
+
logger.info(
|
| 104 |
+
f"Legacy cache migration complete: moved={moved}, skipped={skipped}, errors={errors}"
|
| 105 |
+
)
|
| 106 |
+
return {"migrated": True, "moved": moved, "skipped": skipped, "errors": errors}
|
| 107 |
+
finally:
|
| 108 |
+
try:
|
| 109 |
+
if fd is not None:
|
| 110 |
+
os.close(fd)
|
| 111 |
+
except Exception:
|
| 112 |
+
pass
|
| 113 |
+
try:
|
| 114 |
+
if lock_file.exists():
|
| 115 |
+
lock_file.unlink()
|
| 116 |
+
except Exception:
|
| 117 |
+
pass
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
# Explicit public API for `from app.core.legacy_migration import *`.
__all__ = ["migrate_legacy_cache_dirs", "migrate_legacy_account_settings"]
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
async def migrate_legacy_account_settings(
    concurrency: int = 10,
    data_dir: Path | None = None,
) -> Dict[str, Any]:
    """
    After legacy data migration, run a one-time TOS + BirthDate + NSFW pass for existing accounts.

    This is best-effort and guarded by a cross-process lock + done marker so
    it runs at most once per deployment.

    Args:
        concurrency: Maximum number of accounts processed at the same time.
        data_dir: Override for the data root (mainly for tests).

    Returns:
        ``{"migrated": False, "reason": ...}`` when skipped, otherwise
        ``{"migrated": True, "total": n, "ok": n, "failed": n}``.
    """
    data_root = data_dir or (Path(__file__).parent.parent.parent / "data")
    lock_dir = data_root / ".locks"
    lock_dir.mkdir(parents=True, exist_ok=True)

    done_marker = lock_dir / "legacy_accounts_tos_birth_nsfw_v2.done"
    if done_marker.exists():
        return {"migrated": False, "reason": "already_done"}

    lock_file = lock_dir / "legacy_accounts_tos_birth_nsfw_v2.lock"
    fd: int | None = None

    try:
        try:
            fd = os.open(str(lock_file), os.O_CREAT | os.O_EXCL | os.O_WRONLY)
        except FileExistsError:
            # Another worker holds the lock; wait briefly for it to finish.
            deadline = time.monotonic() + 30.0
            while time.monotonic() < deadline:
                if done_marker.exists():
                    return {"migrated": False, "reason": "waited_for_other_process"}
                await asyncio.sleep(0.2)
            return {"migrated": False, "reason": "lock_timeout"}

        # Imported lazily to avoid circular imports at module load time.
        from app.core.config import get_config
        from app.core.storage import get_storage
        from app.services.register.services import (
            UserAgreementService,
            BirthDateService,
            NsfwSettingsService,
        )

        storage = get_storage()
        try:
            token_data = await storage.load_tokens()
        except Exception as exc:
            logger.warning("Legacy account migration: failed to load tokens: {}", exc)
            return {"migrated": False, "reason": "load_tokens_failed"}

        token_data = token_data or {}
        tokens: list[str] = []
        for items in token_data.values():
            if not isinstance(items, list):
                continue
            for item in items:
                if isinstance(item, str):
                    tokens.append(item)
                elif isinstance(item, dict):
                    token_val = item.get("token")
                    if isinstance(token_val, str):
                        tokens.append(token_val)

        # De-duplicate while preserving order.
        tokens = list(dict.fromkeys([t.strip() for t in tokens if isinstance(t, str) and t.strip()]))
        if not tokens:
            done_marker.write_text(str(int(time.time())), encoding="utf-8")
            return {"migrated": True, "total": 0, "ok": 0, "failed": 0}

        try:
            concurrency = max(1, int(concurrency))
        except Exception:
            concurrency = 10

        cf_clearance = str(get_config("grok.cf_clearance", "") or "").strip()

        def _extract_cookie_value(cookie_str: str, name: str) -> str | None:
            """Return the value of cookie *name* from a raw Cookie header string."""
            needle = f"{name}="
            if needle not in cookie_str:
                return None
            for part in cookie_str.split(";"):
                part = part.strip()
                if part.startswith(needle):
                    return part[len(needle):].strip()
            return None

        def _normalize_tokens(raw_token: str) -> tuple[str, str]:
            """Extract the (sso, sso-rw) cookie values from a raw token string."""
            raw_token = raw_token.strip()
            if ";" in raw_token:
                sso_val = _extract_cookie_value(raw_token, "sso") or ""
                sso_rw_val = _extract_cookie_value(raw_token, "sso-rw") or sso_val
            else:
                sso_val = raw_token[4:] if raw_token.startswith("sso=") else raw_token
                sso_rw_val = sso_val
            return sso_val, sso_rw_val

        def _apply_settings(raw_token: str) -> bool:
            """Run the TOS / birth date / NSFW sequence for one account (blocking)."""
            sso_val, sso_rw_val = _normalize_tokens(raw_token)
            if not sso_val:
                return False

            user_service = UserAgreementService(cf_clearance=cf_clearance)
            birth_service = BirthDateService(cf_clearance=cf_clearance)
            nsfw_service = NsfwSettingsService(cf_clearance=cf_clearance)

            tos_result = user_service.accept_tos_version(
                sso=sso_val,
                sso_rw=sso_rw_val or sso_val,
                impersonate="chrome120",
            )
            if not tos_result.get("ok"):
                return False

            birth_result = birth_service.set_birth_date(
                sso=sso_val,
                sso_rw=sso_rw_val or sso_val,
                impersonate="chrome120",
            )
            if not birth_result.get("ok"):
                return False

            nsfw_result = nsfw_service.enable_nsfw(
                sso=sso_val,
                sso_rw=sso_rw_val or sso_val,
                impersonate="chrome120",
            )
            return bool(nsfw_result.get("ok"))

        sem = asyncio.Semaphore(concurrency)

        async def _run_one(token: str) -> bool:
            # Bound concurrency; run the blocking HTTP work in a worker thread.
            async with sem:
                return await asyncio.to_thread(_apply_settings, token)

        tasks = [_run_one(token) for token in tokens]
        results = await asyncio.gather(*tasks, return_exceptions=True)

        ok = 0
        failed = 0
        for res in results:
            if isinstance(res, Exception):
                failed += 1
            elif res:
                ok += 1
            else:
                failed += 1

        done_marker.write_text(str(int(time.time())), encoding="utf-8")
        # BUG FIX: loguru does not substitute %-style placeholders (it uses
        # str.format), so the previous call printed literal "%d" markers.
        logger.info(
            f"Legacy account migration complete: total={len(tokens)}, ok={ok}, failed={failed}"
        )
        return {"migrated": True, "total": len(tokens), "ok": ok, "failed": failed}
    finally:
        try:
            if fd is not None:
                os.close(fd)
        except Exception:
            pass
        # BUG FIX: only remove the lock file if this process created it;
        # previously the waiting path deleted the other process's lock file.
        if fd is not None:
            try:
                if lock_file.exists():
                    lock_file.unlink()
            except Exception:
                pass
|
app/core/logger.py
ADDED
|
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
结构化 JSON 日志 - 极简格式
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import sys
|
| 6 |
+
import json
|
| 7 |
+
import traceback
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from loguru import logger
|
| 10 |
+
|
| 11 |
+
# 日志目录
|
| 12 |
+
LOG_DIR = Path(__file__).parent.parent.parent / "logs"
|
| 13 |
+
LOG_DIR.mkdir(parents=True, exist_ok=True)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def _format_json(record) -> str:
|
| 17 |
+
"""格式化日志"""
|
| 18 |
+
# ISO8601 时间
|
| 19 |
+
time_str = record["time"].strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3]
|
| 20 |
+
tz = record["time"].strftime("%z")
|
| 21 |
+
if tz:
|
| 22 |
+
time_str += tz[:3] + ":" + tz[3:]
|
| 23 |
+
|
| 24 |
+
log_entry = {
|
| 25 |
+
"time": time_str,
|
| 26 |
+
"level": record["level"].name.lower(),
|
| 27 |
+
"msg": record["message"],
|
| 28 |
+
"caller": f"{record['file'].name}:{record['line']}",
|
| 29 |
+
}
|
| 30 |
+
|
| 31 |
+
# trace 上下文
|
| 32 |
+
extra = record["extra"]
|
| 33 |
+
if extra.get("traceID"):
|
| 34 |
+
log_entry["traceID"] = extra["traceID"]
|
| 35 |
+
if extra.get("spanID"):
|
| 36 |
+
log_entry["spanID"] = extra["spanID"]
|
| 37 |
+
|
| 38 |
+
# 其他 extra 字段
|
| 39 |
+
for key, value in extra.items():
|
| 40 |
+
if key not in ("traceID", "spanID") and not key.startswith("_"):
|
| 41 |
+
log_entry[key] = value
|
| 42 |
+
|
| 43 |
+
# 错误及以上级别添加堆栈跟踪
|
| 44 |
+
if record["level"].no >= 40 and record["exception"]:
|
| 45 |
+
log_entry["stacktrace"] = "".join(traceback.format_exception(
|
| 46 |
+
record["exception"].type,
|
| 47 |
+
record["exception"].value,
|
| 48 |
+
record["exception"].traceback
|
| 49 |
+
))
|
| 50 |
+
|
| 51 |
+
return json.dumps(log_entry, ensure_ascii=False)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def _make_json_sink(output):
    """Return a loguru sink that writes JSON lines to *output*."""
    def sink(message):
        # Each loguru message carries the structured record to format.
        print(_format_json(message.record), file=output, flush=True)
    return sink
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def _file_json_sink(message):
    """Append one JSON log line to the per-day log file."""
    record = message.record
    line = _format_json(record)
    # One file per calendar day, e.g. logs/app_2024-01-02.log.
    log_file = LOG_DIR / f"app_{record['time'].strftime('%Y-%m-%d')}.log"
    with open(log_file, "a", encoding="utf-8") as fh:
        fh.write(line + "\n")
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def setup_logging(
    level: str = "DEBUG",
    json_console: bool = True,
    file_logging: bool = True,
):
    """Install the configured loguru sinks and return the shared logger.

    Args:
        level: Minimum log level for every sink.
        json_console: When True, write JSON lines to stdout; otherwise use a
            colored human-readable console format.
        file_logging: When True, also append JSON lines to the daily file.
    """
    logger.remove()

    # Console sink: JSON or colored text, chosen up-front.
    if json_console:
        console_sink = _make_json_sink(sys.stdout)
        console_kwargs = {"format": "{message}", "colorize": False}
    else:
        console_sink = sys.stdout
        console_kwargs = {
            "format": "<green>{time:YYYY-MM-DD HH:mm:ss}</green> | <level>{level: <8}</level> | <cyan>{file.name}:{line}</cyan> - <level>{message}</level>",
            "colorize": True,
        }
    logger.add(console_sink, level=level, **console_kwargs)

    # File sink; enqueue=True hands writes to a background worker.
    if file_logging:
        logger.add(_file_json_sink, level=level, format="{message}", enqueue=True)

    return logger
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def get_logger(trace_id: str = "", span_id: str = ""):
    """Return a logger bound with the given trace context (if any)."""
    context = {
        key: value
        for key, value in (("traceID", trace_id), ("spanID", span_id))
        if value
    }
    # Avoid binding when there is nothing to attach.
    return logger.bind(**context) if context else logger
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
# Explicit public API for `from app.core.logger import *`.
__all__ = ["logger", "setup_logging", "get_logger", "LOG_DIR"]
|
app/core/response_middleware.py
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
响应中间件
|
| 3 |
+
Response Middleware
|
| 4 |
+
|
| 5 |
+
用于记录请求日志、生成 TraceID 和计算请求耗时
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import time
|
| 9 |
+
import uuid
|
| 10 |
+
from starlette.middleware.base import BaseHTTPMiddleware
|
| 11 |
+
from starlette.requests import Request
|
| 12 |
+
from starlette.types import ASGIApp
|
| 13 |
+
|
| 14 |
+
from app.core.logger import logger
|
| 15 |
+
|
| 16 |
+
class ResponseLoggerMiddleware(BaseHTTPMiddleware):
    """
    Request logging and response tracking middleware.

    Generates one trace ID per request, exposes it on ``request.state`` and
    logs the request, its response status, and the elapsed time in ms.
    """

    async def dispatch(self, request: Request, call_next):
        # One trace ID per request, available to downstream handlers.
        trace_id = str(uuid.uuid4())
        request.state.trace_id = trace_id

        start_time = time.time()

        # BUG FIX: loguru does not use stdlib logging's `extra=` keyword --
        # kwargs are captured into record["extra"] under their own names, so
        # `extra={...}` nested the fields under extra["extra"] and the JSON
        # formatter's traceID/spanID promotion never fired. Use bind() so the
        # fields land flat in record["extra"].
        log = logger.bind(traceID=trace_id, method=request.method, path=request.url.path)
        log.info(f"Request: {request.method} {request.url.path}")

        try:
            response = await call_next(request)

            # Elapsed time in milliseconds.
            duration = (time.time() - start_time) * 1000

            log.bind(status=response.status_code, duration_ms=round(duration, 2)).info(
                f"Response: {request.method} {request.url.path} - {response.status_code} ({duration:.2f}ms)"
            )

            return response

        except Exception as e:
            duration = (time.time() - start_time) * 1000
            log.bind(duration_ms=round(duration, 2), error=str(e)).error(
                f"Response Error: {request.method} {request.url.path} - {str(e)} ({duration:.2f}ms)"
            )
            # Bare `raise` preserves the original traceback unchanged.
            raise
|
app/core/storage.py
ADDED
|
@@ -0,0 +1,720 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
统一存储服务 (Professional Storage Service)
|
| 3 |
+
支持 Local (TOML), Redis, MySQL, PostgreSQL
|
| 4 |
+
|
| 5 |
+
特性:
|
| 6 |
+
- 全异步 I/O (Async I/O)
|
| 7 |
+
- 连接池管理 (Connection Pooling)
|
| 8 |
+
- 分布式/本地锁 (Distributed/Local Locking)
|
| 9 |
+
- 内存优化 (序列化性能优化)
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import abc
|
| 13 |
+
import os
|
| 14 |
+
import asyncio
|
| 15 |
+
import os
|
| 16 |
+
import hashlib
|
| 17 |
+
import time
|
| 18 |
+
import tomllib
|
| 19 |
+
from typing import Any, Dict, Optional
|
| 20 |
+
from pathlib import Path
|
| 21 |
+
from enum import Enum
|
| 22 |
+
try:
|
| 23 |
+
import fcntl
|
| 24 |
+
except ImportError: # pragma: no cover - non-posix platforms
|
| 25 |
+
fcntl = None
|
| 26 |
+
from contextlib import asynccontextmanager
|
| 27 |
+
|
| 28 |
+
import orjson
|
| 29 |
+
import aiofiles
|
| 30 |
+
from app.core.logger import logger
|
| 31 |
+
|
| 32 |
+
# 配置文件路径
|
| 33 |
+
CONFIG_FILE = Path(__file__).parent.parent.parent / "data" / "config.toml"
|
| 34 |
+
TOKEN_FILE = Path(__file__).parent.parent.parent / "data" / "token.json"
|
| 35 |
+
LOCK_DIR = Path(__file__).parent.parent.parent / "data" / ".locks"
|
| 36 |
+
|
| 37 |
+
# JSON 序列化优化助手函数
|
| 38 |
+
def json_dumps(obj: Any) -> str:
    """Serialize *obj* to a JSON string via orjson (fast path)."""
    raw = orjson.dumps(obj)
    return raw.decode("utf-8")
|
| 40 |
+
|
| 41 |
+
def json_loads(obj: str | bytes) -> Any:
    """Deserialize JSON text or bytes via orjson (fast path)."""
    parsed = orjson.loads(obj)
    return parsed
|
| 43 |
+
|
| 44 |
+
class StorageError(Exception):
    """Base exception for the storage service layer."""
    pass
|
| 47 |
+
|
| 48 |
+
class BaseStorage(abc.ABC):
    """Abstract base class for storage backends (config + token persistence)."""

    @abc.abstractmethod
    async def load_config(self) -> Dict[str, Any]:
        """Load the full configuration mapping."""
        pass

    @abc.abstractmethod
    async def save_config(self, data: Dict[str, Any]):
        """Persist the full configuration mapping."""
        pass

    @abc.abstractmethod
    async def load_tokens(self) -> Dict[str, Any]:
        """Load all tokens."""
        pass

    @abc.abstractmethod
    async def save_tokens(self, data: Dict[str, Any]):
        """Persist all tokens."""
        pass

    @abc.abstractmethod
    async def close(self):
        """Release any held resources (connections, files)."""
        pass

    @asynccontextmanager
    async def acquire_lock(self, name: str, timeout: int = 10):
        """
        Acquire a named lock (mutual exclusion).
        Guards the critical section of read/write operations.

        Args:
            name: Lock name.
            timeout: Timeout in seconds.
        """
        # Default no-op implementation, used as a fallback for backends
        # without real locking.
        yield

    async def verify_connection(self) -> bool:
        """Health check; subclasses override with a real probe."""
        return True
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
class LocalStorage(BaseStorage):
    """
    Local file storage.

    - Async I/O via aiofiles.
    - In-process concurrency control via asyncio.Lock.
    - Cross-process safety via an OS-level file lock (fcntl) when available.
    """

    def __init__(self):
        # Serializes lock acquisition within this process/event loop.
        self._lock = asyncio.Lock()

    @staticmethod
    def _toml_escape(text: str) -> str:
        """Escape a value for inclusion in a TOML basic string literal."""
        # Backslashes must be escaped before quotes, otherwise a '\' in a
        # value produces an invalid TOML escape sequence on reload.
        return text.replace("\\", "\\\\").replace('"', '\\"')

    @asynccontextmanager
    async def acquire_lock(self, name: str, timeout: int = 10):
        """Acquire a named lock; uses an fcntl file lock when available."""
        if fcntl is None:
            # Non-POSIX fallback: in-process asyncio lock only.
            try:
                async with asyncio.timeout(timeout):
                    async with self._lock:
                        yield
            except asyncio.TimeoutError:
                logger.warning(f"LocalStorage: 获取锁 '{name}' 超时 ({timeout}s)")
                raise StorageError(f"无法获取锁 '{name}'")
            return

        lock_path = LOCK_DIR / f"{name}.lock"
        lock_path.parent.mkdir(parents=True, exist_ok=True)
        fd = None
        locked = False
        start = time.monotonic()

        async with self._lock:
            try:
                fd = open(lock_path, "a+")
                # Poll with a non-blocking flock so the event loop stays live.
                while True:
                    try:
                        fcntl.flock(fd, fcntl.LOCK_EX | fcntl.LOCK_NB)
                        locked = True
                        break
                    except BlockingIOError:
                        if time.monotonic() - start >= timeout:
                            raise StorageError(f"无法获取锁 '{name}'")
                        await asyncio.sleep(0.05)
                yield
            except StorageError:
                logger.warning(f"LocalStorage: 获取锁 '{name}' 超时 ({timeout}s)")
                raise
            finally:
                if fd:
                    if locked:
                        try:
                            fcntl.flock(fd, fcntl.LOCK_UN)
                        except Exception:
                            pass
                    try:
                        fd.close()
                    except Exception:
                        pass

    async def load_config(self) -> Dict[str, Any]:
        """Load config.toml; returns {} when missing or unreadable."""
        if not CONFIG_FILE.exists():
            return {}
        try:
            async with aiofiles.open(CONFIG_FILE, "rb") as f:
                content = await f.read()
            return tomllib.loads(content.decode("utf-8"))
        except Exception as e:
            logger.error(f"LocalStorage: 加载配置失败: {e}")
            return {}

    async def save_config(self, data: Dict[str, Any]):
        """Serialize *data* as simple two-level TOML and write CONFIG_FILE.

        Raises:
            StorageError: if the file cannot be written.
        """
        try:
            lines = []
            for section, items in data.items():
                if not isinstance(items, dict):
                    continue
                lines.append(f"[{section}]")
                for key, val in items.items():
                    if isinstance(val, bool):
                        val_str = "true" if val else "false"
                    elif isinstance(val, str):
                        # BUG FIX: also escape backslashes; previously a value
                        # containing '\' wrote invalid TOML that tomllib could
                        # not read back.
                        val_str = f'"{self._toml_escape(val)}"'
                    elif isinstance(val, (int, float)):
                        val_str = str(val)
                    elif isinstance(val, (list, dict)):
                        val_str = json_dumps(val)
                    else:
                        # Fallback: stringify and escape like a normal string.
                        val_str = f'"{self._toml_escape(str(val))}"'
                    lines.append(f"{key} = {val_str}")
                lines.append("")

            content = "\n".join(lines)

            CONFIG_FILE.parent.mkdir(parents=True, exist_ok=True)
            async with aiofiles.open(CONFIG_FILE, "w", encoding="utf-8") as f:
                await f.write(content)
        except Exception as e:
            logger.error(f"LocalStorage: 保存配置失败: {e}")
            raise StorageError(f"保存配置失败: {e}") from e

    async def load_tokens(self) -> Dict[str, Any]:
        """Load token.json; returns {} when missing or unreadable."""
        if not TOKEN_FILE.exists():
            return {}
        try:
            async with aiofiles.open(TOKEN_FILE, "rb") as f:
                content = await f.read()
            return json_loads(content)
        except Exception as e:
            logger.error(f"LocalStorage: 加载 Token 失败: {e}")
            return {}

    async def save_tokens(self, data: Dict[str, Any]):
        """Atomically persist all tokens to token.json.

        Raises:
            StorageError: if the file cannot be written.
        """
        try:
            TOKEN_FILE.parent.mkdir(parents=True, exist_ok=True)
            temp_path = TOKEN_FILE.with_suffix('.tmp')

            # Atomic write: write to a temp file, then rename over the target.
            async with aiofiles.open(temp_path, "wb") as f:
                await f.write(orjson.dumps(data, option=orjson.OPT_INDENT_2))

            # os.replace guarantees atomicity on the same filesystem.
            os.replace(temp_path, TOKEN_FILE)

        except Exception as e:
            logger.error(f"LocalStorage: 保存 Token 失败: {e}")
            raise StorageError(f"保存 Token 失败: {e}") from e

    async def close(self):
        """No resources to release for local file storage."""
        pass
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
class RedisStorage(BaseStorage):
|
| 224 |
+
"""
|
| 225 |
+
Redis 存储
|
| 226 |
+
- 使用 redis-py 异步客户端 (自带连接池)
|
| 227 |
+
- 支持分布式锁 (redis.lock)
|
| 228 |
+
- 扁平化数据结构优化性能
|
| 229 |
+
"""
|
| 230 |
+
|
| 231 |
+
def __init__(self, url: str):
|
| 232 |
+
try:
|
| 233 |
+
from redis import asyncio as aioredis
|
| 234 |
+
from redis.asyncio.lock import Lock
|
| 235 |
+
except ImportError:
|
| 236 |
+
raise ImportError("需要安装 redis 包: pip install redis")
|
| 237 |
+
|
| 238 |
+
# 显式配置连接池
|
| 239 |
+
# 使用 decode_responses=True 简化字符串处理,但在处理复杂对象时使用 orjson
|
| 240 |
+
self.redis = aioredis.from_url(
|
| 241 |
+
url,
|
| 242 |
+
decode_responses=True,
|
| 243 |
+
health_check_interval=30
|
| 244 |
+
)
|
| 245 |
+
self.config_key = "grok2api:config" # Hash: section.key -> value_json
|
| 246 |
+
self.key_pools = "grok2api:pools" # Set: pool_names
|
| 247 |
+
self.prefix_pool_set = "grok2api:pool:" # Set: pool -> token_ids
|
| 248 |
+
self.prefix_token_hash = "grok2api:token:"# Hash: token_id -> token_data
|
| 249 |
+
self.lock_prefix = "grok2api:lock:"
|
| 250 |
+
|
| 251 |
+
@asynccontextmanager
|
| 252 |
+
async def acquire_lock(self, name: str, timeout: int = 10):
|
| 253 |
+
# 使用 Redis 分布式锁
|
| 254 |
+
lock_key = f"{self.lock_prefix}{name}"
|
| 255 |
+
lock = self.redis.lock(lock_key, timeout=timeout, blocking_timeout=5)
|
| 256 |
+
acquired = False
|
| 257 |
+
try:
|
| 258 |
+
acquired = await lock.acquire()
|
| 259 |
+
if not acquired:
|
| 260 |
+
raise StorageError(f"RedisStorage: 无法获取锁 '{name}'")
|
| 261 |
+
yield
|
| 262 |
+
finally:
|
| 263 |
+
if acquired:
|
| 264 |
+
try:
|
| 265 |
+
await lock.release()
|
| 266 |
+
except Exception:
|
| 267 |
+
# 锁可能已过期或被意外释放,忽略异常
|
| 268 |
+
pass
|
| 269 |
+
|
| 270 |
+
async def verify_connection(self) -> bool:
|
| 271 |
+
try:
|
| 272 |
+
return await self.redis.ping()
|
| 273 |
+
except Exception:
|
| 274 |
+
return False
|
| 275 |
+
|
| 276 |
+
async def load_config(self) -> Dict[str, Any]:
|
| 277 |
+
"""从 Redis Hash 加载配置"""
|
| 278 |
+
try:
|
| 279 |
+
raw_data = await self.redis.hgetall(self.config_key)
|
| 280 |
+
if not raw_data:
|
| 281 |
+
return None
|
| 282 |
+
|
| 283 |
+
config = {}
|
| 284 |
+
for composite_key, val_str in raw_data.items():
|
| 285 |
+
if "." not in composite_key: continue
|
| 286 |
+
section, key = composite_key.split(".", 1)
|
| 287 |
+
|
| 288 |
+
if section not in config: config[section] = {}
|
| 289 |
+
|
| 290 |
+
try:
|
| 291 |
+
val = json_loads(val_str)
|
| 292 |
+
except:
|
| 293 |
+
val = val_str
|
| 294 |
+
config[section][key] = val
|
| 295 |
+
return config
|
| 296 |
+
except Exception as e:
|
| 297 |
+
logger.error(f"RedisStorage: 加载配置失败: {e}")
|
| 298 |
+
return None
|
| 299 |
+
|
| 300 |
+
async def save_config(self, data: Dict[str, Any]):
|
| 301 |
+
"""保存配置到 Redis Hash"""
|
| 302 |
+
if not data: return
|
| 303 |
+
try:
|
| 304 |
+
mapping = {}
|
| 305 |
+
for section, items in data.items():
|
| 306 |
+
if not isinstance(items, dict): continue
|
| 307 |
+
for key, val in items.items():
|
| 308 |
+
composite_key = f"{section}.{key}"
|
| 309 |
+
mapping[composite_key] = json_dumps(val)
|
| 310 |
+
|
| 311 |
+
if mapping:
|
| 312 |
+
await self.redis.hset(self.config_key, mapping=mapping)
|
| 313 |
+
except Exception as e:
|
| 314 |
+
logger.error(f"RedisStorage: 保存配置失败: {e}")
|
| 315 |
+
raise
|
| 316 |
+
|
| 317 |
+
async def load_tokens(self) -> Dict[str, Any]:
    """Load every token, grouped by pool name.

    Redis layout:
      - ``self.key_pools``              : set of pool names
      - ``self.prefix_pool_set`` + P    : set of token ids in pool P
      - ``self.prefix_token_hash`` + T  : hash holding token T's fields

    Returns ``{pool_name: [token_dict, ...]}``, or ``None`` when no pools
    exist or a Redis error occurred.
    """
    try:
        pool_names = await self.redis.smembers(self.key_pools)
        if not pool_names: return None

        pools = {}
        # One pipelined round trip fetches the token-id set of every pool.
        async with self.redis.pipeline() as pipe:
            for pool_name in pool_names:
                # Fetch all token ids belonging to this pool.
                pipe.smembers(f"{self.prefix_pool_set}{pool_name}")
            pool_tokens_res = await pipe.execute()

        # Collect every token id so details can be fetched in one batch.
        all_token_ids = []
        pool_map = {}  # pool_name -> list[token_id]

        # pool_tokens_res[i] corresponds positionally to pool_names[i].
        for i, pool_name in enumerate(pool_names):
            tids = list(pool_tokens_res[i])
            pool_map[pool_name] = tids
            all_token_ids.extend(tids)

        if not all_token_ids:
            # Pools exist but all are empty.
            return {name: [] for name in pool_names}

        # Batch-fetch token detail hashes.
        async with self.redis.pipeline() as pipe:
            for tid in all_token_ids:
                pipe.hgetall(f"{self.prefix_token_hash}{tid}")
            token_data_list = await pipe.execute()

        # Rebuild the in-memory structure.
        token_lookup = {}
        for i, tid in enumerate(all_token_ids):
            t_data = token_data_list[i]
            if not t_data: continue

            # Restore tags (JSON string -> list); fall back to [] on bad data.
            if "tags" in t_data:
                try: t_data["tags"] = json_loads(t_data["tags"])
                except: t_data["tags"] = []

            # Type coercion: Redis returns every hash field as a string.
            # "None" strings (stringified None on save) are left untouched.
            for int_field in ["quota", "created_at", "use_count", "fail_count", "last_used_at", "last_fail_at", "last_sync_at"]:
                if t_data.get(int_field) and t_data[int_field] != "None":
                    try: t_data[int_field] = int(t_data[int_field])
                    except: pass

            token_lookup[tid] = t_data

        # Group tokens back by pool; ids with no surviving hash are dropped.
        for pool_name in pool_names:
            pools[pool_name] = []
            for tid in pool_map[pool_name]:
                if tid in token_lookup:
                    pools[pool_name].append(token_lookup[tid])

        return pools

    except Exception as e:
        logger.error(f"RedisStorage: 加载 Token 失败: {e}")
        return None
|
| 379 |
+
|
| 380 |
+
async def save_tokens(self, data: Dict[str, Any]):
    """Persist the full ``{pool_name: [token_dict, ...]}`` snapshot.

    Diffs the snapshot against the current Redis state so token hashes that
    disappeared are deleted, then rewrites the pool index, every per-pool
    set, and every token hash inside a single pipeline.
    """
    if data is None:
        return
    try:
        new_pools = set(data.keys()) if isinstance(data, dict) else set()
        pool_tokens_map = {}
        new_token_ids = set()

        # The token string itself doubles as the token id / hash key.
        for pool_name, tokens in (data or {}).items():
            tids_in_pool = []
            for t in tokens:
                token_str = t.get("token")
                if not token_str:
                    continue
                tids_in_pool.append(token_str)
                new_token_ids.add(token_str)
            pool_tokens_map[pool_name] = tids_in_pool

        existing_pools = await self.redis.smembers(self.key_pools)
        existing_pools = set(existing_pools) if existing_pools else set()

        # Gather every token id currently stored, across all existing pools.
        existing_token_ids = set()
        if existing_pools:
            async with self.redis.pipeline() as pipe:
                for pool_name in existing_pools:
                    pipe.smembers(f"{self.prefix_pool_set}{pool_name}")
                pool_tokens_res = await pipe.execute()
                for tokens in pool_tokens_res:
                    existing_token_ids.update(list(tokens or []))

        # Tokens present in Redis but absent from the new snapshot are purged.
        tokens_to_delete = existing_token_ids - new_token_ids
        all_pools = existing_pools.union(new_pools)

        # NOTE: the pipeline batches commands but is not a MULTI/EXEC
        # transaction here; a concurrent reader may observe a partial state.
        async with self.redis.pipeline() as pipe:
            # Reset pool index
            pipe.delete(self.key_pools)
            if new_pools:
                pipe.sadd(self.key_pools, *new_pools)

            # Reset pool sets (old pools are cleared even if now absent)
            for pool_name in all_pools:
                pipe.delete(f"{self.prefix_pool_set}{pool_name}")
            for pool_name, tids_in_pool in pool_tokens_map.items():
                if tids_in_pool:
                    pipe.sadd(f"{self.prefix_pool_set}{pool_name}", *tids_in_pool)

            # Remove deleted token hashes
            for token_str in tokens_to_delete:
                pipe.delete(f"{self.prefix_token_hash}{token_str}")

            # Upsert token hashes
            for pool_name, tokens in (data or {}).items():
                for t in tokens:
                    token_str = t.get("token")
                    if not token_str:
                        continue
                    t_flat = t.copy()
                    # tags list is stored as a JSON string inside the hash.
                    if "tags" in t_flat:
                        t_flat["tags"] = json_dumps(t_flat["tags"])
                    # Normalize status: both the stringified repr
                    # "TokenStatus.X" and a live Enum member collapse to the
                    # plain lowercase value.
                    status = t_flat.get("status")
                    if isinstance(status, str) and status.startswith("TokenStatus."):
                        t_flat["status"] = status.split(".", 1)[1].lower()
                    elif isinstance(status, Enum):
                        t_flat["status"] = status.value
                    # Redis hashes hold only strings; None-valued fields are
                    # dropped (load_tokens treats missing as unset).
                    t_flat = {k: str(v) for k, v in t_flat.items() if v is not None}
                    pipe.hset(f"{self.prefix_token_hash}{token_str}", mapping=t_flat)

            await pipe.execute()

    except Exception as e:
        logger.error(f"RedisStorage: 保存 Token 失败: {e}")
        raise
|
| 453 |
+
|
| 454 |
+
async def close(self):
    """Close the Redis connection, swallowing shutdown-time errors.

    ``asyncio.CancelledError`` must stay listed explicitly: on Python 3.8+
    it derives from BaseException, so ``except Exception`` alone would not
    catch it. The former ``RuntimeError`` entry was redundant (already a
    subclass of Exception) and has been dropped.
    """
    try:
        await self.redis.close()
    except (Exception, asyncio.CancelledError):
        # Ignore "Event loop is closed" and similar teardown noise.
        pass
|
| 460 |
+
|
| 461 |
+
|
| 462 |
+
class SQLStorage(BaseStorage):
    """
    SQL database storage backend (MySQL / PostgreSQL).

    - Uses the SQLAlchemy async engine
    - Schema is created lazily on first use (see ``_ensure_schema``)
    - Relies on SQLAlchemy's built-in connection pooling
    """

    def __init__(self, url: str):
        # Import lazily so sqlalchemy stays an optional dependency for
        # deployments that use local/Redis storage only.
        try:
            from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
            from sqlalchemy import text, MetaData
        except ImportError:
            raise ImportError("需要安装 sqlalchemy 和 async 驱动: pip install sqlalchemy[asyncio]")

        # "mysql+aiomysql://..." -> "mysql"; selects dialect-specific SQL below.
        self.dialect = url.split(":", 1)[0].split("+", 1)[0].lower()

        # Connection pool tuned for resilience: recycle + pre-ping guard
        # against connections silently dropped by the server.
        self.engine = create_async_engine(
            url,
            echo=False,
            pool_size=20,
            max_overflow=10,
            pool_recycle=3600,
            pool_pre_ping=True
        )
        self.async_session = async_sessionmaker(self.engine, expire_on_commit=False)
        self._initialized = False

    async def _ensure_schema(self):
        """Create tables/indexes if missing; runs once per instance."""
        if self._initialized: return
        try:
            async with self.engine.begin() as conn:
                from sqlalchemy import text

                # tokens table (portable SQL)
                await conn.execute(text("""
                    CREATE TABLE IF NOT EXISTS tokens (
                        token VARCHAR(512) PRIMARY KEY,
                        pool_name VARCHAR(64) NOT NULL,
                        data TEXT,
                        updated_at BIGINT
                    )
                """))

                # configuration table
                await conn.execute(text("""
                    CREATE TABLE IF NOT EXISTS app_config (
                        section VARCHAR(64) NOT NULL,
                        key_name VARCHAR(64) NOT NULL,
                        value TEXT,
                        PRIMARY KEY (section, key_name)
                    )
                """))

                # Index may already exist (no IF NOT EXISTS on all dialects);
                # failure is expected and ignored.
                try:
                    await conn.execute(text("CREATE INDEX idx_tokens_pool ON tokens (pool_name)"))
                except Exception:
                    pass

                # Best-effort widening of columns from an older schema.
                try:
                    if self.dialect in ("mysql", "mariadb"):
                        await conn.execute(text("ALTER TABLE tokens MODIFY token VARCHAR(512)"))
                        await conn.execute(text("ALTER TABLE tokens MODIFY data TEXT"))
                    elif self.dialect in ("postgres", "postgresql", "pgsql"):
                        await conn.execute(text("ALTER TABLE tokens ALTER COLUMN token TYPE VARCHAR(512)"))
                        await conn.execute(text("ALTER TABLE tokens ALTER COLUMN data TYPE TEXT"))
                except Exception:
                    pass

            self._initialized = True
        except Exception as e:
            logger.error(f"SQLStorage: Schema 初始化失败: {e}")
            raise

    @asynccontextmanager
    async def acquire_lock(self, name: str, timeout: int = 10):
        # Distributed lock: MySQL GET_LOCK / PostgreSQL advisory lock.
        from sqlalchemy import text
        # Hash the name to stay within MySQL's lock-name length limit.
        lock_name = f"g2a:{hashlib.sha1(name.encode('utf-8')).hexdigest()[:24]}"
        if self.dialect in ("mysql", "mariadb"):
            async with self.async_session() as session:
                # GET_LOCK blocks up to `timeout` seconds; returns 1 on success.
                res = await session.execute(
                    text("SELECT GET_LOCK(:name, :timeout)"),
                    {"name": lock_name, "timeout": timeout}
                )
                got = res.scalar()
                if got != 1:
                    raise StorageError(f"SQLStorage: 无法获取锁 '{name}'")
                try:
                    yield
                finally:
                    # Must release on the same session/connection that
                    # acquired the lock (GET_LOCK is connection-scoped).
                    try:
                        await session.execute(text("SELECT RELEASE_LOCK(:name)"), {"name": lock_name})
                        await session.commit()
                    except Exception:
                        pass
        elif self.dialect in ("postgres", "postgresql", "pgsql"):
            # Advisory locks key on a 64-bit integer derived from the name.
            # NOTE(review): signed=False can yield values >= 2**63, which
            # overflows PostgreSQL's signed bigint parameter -- TODO confirm
            # the driver handles this (signed=True would be safe).
            lock_key = int.from_bytes(hashlib.sha256(name.encode("utf-8")).digest()[:8], "big", signed=False)
            async with self.async_session() as session:
                start = time.monotonic()
                # Poll the non-blocking variant until acquired or timed out.
                while True:
                    res = await session.execute(
                        text("SELECT pg_try_advisory_lock(:key)"),
                        {"key": lock_key}
                    )
                    if res.scalar():
                        break
                    if time.monotonic() - start >= timeout:
                        raise StorageError(f"SQLStorage: 无法获取锁 '{name}'")
                    await asyncio.sleep(0.1)
                try:
                    yield
                finally:
                    try:
                        await session.execute(text("SELECT pg_advisory_unlock(:key)"), {"key": lock_key})
                        await session.commit()
                    except Exception:
                        pass
        else:
            # Unknown dialect (e.g. sqlite): no cross-process lock available.
            yield

    async def load_config(self) -> Dict[str, Any]:
        """Load the nested config dict from app_config; None on miss/error."""
        await self._ensure_schema()
        from sqlalchemy import text
        try:
            async with self.async_session() as session:
                res = await session.execute(text("SELECT section, key_name, value FROM app_config"))
                rows = res.fetchall()
                if not rows: return None

                config = {}
                for section, key, val_str in rows:
                    if section not in config: config[section] = {}
                    try:
                        val = json_loads(val_str)
                    except:
                        # Not valid JSON -- keep the raw string.
                        val = val_str
                    config[section][key] = val
                return config
        except Exception as e:
            logger.error(f"SQLStorage: 加载配置失败: {e}")
            return None

    async def save_config(self, data: Dict[str, Any]):
        """Upsert every (section, key) pair into app_config."""
        await self._ensure_schema()
        from sqlalchemy import text
        try:
            async with self.async_session() as session:
                for section, items in data.items():
                    if not isinstance(items, dict): continue
                    for key, val in items.items():
                        val_str = json_dumps(val)

                        # Upsert via delete-then-insert: portable across
                        # dialects, at the cost of two statements per key.
                        await session.execute(
                            text("DELETE FROM app_config WHERE section=:s AND key_name=:k"),
                            {"s": section, "k": key}
                        )
                        await session.execute(
                            text("INSERT INTO app_config (section, key_name, value) VALUES (:s, :k, :v)"),
                            {"s": section, "k": key, "v": val_str}
                        )
                await session.commit()
        except Exception as e:
            logger.error(f"SQLStorage: 保存配置失败: {e}")
            raise

    async def load_tokens(self) -> Dict[str, Any]:
        """Load tokens grouped by pool; each row's data column is JSON."""
        await self._ensure_schema()
        from sqlalchemy import text
        try:
            async with self.async_session() as session:
                res = await session.execute(text("SELECT pool_name, data FROM tokens"))
                rows = res.fetchall()
                if not rows: return None

                pools = {}
                for pool_name, data_json in rows:
                    if pool_name not in pools: pools[pool_name] = []

                    try:
                        if isinstance(data_json, str):
                            t_data = json_loads(data_json)
                        else:
                            # Driver may already hand back a decoded object.
                            t_data = data_json
                        pools[pool_name].append(t_data)
                    except:
                        # Corrupt row: skip rather than fail the whole load.
                        pass
                return pools
        except Exception as e:
            logger.error(f"SQLStorage: 加载 Token 失败: {e}")
            return None

    async def save_tokens(self, data: Dict[str, Any]):
        """Replace the tokens table with the given snapshot (full rewrite)."""
        await self._ensure_schema()
        from sqlalchemy import text
        try:
            async with self.async_session() as session:
                # Wipe-and-rewrite inside one transaction (commit below).
                await session.execute(text("DELETE FROM tokens"))

                params = []
                for pool_name, tokens in data.items():
                    for t in tokens:
                        params.append({
                            "token": t.get("token"),
                            "pool_name": pool_name,
                            "data": json_dumps(t),
                            "updated_at": 0
                        })

                if params:
                    # executemany-style batch insert.
                    await session.execute(
                        text("INSERT INTO tokens (token, pool_name, data, updated_at) VALUES (:token, :pool_name, :data, :updated_at)"),
                        params
                    )
                await session.commit()
        except Exception as e:
            logger.error(f"SQLStorage: 保存 Token 失败: {e}")
            raise

    async def close(self):
        # Dispose the engine and its pooled connections.
        await self.engine.dispose()
|
| 689 |
+
|
| 690 |
+
|
| 691 |
+
class StorageFactory:
    """Builds and caches the process-wide storage backend singleton."""

    _instance: Optional[BaseStorage] = None

    @classmethod
    def get_storage(cls) -> BaseStorage:
        """Return the singleton storage instance, creating it on first call.

        Backend selection comes from SERVER_STORAGE_TYPE ("redis", "mysql",
        "pgsql", anything else -> local); SERVER_STORAGE_URL supplies the
        connection string for the remote backends.
        """
        if cls._instance:
            return cls._instance

        backend = os.getenv("SERVER_STORAGE_TYPE", "local").lower()
        url = os.getenv("SERVER_STORAGE_URL", "")

        logger.info(f"StorageFactory: 初始化存储后端: {backend}")

        if backend == "redis":
            if not url:
                raise ValueError("Redis 存储需要设置 SERVER_STORAGE_URL")
            instance: BaseStorage = RedisStorage(url)
        elif backend in ("mysql", "pgsql"):
            if not url:
                raise ValueError("SQL 存储需要设置 SERVER_STORAGE_URL")
            instance = SQLStorage(url)
        else:
            instance = LocalStorage()

        cls._instance = instance
        return cls._instance
|
| 718 |
+
|
| 719 |
+
def get_storage() -> BaseStorage:
    """Module-level shortcut for ``StorageFactory.get_storage()``."""
    storage = StorageFactory.get_storage()
    return storage
|
app/services/api_keys.py
ADDED
|
@@ -0,0 +1,432 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""API Key 管理器 - 多用户密钥管理"""
|
| 2 |
+
|
| 3 |
+
import orjson
|
| 4 |
+
import time
|
| 5 |
+
import os
|
| 6 |
+
import secrets
|
| 7 |
+
import asyncio
|
| 8 |
+
from datetime import datetime, timezone, timedelta
|
| 9 |
+
from typing import List, Dict, Optional, Any, Tuple
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
|
| 12 |
+
from app.core.logger import logger
|
| 13 |
+
from app.core.config import get_config
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class ApiKeyManager:
|
| 17 |
+
"""API Key 管理服务"""
|
| 18 |
+
|
| 19 |
+
_instance = None
|
| 20 |
+
|
| 21 |
+
def __new__(cls):
|
| 22 |
+
if cls._instance is None:
|
| 23 |
+
cls._instance = super().__new__(cls)
|
| 24 |
+
return cls._instance
|
| 25 |
+
|
| 26 |
+
def __init__(self):
|
| 27 |
+
if hasattr(self, '_initialized'):
|
| 28 |
+
return
|
| 29 |
+
|
| 30 |
+
self.file_path = Path(__file__).parents[2] / "data" / "api_keys.json"
|
| 31 |
+
self.usage_path = Path(__file__).parents[2] / "data" / "api_key_usage.json"
|
| 32 |
+
self._keys: List[Dict] = []
|
| 33 |
+
self._lock = asyncio.Lock()
|
| 34 |
+
self._loaded = False
|
| 35 |
+
|
| 36 |
+
self._usage: Dict[str, Dict[str, Dict[str, int]]] = {}
|
| 37 |
+
self._usage_lock = asyncio.Lock()
|
| 38 |
+
self._usage_loaded = False
|
| 39 |
+
|
| 40 |
+
self._initialized = True
|
| 41 |
+
logger.debug(f"[ApiKey] 初始化完成: {self.file_path}")
|
| 42 |
+
|
| 43 |
+
async def init(self):
|
| 44 |
+
"""初始化加载数据"""
|
| 45 |
+
if not self._loaded:
|
| 46 |
+
await self._load_data()
|
| 47 |
+
if not self._usage_loaded:
|
| 48 |
+
await self._load_usage_data()
|
| 49 |
+
|
| 50 |
+
async def _load_data(self):
|
| 51 |
+
"""加载 API Keys"""
|
| 52 |
+
if self._loaded:
|
| 53 |
+
return
|
| 54 |
+
|
| 55 |
+
if not self.file_path.exists():
|
| 56 |
+
self._keys = []
|
| 57 |
+
self._loaded = True
|
| 58 |
+
return
|
| 59 |
+
|
| 60 |
+
try:
|
| 61 |
+
async with self._lock:
|
| 62 |
+
content = await asyncio.to_thread(self.file_path.read_bytes)
|
| 63 |
+
if content:
|
| 64 |
+
data = orjson.loads(content)
|
| 65 |
+
if isinstance(data, list):
|
| 66 |
+
out: List[Dict[str, Any]] = []
|
| 67 |
+
for item in data:
|
| 68 |
+
if not isinstance(item, dict):
|
| 69 |
+
continue
|
| 70 |
+
row = self._normalize_key_row(item)
|
| 71 |
+
if row.get("key"):
|
| 72 |
+
out.append(row)
|
| 73 |
+
self._keys = out
|
| 74 |
+
else:
|
| 75 |
+
self._keys = []
|
| 76 |
+
else:
|
| 77 |
+
self._keys = []
|
| 78 |
+
self._loaded = True
|
| 79 |
+
logger.debug(f"[ApiKey] 加载了 {len(self._keys)} 个 API Key")
|
| 80 |
+
except Exception as e:
|
| 81 |
+
logger.error(f"[ApiKey] 加载失败: {e}")
|
| 82 |
+
self._keys = []
|
| 83 |
+
self._loaded = True # 即使加载失败也认为已尝试加载,防止后续保存清空数据(或者抛出异常)
|
| 84 |
+
|
| 85 |
+
async def _save_data(self):
|
| 86 |
+
"""保存 API Keys"""
|
| 87 |
+
if not self._loaded:
|
| 88 |
+
logger.warning("[ApiKey] 尝试在数据未加载时保存,已取消操作以防覆盖数据")
|
| 89 |
+
return
|
| 90 |
+
|
| 91 |
+
try:
|
| 92 |
+
# 确保目录存在
|
| 93 |
+
self.file_path.parent.mkdir(parents=True, exist_ok=True)
|
| 94 |
+
|
| 95 |
+
async with self._lock:
|
| 96 |
+
content = orjson.dumps(self._keys, option=orjson.OPT_INDENT_2)
|
| 97 |
+
await asyncio.to_thread(self.file_path.write_bytes, content)
|
| 98 |
+
except Exception as e:
|
| 99 |
+
logger.error(f"[ApiKey] 保存失败: {e}")
|
| 100 |
+
|
| 101 |
+
def _normalize_limit(self, v: Any) -> int:
|
| 102 |
+
"""Normalize a daily limit value. -1 means unlimited."""
|
| 103 |
+
if v is None or v == "":
|
| 104 |
+
return -1
|
| 105 |
+
try:
|
| 106 |
+
n = int(v)
|
| 107 |
+
except Exception:
|
| 108 |
+
return -1
|
| 109 |
+
return max(-1, n)
|
| 110 |
+
|
| 111 |
+
def _normalize_key_row(self, row: Dict[str, Any]) -> Dict[str, Any]:
|
| 112 |
+
out = dict(row or {})
|
| 113 |
+
out["key"] = str(out.get("key") or "").strip()
|
| 114 |
+
out["name"] = str(out.get("name") or "").strip()
|
| 115 |
+
try:
|
| 116 |
+
out["created_at"] = int(out.get("created_at") or int(time.time()))
|
| 117 |
+
except Exception:
|
| 118 |
+
out["created_at"] = int(time.time())
|
| 119 |
+
out["is_active"] = bool(out.get("is_active", True))
|
| 120 |
+
|
| 121 |
+
# Daily limits (-1 = unlimited)
|
| 122 |
+
out["chat_limit"] = self._normalize_limit(out.get("chat_limit", -1))
|
| 123 |
+
out["heavy_limit"] = self._normalize_limit(out.get("heavy_limit", -1))
|
| 124 |
+
out["image_limit"] = self._normalize_limit(out.get("image_limit", -1))
|
| 125 |
+
out["video_limit"] = self._normalize_limit(out.get("video_limit", -1))
|
| 126 |
+
return out
|
| 127 |
+
|
| 128 |
+
def _tz_offset_minutes(self) -> int:
|
| 129 |
+
raw = (os.getenv("CACHE_RESET_TZ_OFFSET_MINUTES", "") or "").strip()
|
| 130 |
+
try:
|
| 131 |
+
n = int(raw)
|
| 132 |
+
except Exception:
|
| 133 |
+
n = 480
|
| 134 |
+
return max(-720, min(840, n))
|
| 135 |
+
|
| 136 |
+
def _day_str(self, at_ms: Optional[int] = None, tz_offset_minutes: Optional[int] = None) -> str:
|
| 137 |
+
now_ms = int(at_ms if at_ms is not None else int(time.time() * 1000))
|
| 138 |
+
offset = self._tz_offset_minutes() if tz_offset_minutes is None else int(tz_offset_minutes)
|
| 139 |
+
dt = datetime.fromtimestamp(now_ms / 1000, tz=timezone.utc) + timedelta(minutes=offset)
|
| 140 |
+
return dt.strftime("%Y-%m-%d")
|
| 141 |
+
|
| 142 |
+
async def _load_usage_data(self):
|
| 143 |
+
"""Load per-day per-key usage counters."""
|
| 144 |
+
if self._usage_loaded:
|
| 145 |
+
return
|
| 146 |
+
|
| 147 |
+
if not self.usage_path.exists():
|
| 148 |
+
self._usage = {}
|
| 149 |
+
self._usage_loaded = True
|
| 150 |
+
return
|
| 151 |
+
|
| 152 |
+
try:
|
| 153 |
+
async with self._usage_lock:
|
| 154 |
+
if self.usage_path.exists():
|
| 155 |
+
content = await asyncio.to_thread(self.usage_path.read_bytes)
|
| 156 |
+
if content:
|
| 157 |
+
data = orjson.loads(content)
|
| 158 |
+
if isinstance(data, dict):
|
| 159 |
+
# { day: { key: { chat_used, ... } } }
|
| 160 |
+
self._usage = data # type: ignore[assignment]
|
| 161 |
+
else:
|
| 162 |
+
self._usage = {}
|
| 163 |
+
else:
|
| 164 |
+
self._usage = {}
|
| 165 |
+
self._usage_loaded = True
|
| 166 |
+
except Exception as e:
|
| 167 |
+
logger.error(f"[ApiKey] Usage 加载失败: {e}")
|
| 168 |
+
self._usage = {}
|
| 169 |
+
self._usage_loaded = True
|
| 170 |
+
|
| 171 |
+
async def _save_usage_data(self):
|
| 172 |
+
if not self._usage_loaded:
|
| 173 |
+
return
|
| 174 |
+
try:
|
| 175 |
+
self.usage_path.parent.mkdir(parents=True, exist_ok=True)
|
| 176 |
+
async with self._usage_lock:
|
| 177 |
+
content = orjson.dumps(self._usage, option=orjson.OPT_INDENT_2)
|
| 178 |
+
await asyncio.to_thread(self.usage_path.write_bytes, content)
|
| 179 |
+
except Exception as e:
|
| 180 |
+
logger.error(f"[ApiKey] Usage 保存失败: {e}")
|
| 181 |
+
|
| 182 |
+
def generate_key(self) -> str:
|
| 183 |
+
"""生成一个新的 sk- 开头的 key"""
|
| 184 |
+
return f"sk-{secrets.token_urlsafe(24)}"
|
| 185 |
+
|
| 186 |
+
def generate_name(self) -> str:
|
| 187 |
+
"""生成一个随机 key 名称"""
|
| 188 |
+
return f"key-{secrets.token_urlsafe(6)}"
|
| 189 |
+
|
| 190 |
+
async def add_key(
|
| 191 |
+
self,
|
| 192 |
+
name: str | None = None,
|
| 193 |
+
key: str | None = None,
|
| 194 |
+
limits: Optional[Dict[str, Any]] = None,
|
| 195 |
+
is_active: bool = True,
|
| 196 |
+
) -> Dict[str, Any]:
|
| 197 |
+
"""添加 API Key(支持自定义 key 与每日额度)"""
|
| 198 |
+
await self.init()
|
| 199 |
+
|
| 200 |
+
name_val = str(name or "").strip() or self.generate_name()
|
| 201 |
+
key_val = str(key or "").strip() or self.generate_key()
|
| 202 |
+
|
| 203 |
+
limits = limits or {}
|
| 204 |
+
new_key: Dict[str, Any] = {
|
| 205 |
+
"key": key_val,
|
| 206 |
+
"name": name_val,
|
| 207 |
+
"created_at": int(time.time()),
|
| 208 |
+
"is_active": bool(is_active),
|
| 209 |
+
"chat_limit": self._normalize_limit(limits.get("chat_limit", limits.get("chat_per_day", -1))),
|
| 210 |
+
"heavy_limit": self._normalize_limit(limits.get("heavy_limit", limits.get("heavy_per_day", -1))),
|
| 211 |
+
"image_limit": self._normalize_limit(limits.get("image_limit", limits.get("image_per_day", -1))),
|
| 212 |
+
"video_limit": self._normalize_limit(limits.get("video_limit", limits.get("video_per_day", -1))),
|
| 213 |
+
}
|
| 214 |
+
|
| 215 |
+
# Ensure uniqueness
|
| 216 |
+
if any(k.get("key") == key_val for k in self._keys):
|
| 217 |
+
raise ValueError("Key already exists")
|
| 218 |
+
|
| 219 |
+
self._keys.append(new_key)
|
| 220 |
+
await self._save_data()
|
| 221 |
+
logger.info(f"[ApiKey] 添加新Key: {name_val}")
|
| 222 |
+
return new_key
|
| 223 |
+
|
| 224 |
+
async def batch_add_keys(self, name_prefix: str, count: int) -> List[Dict]:
|
| 225 |
+
"""批量添加 API Key"""
|
| 226 |
+
new_keys = []
|
| 227 |
+
for i in range(1, count + 1):
|
| 228 |
+
name = f"{name_prefix}-{i}" if count > 1 else name_prefix
|
| 229 |
+
new_keys.append({
|
| 230 |
+
"key": self.generate_key(),
|
| 231 |
+
"name": name,
|
| 232 |
+
"created_at": int(time.time()),
|
| 233 |
+
"is_active": True,
|
| 234 |
+
"chat_limit": -1,
|
| 235 |
+
"heavy_limit": -1,
|
| 236 |
+
"image_limit": -1,
|
| 237 |
+
"video_limit": -1,
|
| 238 |
+
})
|
| 239 |
+
|
| 240 |
+
self._keys.extend(new_keys)
|
| 241 |
+
await self._save_data()
|
| 242 |
+
logger.info(f"[ApiKey] 批量添加 {count} 个 Key, 前缀: {name_prefix}")
|
| 243 |
+
return new_keys
|
| 244 |
+
|
| 245 |
+
async def delete_key(self, key: str) -> bool:
|
| 246 |
+
"""删除 API Key"""
|
| 247 |
+
initial_len = len(self._keys)
|
| 248 |
+
self._keys = [k for k in self._keys if k["key"] != key]
|
| 249 |
+
|
| 250 |
+
if len(self._keys) != initial_len:
|
| 251 |
+
await self._save_data()
|
| 252 |
+
logger.info(f"[ApiKey] 删除Key: {key[:10]}...")
|
| 253 |
+
return True
|
| 254 |
+
return False
|
| 255 |
+
|
| 256 |
+
async def batch_delete_keys(self, keys: List[str]) -> int:
|
| 257 |
+
"""批量删除 API Key"""
|
| 258 |
+
initial_len = len(self._keys)
|
| 259 |
+
self._keys = [k for k in self._keys if k["key"] not in keys]
|
| 260 |
+
|
| 261 |
+
deleted_count = initial_len - len(self._keys)
|
| 262 |
+
if deleted_count > 0:
|
| 263 |
+
await self._save_data()
|
| 264 |
+
logger.info(f"[ApiKey] 批量删除 {deleted_count} 个 Key")
|
| 265 |
+
return deleted_count
|
| 266 |
+
|
| 267 |
+
async def update_key_status(self, key: str, is_active: bool) -> bool:
|
| 268 |
+
"""更新 Key 状态"""
|
| 269 |
+
for k in self._keys:
|
| 270 |
+
if k["key"] == key:
|
| 271 |
+
k["is_active"] = is_active
|
| 272 |
+
await self._save_data()
|
| 273 |
+
return True
|
| 274 |
+
return False
|
| 275 |
+
|
| 276 |
+
async def batch_update_keys_status(self, keys: List[str], is_active: bool) -> int:
|
| 277 |
+
"""批量更新 Key 状态"""
|
| 278 |
+
updated_count = 0
|
| 279 |
+
for k in self._keys:
|
| 280 |
+
if k["key"] in keys:
|
| 281 |
+
if k["is_active"] != is_active:
|
| 282 |
+
k["is_active"] = is_active
|
| 283 |
+
updated_count += 1
|
| 284 |
+
|
| 285 |
+
if updated_count > 0:
|
| 286 |
+
await self._save_data()
|
| 287 |
+
logger.info(f"[ApiKey] 批量更新 {updated_count} 个 Key 状态为: {is_active}")
|
| 288 |
+
return updated_count
|
| 289 |
+
|
| 290 |
+
async def update_key_name(self, key: str, name: str) -> bool:
|
| 291 |
+
"""更新 Key 备注"""
|
| 292 |
+
for k in self._keys:
|
| 293 |
+
if k["key"] == key:
|
| 294 |
+
k["name"] = name
|
| 295 |
+
await self._save_data()
|
| 296 |
+
return True
|
| 297 |
+
return False
|
| 298 |
+
|
| 299 |
+
async def update_key_limits(self, key: str, limits: Dict[str, Any]) -> bool:
|
| 300 |
+
"""更新 Key 每日额度(-1 表示不限)"""
|
| 301 |
+
limits = limits or {}
|
| 302 |
+
for k in self._keys:
|
| 303 |
+
if k.get("key") != key:
|
| 304 |
+
continue
|
| 305 |
+
if "chat_limit" in limits or "chat_per_day" in limits:
|
| 306 |
+
k["chat_limit"] = self._normalize_limit(limits.get("chat_limit", limits.get("chat_per_day")))
|
| 307 |
+
if "heavy_limit" in limits or "heavy_per_day" in limits:
|
| 308 |
+
k["heavy_limit"] = self._normalize_limit(limits.get("heavy_limit", limits.get("heavy_per_day")))
|
| 309 |
+
if "image_limit" in limits or "image_per_day" in limits:
|
| 310 |
+
k["image_limit"] = self._normalize_limit(limits.get("image_limit", limits.get("image_per_day")))
|
| 311 |
+
if "video_limit" in limits or "video_per_day" in limits:
|
| 312 |
+
k["video_limit"] = self._normalize_limit(limits.get("video_limit", limits.get("video_per_day")))
|
| 313 |
+
await self._save_data()
|
| 314 |
+
return True
|
| 315 |
+
return False
|
| 316 |
+
|
| 317 |
+
def get_key_row(self, key: str) -> Optional[Dict[str, Any]]:
|
| 318 |
+
"""获取 Key 原始记录(不要求 active)"""
|
| 319 |
+
for k in self._keys:
|
| 320 |
+
if k.get("key") == key:
|
| 321 |
+
return self._normalize_key_row(k)
|
| 322 |
+
return None
|
| 323 |
+
|
| 324 |
+
async def usage_for_day(self, day: str) -> Dict[str, Dict[str, int]]:
|
| 325 |
+
"""返回指定 day 的 usage map: { key: {chat_used,...} }"""
|
| 326 |
+
await self.init()
|
| 327 |
+
if not self._usage_loaded:
|
| 328 |
+
await self._load_usage_data()
|
| 329 |
+
day_map = self._usage.get(day)
|
| 330 |
+
return day_map if isinstance(day_map, dict) else {}
|
| 331 |
+
|
| 332 |
+
async def usage_today(self) -> Tuple[str, Dict[str, Dict[str, int]]]:
|
| 333 |
+
day = self._day_str()
|
| 334 |
+
return day, await self.usage_for_day(day)
|
| 335 |
+
|
| 336 |
+
async def consume_daily_usage(
    self,
    key: str,
    incs: Dict[str, int],
    tz_offset_minutes: Optional[int] = None,
) -> bool:
    """
    Consume per-day quota for the given API key.

    incs keys: chat_used/heavy_used/image_used/video_used

    Returns True when the increments were applied (or were a no-op),
    False when applying them would exceed any configured per-day limit.
    A negative limit (-1) means "unlimited" for that bucket.
    """
    await self.init()
    row = self.get_key_row(key)
    if not row or not row.get("is_active"):
        # Unknown/disabled keys are already rejected by auth; keep best-effort safe here.
        return True

    # Lazily load persisted usage counters on first use.
    if not self._usage_loaded:
        await self._load_usage_data()

    # Day bucket honors the caller-supplied timezone offset when given.
    day = self._day_str(tz_offset_minutes=tz_offset_minutes)
    at_ms = int(time.time() * 1000)

    # Normalize incs: keep only strictly positive integer increments.
    normalized: Dict[str, int] = {}
    for k, v in (incs or {}).items():
        try:
            inc = int(v)
        except Exception:
            # Non-numeric increments are ignored rather than failing the call.
            continue
        if inc <= 0:
            continue
        normalized[k] = inc
    if not normalized:
        return True

    # Per-bucket limits from the key row; -1 disables the check for a bucket.
    limits = {
        "chat_used": int(row.get("chat_limit", -1)),
        "heavy_used": int(row.get("heavy_limit", -1)),
        "image_used": int(row.get("image_limit", -1)),
        "video_used": int(row.get("video_limit", -1)),
    }

    # All reads/writes of the in-memory usage map happen under the lock so
    # the multi-bucket check-then-apply below is atomic.
    async with self._usage_lock:
        day_map = self._usage.get(day)
        if not isinstance(day_map, dict):
            day_map = {}
            self._usage[day] = day_map  # type: ignore[assignment]

        usage = day_map.get(key)
        if not isinstance(usage, dict):
            usage = {"chat_used": 0, "heavy_used": 0, "image_used": 0, "video_used": 0, "updated_at": at_ms}
            day_map[key] = usage  # type: ignore[assignment]

        # Check all limits first (atomic for multi-bucket)
        for bucket, inc in normalized.items():
            lim = int(limits.get(bucket, -1))
            used = int(usage.get(bucket, 0) or 0)
            if lim >= 0 and used + inc > lim:
                return False

        # Apply
        for bucket, inc in normalized.items():
            usage[bucket] = int(usage.get(bucket, 0) or 0) + inc
        usage["updated_at"] = at_ms

        # Persist while still holding the lock so the saved snapshot is consistent.
        await self._save_usage_data()
        return True
|
| 404 |
+
|
| 405 |
+
def validate_key(self, key: str) -> Optional[Dict]:
    """Validate an API key and return its record, or None when invalid/disabled."""
    # 1. The globally-configured key acts as the default admin key.
    admin_key = str(get_config("app.api_key", "") or "").strip()
    if admin_key and key == admin_key:
        return {
            "key": admin_key,
            "name": "默认管理员",
            "is_active": True,
            "is_admin": True
        }

    # 2. Look the key up in the multi-key list.
    for row in self._keys:
        if row["key"] != key:
            continue
        if not row["is_active"]:
            return None
        # Regular keys are treated as non-admin for now; this performs
        # identity recognition only, not fine-grained permissions.
        return {**row, "is_admin": False}

    return None
|
| 425 |
+
|
| 426 |
+
def get_all_keys(self) -> List[Dict]:
    """Return every stored key row in normalized form."""
    normalized = []
    for row in self._keys:
        normalized.append(self._normalize_key_row(row))
    return normalized
|
| 429 |
+
|
| 430 |
+
|
| 431 |
+
# Module-level singleton shared by the whole application.
api_key_manager = ApiKeyManager()
|
app/services/base.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Base service interface will be defined here
|
| 2 |
+
# Placeholder for service abstraction with concurrency control
|
app/services/grok/assets.py
ADDED
|
@@ -0,0 +1,875 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Grok 文件资产服务
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import asyncio
|
| 6 |
+
import base64
|
| 7 |
+
import os
|
| 8 |
+
import time
|
| 9 |
+
import hashlib
|
| 10 |
+
import re
|
| 11 |
+
import uuid
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
from contextlib import asynccontextmanager
|
| 14 |
+
try:
|
| 15 |
+
import fcntl
|
| 16 |
+
except ImportError: # pragma: no cover - non-posix platforms
|
| 17 |
+
fcntl = None
|
| 18 |
+
from typing import Tuple, List, Dict, Optional, Any
|
| 19 |
+
from urllib.parse import urlparse
|
| 20 |
+
|
| 21 |
+
import aiofiles
|
| 22 |
+
from curl_cffi.requests import AsyncSession
|
| 23 |
+
|
| 24 |
+
from app.core.logger import logger
|
| 25 |
+
from app.core.config import get_config
|
| 26 |
+
from app.core.exceptions import (
|
| 27 |
+
AppException,
|
| 28 |
+
UpstreamException,
|
| 29 |
+
ValidationException
|
| 30 |
+
)
|
| 31 |
+
from app.services.grok.statsig import StatsigService
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
# ==================== Constants ====================

# Grok upstream endpoints.
UPLOAD_API = "https://grok.com/rest/app-chat/upload-file"
LIST_API = "https://grok.com/rest/assets"
DELETE_API = "https://grok.com/rest/assets-metadata"
DOWNLOAD_API = "https://assets.grok.com"
# Directory for cross-process file locks (created on demand).
LOCK_DIR = Path(__file__).parent.parent.parent.parent / "data" / ".locks"

TIMEOUT = 120  # default HTTP timeout (seconds)
BROWSER = "chrome136"  # curl_cffi browser impersonation target
DEFAULT_MIME = "application/octet-stream"

# Concurrency control
DEFAULT_MAX_CONCURRENT = 25
DEFAULT_DELETE_BATCH_SIZE = 10
# Shared semaphore limiting concurrent asset operations; rebuilt by
# _get_assets_semaphore() when the configured size changes.
_ASSETS_SEMAPHORE = asyncio.Semaphore(DEFAULT_MAX_CONCURRENT)
_ASSETS_SEM_VALUE = DEFAULT_MAX_CONCURRENT
|
| 51 |
+
|
| 52 |
+
def _get_assets_semaphore() -> asyncio.Semaphore:
    """Return the shared assets semaphore, rebuilding it when the configured size changes."""
    global _ASSETS_SEMAPHORE, _ASSETS_SEM_VALUE
    raw = get_config("performance.assets_max_concurrent", DEFAULT_MAX_CONCURRENT)
    try:
        limit = int(raw)
    except Exception:
        limit = DEFAULT_MAX_CONCURRENT
    limit = max(1, limit)
    if limit != _ASSETS_SEM_VALUE:
        # Config changed: swap in a fresh semaphore sized to the new limit.
        _ASSETS_SEM_VALUE = limit
        _ASSETS_SEMAPHORE = asyncio.Semaphore(limit)
    return _ASSETS_SEMAPHORE
|
| 64 |
+
|
| 65 |
+
def _get_delete_batch_size() -> int:
    """Resolve the delete batch size from config, clamped to at least 1."""
    raw = get_config("performance.assets_delete_batch_size", DEFAULT_DELETE_BATCH_SIZE)
    try:
        size = int(raw)
    except Exception:
        size = DEFAULT_DELETE_BATCH_SIZE
    return size if size >= 1 else 1
|
| 72 |
+
|
| 73 |
+
@asynccontextmanager
async def _file_lock(name: str, timeout: int = 10):
    """
    Best-effort cross-process file lock based on POSIX flock.

    Yields whether or not the lock was acquired: when `fcntl` is
    unavailable (non-POSIX platform) or the lock cannot be obtained
    within *timeout* seconds, the caller proceeds without exclusion.
    """
    if fcntl is None:
        # Non-POSIX platform: degrade to a no-op lock.
        yield
        return
    LOCK_DIR.mkdir(parents=True, exist_ok=True)
    lock_path = LOCK_DIR / f"{name}.lock"
    fd = None
    locked = False
    start = time.monotonic()
    try:
        fd = open(lock_path, "a+")
        while True:
            try:
                # Non-blocking attempt; poll until acquired or timed out.
                fcntl.flock(fd, fcntl.LOCK_EX | fcntl.LOCK_NB)
                locked = True
                break
            except BlockingIOError:
                if time.monotonic() - start >= timeout:
                    # Give up and proceed unlocked (best-effort semantics).
                    break
                await asyncio.sleep(0.05)
        yield
    finally:
        if fd:
            if locked:
                try:
                    fcntl.flock(fd, fcntl.LOCK_UN)
                except Exception:
                    pass
            try:
                fd.close()
            except Exception:
                pass
|
| 106 |
+
|
| 107 |
+
# Extension -> MIME type table used for upload metadata and cache lookups.
MIME_TYPES = {
    # Images
    '.jpg': 'image/jpeg', '.jpeg': 'image/jpeg', '.png': 'image/png',
    '.gif': 'image/gif', '.webp': 'image/webp', '.bmp': 'image/bmp',

    # Documents
    '.pdf': 'application/pdf', '.txt': 'text/plain', '.md': 'text/markdown',
    '.doc': 'application/msword',
    '.docx': 'application/vnd.openxmlformats-officedocument.wordprocessingml.document',
    '.rtf': 'application/rtf',

    # Spreadsheets
    '.csv': 'text/csv',
    '.xls': 'application/vnd.ms-excel',
    '.xlsx': 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet',

    # Code
    '.py': 'text/x-python-script', '.js': 'application/javascript', '.ts': 'application/typescript',
    '.java': 'text/x-java', '.cpp': 'text/x-c++', '.c': 'text/x-c',
    '.go': 'text/x-go', '.rs': 'text/x-rust', '.rb': 'text/x-ruby',
    '.php': 'text/x-php', '.sh': 'application/x-sh', '.html': 'text/html',
    '.css': 'text/css', '.sql': 'application/sql',

    # Data
    '.json': 'application/json', '.xml': 'application/xml', '.yaml': 'application/x-yaml',
    '.yml': 'application/x-yaml', '.toml': 'application/toml', '.ini': 'text/plain',
    '.log': 'text/plain', '.tmp': 'application/octet-stream',

    # Misc
    '.graphql': 'application/graphql', '.proto': 'application/x-protobuf',
    '.latex': 'application/x-latex', '.wiki': 'text/plain', '.rst': 'text/x-rst',
}

# Extensions classified as image/video for cache-directory routing.
IMAGE_EXTS = {'.jpg', '.jpeg', '.png', '.gif', '.webp', '.bmp'}
VIDEO_EXTS = {'.mp4', '.mov', '.m4v', '.webm', '.avi', '.mkv'}
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
# ==================== 基础服务 ====================
|
| 145 |
+
|
| 146 |
+
class BaseService:
    """Common plumbing for Grok asset services: headers, proxies, session reuse.

    Holds a lazily-created curl_cffi ``AsyncSession`` that subclasses reuse
    across requests; call :meth:`close` when done.
    """

    def __init__(self, proxy: Optional[str] = None):
        # Asset-specific proxy wins over the base proxy; empty string disables.
        self.proxy = proxy or get_config("grok.asset_proxy_url") or get_config("grok.base_proxy_url", "")
        self.timeout = get_config("grok.timeout", TIMEOUT)
        # Quoted annotation: avoids evaluating the third-party type at def time.
        self._session: "Optional[AsyncSession]" = None

    def _headers(self, token: str, referer: str = "https://grok.com/") -> dict:
        """Build browser-impersonating request headers for Grok API calls."""
        headers = {
            "Accept": "*/*",
            "Accept-Encoding": "gzip, deflate, br, zstd",
            "Accept-Language": "zh-CN,zh;q=0.9",
            "Baggage": "sentry-environment=production,sentry-release=d6add6fb0460641fd482d767a335ef72b9b6abb8,sentry-public_key=b311e0f2690c81f25e2c4cf6d4f7ce1c",
            "Cache-Control": "no-cache",
            "Content-Type": "application/json",
            "Origin": "https://grok.com",
            "Pragma": "no-cache",
            "Priority": "u=1, i",
            "Referer": referer,
            "Sec-Ch-Ua": '"Google Chrome";v="136", "Chromium";v="136", "Not(A:Brand";v="24"',
            "Sec-Ch-Ua-Arch": "arm",
            "Sec-Ch-Ua-Bitness": "64",
            "Sec-Ch-Ua-Mobile": "?0",
            "Sec-Ch-Ua-Model": "",
            "Sec-Ch-Ua-Platform": '"macOS"',
            "Sec-Fetch-Dest": "empty",
            "Sec-Fetch-Mode": "cors",
            "Sec-Fetch-Site": "same-origin",
            "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/136.0.0.0 Safari/537.36",
        }

        # Per-request anti-bot / tracing identifiers.
        headers["x-statsig-id"] = StatsigService.gen_id()
        headers["x-xai-request-id"] = str(uuid.uuid4())

        # Cookie: accept tokens with or without a leading "sso=" prefix.
        token = token[4:] if token.startswith("sso=") else token
        cf = get_config("grok.cf_clearance", "")
        headers["Cookie"] = f"sso={token};cf_clearance={cf}" if cf else f"sso={token}"

        return headers

    def _proxies(self) -> Optional[dict]:
        """Build the proxies mapping for curl_cffi, or None when no proxy is set."""
        return {"http": self.proxy, "https": self.proxy} if self.proxy else None

    def _dl_headers(self, token: str, file_path: str) -> dict:
        """Build document-navigation style headers for asset downloads."""
        headers = {
            "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,*/*;q=0.8",
            "Sec-Fetch-Dest": "document",
            "Sec-Fetch-Mode": "navigate",
            "Sec-Fetch-Site": "same-site",
            "Sec-Fetch-User": "?1",
            "Upgrade-Insecure-Requests": "1",
            "Referer": "https://grok.com/",
        }

        # Cookie: accept tokens with or without a leading "sso=" prefix.
        token = token[4:] if token.startswith("sso=") else token
        cf = get_config("grok.cf_clearance", "")
        headers["Cookie"] = f"sso={token};cf_clearance={cf}" if cf else f"sso={token}"

        return headers

    async def _get_session(self) -> "AsyncSession":
        """Return the reusable session, creating it on first use."""
        if self._session is None:
            self._session = AsyncSession()
        return self._session

    async def close(self):
        """Close and drop the reusable session (safe to call repeatedly)."""
        if self._session:
            await self._session.close()
            self._session = None

    @staticmethod
    def is_url(input_str: str) -> bool:
        """Return True when *input_str* parses as an absolute http(s) URL."""
        try:
            result = urlparse(input_str)
            return all([result.scheme, result.netloc]) and result.scheme in ['http', 'https']
        except Exception:
            # Narrowed from a bare except: only parsing failures mean "not a URL".
            return False

    @staticmethod
    async def fetch(url: str) -> Tuple[str, str, str]:
        """
        Fetch a remote resource and return (filename, base64 content, mime type).

        Raises:
            UpstreamException: when fetching fails
        """
        try:
            async with AsyncSession() as session:
                response = await session.get(url, timeout=10)
                if response.status_code >= 400:
                    raise UpstreamException(
                        message=f"Failed to fetch resource: {response.status_code}",
                        details={"url": url, "status": response.status_code}
                    )

                # Last path segment (query string stripped) as a filename guess.
                filename = url.split('/')[-1].split('?')[0] or 'download'
                content_type = response.headers.get('content-type', DEFAULT_MIME).split(';')[0]
                b64 = base64.b64encode(response.content).decode()

                logger.debug(f"Fetched: {url} -> (unknown)")
                return filename, b64, content_type
        except Exception as e:
            logger.error(f"Fetch failed: {url} - {e}")
            if isinstance(e, AppException):
                raise
            raise UpstreamException(f"Resource fetch failed: {str(e)}", details={"url": url})

    @staticmethod
    def parse_b64(data_uri: str) -> Tuple[str, str, str]:
        """Parse base64 input; returns (filename, base64 payload, mime type).

        Accepts both ``data:`` URIs and raw base64 (falls back to a generic
        binary file name and mime type).
        """
        if data_uri.startswith("data:"):
            match = re.match(r"data:([^;]+);base64,(.+)", data_uri)
            if match:
                mime = match.group(1)
                b64 = match.group(2)
                ext = mime.split('/')[-1] if '/' in mime else 'bin'
                return f"file.{ext}", b64, mime
        return "file.bin", data_uri, DEFAULT_MIME

    @staticmethod
    def to_b64(file_path: Path, mime_type: str) -> str:
        """Read a local file and return it as a base64 data URI.

        Raises:
            AppException: when the file cannot be read
        """
        try:
            b64_data = base64.b64encode(file_path.read_bytes()).decode()
            return f"data:{mime_type};base64,{b64_data}"
        except Exception as e:
            logger.error(f"File to base64 failed: {file_path} - {e}")
            raise AppException(f"Failed to read file: {file_path}", code="file_read_error")
|
| 284 |
+
|
| 285 |
+
|
| 286 |
+
# ==================== 上传服务 ====================
|
| 287 |
+
|
| 288 |
+
class UploadService(BaseService):
    """File upload service."""

    async def upload(self, file_input: str, token: str) -> Tuple[str, str]:
        """
        Upload a file (remote URL or base64 data URI) to Grok.

        Returns:
            (file_id, file_uri)

        Raises:
            ValidationException: invalid input
            UpstreamException: upload failure
        """
        async with _get_assets_semaphore():
            try:
                # Resolve input into (filename, base64 payload, mime type).
                if self.is_url(file_input):
                    filename, b64, mime = await self.fetch(file_input)
                else:
                    filename, b64, mime = self.parse_b64(file_input)

                if not b64:
                    raise ValidationException("Invalid file input: empty content")

                # Perform the upload through the reusable session.
                session = await self._get_session()
                response = await session.post(
                    UPLOAD_API,
                    headers=self._headers(token),
                    json={
                        "fileName": filename,
                        "fileMimeType": mime,
                        "content": b64,
                    },
                    impersonate=BROWSER,
                    timeout=self.timeout,
                    proxies=self._proxies(),
                )

                if response.status_code != 200:
                    logger.error(
                        f"Upload failed: (unknown) - {response.status_code}",
                        extra={"response": response.text[:200]}
                    )
                    raise UpstreamException(
                        message=f"Upload failed with status {response.status_code}",
                        details={"status": response.status_code, "response": response.text[:200]}
                    )

                result = response.json()
                file_id = result.get("fileMetadataId", "")
                file_uri = result.get("fileUri", "")
                logger.info(f"Upload success: (unknown) -> {file_id}", extra={"file_id": file_id})
                return file_id, file_uri

            except Exception as e:
                logger.error(f"Upload error: {e}")
                if isinstance(e, AppException):
                    raise e
                raise UpstreamException(f"Upload process error: {str(e)}")
|
| 353 |
+
|
| 354 |
+
|
| 355 |
+
# ==================== 列表服务 ====================
|
| 356 |
+
|
| 357 |
+
class ListService(BaseService):
    """Asset listing service."""

    async def iter_assets(self, token: str):
        """
        Iterate asset pages from the upstream list endpoint.

        Yields one list of asset dicts per page; stops when there is no
        nextPageToken, or when a page token repeats (guard against an
        upstream pagination loop).

        Raises:
            UpstreamException: when a page request fails
        """
        headers = self._headers(token, referer="https://grok.com/files")
        base_params = {
            "pageSize": 50,
            "orderBy": "ORDER_BY_LAST_USE_TIME",
            "source": "SOURCE_ANY",
            "isLatest": "true",
        }
        page_token = None
        seen_tokens = set()

        async with AsyncSession() as session:
            while True:
                params = dict(base_params)
                if page_token:
                    if page_token in seen_tokens:
                        # Defensive: upstream handed back a token we already used.
                        logger.warning("List pagination stopped due to repeated page token")
                        break
                    seen_tokens.add(page_token)
                    params["pageToken"] = page_token

                response = await session.get(
                    LIST_API,
                    headers=headers,
                    params=params,
                    impersonate=BROWSER,
                    timeout=self.timeout,
                    proxies=self._proxies(),
                )

                if response.status_code != 200:
                    logger.error(f"List failed: {response.status_code}")
                    raise UpstreamException(
                        message=f"List assets failed: {response.status_code}",
                        details={"status": response.status_code}
                    )

                result = response.json()
                page_assets = result.get("assets", [])
                yield page_assets

                page_token = result.get("nextPageToken")
                if not page_token:
                    break

    async def list(self, token: str) -> List[Dict]:
        """
        Fetch the full asset list (all pages materialized in memory).

        Raises:
            UpstreamException: when listing fails
        """
        try:
            assets: List[Dict] = []
            async for page_assets in self.iter_assets(token):
                assets.extend(page_assets)

            logger.info(f"List success: {len(assets)} files")
            return assets

        except Exception as e:
            logger.error(f"List error: {e}")
            if isinstance(e, AppException):
                raise e
            raise UpstreamException(f"List assets error: {str(e)}")

    async def count(self, token: str) -> int:
        """
        Count assets without retaining per-item details (streams pages).
        """
        try:
            total = 0
            async for page_assets in self.iter_assets(token):
                total += len(page_assets)
            return total
        except Exception as e:
            logger.error(f"List count error: {e}")
            if isinstance(e, AppException):
                raise e
            raise UpstreamException(f"List assets error: {str(e)}")
|
| 443 |
+
|
| 444 |
+
|
| 445 |
+
# ==================== 删除服务 ====================
|
| 446 |
+
|
| 447 |
+
class DeleteService(BaseService):
    """File deletion service."""

    async def delete(self, token: str, asset_id: str) -> bool:
        """
        Delete a single asset.

        Returns True on success.

        Raises:
            UpstreamException: when the upstream delete fails
        """
        async with _get_assets_semaphore():
            try:
                headers = self._headers(token, referer="https://grok.com/files")
                url = f"{DELETE_API}/{asset_id}"

                session = await self._get_session()
                response = await session.delete(
                    url,
                    headers=headers,
                    impersonate=BROWSER,
                    timeout=self.timeout,
                    proxies=self._proxies(),
                )

                if response.status_code == 200:
                    logger.debug(f"Delete success: {asset_id}")
                    return True

                logger.error(f"Delete failed: {asset_id} - {response.status_code}")
                # Raise (rather than return False) so API callers get structured feedback.
                raise UpstreamException(
                    message=f"Delete failed: {asset_id}",
                    details={"status": response.status_code}
                )

            except Exception as e:
                logger.error(f"Delete error: {asset_id} - {e}")
                if isinstance(e, AppException):
                    raise e
                raise UpstreamException(f"Delete error: {str(e)}")

    async def delete_all(self, token: str) -> Dict[str, int]:
        """
        Delete every asset on the account, page by page, in concurrent batches.

        Returns a summary dict {"total", "success", "failed"}; includes
        "skipped": True when there was nothing to delete. A failure while
        listing returns the partial counts accumulated so far.
        """
        total = 0
        success = 0
        failed = 0
        list_service = ListService(self.proxy)
        try:
            async for assets in list_service.iter_assets(token):
                if not assets:
                    continue
                total += len(assets)

                async def _delete_one(asset: Dict, index: int) -> bool:
                    # Small stagger so a batch does not hit upstream simultaneously.
                    await asyncio.sleep(0.01 * index)
                    asset_id = asset.get("assetId", "")
                    if not asset_id:
                        return False
                    try:
                        return await self.delete(token, asset_id)
                    except Exception:
                        # Best-effort: a single failed delete is counted, not fatal.
                        # (Narrowed from a bare except so KeyboardInterrupt propagates.)
                        return False

                batch_size = _get_delete_batch_size()
                for i in range(0, len(assets), batch_size):
                    batch = assets[i:i + batch_size]
                    results = await asyncio.gather(*[
                        _delete_one(asset, idx) for idx, asset in enumerate(batch)
                    ])
                    ok = sum(results)  # count True results once
                    success += ok
                    failed += len(batch) - ok

            if total == 0:
                logger.info("No assets to delete")
                return {"total": 0, "success": 0, "failed": 0, "skipped": True}
        except Exception as e:
            logger.error(f"Delete all failed during list: {e}")
            return {"total": total, "success": success, "failed": failed}
        finally:
            await list_service.close()

        logger.info(f"Delete all: total={total}, success={success}, failed={failed}")
        return {"total": total, "success": success, "failed": failed}
|
| 535 |
+
|
| 536 |
+
|
| 537 |
+
# ==================== 下载服务 ====================
|
| 538 |
+
|
| 539 |
+
class DownloadService(BaseService):
|
| 540 |
+
"""文件下载服务"""
|
| 541 |
+
|
| 542 |
+
def __init__(self, proxy: str = None):
    """Set up download cache directories under data/tmp, with legacy data/temp fallbacks."""
    super().__init__(proxy)
    data_root = Path(__file__).parent.parent.parent.parent / "data"
    self.base_dir = data_root / "tmp"
    self.legacy_base_dir = data_root / "temp"
    self.legacy_image_dir = self.legacy_base_dir / "image"
    self.legacy_video_dir = self.legacy_base_dir / "video"
    self.image_dir = self.base_dir / "image"
    self.video_dir = self.base_dir / "video"
    # Only the current cache dirs are created; legacy dirs are read-only fallbacks.
    for cache_dir in (self.image_dir, self.video_dir):
        cache_dir.mkdir(parents=True, exist_ok=True)
    self._cleanup_running = False
|
| 554 |
+
|
| 555 |
+
def _cache_path(self, file_path: str, media_type: str) -> Path:
    """Map an asset path to its local cache file (flattened filename)."""
    flattened = file_path.lstrip('/').replace('/', '-')
    if media_type == "image":
        return self.image_dir / flattened
    return self.video_dir / flattened
|
| 560 |
+
|
| 561 |
+
def _legacy_cache_path(self, file_path: str, media_type: str) -> Path:
    """Legacy cache path under data/temp (read-only fallback location)."""
    base = self.legacy_image_dir if media_type == "image" else self.legacy_video_dir
    return base / file_path.lstrip("/").replace("/", "-")
|
| 566 |
+
|
| 567 |
+
async def download(self, file_path: str, token: str, media_type: str = "image") -> Tuple[Optional[Path], str]:
    """
    Download an asset into the local cache and return (path, mime type).

    Checks the current cache, then the legacy cache, before hitting
    upstream. Concurrent downloads of the same path are serialized via a
    best-effort file lock, and writes go through a temp file followed by
    an atomic rename so readers never see a partial file.

    Raises:
        UpstreamException: when the download fails
    """
    async with _get_assets_semaphore():
        try:
            # Be forgiving: callers may pass absolute URLs.
            if isinstance(file_path, str) and file_path.startswith("http"):
                try:
                    file_path = urlparse(file_path).path
                except Exception:
                    pass

            cache_path = self._cache_path(file_path, media_type)

            # Fast path: already cached in the current cache dir.
            if cache_path.exists():
                logger.debug(f"Cache hit: {cache_path}")
                mime_type = MIME_TYPES.get(cache_path.suffix.lower(), DEFAULT_MIME)
                return cache_path, mime_type

            # Fallback: file still present in the legacy (data/temp) cache.
            legacy_path = self._legacy_cache_path(file_path, media_type)
            if legacy_path.exists():
                logger.debug(f"Legacy cache hit: {legacy_path}")
                mime_type = MIME_TYPES.get(legacy_path.suffix.lower(), DEFAULT_MIME)
                return legacy_path, mime_type

            # Serialize concurrent downloads of the same target path.
            lock_name = f"download_{media_type}_{hashlib.sha1(str(cache_path).encode('utf-8')).hexdigest()[:16]}"
            async with _file_lock(lock_name, timeout=10):
                # Double-check after lock
                if cache_path.exists():
                    logger.debug(f"Cache hit after lock: {cache_path}")
                    mime_type = MIME_TYPES.get(cache_path.suffix.lower(), DEFAULT_MIME)
                    return cache_path, mime_type

                # Download from upstream; asset paths are rooted at DOWNLOAD_API.
                if not file_path.startswith("/"):
                    file_path = f"/{file_path}"

                url = f"{DOWNLOAD_API}{file_path}"
                headers = self._dl_headers(token, file_path)

                session = await self._get_session()
                response = await session.get(
                    url,
                    headers=headers,
                    proxies=self._proxies(),
                    timeout=self.timeout,
                    allow_redirects=True,
                    impersonate=BROWSER,
                    stream=True,
                )

                if response.status_code != 200:
                    raise UpstreamException(
                        message=f"Download failed: {response.status_code}",
                        details={"path": file_path, "status": response.status_code}
                    )

                # Save in chunks (avoids holding large files in memory) via
                # temp file + atomic os.replace.
                tmp_path = cache_path.with_suffix(cache_path.suffix + ".tmp")
                try:
                    async with aiofiles.open(tmp_path, "wb") as f:
                        # The streaming iterator name varies across curl_cffi
                        # versions; probe the available one.
                        if hasattr(response, "aiter_content"):
                            async for chunk in response.aiter_content():
                                if chunk:
                                    await f.write(chunk)
                        elif hasattr(response, "aiter_bytes"):
                            async for chunk in response.aiter_bytes():
                                if chunk:
                                    await f.write(chunk)
                        elif hasattr(response, "aiter_raw"):
                            async for chunk in response.aiter_raw():
                                if chunk:
                                    await f.write(chunk)
                        else:
                            await f.write(response.content)
                    os.replace(tmp_path, cache_path)
                finally:
                    # Remove a partial temp file if the rename never happened.
                    if tmp_path.exists() and not cache_path.exists():
                        try:
                            tmp_path.unlink()
                        except Exception:
                            pass
                mime_type = response.headers.get('content-type', DEFAULT_MIME).split(';')[0]

                logger.info(f"Download success: {file_path}")

                # Fire-and-forget cache check (check_limit is defined elsewhere;
                # presumably enforces a cache size limit — confirm in class tail).
                asyncio.create_task(self.check_limit())

                return cache_path, mime_type

        except Exception as e:
            logger.error(f"Download failed: {file_path} - {e}")
            if isinstance(e, AppException):
                raise e
            raise UpstreamException(f"Download error: {str(e)}")
|
| 668 |
+
|
| 669 |
+
async def to_base64(
    self,
    file_path: str,
    token: str,
    media_type: str = "image"
) -> str:
    """Download a remote asset and return it as a base64 data URI.

    Args:
        file_path: upstream asset path.
        token: auth token forwarded to the download call.
        media_type: "image" or "video" cache bucket.

    Returns:
        A data-URI string produced by ``self.to_b64``.

    Raises:
        AppException: when the download yields no usable file or the
            conversion fails for any reason.
    """
    try:
        local_path, mime_type = await self.download(file_path, token, media_type)

        if not local_path or not local_path.exists():
            raise AppException("File download returned invalid path")

        # Convert via the base-service helper.
        data_uri = self.to_b64(local_path, mime_type)

        # By default the downloaded file is kept in the local cache so the
        # admin "cache management" view can count and reuse it; setting
        # cache.keep_base64_cache=false switches to throwaway behavior.
        if data_uri and not get_config("cache.keep_base64_cache", True):
            try:
                local_path.unlink()
            except Exception as exc:
                logger.warning(f"Delete temp file failed: {exc}")

        return data_uri

    except Exception as exc:
        logger.error(f"To base64 failed: {file_path} - {exc}")
        if isinstance(exc, AppException):
            raise exc
        raise AppException(f"Base64 conversion failed: {str(exc)}")
|
| 703 |
+
|
| 704 |
+
def get_stats(self, media_type: str = "image") -> Dict[str, Any]:
    """Return file count and total size (in MB) for one media cache dir."""
    target_dir = self.video_dir if media_type != "image" else self.image_dir
    if not target_dir.exists():
        return {"count": 0, "size_mb": 0.0}

    # Count every regular file — some asset paths lack a standard suffix,
    # so we do not filter by extension here.
    sizes = [entry.stat().st_size for entry in target_dir.glob("*") if entry.is_file()]
    return {
        "count": len(sizes),
        "size_mb": round(sum(sizes) / 1024 / 1024, 2),
    }
|
| 718 |
+
|
| 719 |
+
def list_files(self, media_type: str = "image", page: int = 1, page_size: int = 1000) -> Dict[str, Any]:
    """List cached files for one media type, newest first, paginated.

    Image items get a ``view_url``; video items additionally get a
    ``preview_url`` when an image with the same stem exists in the
    image cache.
    """
    cache_dir = self.image_dir if media_type == "image" else self.video_dir
    if not cache_dir.exists():
        return {"total": 0, "page": page, "page_size": page_size, "items": []}

    entries = []
    for path in cache_dir.glob("*"):
        if not path.is_file():
            continue
        try:
            info = path.stat()
        except Exception:
            # File vanished between glob and stat — skip it.
            continue
        entries.append({
            "name": path.name,
            "size_bytes": info.st_size,
            "mtime_ms": int(info.st_mtime * 1000),
        })

    entries.sort(key=lambda e: e["mtime_ms"], reverse=True)
    total = len(entries)
    start = max(0, (page - 1) * page_size)
    paged = entries[start:start + page_size]

    if media_type == "image":
        for entry in paged:
            entry["view_url"] = f"/v1/files/image/{entry['name']}"
    else:
        # Map image stems to file names so each video can link a preview
        # image that shares its stem (first match wins).
        previews = {}
        if self.image_dir.exists():
            for img in self.image_dir.glob("*"):
                if img.is_file() and img.suffix.lower() in IMAGE_EXTS:
                    previews.setdefault(img.stem, img.name)
        for entry in paged:
            entry["view_url"] = f"/v1/files/video/{entry['name']}"
            preview = previews.get(Path(entry["name"]).stem)
            if preview:
                entry["preview_url"] = f"/v1/files/image/{preview}"

    return {"total": total, "page": page, "page_size": page_size, "items": paged}
|
| 760 |
+
|
| 761 |
+
def delete_file(self, media_type: str, name: str) -> Dict[str, Any]:
    """Delete a single cached file by name.

    Args:
        media_type: "image" or "video" cache bucket.
        name: file name as shown by ``list_files``. The name may come from
            an admin API request, so it is reduced to its final path
            component to prevent path traversal (the previous
            ``name.replace("/", "-")`` approach missed backslashes and
            a bare "..").

    Returns:
        {"deleted": bool} — False when the file is absent, the name is
        unsafe, or deletion fails.
    """
    cache_dir = self.image_dir if media_type == "image" else self.video_dir
    # Keep only the last path component; normalize backslashes first so
    # Windows-style separators cannot smuggle directory parts through.
    safe_name = Path(name.replace("\\", "/")).name
    if not safe_name or safe_name in {".", ".."}:
        return {"deleted": False}
    file_path = cache_dir / safe_name
    if not file_path.exists():
        return {"deleted": False}
    try:
        file_path.unlink()
        return {"deleted": True}
    except Exception:
        # Best-effort: report failure instead of raising to the API layer.
        return {"deleted": False}
|
| 773 |
+
|
| 774 |
+
def clear(self, media_type: str = "image") -> Dict[str, Any]:
    """Delete every file in one media cache directory.

    Fixes vs. previous version: the docstring was mojibake; directories
    were not filtered out (inconsistent with get_stats/list_files); and
    size_mb reported the pre-delete total even when deletions failed —
    it now reports only what was actually removed.

    Returns:
        {"count": files_removed, "size_mb": removed_size_in_MB}
    """
    cache_dir = self.image_dir if media_type == "image" else self.video_dir
    if not cache_dir.exists():
        return {"count": 0, "size_mb": 0.0}

    # Only regular files — subdirectories (if any) are left alone,
    # matching get_stats() and list_files().
    files = [f for f in cache_dir.glob("*") if f.is_file()]
    count = 0
    deleted_size = 0

    for f in files:
        try:
            size = f.stat().st_size
            f.unlink()
            count += 1
            deleted_size += size
        except Exception as e:
            logger.error(f"Failed to delete {f}: {e}")

    return {
        "count": count,
        "size_mb": round(deleted_size / 1024 / 1024, 2)
    }
|
| 795 |
+
|
| 796 |
+
async def check_limit(self):
    """Enforce the configured cache size limit (best-effort).

    Guarded two ways: a per-process flag (``_cleanup_running``) prevents
    re-entrancy, and the ``cache_cleanup`` file lock serializes workers.
    When total cache size exceeds ``cache.limit_mb``, the oldest files
    (by mtime) are deleted until usage drops to 80% of the limit.

    Fix vs. previous version: the bare ``except:`` around ``stat()`` is
    narrowed to ``OSError`` so programming errors are no longer swallowed.
    """
    if self._cleanup_running:
        return
    self._cleanup_running = True
    try:
        # Cross-process guard: only one worker cleans at a time.
        async with _file_lock("cache_cleanup", timeout=5):
            if not get_config("cache.enable_auto_clean", True):
                return

            limit_mb = get_config("cache.limit_mb", 1024)

            # Collect (path, mtime, size) for every cached file across
            # both media directories.
            total_size = 0
            all_files = []

            for d in [self.image_dir, self.video_dir]:
                if d.exists():
                    for f in d.glob("*"):
                        try:
                            stat = f.stat()
                        except OSError:
                            # File vanished or is unreadable — skip it.
                            continue
                        total_size += stat.st_size
                        all_files.append((f, stat.st_mtime, stat.st_size))

            current_mb = total_size / 1024 / 1024
            if current_mb <= limit_mb:
                return

            logger.info(f"Cache limit exceeded ({current_mb:.2f}MB > {limit_mb}MB), cleaning up...")

            # Oldest first (LRU-ish eviction by mtime).
            all_files.sort(key=lambda x: x[1])

            deleted_count = 0
            deleted_size = 0
            target_mb = limit_mb * 0.8  # free down to 80% of the limit

            for f, _, size in all_files:
                try:
                    f.unlink()
                    deleted_count += 1
                    deleted_size += size
                    total_size -= size

                    if (total_size / 1024 / 1024) <= target_mb:
                        break
                except Exception as e:
                    logger.error(f"Cleanup failed for {f}: {e}")

            logger.info(f"Cache cleanup: deleted {deleted_count} files ({deleted_size/1024/1024:.2f}MB)")
    finally:
        self._cleanup_running = False
|
| 851 |
+
|
| 852 |
+
def get_public_url(self, file_path: str) -> str:
    """Build a public URL for a cached asset.

    When ``app.app_url`` is configured the URL points at this service's
    ``/v1/files`` route; otherwise the original Grok download URL is
    returned.
    """
    normalized = file_path if file_path.startswith("/") else f"/{file_path}"
    app_url = get_config("app.app_url", "")
    if not app_url:
        return f"{DOWNLOAD_API}{normalized}"
    # Self-hosted mode: serve through our own /v1/files endpoint.
    return f"{app_url.rstrip('/')}/v1/files{normalized}"
|
| 867 |
+
|
| 868 |
+
|
| 869 |
+
# Public service classes exported by this module.
__all__ = [
    "BaseService",
    "UploadService",
    "ListService",
    "DeleteService",
    "DownloadService",
]
|
app/services/grok/chat.py
ADDED
|
@@ -0,0 +1,571 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Grok Chat 服务
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import asyncio
|
| 6 |
+
import uuid
|
| 7 |
+
import orjson
|
| 8 |
+
from typing import Dict, List, Any
|
| 9 |
+
from dataclasses import dataclass
|
| 10 |
+
|
| 11 |
+
from curl_cffi.requests import AsyncSession
|
| 12 |
+
|
| 13 |
+
from app.core.logger import logger
|
| 14 |
+
from app.core.config import get_config
|
| 15 |
+
from app.core.exceptions import (
|
| 16 |
+
AppException,
|
| 17 |
+
UpstreamException,
|
| 18 |
+
ValidationException,
|
| 19 |
+
ErrorType
|
| 20 |
+
)
|
| 21 |
+
from app.services.grok.statsig import StatsigService
|
| 22 |
+
from app.services.grok.model import ModelService
|
| 23 |
+
from app.services.grok.assets import UploadService
|
| 24 |
+
from app.services.grok.processor import StreamProcessor, CollectProcessor
|
| 25 |
+
from app.services.grok.retry import retry_on_status
|
| 26 |
+
from app.services.token import get_token_manager
|
| 27 |
+
from app.services.request_stats import request_stats
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
# Upstream endpoint for creating a new Grok conversation.
CHAT_API = "https://grok.com/rest/app-chat/conversations/new"
# Default request timeout in seconds (overridable via the grok.timeout config).
TIMEOUT = 120
# curl_cffi browser fingerprint to impersonate for TLS/HTTP2.
BROWSER = "chrome136"
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
@dataclass
|
| 36 |
+
class ChatRequest:
|
| 37 |
+
"""聊天请求数据"""
|
| 38 |
+
model: str
|
| 39 |
+
messages: List[Dict[str, Any]]
|
| 40 |
+
stream: bool = None
|
| 41 |
+
think: bool = None
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class MessageExtractor:
    """Extracts plain text and uploadable attachments from OpenAI-format messages."""

    # Content-part types that require an upload to Grok before chatting.
    UPLOAD_TYPES = {"image_url", "input_audio", "file"}
    # Content-part types the video models cannot accept.
    VIDEO_UNSUPPORTED = {"input_audio", "file"}

    @staticmethod
    def extract(messages: List[Dict[str, Any]], is_video: bool = False) -> tuple[str, List[tuple[str, str]]]:
        """
        Flatten OpenAI-format messages into one prompt string plus attachments.

        Args:
            messages: OpenAI-style message list.
            is_video: True when the target model is a video model.

        Returns:
            (text, attachments): the merged prompt text and a list of
            (kind, payload) tuples where kind is "image"/"audio"/"file"
            and payload is a URL or base64 blob. (Return annotation fixed:
            it previously claimed ``List[str]``.)

        Raises:
            ValueError: a video model received an unsupported content type.
        """
        attachments: List[tuple[str, str]] = []
        # Per-message extracted text; role is kept for the merge step below.
        extracted: List[Dict[str, str]] = []

        for msg in messages:
            role = msg.get("role", "")
            content = msg.get("content", "")
            parts: List[str] = []

            # Simple string content
            if isinstance(content, str):
                if content.strip():
                    parts.append(content)

            # Structured multi-part content
            elif isinstance(content, list):
                for item in content:
                    item_type = item.get("type", "")

                    if item_type == "text":
                        text = item.get("text", "")
                        if text.strip():
                            parts.append(text)

                    elif item_type == "image_url":
                        image_data = item.get("image_url", {})
                        url = image_data.get("url", "") if isinstance(image_data, dict) else str(image_data)
                        if url:
                            attachments.append(("image", url))

                    elif item_type == "input_audio":
                        if is_video:
                            raise ValueError("视频模型不支持 input_audio 类型")
                        audio_data = item.get("input_audio", {})
                        data = audio_data.get("data", "") if isinstance(audio_data, dict) else str(audio_data)
                        if data:
                            attachments.append(("audio", data))

                    elif item_type == "file":
                        if is_video:
                            raise ValueError("视频模型不支持 file 类型")
                        file_data = item.get("file", {})
                        # BUG FIX: `file` may be a bare URL/base64 string or a
                        # dict with "url"/"data" keys. The previous code called
                        # .get() before the isinstance check, so a string value
                        # raised AttributeError.
                        if isinstance(file_data, str):
                            url = file_data
                        else:
                            url = file_data.get("url", "") or file_data.get("data", "")
                        if url:
                            attachments.append(("file", url))

            if parts:
                extracted.append({"role": role, "text": "\n".join(parts)})

        # The most recent user message is emitted without a role prefix;
        # everything else is prefixed "role: text".
        last_user_index = None
        for i in range(len(extracted) - 1, -1, -1):
            if extracted[i]["role"] == "user":
                last_user_index = i
                break

        texts: List[str] = []
        for i, item in enumerate(extracted):
            role = item["role"] or "user"
            if i == last_user_index:
                texts.append(item["text"])
            else:
                texts.append(f"{role}: {item['text']}")

        # Join message texts with blank lines.
        return "\n\n".join(texts), attachments

    @staticmethod
    def extract_text_only(messages: List[Dict[str, Any]]) -> str:
        """Extract only the merged text (attachments are discarded).

        Uses is_video=True, so unsupported attachment types still raise.
        """
        text, _ = MessageExtractor.extract(messages, is_video=True)
        return text
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
class ChatRequestBuilder:
    """Builds HTTP headers and JSON payloads for Grok chat requests.

    The literal header and payload dictionaries mimic what the grok.com
    web client sends; field values are kept verbatim on purpose.
    """

    @staticmethod
    def build_headers(token: str) -> Dict[str, str]:
        """Build browser-mimicking request headers, including the sso cookie."""
        headers = {
            "Accept": "*/*",
            "Accept-Encoding": "gzip, deflate, br, zstd",
            "Accept-Language": "zh-CN,zh;q=0.9",
            "Baggage": "sentry-environment=production,sentry-release=d6add6fb0460641fd482d767a335ef72b9b6abb8,sentry-public_key=b311e0f2690c81f25e2c4cf6d4f7ce1c",
            "Cache-Control": "no-cache",
            "Content-Type": "application/json",
            "Origin": "https://grok.com",
            "Pragma": "no-cache",
            "Priority": "u=1, i",
            "Referer": "https://grok.com/",
            "Sec-Ch-Ua": '"Google Chrome";v="136", "Chromium";v="136", "Not(A:Brand";v="24"',
            "Sec-Ch-Ua-Arch": "arm",
            "Sec-Ch-Ua-Bitness": "64",
            "Sec-Ch-Ua-Mobile": "?0",
            "Sec-Ch-Ua-Model": "",
            "Sec-Ch-Ua-Platform": '"macOS"',
            "Sec-Fetch-Dest": "empty",
            "Sec-Fetch-Mode": "cors",
            "Sec-Fetch-Site": "same-origin",
            "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/136.0.0.0 Safari/537.36",
        }

        # Anti-bot identifiers required by the Grok endpoint.
        headers["x-statsig-id"] = StatsigService.gen_id()
        headers["x-xai-request-id"] = str(uuid.uuid4())

        # Cookie: strip an optional "sso=" prefix from the stored token,
        # append cf_clearance when it is configured.
        token = token[4:] if token.startswith("sso=") else token
        cf = get_config("grok.cf_clearance", "")
        headers["Cookie"] = f"sso={token};cf_clearance={cf}" if cf else f"sso={token}"

        return headers

    @staticmethod
    def build_payload(
        message: str,
        model: str,
        mode: str,
        think: bool = None,
        file_attachments: List[str] = None,
        image_attachments: List[str] = None
    ) -> Dict[str, Any]:
        """
        Build the request body.

        Args:
            message: prompt text
            model: Grok model name
            mode: model mode (e.g. MODEL_MODE_FAST)
            think: enable thinking mode (falls back to grok.thinking config)
            file_attachments: uploaded file attachment IDs
            image_attachments: uploaded image attachment IDs
        """
        temporary = get_config("grok.temporary", True)
        if think is None:
            think = get_config("grok.thinking", False)
        # NOTE(review): `think` is resolved here but never referenced in the
        # returned payload ("isReasoning" is hard-coded False) — confirm
        # whether thinking is meant to be wired into this request body.

        # Upstream payload expects image attachments merged into fileAttachments.
        merged_attachments: List[str] = []
        if file_attachments:
            merged_attachments.extend(file_attachments)
        if image_attachments:
            merged_attachments.extend(image_attachments)

        return {
            "temporary": temporary,
            "modelName": model,
            "modelMode": mode,
            "message": message,
            "fileAttachments": merged_attachments,
            "imageAttachments": [],
            "disableSearch": False,
            "enableImageGeneration": True,
            "returnImageBytes": False,
            "returnRawGrokInXaiRequest": False,
            "enableImageStreaming": True,
            "imageGenerationCount": 2,
            "forceConcise": False,
            "toolOverrides": {},
            "enableSideBySide": True,
            "sendFinalMetadata": True,
            "isReasoning": False,
            "disableTextFollowUps": False,
            "responseMetadata": {
                "modelConfigOverride": {"modelMap": {}},
                "requestModelDetails": {"modelId": model}
            },
            "disableMemory": False,
            "forceSideBySide": False,
            "isAsyncChat": False,
            "disableSelfHarmShortCircuit": False,
            # Fixed fake device/viewport metrics matching the impersonated browser.
            "deviceEnvInfo": {
                "darkModeEnabled": False,
                "devicePixelRatio": 2,
                "screenWidth": 2056,
                "screenHeight": 1329,
                "viewportWidth": 2056,
                "viewportHeight": 1083
            }
        }
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
# ==================== Grok 服务 ====================
|
| 261 |
+
|
| 262 |
+
class GrokChatService:
    """Low-level Grok API caller: opens a streaming chat request and yields
    raw SSE/JSON lines; `chat_openai` adapts OpenAI-format requests."""

    def __init__(self, proxy: str = None):
        # Explicit proxy wins; otherwise fall back to the configured base proxy.
        self.proxy = proxy or get_config("grok.base_proxy_url", "")

    async def chat(
        self,
        token: str,
        message: str,
        model: str = "grok-3",
        mode: str = "MODEL_MODE_FAST",
        think: bool = None,
        stream: bool = None,
        file_attachments: List[str] = None,
        image_attachments: List[str] = None
    ):
        """
        Send a chat request and return an async generator of response lines.

        Args:
            token: auth token (sso cookie value)
            message: prompt text
            mode: Grok model mode
            think: enable thinking mode
            stream: streaming flag (falls back to grok.stream config)
            file_attachments: uploaded file attachment IDs
            image_attachments: uploaded image attachment IDs

        Raises:
            UpstreamException: when the Grok API errors and retries are exhausted.
        """
        if stream is None:
            stream = get_config("grok.stream", True)

        headers = ChatRequestBuilder.build_headers(token)
        payload = ChatRequestBuilder.build_payload(
            message, model, mode, think,
            file_attachments, image_attachments
        )
        proxies = {"http": self.proxy, "https": self.proxy} if self.proxy else None
        timeout = get_config("grok.timeout", TIMEOUT)

        # Pull an HTTP status out of an UpstreamException so retry_on_status
        # can decide whether the failure is retryable.
        def extract_status(e: Exception) -> int | None:
            if isinstance(e, UpstreamException) and e.details:
                return e.details.get("status")
            return None

        # Opens the connection; on any failure the session is closed so no
        # socket leaks across retries.
        async def establish_connection():
            """Open the streaming POST and return (session, response)."""
            session = AsyncSession(impersonate=BROWSER)
            try:
                response = await session.post(
                    CHAT_API,
                    headers=headers,
                    data=orjson.dumps(payload),
                    timeout=timeout,
                    stream=True,
                    proxies=proxies
                )

                if response.status_code != 200:
                    try:
                        # NOTE(review): `await response.text()` assumes
                        # curl_cffi streaming responses expose an awaitable
                        # text() — confirm against the installed version.
                        content = await response.text()
                        content = content[:1000]  # cap log size
                    except:
                        # NOTE(review): bare except — consider narrowing.
                        content = "Unable to read response content"

                    logger.error(
                        f"Chat failed: {response.status_code}, {content}",
                        extra={"status": response.status_code, "token": token[:10] + "..."}
                    )
                    # Close the session before raising so retries start clean.
                    try:
                        await session.close()
                    except:
                        pass
                    raise UpstreamException(
                        message=f"Grok API request failed: {response.status_code}",
                        details={"status": response.status_code}
                    )

                # Caller owns both: the session stays open while streaming.
                return session, response

            except UpstreamException:
                # Already handled above; re-raise unchanged.
                raise
            except Exception as e:
                # Any other failure: close the session and wrap.
                logger.error(f"Chat request error: {e}")
                try:
                    await session.close()
                except:
                    pass
                raise UpstreamException(
                    message=f"Chat connection failed: {str(e)}",
                    details={"error": str(e)}
                )

        # Establish the connection with status-aware retries.
        session = None
        response = None
        try:
            session, response = await retry_on_status(
                establish_connection,
                extract_status=extract_status
            )
        except Exception as e:
            # Record the failure against the token so the pool can rotate it.
            status_code = extract_status(e)
            if status_code:
                token_mgr = await get_token_manager()
                await token_mgr.record_fail(token, status_code, str(e))
            raise

        # Stream lines to the caller; the session is closed when the
        # generator is exhausted or garbage-collected (finally block).
        async def stream_response():
            try:
                async for line in response.aiter_lines():
                    yield line
            finally:
                if session:
                    await session.close()

        return stream_response()

    async def chat_openai(self, token: str, request: ChatRequest):
        """OpenAI-compatible entry: maps the public model name, uploads any
        attachments, then delegates to chat().

        Returns:
            (response_generator, stream_flag, public_model_name)
        """
        model_info = ModelService.get(request.model)
        if not model_info:
            raise ValidationException(f"Unknown model: {request.model}")

        grok_model = model_info.grok_model
        mode = model_info.model_mode
        is_video = model_info.is_video

        # Extract prompt text and attachments from the OpenAI message list.
        try:
            message, attachments = MessageExtractor.extract(request.messages, is_video=is_video)
        except ValueError as e:
            raise ValidationException(str(e))

        # Upload attachments; images and files are tracked separately.
        file_ids = []
        image_ids = []

        if attachments:
            upload_service = UploadService()
            try:
                for attach_type, attach_data in attachments:
                    # Upload and capture the Grok-side file ID.
                    file_id, _ = await upload_service.upload(attach_data, token)

                    if attach_type == "image":
                        image_ids.append(file_id)
                        logger.debug(f"Image uploaded: {file_id}")
                    else:
                        file_ids.append(file_id)
                        logger.debug(f"File uploaded: {file_id}")
            finally:
                await upload_service.close()

        # Resolve per-request overrides, falling back to config defaults.
        stream = request.stream if request.stream is not None else get_config("grok.stream", True)
        think = request.think if request.think is not None else get_config("grok.thinking", False)

        response = await self.chat(
            token, message, grok_model, mode, think, stream,
            file_attachments=file_ids,
            image_attachments=image_ids
        )

        return response, stream, request.model
|
| 440 |
+
|
| 441 |
+
|
| 442 |
+
# ==================== Chat 业务服务 ====================
|
| 443 |
+
|
| 444 |
+
class ChatService:
    """Business-level chat service: token selection, request dispatch, and
    usage/statistics accounting around GrokChatService."""

    @staticmethod
    async def completions(
        model: str,
        messages: List[Dict[str, Any]],
        stream: bool = None,
        thinking: str = None
    ):
        """
        Chat Completions entry point.

        Args:
            model: public model name
            messages: OpenAI-format message list
            stream: streaming flag (None => grok.stream config)
            thinking: "enabled"/"disabled"/None thinking mode selector

        Returns:
            AsyncGenerator of SSE chunks when streaming, otherwise a dict.

        Raises:
            AppException: token acquisition failed or pool exhausted (429).
            UpstreamException: Grok request failed.
        """
        # Acquire a token for this model from the pool.
        try:
            token_mgr = await get_token_manager()
            await token_mgr.reload_if_stale()
            token = token_mgr.get_token_for_model(model)
        except Exception as e:
            logger.error(f"Failed to get token: {e}")
            # Stats recording is best-effort: never mask the real error.
            try:
                await request_stats.record_request(model, success=False)
            except Exception:
                pass
            raise AppException(
                message="Internal service error obtaining token",
                error_type=ErrorType.SERVER.value,
                code="internal_error"
            )

        if not token:
            try:
                await request_stats.record_request(model, success=False)
            except Exception:
                pass
            raise AppException(
                message="No available tokens. Please try again later.",
                error_type=ErrorType.RATE_LIMIT.value,
                code="rate_limit_exceeded",
                status_code=429
            )

        # Map the tri-state thinking selector onto an optional bool.
        think = None
        if thinking == "enabled":
            think = True
        elif thinking == "disabled":
            think = False

        is_stream = stream if stream is not None else get_config("grok.stream", True)

        # Build the normalized request object.
        chat_request = ChatRequest(
            model=model,
            messages=messages,
            stream=is_stream,
            think=think
        )

        # Dispatch to Grok; every failure path records a failed request.
        service = GrokChatService()
        try:
            response, _, model_name = await service.chat_openai(token, chat_request)
        except AppException:
            try:
                await request_stats.record_request(model, success=False)
            except Exception:
                pass
            raise
        except Exception as e:
            logger.error(f"Chat service error: {e}")
            try:
                await request_stats.record_request(model, success=False)
            except Exception:
                pass
            raise UpstreamException(
                message=f"Service processing failed: {str(e)}",
                details={"error": str(e)}
            )

        # Streaming path: wrap the processor so usage/stats are recorded
        # exactly once, after the stream finishes (or breaks).
        if is_stream:
            processor = StreamProcessor(model_name, token, think).process(response)

            async def _wrapped_stream():
                completed = False
                try:
                    async for chunk in processor:
                        yield chunk
                    completed = True
                finally:
                    # Only count as "success" when the stream ends naturally.
                    try:
                        if completed:
                            await token_mgr.sync_usage(token, model_name, consume_on_fail=True, is_usage=True)
                            await request_stats.record_request(model_name, success=True)
                        else:
                            await request_stats.record_request(model_name, success=False)
                    except Exception:
                        pass

            return _wrapped_stream()

        # Non-streaming path: collect the whole response, then account.
        result = await CollectProcessor(model_name, token).process(response)
        try:
            await token_mgr.sync_usage(token, model_name, consume_on_fail=True, is_usage=True)
            await request_stats.record_request(model_name, success=True)
        except Exception:
            pass
        return result
|
| 563 |
+
|
| 564 |
+
|
| 565 |
+
# Public API of this module.
__all__ = [
    "GrokChatService",
    "ChatRequest",
    "ChatRequestBuilder",
    "MessageExtractor",
    "ChatService",
]
|
app/services/grok/imagine_experimental.py
ADDED
|
@@ -0,0 +1,416 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Experimental imagine/image-edit upstream calls.
|
| 3 |
+
|
| 4 |
+
This module provides:
|
| 5 |
+
- WebSocket imagine generation (ws/imagine/listen)
|
| 6 |
+
- Experimental image-edit payloads via conversations/new
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from __future__ import annotations
|
| 10 |
+
|
| 11 |
+
import asyncio
|
| 12 |
+
import time
|
| 13 |
+
import uuid
|
| 14 |
+
from typing import Any, Awaitable, Callable, Dict, Iterable, List, Optional
|
| 15 |
+
from urllib.parse import urlparse
|
| 16 |
+
|
| 17 |
+
import orjson
|
| 18 |
+
from curl_cffi.requests import AsyncSession
|
| 19 |
+
|
| 20 |
+
from app.core.config import get_config
|
| 21 |
+
from app.core.exceptions import UpstreamException
|
| 22 |
+
from app.core.logger import logger
|
| 23 |
+
from app.services.grok.assets import DownloadService
|
| 24 |
+
from app.services.grok.chat import BROWSER, CHAT_API, ChatRequestBuilder
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
# Canonical image-generation method names.
IMAGE_METHOD_LEGACY = "legacy"
IMAGE_METHOD_IMAGINE_WS_EXPERIMENTAL = "imagine_ws_experimental"
IMAGE_METHODS = {IMAGE_METHOD_LEGACY, IMAGE_METHOD_IMAGINE_WS_EXPERIMENTAL}
# Accepted user-facing aliases, all mapping onto the experimental websocket method.
IMAGE_METHOD_ALIASES = {
    "imagine_ws": IMAGE_METHOD_IMAGINE_WS_EXPERIMENTAL,
    "experimental": IMAGE_METHOD_IMAGINE_WS_EXPERIMENTAL,
    "new": IMAGE_METHOD_IMAGINE_WS_EXPERIMENTAL,
    "new_method": IMAGE_METHOD_IMAGINE_WS_EXPERIMENTAL,
}

# Upstream endpoints and default request timeout (seconds).
IMAGINE_WS_API = "wss://grok.com/ws/imagine/listen"
ASSET_API = "https://assets.grok.com"
TIMEOUT = 120

# Callbacks invoked during websocket generation; they may be sync (return None)
# or async (return an awaitable).  Fixed: the previous
# Optional[Awaitable[None] | None] doubled the None member redundantly.
ProgressCallback = Callable[[int, float], Optional[Awaitable[None]]]
CompletedCallback = Callable[[int, str], Optional[Awaitable[None]]]
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def resolve_image_generation_method(raw: Any) -> str:
    """Normalize a user-supplied method name to a canonical generation method.

    Matching is case-insensitive and whitespace-tolerant; recognized aliases
    are mapped through IMAGE_METHOD_ALIASES, and anything else falls back to
    the legacy method.
    """
    name = str(raw or "").strip().lower()
    if name in IMAGE_METHODS:
        return name
    # Alias values are non-empty strings, so `or` is a safe fallback here.
    return IMAGE_METHOD_ALIASES.get(name) or IMAGE_METHOD_LEGACY
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def _normalize_asset_path(raw_url: str) -> str:
    """Reduce an asset reference (absolute URL or bare path) to a leading-slash path.

    Empty/unparseable input yields "/".
    """
    value = str(raw_url or "").strip()
    if not value:
        return "/"
    if value.startswith(("http://", "https://")):
        try:
            path = urlparse(value).path or "/"
        except Exception:
            path = "/"
    else:
        path = value
    return path if path.startswith("/") else "/" + path
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
class ImagineExperimentalService:
    """Experimental Grok "imagine" upstream client.

    Covers two flows:
    - image generation over the ws/imagine/listen websocket (``generate_ws``)
    - experimental image editing through conversations/new (``chat_edit``)
    """

    def __init__(self, proxy: str | None = None):
        # Fall back to the globally configured upstream proxy when none is given.
        self.proxy = proxy or get_config("grok.base_proxy_url", "")
        # `or TIMEOUT` guards against a falsy/empty configured timeout.
        self.timeout = int(get_config("grok.timeout", TIMEOUT) or TIMEOUT)

    def _proxies(self) -> Optional[dict]:
        """Proxy mapping for curl_cffi, or None when no proxy is configured."""
        return {"http": self.proxy, "https": self.proxy} if self.proxy else None

    def _headers(self, token: str, referer: str = "https://grok.com/imagine") -> Dict[str, str]:
        """Standard chat headers plus the imagine Referer/Origin pair."""
        headers = ChatRequestBuilder.build_headers(token)
        headers["Referer"] = referer
        headers["Origin"] = "https://grok.com"
        return headers

    @staticmethod
    def _build_ws_payload(
        prompt: str,
        request_id: str,
        aspect_ratio: str = "2:3",
    ) -> Dict[str, Any]:
        """Build the conversation.item.create message sent over the imagine websocket.

        `request_id` is echoed back by upstream and used to filter replies.
        """
        return {
            "type": "conversation.item.create",
            "timestamp": int(time.time() * 1000),
            "item": {
                "type": "message",
                "content": [
                    {
                        "requestId": request_id,
                        "text": prompt,
                        "type": "input_scroll",
                        "properties": {
                            "section_count": 0,
                            "is_kids_mode": False,
                            "enable_nsfw": True,
                            "skip_upsampler": False,
                            "is_initial": False,
                            "aspect_ratio": aspect_ratio,
                        },
                    }
                ],
            },
        }

    @staticmethod
    def _extract_url(msg: Dict[str, Any]) -> str:
        """Return the first non-empty image URL found in a websocket message, or ""."""
        for key in ("url", "imageUrl", "image_url"):
            value = msg.get(key)
            if isinstance(value, str) and value.strip():
                return value.strip()
        return ""

    @staticmethod
    def _extract_progress(msg: Dict[str, Any]) -> Optional[float]:
        """Extract a progress value from a websocket message, clamped to 0-100.

        Returns None when no recognizable progress key is present or parseable.
        """
        for key in ("progress", "percentage_complete", "percentageComplete"):
            value = msg.get(key)
            if value is None:
                continue
            try:
                pct = float(value)
                # Clamp into the 0-100 range before reporting.
                if pct < 0:
                    pct = 0
                if pct > 100:
                    pct = 100
                return pct
            except Exception:
                continue
        return None

    @staticmethod
    def _is_completed(msg: Dict[str, Any], progress: Optional[float]) -> bool:
        """Whether a websocket message marks its image as fully generated.

        Either an explicit terminal status or a 100% progress value counts.
        """
        status = str(msg.get("current_status") or msg.get("currentStatus") or "").strip().lower()
        if status in {"completed", "done", "success"}:
            return True
        if progress is not None and progress >= 100:
            return True
        return False

    async def generate_ws(
        self,
        token: str,
        prompt: str,
        n: int = 2,
        aspect_ratio: str = "2:3",
        progress_cb: Optional[ProgressCallback] = None,
        completed_cb: Optional[CompletedCallback] = None,
        timeout: Optional[int] = None,
    ) -> List[str]:
        """Generate up to `n` images over the imagine websocket.

        Sends a single prompt message, then consumes progress/result messages
        until `n` distinct images complete or the overall timeout elapses.
        Callbacks (sync or async) receive a stable per-image index.

        Raises:
            UpstreamException: on websocket receive failure, an upstream
                error message, or when no image completed before timeout.
        """
        request_id = str(uuid.uuid4())
        target_count = max(1, int(n or 1))
        # Never run with less than a 10s budget, whatever was configured.
        effective_timeout = max(10, int(timeout or self.timeout))
        payload = self._build_ws_payload(
            prompt=prompt,
            request_id=request_id,
            aspect_ratio=aspect_ratio,
        )

        session = AsyncSession(impersonate=BROWSER)
        ws = None
        started_at = time.monotonic()
        # image id -> stable zero-based index (order of first appearance).
        image_indices: Dict[str, int] = {}
        # image id -> first completed URL (insertion order preserved).
        final_urls: Dict[str, str] = {}

        try:
            ws = await session.ws_connect(
                IMAGINE_WS_API,
                headers=self._headers(token),
                timeout=effective_timeout,
                proxies=self._proxies(),
                impersonate=BROWSER,
            )
            await ws.send_json(payload)

            while time.monotonic() - started_at < effective_timeout:
                # Poll in short slices so the overall deadline stays responsive.
                remain = max(1.0, effective_timeout - (time.monotonic() - started_at))
                try:
                    msg = await ws.recv_json(timeout=min(5.0, remain))
                except asyncio.TimeoutError:
                    continue
                except Exception as e:
                    raise UpstreamException(f"Imagine websocket receive failed: {e}") from e

                if not isinstance(msg, dict):
                    continue

                # Ignore messages addressed to other requests (untagged ones pass).
                msg_request_id = str(msg.get("request_id") or msg.get("requestId") or "")
                if msg_request_id and msg_request_id != request_id:
                    continue

                msg_type = str(msg.get("type") or "").lower()
                status = str(msg.get("current_status") or msg.get("currentStatus") or "").lower()
                if msg_type == "error" or status == "error":
                    err_code = str(msg.get("err_code") or msg.get("errCode") or "unknown")
                    err_msg = str(
                        msg.get("err_message") or msg.get("err_msg") or msg.get("error") or "unknown error"
                    )
                    raise UpstreamException(
                        message=f"Imagine websocket error ({err_code}): {err_msg}",
                        details={"code": err_code, "message": err_msg},
                    )

                # Assign a synthetic id when upstream omits one, so indexing stays stable.
                image_id = str(msg.get("id") or msg.get("imageId") or msg.get("image_id") or "")
                if not image_id:
                    image_id = f"image-{len(image_indices)}"
                if image_id not in image_indices:
                    image_indices[image_id] = len(image_indices)

                progress = self._extract_progress(msg)
                if progress is not None and progress_cb is not None:
                    # Callback failures are logged, never allowed to abort generation.
                    try:
                        maybe_coro = progress_cb(image_indices[image_id], progress)
                        if asyncio.iscoroutine(maybe_coro):
                            await maybe_coro
                    except Exception as e:
                        logger.debug(f"Imagine progress callback failed: {e}")

                image_url = self._extract_url(msg)
                if image_url and self._is_completed(msg, progress):
                    # Keep only the first completed URL per image id.
                    is_new = image_id not in final_urls
                    final_urls.setdefault(image_id, image_url)
                    if is_new and completed_cb is not None:
                        try:
                            maybe_coro = completed_cb(image_indices[image_id], image_url)
                            if asyncio.iscoroutine(maybe_coro):
                                await maybe_coro
                        except Exception as e:
                            logger.debug(f"Imagine completion callback failed: {e}")
                    if len(final_urls) >= target_count:
                        break

            if not final_urls:
                raise UpstreamException("Imagine websocket returned no completed images")

            return list(final_urls.values())
        finally:
            # Best-effort teardown of both the websocket and the HTTP session.
            if ws is not None:
                try:
                    await ws.close()
                except Exception:
                    pass
            try:
                await session.close()
            except Exception:
                pass

    async def convert_urls(self, token: str, urls: Iterable[str], response_format: str = "b64_json") -> List[str]:
        """Convert upstream asset URLs to the requested response format.

        "url" mode downloads each asset locally and returns app-served URLs
        (prefixed with the configured app.app_url when set); any other mode
        returns bare base64 payloads (data-URI header stripped).  Empty or
        failed conversions are skipped silently.
        """
        mode = str(response_format or "b64_json").strip().lower()
        out: List[str] = []
        dl = DownloadService(self.proxy)
        try:
            for raw in urls:
                raw = str(raw or "").strip()
                if not raw:
                    continue
                if mode == "url":
                    path = _normalize_asset_path(raw)
                    if path in {"", "/"}:
                        continue
                    # Cache the asset locally so /v1/files can serve it.
                    await dl.download(path, token, "image")
                    app_url = str(get_config("app.app_url", "") or "").strip()
                    local_path = f"/v1/files/image{path}"
                    if app_url:
                        out.append(f"{app_url.rstrip('/')}{local_path}")
                    else:
                        out.append(local_path)
                    continue

                data_uri = await dl.to_base64(raw, token, "image")
                if not data_uri:
                    continue
                # Strip a "data:...;base64," prefix when present.
                if "," in data_uri:
                    out.append(data_uri.split(",", 1)[1])
                else:
                    out.append(data_uri)
            return out
        finally:
            await dl.close()

    async def convert_url(self, token: str, url: str, response_format: str = "b64_json") -> str:
        """Single-URL convenience wrapper around convert_urls; "" when conversion fails."""
        items = await self.convert_urls(token=token, urls=[url], response_format=response_format)
        return items[0] if items else ""

    @staticmethod
    def _to_asset_urls(file_uris: List[str]) -> List[str]:
        """Resolve uploaded file URIs to absolute asset URLs; blanks are dropped."""
        out = []
        for uri in file_uris:
            value = str(uri or "").strip()
            if not value:
                continue
            if value.startswith("http://") or value.startswith("https://"):
                out.append(value)
            else:
                # Relative URI: anchor it on the assets host.
                out.append(f"{ASSET_API}/{value.lstrip('/')}")
        return out

    @staticmethod
    def _build_edit_payload(prompt: str, image_urls: List[str], model_name: str) -> Dict[str, Any]:
        """Build a conversations/new payload that routes to the imagine edit model.

        The reference images ride in a model-config override; `model_name`
        selects the chat model carrying the request.
        """
        model_map = {
            "imageEditModel": "imagine",
            "imageEditModelConfig": {
                "imageReferences": image_urls,
            },
        }
        payload: Dict[str, Any] = {
            "temporary": True,
            "modelName": model_name,
            "message": prompt,
            "fileAttachments": [],
            "imageAttachments": [],
            "disableSearch": False,
            "enableImageGeneration": True,
            "returnImageBytes": False,
            "returnRawGrokInXaiRequest": False,
            "enableImageStreaming": True,
            "imageGenerationCount": 2,
            "forceConcise": False,
            "toolOverrides": {"imageGen": True},
            "enableSideBySide": True,
            "sendFinalMetadata": True,
            "isReasoning": False,
            "disableTextFollowUps": False,
            "disableMemory": False,
            "forceSideBySide": False,
            "isAsyncChat": False,
            "responseMetadata": {
                "modelConfigOverride": {
                    "modelMap": model_map,
                },
                "requestModelDetails": {
                    "modelId": model_name,
                },
            },
        }
        # grok-3 additionally requires an explicit model mode.
        if model_name == "grok-3":
            payload["modelMode"] = "MODEL_MODE_FAST"
        return payload

    async def chat_edit(
        self,
        token: str,
        prompt: str,
        file_uris: List[str],
    ):
        """Start an experimental image edit and return an async line iterator.

        Tries the "imagine-image-edit" payload first, then falls back to
        "grok-3".  The returned generator streams raw response lines and
        closes the underlying session when exhausted.

        Raises:
            UpstreamException: when no uploaded image is supplied or every
                payload variant fails (the last concrete error is re-raised).
        """
        image_urls = self._to_asset_urls(file_uris)
        if not image_urls:
            raise UpstreamException("Experimental image edit requires at least one uploaded image")

        headers = self._headers(token, referer="https://grok.com/imagine")
        proxies = self._proxies()
        timeout = self.timeout

        payloads = [
            self._build_edit_payload(prompt, image_urls, "imagine-image-edit"),
            self._build_edit_payload(prompt, image_urls, "grok-3"),
        ]

        last_error: Optional[Exception] = None
        for payload in payloads:
            # Fresh session per attempt: the successful one is owned by the
            # returned stream; failed ones are closed immediately below.
            session = AsyncSession(impersonate=BROWSER)
            response = None
            try:
                response = await session.post(
                    CHAT_API,
                    headers=headers,
                    data=orjson.dumps(payload),
                    timeout=timeout,
                    stream=True,
                    proxies=proxies,
                )
                if response.status_code != 200:
                    try:
                        # NOTE(review): `text` on a streamed curl_cffi response may
                        # not be awaitable on all versions — confirm; the except
                        # clause currently papers over any failure here.
                        body = await response.text()
                    except Exception:
                        body = ""
                    raise UpstreamException(
                        message=f"Experimental image edit request failed: {response.status_code}",
                        details={"status": response.status_code, "body": body[:500]},
                    )

                async def _stream_response():
                    # Yields response lines; closes the session once the
                    # consumer finishes or abandons the stream.
                    try:
                        async for line in response.aiter_lines():
                            yield line
                    finally:
                        await session.close()

                return _stream_response()
            except Exception as e:
                last_error = e
                try:
                    await session.close()
                except Exception:
                    pass
                continue

        if isinstance(last_error, Exception):
            raise last_error
        raise UpstreamException("Experimental image edit request failed")
|
| 408 |
+
|
| 409 |
+
|
| 410 |
+
# Public API of this module.
__all__ = [
    "ImagineExperimentalService",
    "IMAGE_METHOD_LEGACY",
    "IMAGE_METHOD_IMAGINE_WS_EXPERIMENTAL",
    "IMAGE_METHODS",
    "resolve_image_generation_method",
]
|
app/services/grok/imagine_generation.py
ADDED
|
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Shared helpers for experimental imagine generation flows.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from __future__ import annotations
|
| 6 |
+
|
| 7 |
+
import asyncio
|
| 8 |
+
from typing import Any, Awaitable, Callable, List, Optional
|
| 9 |
+
|
| 10 |
+
from app.core.exceptions import UpstreamException
|
| 11 |
+
from app.core.logger import logger
|
| 12 |
+
from app.services.grok.imagine_experimental import ImagineExperimentalService
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
# OpenAI-style "WxH" size strings mapped to Grok aspect ratios.
_SIZE_TO_RATIO = {
    "1024x1024": "1:1",
    "512x512": "1:1",
    "1024x576": "16:9",
    "1280x720": "16:9",
    "1536x864": "16:9",
    "576x1024": "9:16",
    "720x1280": "9:16",
    "864x1536": "9:16",
    "1024x1536": "2:3",
    "1024x1792": "2:3",
    "512x768": "2:3",
    "768x1024": "2:3",
    "1536x1024": "3:2",
    "1792x1024": "3:2",
    "768x512": "3:2",
    "1024x768": "3:2",
}

_KNOWN_RATIOS = frozenset({"16:9", "9:16", "1:1", "2:3", "3:2"})


def resolve_aspect_ratio(size: Optional[str]) -> str:
    """Map a size string (or ratio literal) to a supported Grok aspect ratio.

    Already-valid ratios pass through unchanged; known "WxH" sizes are
    translated; anything else (including None) defaults to "2:3".
    """
    normalized = str(size or "").strip().lower()
    if normalized in _KNOWN_RATIOS:
        return normalized
    return _SIZE_TO_RATIO.get(normalized, "2:3")
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def is_valid_image_value(value: Any) -> bool:
    """True for a non-empty string payload that is not the "error" sentinel."""
    if not isinstance(value, str):
        return False
    return value not in ("", "error")
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def dedupe_images(images: List[str]) -> List[str]:
    """Drop non-string entries and duplicates, preserving first-seen order."""
    # dict.fromkeys keeps insertion order, giving an order-stable dedupe.
    return list(dict.fromkeys(item for item in images if isinstance(item, str)))
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
async def gather_limited(
    task_factories: List[Callable[[], Awaitable[List[str]]]],
    max_concurrency: int,
) -> List[Any]:
    """Run every factory's coroutine with at most `max_concurrency` in flight.

    Results are returned in factory order; exceptions are captured in-place
    (gather with return_exceptions=True), so callers must inspect each entry.
    """
    limit = int(max_concurrency or 1)
    gate = asyncio.Semaphore(limit if limit > 1 else 1)

    async def _guarded(factory: Callable[[], Awaitable[List[str]]]) -> Any:
        async with gate:
            return await factory()

    coros = (_guarded(factory) for factory in task_factories)
    return await asyncio.gather(*coros, return_exceptions=True)
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
async def call_experimental_generation_once(
    token: str,
    prompt: str,
    response_format: str = "b64_json",
    n: int = 4,
    aspect_ratio: str = "2:3",
) -> List[str]:
    """Run one websocket imagine generation and convert its results.

    Returns the generated images in the requested response format
    (base64 payloads, or locally-served URLs for "url" mode).
    """
    svc = ImagineExperimentalService()
    upstream_urls = await svc.generate_ws(
        token=token,
        prompt=prompt,
        n=n,
        aspect_ratio=aspect_ratio,
    )
    converted = await svc.convert_urls(
        token=token,
        urls=upstream_urls,
        response_format=response_format,
    )
    return converted
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
async def collect_experimental_generation_images(
    token: str,
    prompt: str,
    n: int,
    response_format: str,
    aspect_ratio: str,
    concurrency: int,
) -> List[str]:
    """Fan out websocket imagine calls (at most 4 images each) to gather `n` images.

    Individual call failures are logged and tolerated; raises
    UpstreamException only when no call produced a usable image.
    """
    total_calls = max(1, (n + 3) // 4)
    factories: List[Callable[[], Awaitable[List[str]]]] = []
    outstanding = n
    for _ in range(total_calls):
        batch = min(4, outstanding)
        if batch < 1:
            batch = 1
        outstanding -= batch
        # Bind `batch` through a default argument so each closure keeps its own value.
        factories.append(
            lambda batch=batch: call_experimental_generation_once(
                token,
                prompt,
                response_format=response_format,
                n=batch,
                aspect_ratio=aspect_ratio,
            )
        )

    limit = max(1, int(concurrency or 1))
    outcomes = await gather_limited(factories, max_concurrency=min(total_calls, limit))

    collected: List[str] = []
    for outcome in outcomes:
        if isinstance(outcome, Exception):
            logger.warning(f"Experimental imagine websocket call failed: {outcome}")
        elif isinstance(outcome, list):
            collected.extend(outcome)

    collected = dedupe_images(collected)
    if not any(is_valid_image_value(item) for item in collected):
        raise UpstreamException("Experimental imagine websocket returned no images")
    return collected
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
# Public API of this module.
__all__ = [
    "resolve_aspect_ratio",
    "is_valid_image_value",
    "dedupe_images",
    "gather_limited",
    "call_experimental_generation_once",
    "collect_experimental_generation_images",
]
|
app/services/grok/media.py
ADDED
|
@@ -0,0 +1,512 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Grok 视频生成服务
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import asyncio
|
| 6 |
+
import uuid
|
| 7 |
+
from typing import AsyncGenerator, Optional
|
| 8 |
+
|
| 9 |
+
import orjson
|
| 10 |
+
from curl_cffi.requests import AsyncSession
|
| 11 |
+
|
| 12 |
+
from app.core.logger import logger
|
| 13 |
+
from app.core.config import get_config
|
| 14 |
+
from app.core.exceptions import UpstreamException, AppException, ValidationException, ErrorType
|
| 15 |
+
from app.services.grok.statsig import StatsigService
|
| 16 |
+
from app.services.grok.model import ModelService
|
| 17 |
+
from app.services.token import get_token_manager
|
| 18 |
+
from app.services.grok.processor import VideoStreamProcessor, VideoCollectProcessor
|
| 19 |
+
from app.services.request_stats import request_stats
|
| 20 |
+
|
| 21 |
+
# API endpoints
CREATE_POST_API = "https://grok.com/rest/media/post/create"
CHAT_API = "https://grok.com/rest/app-chat/conversations/new"

# Constants
BROWSER = "chrome136"  # curl_cffi browser impersonation target
TIMEOUT = 300  # default upstream timeout, seconds
DEFAULT_MAX_CONCURRENT = 50  # default cap on concurrent media requests
# Shared semaphore guarding media concurrency; rebuilt by _get_media_semaphore
# when the configured limit changes.
_MEDIA_SEMAPHORE = asyncio.Semaphore(DEFAULT_MAX_CONCURRENT)
_MEDIA_SEM_VALUE = DEFAULT_MAX_CONCURRENT
|
| 31 |
+
|
| 32 |
+
def _get_media_semaphore() -> asyncio.Semaphore:
    """Return the shared media-concurrency semaphore.

    Re-reads the configured limit on every call and rebuilds the semaphore
    when the limit changed (minimum 1; non-numeric config falls back to the
    default).
    """
    global _MEDIA_SEMAPHORE, _MEDIA_SEM_VALUE
    raw = get_config("performance.media_max_concurrent", DEFAULT_MAX_CONCURRENT)
    try:
        limit = int(raw)
    except Exception:
        limit = DEFAULT_MAX_CONCURRENT
    if limit < 1:
        limit = 1
    if limit != _MEDIA_SEM_VALUE:
        # NOTE(review): swapping the semaphore abandons waiters queued on the
        # old one; acceptable while config changes are rare — confirm.
        _MEDIA_SEM_VALUE = limit
        _MEDIA_SEMAPHORE = asyncio.Semaphore(limit)
    return _MEDIA_SEMAPHORE
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class VideoService:
|
| 47 |
+
"""视频生成服务"""
|
| 48 |
+
|
| 49 |
+
    def __init__(self, proxy: str = None):
        # Prefer the explicit proxy, else the globally configured upstream proxy.
        self.proxy = proxy or get_config("grok.base_proxy_url", "")
        # Upstream request timeout in seconds (falls back to module default).
        self.timeout = get_config("grok.timeout", TIMEOUT)
|
| 52 |
+
|
| 53 |
+
    def _build_headers(self, token: str, referer: str = "https://grok.com/imagine") -> dict:
        """Build browser-like request headers for Grok media endpoints.

        Adds a Statsig anti-bot id, a per-request trace id, and the sso
        auth cookie (plus cf_clearance when configured).
        """
        headers = {
            "Accept": "*/*",
            "Accept-Encoding": "gzip, deflate, br, zstd",
            "Accept-Language": "zh-CN,zh;q=0.9",
            "Baggage": "sentry-environment=production,sentry-release=d6add6fb0460641fd482d767a335ef72b9b6abb8,sentry-public_key=b311e0f2690c81f25e2c4cf6d4f7ce1c",
            "Cache-Control": "no-cache",
            "Content-Type": "application/json",
            "Origin": "https://grok.com",
            "Pragma": "no-cache",
            "Priority": "u=1, i",
            "Referer": referer,
            "Sec-Ch-Ua": '"Google Chrome";v="136", "Chromium";v="136", "Not(A:Brand";v="24"',
            "Sec-Ch-Ua-Arch": "arm",
            "Sec-Ch-Ua-Bitness": "64",
            "Sec-Ch-Ua-Mobile": "?0",
            "Sec-Ch-Ua-Model": "",
            "Sec-Ch-Ua-Platform": '"macOS"',
            "Sec-Fetch-Dest": "empty",
            "Sec-Fetch-Mode": "cors",
            "Sec-Fetch-Site": "same-origin",
            "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/136.0.0.0 Safari/537.36",
        }

        # Statsig ID and a unique per-request trace id.
        headers["x-statsig-id"] = StatsigService.gen_id()
        headers["x-xai-request-id"] = str(uuid.uuid4())

        # Cookie: strip an optional "sso=" prefix from the token, then append
        # cf_clearance when one is configured.
        token = token[4:] if token.startswith("sso=") else token
        cf = get_config("grok.cf_clearance", "")
        headers["Cookie"] = f"sso={token};cf_clearance={cf}" if cf else f"sso={token}"

        return headers
|
| 88 |
+
|
| 89 |
+
    def _build_proxies(self) -> Optional[dict]:
        """Proxy mapping for curl_cffi, or None when no proxy is configured."""
        return {"http": self.proxy, "https": self.proxy} if self.proxy else None
|
| 92 |
+
|
| 93 |
+
    async def create_post(self, token: str, prompt: str, media_type: str = "MEDIA_POST_TYPE_VIDEO", media_url: str = None) -> str:
        """
        Create a media post.

        Args:
            token: Auth token.
            prompt: Prompt text (used for video generation).
            media_type: Media type (MEDIA_POST_TYPE_VIDEO or MEDIA_POST_TYPE_IMAGE).
            media_url: Media URL (used in image mode).

        Returns:
            post ID

        Raises:
            UpstreamException: on non-200 response, missing post id, or any
                other failure (AppException subclasses propagate unchanged).
        """
        try:
            headers = self._build_headers(token)

            # Image posts reference an existing asset URL; video posts carry the prompt.
            if media_type == "MEDIA_POST_TYPE_IMAGE" and media_url:
                payload = {
                    "mediaType": media_type,
                    "mediaUrl": media_url
                }
            else:
                payload = {
                    "mediaType": media_type,
                    "prompt": prompt
                }

            async with AsyncSession() as session:
                response = await session.post(
                    CREATE_POST_API,
                    headers=headers,
                    json=payload,
                    impersonate=BROWSER,
                    timeout=30,
                    proxies=self._build_proxies()
                )

                if response.status_code != 200:
                    logger.error(f"Create post failed: {response.status_code}")
                    raise UpstreamException(f"Failed to create post: {response.status_code}")

                data = response.json()
                post_id = data.get("post", {}).get("id", "")

                if not post_id:
                    raise UpstreamException("No post ID in response")

                logger.info(f"Media post created: {post_id} (type={media_type})")
                return post_id

        except Exception as e:
            logger.error(f"Create post error: {e}")
            # Preserve application-level errors; wrap everything else.
            if isinstance(e, AppException):
                raise e
            raise UpstreamException(f"Create post error: {str(e)}")
|
| 149 |
+
|
| 150 |
+
async def create_image_post(self, token: str, image_url: str) -> str:
    """Create an image post from an already-uploaded asset.

    Thin wrapper over :meth:`create_post` in image mode; the prompt is
    irrelevant there and is sent empty.

    Args:
        token: auth token.
        image_url: full asset URL (https://assets.grok.com/...).

    Returns:
        The upstream post ID.
    """
    return await self.create_post(
        token,
        prompt="",
        media_type="MEDIA_POST_TYPE_IMAGE",
        media_url=image_url,
    )
|
| 168 |
+
def _build_payload(
|
| 169 |
+
self,
|
| 170 |
+
prompt: str,
|
| 171 |
+
post_id: str,
|
| 172 |
+
aspect_ratio: str = "3:2",
|
| 173 |
+
video_length: int = 6,
|
| 174 |
+
resolution: str = "SD",
|
| 175 |
+
preset: str = "normal"
|
| 176 |
+
) -> dict:
|
| 177 |
+
"""构建视频生成载荷"""
|
| 178 |
+
mode_flag = "--mode=custom"
|
| 179 |
+
if preset == "fun":
|
| 180 |
+
mode_flag = "--mode=extremely-crazy"
|
| 181 |
+
elif preset == "normal":
|
| 182 |
+
mode_flag = "--mode=normal"
|
| 183 |
+
elif preset == "spicy":
|
| 184 |
+
mode_flag = "--mode=extremely-spicy-or-crazy"
|
| 185 |
+
|
| 186 |
+
full_prompt = f"{prompt} {mode_flag}"
|
| 187 |
+
|
| 188 |
+
return {
|
| 189 |
+
"temporary": True,
|
| 190 |
+
"modelName": "grok-3",
|
| 191 |
+
"message": full_prompt,
|
| 192 |
+
"toolOverrides": {"videoGen": True},
|
| 193 |
+
"enableSideBySide": True,
|
| 194 |
+
"responseMetadata": {
|
| 195 |
+
"experiments": [],
|
| 196 |
+
"modelConfigOverride": {
|
| 197 |
+
"modelMap": {
|
| 198 |
+
"videoGenModelConfig": {
|
| 199 |
+
"parentPostId": post_id,
|
| 200 |
+
"aspectRatio": aspect_ratio,
|
| 201 |
+
"videoLength": video_length,
|
| 202 |
+
"videoResolution": resolution
|
| 203 |
+
}
|
| 204 |
+
}
|
| 205 |
+
}
|
| 206 |
+
}
|
| 207 |
+
}
|
| 208 |
+
|
| 209 |
+
async def generate(
    self,
    token: str,
    prompt: str,
    aspect_ratio: str = "3:2",
    video_length: int = 6,
    resolution: str = "SD",
    stream: bool = True,
    preset: str = "normal"
) -> AsyncGenerator[bytes, None]:
    """Generate a video from a text prompt.

    Args:
        token: auth token.
        prompt: video description.
        aspect_ratio: output aspect ratio, e.g. "3:2".
        video_length: clip length in seconds.
        resolution: output resolution, e.g. "SD".
        stream: accepted for interface symmetry; the upstream request is
            always streamed regardless of this value.
        preset: generation preset ("normal", "fun", "spicy", anything else
            falls back to custom mode).

    Returns:
        AsyncGenerator yielding raw upstream stream lines.

    Raises:
        UpstreamException: when the upstream request fails.
    """
    async with _get_media_semaphore():
        session = None
        try:
            # Step 1: create the parent media post.
            post_id = await self.create_post(token, prompt)

            # Step 2: open the streaming chat connection.
            headers = self._build_headers(token)
            payload = self._build_payload(prompt, post_id, aspect_ratio, video_length, resolution, preset)

            session = AsyncSession(impersonate=BROWSER)
            response = await session.post(
                CHAT_API,
                headers=headers,
                data=orjson.dumps(payload),
                timeout=self.timeout,
                stream=True,
                proxies=self._build_proxies()
            )

            if response.status_code != 200:
                logger.error(f"Video generation failed: {response.status_code}")
                try:
                    await session.close()
                # Narrowed from a bare `except:` — a bare except would also
                # swallow CancelledError/KeyboardInterrupt during cleanup.
                except Exception:
                    pass
                raise UpstreamException(
                    message=f"Video generation failed: {response.status_code}",
                    details={"status": response.status_code}
                )

            # Step 3: stream lines back, closing the session when the
            # consumer finishes or abandons the generator.
            async def stream_response():
                try:
                    async for line in response.aiter_lines():
                        yield line
                finally:
                    if session:
                        await session.close()

            return stream_response()

        except Exception as e:
            if session:
                try:
                    await session.close()
                except Exception:
                    pass  # best-effort cleanup; the original error matters more
            logger.error(f"Video generation error: {e}")
            if isinstance(e, AppException):
                raise e
            raise UpstreamException(f"Video generation error: {str(e)}")
|
| 290 |
+
async def generate_from_image(
    self,
    token: str,
    prompt: str,
    image_url: str,
    aspect_ratio: str = "3:2",
    video_length: int = 6,
    resolution: str = "SD",
    stream: bool = True,
    preset: str = "normal"
) -> AsyncGenerator[bytes, None]:
    """Generate a video seeded from an existing image.

    Args:
        token: auth token.
        prompt: video description.
        image_url: full asset URL of the seed image.
        aspect_ratio: output aspect ratio.
        video_length: clip length in seconds.
        resolution: output resolution.
        stream: accepted for interface symmetry; the upstream request is
            always streamed regardless of this value.
        preset: generation preset.

    Returns:
        AsyncGenerator yielding raw upstream stream lines.

    Raises:
        UpstreamException: when the upstream request fails.
    """
    async with _get_media_semaphore():
        session = None
        try:
            # Step 1: create the parent image post.
            post_id = await self.create_image_post(token, image_url)

            # Step 2: open the streaming chat connection.
            headers = self._build_headers(token)
            payload = self._build_payload(prompt, post_id, aspect_ratio, video_length, resolution, preset)

            session = AsyncSession(impersonate=BROWSER)
            response = await session.post(
                CHAT_API,
                headers=headers,
                data=orjson.dumps(payload),
                timeout=self.timeout,
                stream=True,
                proxies=self._build_proxies()
            )

            if response.status_code != 200:
                logger.error(f"Video from image failed: {response.status_code}")
                try:
                    await session.close()
                # Narrowed from a bare `except:` — a bare except would also
                # swallow CancelledError/KeyboardInterrupt during cleanup.
                except Exception:
                    pass
                raise UpstreamException(
                    message=f"Video from image failed: {response.status_code}",
                    details={"status": response.status_code}
                )

            # Step 3: stream lines back, closing the session when done.
            async def stream_response():
                try:
                    async for line in response.aiter_lines():
                        yield line
                finally:
                    if session:
                        await session.close()

            return stream_response()

        except Exception as e:
            if session:
                try:
                    await session.close()
                except Exception:
                    pass  # best-effort cleanup; the original error matters more
            logger.error(f"Video from image error: {e}")
            if isinstance(e, AppException):
                raise e
            raise UpstreamException(f"Video from image error: {str(e)}")
|
| 370 |
+
@staticmethod
async def completions(
    model: str,
    messages: list,
    stream: bool = None,
    thinking: str = None,
    aspect_ratio: str = "3:2",
    video_length: int = 6,
    resolution: str = "SD",
    preset: str = "normal"
):
    """Entry point for OpenAI-style video generation requests.

    Args:
        model: model name.
        messages: OpenAI-style message list.
        stream: stream the response; defaults to the "grok.stream" config.
        thinking: "enabled"/"disabled" to force think output on or off.
        aspect_ratio: output aspect ratio.
        video_length: clip length in seconds.
        resolution: output resolution.
        preset: generation preset.

    Returns:
        AsyncGenerator (streaming) or dict (non-streaming).
    """

    async def _record_failure():
        # Stats recording is best-effort and must never mask the real error.
        try:
            await request_stats.record_request(model, success=False)
        except Exception:
            pass

    # Acquire an upstream token for this model.
    try:
        token_mgr = await get_token_manager()
        await token_mgr.reload_if_stale()
        token = token_mgr.get_token_for_model(model)
    except Exception as e:
        logger.error(f"Failed to get token: {e}")
        await _record_failure()
        raise AppException(
            message="Internal service error obtaining token",
            error_type=ErrorType.SERVER.value,
            code="internal_error"
        )

    if not token:
        await _record_failure()
        raise AppException(
            message="No available tokens. Please try again later.",
            error_type=ErrorType.RATE_LIMIT.value,
            code="rate_limit_exceeded",
            status_code=429
        )

    # Resolve think / stream flags ("enabled"/"disabled" -> bool, else None).
    think = {"enabled": True, "disabled": False}.get(thinking)
    is_stream = stream if stream is not None else get_config("grok.stream", True)

    # Extract the prompt and any attachments from the message list.
    from app.services.grok.chat import MessageExtractor
    from app.services.grok.assets import UploadService

    try:
        prompt, attachments = MessageExtractor.extract(messages, is_video=True)
    except ValueError as e:
        raise ValidationException(str(e))

    # Upload the first image attachment, if any — the video model only
    # consumes a single seed image.
    image_url = None
    if attachments:
        upload_service = UploadService()
        try:
            for attach_type, attach_data in attachments:
                if attach_type == "image":
                    _, file_uri = await upload_service.upload(attach_data, token)
                    image_url = f"https://assets.grok.com/{file_uri}"
                    logger.info(f"Image uploaded for video: {image_url}")
                    break
        finally:
            await upload_service.close()

    service = VideoService()

    # Kick off generation (image-to-video when a seed image was uploaded).
    try:
        if image_url:
            response = await service.generate_from_image(
                token, prompt, image_url,
                aspect_ratio, video_length, resolution, stream, preset
            )
        else:
            response = await service.generate(
                token, prompt,
                aspect_ratio, video_length, resolution, stream, preset
            )
    except Exception:
        await _record_failure()
        raise

    if is_stream:
        processor = VideoStreamProcessor(model, token, think).process(response)

        async def _wrapped_stream():
            completed = False
            try:
                async for chunk in processor:
                    yield chunk
                completed = True
            finally:
                # Usage/stat sync is best-effort; never break the stream here.
                try:
                    if completed:
                        await token_mgr.sync_usage(token, model, consume_on_fail=True, is_usage=True)
                        await request_stats.record_request(model, success=True)
                    else:
                        await request_stats.record_request(model, success=False)
                except Exception:
                    pass

        return _wrapped_stream()

    result = await VideoCollectProcessor(model, token).process(response)
    try:
        await token_mgr.sync_usage(token, model, consume_on_fail=True, is_usage=True)
        await request_stats.record_request(model, success=True)
    except Exception:
        pass
    return result
| 511 |
+
|
| 512 |
+
__all__ = ["VideoService"]
|
app/services/grok/model.py
ADDED
|
@@ -0,0 +1,226 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Grok 模型管理服务
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from __future__ import annotations
|
| 6 |
+
|
| 7 |
+
from enum import Enum
|
| 8 |
+
from typing import Optional, Tuple
|
| 9 |
+
from pydantic import BaseModel, Field
|
| 10 |
+
|
| 11 |
+
from app.core.exceptions import ValidationException
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class Tier(str, Enum):
    """Model tier; SUPER-tier models are served from the super token pool."""
    BASIC = "basic"
    SUPER = "super"
+
|
| 20 |
+
class Cost(str, Enum):
    """Billing weight of a model."""
    LOW = "low"
    HIGH = "high"
+
|
| 26 |
+
class ModelInfo(BaseModel):
    """Static metadata describing one model exposed by this API."""
    model_id: str            # public model name exposed by this API
    grok_model: str          # upstream Grok model name
    rate_limit_model: str    # modelName used against /rest/rate-limits
    model_mode: str          # upstream MODEL_MODE_* constant
    tier: Tier = Field(default=Tier.BASIC)
    cost: Cost = Field(default=Cost.LOW)
    display_name: str        # human-readable name
    description: str = ""
    is_video: bool = False   # video generation model
    is_image: bool = False   # image generation model
+
|
| 40 |
+
class ModelService:
    """Registry of supported models plus model-related lookups."""

    # Static registry; order matters only for display purposes.
    MODELS = [
        ModelInfo(
            model_id="grok-3",
            grok_model="grok-3",
            rate_limit_model="grok-3",
            model_mode="MODEL_MODE_GROK_3",
            cost=Cost.LOW,
            display_name="Grok 3"
        ),
        ModelInfo(
            model_id="grok-3-mini",
            grok_model="grok-3",
            rate_limit_model="grok-3",
            model_mode="MODEL_MODE_GROK_3_MINI_THINKING",
            cost=Cost.LOW,
            display_name="Grok 3 Mini"
        ),
        ModelInfo(
            model_id="grok-3-thinking",
            grok_model="grok-3",
            rate_limit_model="grok-3",
            model_mode="MODEL_MODE_GROK_3_THINKING",
            cost=Cost.LOW,
            display_name="Grok 3 Thinking"
        ),
        ModelInfo(
            model_id="grok-4",
            grok_model="grok-4",
            rate_limit_model="grok-4",
            model_mode="MODEL_MODE_GROK_4",
            cost=Cost.LOW,
            display_name="Grok 4"
        ),
        ModelInfo(
            model_id="grok-4-mini",
            grok_model="grok-4-mini",
            rate_limit_model="grok-4-mini",
            model_mode="MODEL_MODE_GROK_4_MINI_THINKING",
            cost=Cost.LOW,
            display_name="Grok 4 Mini"
        ),
        ModelInfo(
            model_id="grok-4-thinking",
            grok_model="grok-4",
            rate_limit_model="grok-4",
            model_mode="MODEL_MODE_GROK_4_THINKING",
            cost=Cost.LOW,
            display_name="Grok 4 Thinking"
        ),
        ModelInfo(
            model_id="grok-4-heavy",
            grok_model="grok-4",
            rate_limit_model="grok-4-heavy",
            model_mode="MODEL_MODE_HEAVY",
            cost=Cost.HIGH,
            tier=Tier.SUPER,
            display_name="Grok 4 Heavy"
        ),
        ModelInfo(
            model_id="grok-4.1-mini",
            grok_model="grok-4-1-thinking-1129",
            rate_limit_model="grok-4-1-thinking-1129",
            model_mode="MODEL_MODE_GROK_4_1_MINI_THINKING",
            cost=Cost.LOW,
            display_name="Grok 4.1 Mini"
        ),
        ModelInfo(
            model_id="grok-4.1-fast",
            grok_model="grok-4-1-thinking-1129",
            rate_limit_model="grok-4-1-thinking-1129",
            model_mode="MODEL_MODE_FAST",
            cost=Cost.LOW,
            display_name="Grok 4.1 Fast"
        ),
        ModelInfo(
            model_id="grok-4.1-expert",
            grok_model="grok-4-1-thinking-1129",
            rate_limit_model="grok-4-1-thinking-1129",
            model_mode="MODEL_MODE_EXPERT",
            cost=Cost.HIGH,
            display_name="Grok 4.1 Expert"
        ),
        ModelInfo(
            model_id="grok-4.1-thinking",
            grok_model="grok-4-1-thinking-1129",
            rate_limit_model="grok-4-1-thinking-1129",
            model_mode="MODEL_MODE_GROK_4_1_THINKING",
            cost=Cost.HIGH,
            display_name="Grok 4.1 Thinking"
        ),
        ModelInfo(
            model_id="grok-4.20-beta",
            grok_model="grok-420",
            rate_limit_model="grok-420",
            model_mode="MODEL_MODE_GROK_420",
            cost=Cost.LOW,
            display_name="Grok 4.20 Beta"
        ),
        ModelInfo(
            model_id="grok-imagine-1.0",
            grok_model="grok-3",
            rate_limit_model="grok-3",
            model_mode="MODEL_MODE_FAST",
            cost=Cost.HIGH,
            display_name="Grok Image",
            description="Image generation model",
            is_image=True
        ),
        ModelInfo(
            model_id="grok-imagine-1.0-edit",
            grok_model="imagine-image-edit",
            rate_limit_model="grok-3",
            model_mode="MODEL_MODE_FAST",
            cost=Cost.HIGH,
            display_name="Grok Image Edit",
            description="Image edit model",
            is_image=True
        ),
        ModelInfo(
            model_id="grok-imagine-1.0-video",
            grok_model="grok-3",
            rate_limit_model="grok-3",
            model_mode="MODEL_MODE_FAST",
            cost=Cost.HIGH,
            display_name="Grok Video",
            description="Video generation model",
            is_video=True
        ),
    ]

    # Index by public model ID for O(1) lookups.
    _map = {m.model_id: m for m in MODELS}

    @classmethod
    def get(cls, model_id: str) -> Optional[ModelInfo]:
        """Look up a model by its public ID, or None if unknown."""
        return cls._map.get(model_id)

    @classmethod
    def list(cls) -> list[ModelInfo]:
        """Return all registered models."""
        return list(cls._map.values())

    @classmethod
    def valid(cls, model_id: str) -> bool:
        """True when the model ID is registered."""
        return model_id in cls._map

    @classmethod
    def to_grok(cls, model_id: str) -> Tuple[str, str]:
        """Translate a public model ID to (grok_model, model_mode).

        Raises:
            ValidationException: for an unknown model ID.
        """
        info = cls.get(model_id)
        if info is None:
            raise ValidationException(f"Invalid model ID: {model_id}")
        return info.grok_model, info.model_mode

    @classmethod
    def rate_limit_model_for(cls, model_id: str) -> str:
        """modelName to use against /rest/rate-limits (falls back to the input)."""
        info = cls.get(model_id)
        return info.rate_limit_model if info else model_id

    @classmethod
    def is_heavy_bucket_model(cls, model_id: str) -> bool:
        """Whether the model consumes the heavy quota bucket (only grok-4-heavy)."""
        return model_id == "grok-4-heavy"

    @classmethod
    def pool_for_model(cls, model_id: str) -> str:
        """Pick the token pool for a model (SUPER tier -> ssoSuper)."""
        info = cls.get(model_id)
        return "ssoSuper" if info and info.tier == Tier.SUPER else "ssoBasic"

    @classmethod
    def pool_candidates_for_model(cls, model_id: str) -> list[str]:
        """Token pools to try, in priority order."""
        info = cls.get(model_id)
        if info and info.tier == Tier.SUPER:
            return ["ssoSuper"]
        return ["ssoBasic", "ssoSuper"]
+
|
| 225 |
+
|
| 226 |
+
__all__ = ["ModelService"]
|
app/services/grok/processor.py
ADDED
|
@@ -0,0 +1,596 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
OpenAI 响应格式处理器
|
| 3 |
+
"""
|
| 4 |
+
import time
|
| 5 |
+
import uuid
|
| 6 |
+
import random
|
| 7 |
+
import html
|
| 8 |
+
import orjson
|
| 9 |
+
from typing import Any, AsyncGenerator, Optional, AsyncIterable, List
|
| 10 |
+
|
| 11 |
+
from app.core.config import get_config
|
| 12 |
+
from app.core.logger import logger
|
| 13 |
+
from app.services.grok.assets import DownloadService
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
ASSET_URL = "https://assets.grok.com/"
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def _build_video_poster_preview(video_url: str, thumbnail_url: str = "") -> str:
|
| 20 |
+
"""将 <video> 替换为可点击的 Poster 预览图(用于前端展示)"""
|
| 21 |
+
safe_video = html.escape(video_url or "", quote=True)
|
| 22 |
+
safe_thumb = html.escape(thumbnail_url or "", quote=True)
|
| 23 |
+
|
| 24 |
+
if not safe_video:
|
| 25 |
+
return ""
|
| 26 |
+
|
| 27 |
+
if not safe_thumb:
|
| 28 |
+
return f'<a href="{safe_video}" target="_blank" rel="noopener noreferrer">{safe_video}</a>'
|
| 29 |
+
|
| 30 |
+
return f'''<a href="{safe_video}" target="_blank" rel="noopener noreferrer" style="display:inline-block;position:relative;max-width:100%;text-decoration:none;">
|
| 31 |
+
<img src="{safe_thumb}" alt="video" style="max-width:100%;height:auto;border-radius:12px;display:block;" />
|
| 32 |
+
<span style="position:absolute;inset:0;display:flex;align-items:center;justify-content:center;">
|
| 33 |
+
<span style="width:64px;height:64px;border-radius:9999px;background:rgba(0,0,0,.55);display:flex;align-items:center;justify-content:center;">
|
| 34 |
+
<span style="width:0;height:0;border-top:12px solid transparent;border-bottom:12px solid transparent;border-left:18px solid #fff;margin-left:4px;"></span>
|
| 35 |
+
</span>
|
| 36 |
+
</span>
|
| 37 |
+
</a>'''
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class BaseProcessor:
    """Shared plumbing for streaming and collecting response processors."""

    def __init__(self, model: str, token: str = ""):
        self.model = model
        self.token = token
        self.created = int(time.time())
        self.app_url = get_config("app.app_url", "")
        # Lazily created, reused download client.
        self._dl_service: Optional[DownloadService] = None

    def _get_dl(self) -> DownloadService:
        """Return the download service, creating it on first use."""
        if self._dl_service is None:
            self._dl_service = DownloadService()
        return self._dl_service

    async def close(self):
        """Release the download service, if one was created."""
        if self._dl_service:
            await self._dl_service.close()
            self._dl_service = None

    async def process_url(self, path: str, media_type: str = "image") -> str:
        """Turn an upstream asset path into a locally served URL."""
        # Reduce absolute URLs to their path component.
        if path.startswith("http"):
            from urllib.parse import urlparse
            path = urlparse(path).path

        if not path.startswith("/"):
            path = f"/{path}"

        # Invalid root path is not a displayable image URL.
        if path in {"", "/"}:
            return ""

        # Always materialize to local cache endpoint so callers don't rely on
        # direct assets.grok.com access (often blocked without upstream cookies).
        downloader = self._get_dl()
        await downloader.download(path, self.token, media_type)
        local_path = f"/v1/files/{media_type}{path}"
        return f"{self.app_url.rstrip('/')}{local_path}" if self.app_url else local_path

    def _sse(self, content: str = "", role: str = None, finish: str = None) -> str:
        """Serialize one OpenAI-style SSE chunk (shared by stream processors)."""
        # Subclasses normally define these; default them for standalone use.
        if not hasattr(self, 'response_id'):
            self.response_id = None
        if not hasattr(self, 'fingerprint'):
            self.fingerprint = ""

        delta = {}
        if role:
            delta["role"] = role
            delta["content"] = ""
        elif content:
            delta["content"] = content

        chunk = {
            "id": self.response_id or f"chatcmpl-{uuid.uuid4().hex[:24]}",
            "object": "chat.completion.chunk",
            "created": self.created,
            "model": self.model,
            "system_fingerprint": self.fingerprint if hasattr(self, 'fingerprint') else "",
            "choices": [{"index": 0, "delta": delta, "logprobs": None, "finish_reason": finish}]
        }
        return f"data: {orjson.dumps(chunk).decode()}\n\n"
|
| 109 |
+
|
| 110 |
+
class StreamProcessor(BaseProcessor):
    """Streaming (SSE) response processor."""

    def __init__(self, model: str, token: str = "", think: bool = None):
        super().__init__(model, token)
        self.response_id: Optional[str] = None
        self.fingerprint: str = ""
        self.think_opened: bool = False
        self.role_sent: bool = False
        self.filter_tags = get_config("grok.filter_tags", [])
        self.image_format = get_config("app.image_format", "url")
        # An explicit think flag wins; otherwise fall back to config.
        self.show_think = get_config("grok.thinking", False) if think is None else think

    async def process(self, response: AsyncIterable[bytes]) -> AsyncGenerator[str, None]:
        """Translate upstream JSON lines into OpenAI SSE chunks."""
        try:
            async for raw in response:
                if not raw:
                    continue
                try:
                    data = orjson.loads(raw)
                except orjson.JSONDecodeError:
                    continue

                resp = data.get("result", {}).get("response", {})

                # Metadata: model fingerprint and response id.
                llm = resp.get("llmInfo")
                if llm and not self.fingerprint:
                    self.fingerprint = llm.get("modelHash", "")
                rid = resp.get("responseId")
                if rid:
                    self.response_id = rid

                # Emit the assistant role once, before any content.
                if not self.role_sent:
                    yield self._sse(role="assistant")
                    self.role_sent = True

                # Image generation progress (surfaced inside <think> when enabled).
                img = resp.get("streamingImageGenerationResponse")
                if img:
                    if self.show_think:
                        if not self.think_opened:
                            yield self._sse("<think>\n")
                            self.think_opened = True
                        idx = img.get('imageIndex', 0) + 1
                        progress = img.get('progress', 0)
                        yield self._sse(f"正在生成第{idx}张图片中,当前进度{progress}%\n")
                    continue

                # Final model response.
                mr = resp.get("modelResponse")
                if mr:
                    if self.think_opened and self.show_think:
                        msg = mr.get("message")
                        if msg:
                            yield self._sse(msg + "\n")
                        yield self._sse("</think>\n")
                        self.think_opened = False

                    # Emit generated images.
                    # NOTE(review): the chunks below contain only "\n"; the
                    # markdown image payload appears lost in extraction —
                    # img_id/base64_data/final_url are computed but unused
                    # here. Confirm against the upstream repository.
                    for url in mr.get("generatedImageUrls", []):
                        parts = url.split("/")
                        img_id = parts[-2] if len(parts) >= 2 else "image"

                        if self.image_format == "base64":
                            dl_service = self._get_dl()
                            base64_data = await dl_service.to_base64(url, self.token, "image")
                            if base64_data:
                                yield self._sse(f"\n")
                            else:
                                final_url = await self.process_url(url, "image")
                                yield self._sse(f"\n")
                        else:
                            final_url = await self.process_url(url, "image")
                            yield self._sse(f"\n")

                    meta = mr.get("metadata", {})
                    if meta.get("llm_info", {}).get("modelHash"):
                        self.fingerprint = meta["llm_info"]["modelHash"]
                    continue

                # Plain token delta (optionally filtered by configured tags).
                token = resp.get("token")
                if token is not None:
                    if token and not (self.filter_tags and any(t in token for t in self.filter_tags)):
                        yield self._sse(token)

            if self.think_opened:
                yield self._sse("</think>\n")
            yield self._sse(finish="stop")
            yield "data: [DONE]\n\n"
        except Exception as e:
            logger.error(f"Stream processing error: {e}", extra={"model": self.model})
            raise
        finally:
            await self.close()
| 206 |
+
|
| 207 |
+
class CollectProcessor(BaseProcessor):
    """Non-streaming chat response processor.

    Drains the upstream NDJSON stream to completion and returns a single
    OpenAI-compatible ``chat.completion`` payload.
    """

    def __init__(self, model: str, token: str = ""):
        super().__init__(model, token)
        # How generated images are embedded in the reply: "url" or "base64".
        self.image_format = get_config("app.image_format", "url")

    async def process(self, response: AsyncIterable[bytes]) -> dict[str, Any]:
        """Collect the complete upstream response.

        Args:
            response: Async iterable of raw NDJSON lines from upstream.

        Returns:
            OpenAI-style completion dict; token usage counters are
            zero-filled because upstream does not report them here.
        """
        response_id = ""
        fingerprint = ""
        content = ""

        try:
            async for line in response:
                if not line:
                    continue
                try:
                    data = orjson.loads(line)
                except orjson.JSONDecodeError:
                    # Skip keep-alives and other non-JSON lines.
                    continue

                resp = data.get("result", {}).get("response", {})

                # First llmInfo seen becomes the system fingerprint.
                if (llm := resp.get("llmInfo")) and not fingerprint:
                    fingerprint = llm.get("modelHash", "")

                # Final model response carries the message text and any images.
                if mr := resp.get("modelResponse"):
                    response_id = mr.get("responseId", "")
                    content = mr.get("message", "")

                    # Append generated images to the message body.
                    if urls := mr.get("generatedImageUrls"):
                        content += "\n"
                        for url in urls:
                            parts = url.split("/")
                            img_id = parts[-2] if len(parts) >= 2 else "image"

                            if self.image_format == "base64":
                                dl_service = self._get_dl()
                                base64_data = await dl_service.to_base64(url, self.token, "image")
                                if base64_data:
                                    # NOTE(review): img_id/base64_data are computed but unused
                                    # and the f-string below is empty — the markdown image
                                    # payload appears lost in this copy; confirm against
                                    # the original source history.
                                    content += f"\n"
                                else:
                                    # Fallback to URL when base64 download fails.
                                    final_url = await self.process_url(url, "image")
                                    content += f"\n"
                            else:
                                final_url = await self.process_url(url, "image")
                                content += f"\n"

                    # Metadata may also carry the model hash; it overrides llmInfo.
                    if (meta := mr.get("metadata", {})).get("llm_info", {}).get("modelHash"):
                        fingerprint = meta["llm_info"]["modelHash"]

        except Exception as e:
            logger.error(f"Collect processing error: {e}", extra={"model": self.model})
        finally:
            await self.close()

        return {
            "id": response_id,
            "object": "chat.completion",
            "created": self.created,
            "model": self.model,
            "system_fingerprint": fingerprint,
            "choices": [{
                "index": 0,
                "message": {"role": "assistant", "content": content, "refusal": None, "annotations": []},
                "finish_reason": "stop"
            }],
            "usage": {
                "prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0,
                "prompt_tokens_details": {"cached_tokens": 0, "text_tokens": 0, "audio_tokens": 0, "image_tokens": 0},
                "completion_tokens_details": {"text_tokens": 0, "audio_tokens": 0, "reasoning_tokens": 0}
            }
        }
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
class VideoStreamProcessor(BaseProcessor):
    """Streaming video-generation response processor.

    Emits SSE chunks: optional <think> progress lines while the video is
    rendering, then the final <video> HTML once progress reaches 100%.
    """

    def __init__(self, model: str, token: str = "", think: bool = None):
        super().__init__(model, token)
        # Last responseId observed in the stream.
        self.response_id: Optional[str] = None
        # True while a <think> block is open in the emitted output.
        self.think_opened: bool = False
        # Whether the initial assistant-role chunk has been sent.
        self.role_sent: bool = False
        self.video_format = get_config("app.video_format", "url")

        # think=None means "use the configured default".
        if think is None:
            self.show_think = get_config("grok.thinking", False)
        else:
            self.show_think = think

    def _build_video_html(self, video_url: str, thumbnail_url: str = "") -> str:
        """Build the <video> HTML tag for the generated clip."""
        if get_config("grok.video_poster_preview", False):
            return _build_video_poster_preview(video_url, thumbnail_url)
        poster_attr = f' poster="{thumbnail_url}"' if thumbnail_url else ""
        return f'''<video id="video" controls="" preload="none"{poster_attr}>
    <source id="mp4" src="{video_url}" type="video/mp4">
</video>'''

    async def process(self, response: AsyncIterable[bytes]) -> AsyncGenerator[str, None]:
        """Translate the upstream NDJSON stream into SSE chunks.

        Yields:
            SSE-formatted strings, terminated by a stop chunk and [DONE].
        """
        try:
            async for line in response:
                if not line:
                    continue
                try:
                    data = orjson.loads(line)
                except orjson.JSONDecodeError:
                    # Skip non-JSON keep-alive lines.
                    continue

                resp = data.get("result", {}).get("response", {})

                if rid := resp.get("responseId"):
                    self.response_id = rid

                # Send the assistant role exactly once, before any content.
                if not self.role_sent:
                    yield self._sse(role="assistant")
                    self.role_sent = True

                # Video generation progress events.
                if video_resp := resp.get("streamingVideoGenerationResponse"):
                    progress = video_resp.get("progress", 0)

                    if self.show_think:
                        if not self.think_opened:
                            yield self._sse("<think>\n")
                            self.think_opened = True
                        yield self._sse(f"正在生成视频中,当前进度{progress}%\n")

                    # The 100%-progress event carries the final URLs.
                    if progress == 100:
                        video_url = video_resp.get("videoUrl", "")
                        thumbnail_url = video_resp.get("thumbnailImageUrl", "")

                        # Close the think block before emitting the video HTML.
                        if self.think_opened and self.show_think:
                            yield self._sse("</think>\n")
                            self.think_opened = False

                        if video_url:
                            final_video_url = await self.process_url(video_url, "video")
                            final_thumbnail_url = ""
                            if thumbnail_url:
                                final_thumbnail_url = await self.process_url(thumbnail_url, "image")

                            video_html = self._build_video_html(final_video_url, final_thumbnail_url)
                            yield self._sse(video_html)

                            logger.info(f"Video generated: {video_url}")
                    continue

            # Close a dangling <think> block before finishing the stream.
            if self.think_opened:
                yield self._sse("</think>\n")
            yield self._sse(finish="stop")
            yield "data: [DONE]\n\n"
        except Exception as e:
            logger.error(f"Video stream processing error: {e}", extra={"model": self.model})
        finally:
            await self.close()
|
| 366 |
+
|
| 367 |
+
|
| 368 |
+
class VideoCollectProcessor(BaseProcessor):
    """Non-streaming video-generation response processor.

    Waits for the 100%-progress event and returns a chat.completion dict
    whose content is the final <video> HTML snippet.
    """

    def __init__(self, model: str, token: str = ""):
        super().__init__(model, token)
        self.video_format = get_config("app.video_format", "url")

    def _build_video_html(self, video_url: str, thumbnail_url: str = "") -> str:
        """Build the <video> HTML tag for the generated clip."""
        if get_config("grok.video_poster_preview", False):
            return _build_video_poster_preview(video_url, thumbnail_url)
        poster_attr = f' poster="{thumbnail_url}"' if thumbnail_url else ""
        return f'''<video id="video" controls="" preload="none"{poster_attr}>
    <source id="mp4" src="{video_url}" type="video/mp4">
</video>'''

    async def process(self, response: AsyncIterable[bytes]) -> dict[str, Any]:
        """Collect the finished video into a completion payload."""
        response_id = ""
        content = ""

        try:
            async for line in response:
                if not line:
                    continue
                try:
                    data = orjson.loads(line)
                except orjson.JSONDecodeError:
                    continue

                resp = data.get("result", {}).get("response", {})

                if video_resp := resp.get("streamingVideoGenerationResponse"):
                    # Only the 100%-progress event carries the final URLs.
                    if video_resp.get("progress") == 100:
                        response_id = resp.get("responseId", "")
                        video_url = video_resp.get("videoUrl", "")
                        thumbnail_url = video_resp.get("thumbnailImageUrl", "")

                        if video_url:
                            final_video_url = await self.process_url(video_url, "video")
                            final_thumbnail_url = ""
                            if thumbnail_url:
                                final_thumbnail_url = await self.process_url(thumbnail_url, "image")

                            content = self._build_video_html(final_video_url, final_thumbnail_url)
                            logger.info(f"Video generated: {video_url}")

        except Exception as e:
            logger.error(f"Video collect processing error: {e}", extra={"model": self.model})
        finally:
            await self.close()

        return {
            "id": response_id,
            "object": "chat.completion",
            "created": self.created,
            "model": self.model,
            "choices": [{
                "index": 0,
                "message": {"role": "assistant", "content": content, "refusal": None},
                "finish_reason": "stop"
            }],
            "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
        }
|
| 431 |
+
|
| 432 |
+
|
| 433 |
+
class ImageStreamProcessor(BaseProcessor):
    """Streaming image-generation processor (OpenAI images SSE format).

    Emits ``image_generation.partial_image`` progress events followed by
    one ``image_generation.completed`` event per collected image.
    """

    def __init__(
        self,
        model: str,
        token: str = "",
        n: int = 1,
        response_format: str = "b64_json",
    ):
        super().__init__(model, token)
        # NOTE(review): partial_index is never read anywhere in this class.
        self.partial_index = 0
        self.n = n
        # Single-image mode picks one of the (presumed two) upstream images
        # at random.  NOTE(review): if upstream returns only one image and
        # target_index == 1, no completed event is emitted — confirm that
        # upstream always yields two images.
        self.target_index = random.randint(0, 1) if n == 1 else None
        self.response_format = (response_format or "b64_json").lower()
        # Name of the payload field carrying the image data in SSE events.
        if self.response_format == "url":
            self.response_field = "url"
        elif self.response_format == "base64":
            self.response_field = "base64"
        else:
            self.response_field = "b64_json"

    def _sse(self, event: str, data: dict) -> str:
        """Build an SSE frame with an explicit event name (overrides base)."""
        return f"event: {event}\ndata: {orjson.dumps(data).decode()}\n\n"

    async def process(self, response: AsyncIterable[bytes]) -> AsyncGenerator[str, None]:
        """Translate the upstream NDJSON stream into image SSE events."""
        final_images = []

        try:
            async for line in response:
                if not line:
                    continue
                try:
                    data = orjson.loads(line)
                except orjson.JSONDecodeError:
                    continue

                resp = data.get("result", {}).get("response", {})

                # Per-image progress events.
                if img := resp.get("streamingImageGenerationResponse"):
                    image_index = img.get("imageIndex", 0)
                    progress = img.get("progress", 0)

                    # Single-image mode: ignore the non-selected image.
                    if self.n == 1 and image_index != self.target_index:
                        continue

                    out_index = 0 if self.n == 1 else image_index

                    # Partial event carries progress only (no pixel data).
                    yield self._sse("image_generation.partial_image", {
                        "type": "image_generation.partial_image",
                        self.response_field: "",
                        "index": out_index,
                        "progress": progress
                    })
                    continue

                # Final model response with the generated image URLs.
                if mr := resp.get("modelResponse"):
                    if urls := mr.get("generatedImageUrls"):
                        for url in urls:
                            if self.response_format == "url":
                                processed = await self.process_url(url, "image")
                                if processed:
                                    final_images.append(processed)
                                continue
                            dl_service = self._get_dl()
                            base64_data = await dl_service.to_base64(url, self.token, "image")
                            if base64_data:
                                # Strip a data-URI prefix if present.
                                if "," in base64_data:
                                    b64 = base64_data.split(",", 1)[1]
                                else:
                                    b64 = base64_data
                                final_images.append(b64)
                    continue

            # Emit a completed event for each collected image.
            for index, b64 in enumerate(final_images):
                if self.n == 1:
                    if index != self.target_index:
                        continue
                    out_index = 0
                else:
                    out_index = index

                yield self._sse("image_generation.completed", {
                    "type": "image_generation.completed",
                    self.response_field: b64,
                    "index": out_index,
                    # NOTE(review): usage numbers are hard-coded placeholders.
                    "usage": {
                        "total_tokens": 50,
                        "input_tokens": 25,
                        "output_tokens": 25,
                        "input_tokens_details": {"text_tokens": 5, "image_tokens": 20}
                    }
                })
        except Exception as e:
            logger.error(f"Image stream processing error: {e}")
            raise
        finally:
            await self.close()
|
| 535 |
+
|
| 536 |
+
|
| 537 |
+
class ImageCollectProcessor(BaseProcessor):
    """Non-streaming image-generation processor.

    Drains the upstream NDJSON stream and returns the generated images as
    a list of URLs or raw base64 strings, depending on ``response_format``.
    """

    def __init__(
        self,
        model: str,
        token: str = "",
        response_format: str = "b64_json",
    ):
        super().__init__(model, token)
        # Normalize: empty/None falls back to the default b64_json format.
        self.response_format = (response_format or "b64_json").lower()

    async def process(self, response: AsyncIterable[bytes]) -> List[str]:
        """Collect every generated image from the stream.

        Returns:
            List of processed URLs (``response_format == "url"``) or bare
            base64 payloads (any other format).
        """
        collected: List[str] = []

        try:
            async for raw_line in response:
                if not raw_line:
                    continue
                try:
                    payload = orjson.loads(raw_line)
                except orjson.JSONDecodeError:
                    continue

                model_resp = (
                    payload.get("result", {})
                    .get("response", {})
                    .get("modelResponse")
                )
                if not model_resp:
                    continue

                urls = model_resp.get("generatedImageUrls")
                if not urls:
                    continue

                for url in urls:
                    if self.response_format == "url":
                        resolved = await self.process_url(url, "image")
                        if resolved:
                            collected.append(resolved)
                        continue

                    downloader = self._get_dl()
                    encoded = await downloader.to_base64(url, self.token, "image")
                    if encoded:
                        # Strip an optional "data:...;base64," prefix,
                        # keeping everything after the first comma.
                        _, sep, tail = encoded.partition(",")
                        collected.append(tail if sep else encoded)

        except Exception as e:
            logger.error(f"Image collect processing error: {e}")
        finally:
            await self.close()

        return collected
|
| 587 |
+
|
| 588 |
+
|
| 589 |
+
# Public processor classes exported by this module.
__all__ = [
    "StreamProcessor",
    "CollectProcessor",
    "VideoStreamProcessor",
    "VideoCollectProcessor",
    "ImageStreamProcessor",
    "ImageCollectProcessor",
]
|
app/services/grok/retry.py
ADDED
|
@@ -0,0 +1,178 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Grok API 重试工具
|
| 3 |
+
|
| 4 |
+
提供可配置的重试机制,支持:
|
| 5 |
+
- 可配置的重试次数
|
| 6 |
+
- 可配置的重试状态码
|
| 7 |
+
- 仅记录最后一次失败
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import asyncio
|
| 11 |
+
from typing import Callable, Any, Optional, List
|
| 12 |
+
from functools import wraps
|
| 13 |
+
|
| 14 |
+
from app.core.logger import logger
|
| 15 |
+
from app.core.config import get_config
|
| 16 |
+
from app.core.exceptions import UpstreamException
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class RetryConfig:
    """Reads retry settings from the application configuration."""

    # Config keys and fallbacks used when the config has no value.
    _MAX_RETRY_KEY = "grok.max_retry"
    _CODES_KEY = "grok.retry_status_codes"

    @staticmethod
    def get_max_retry() -> int:
        """Maximum number of retry attempts (config default: 1)."""
        return get_config(RetryConfig._MAX_RETRY_KEY, 1)

    @staticmethod
    def get_retry_codes() -> List[int]:
        """HTTP status codes that are eligible for retry."""
        return get_config(RetryConfig._CODES_KEY, [401, 429, 403])
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class RetryContext:
    """Mutable state tracked across retry attempts.

    Attributes:
        attempt: number of failed attempts recorded so far.
        max_retry: retry budget, read from config at construction.
        retry_codes: status codes eligible for retry.
        last_error: most recently recorded exception (or None).
        last_status: most recently recorded status code (or None).
    """

    def __init__(self):
        self.attempt = 0
        self.max_retry = RetryConfig.get_max_retry()
        self.retry_codes = RetryConfig.get_retry_codes()
        self.last_error = None
        self.last_status = None

    def should_retry(self, status_code: int) -> bool:
        """Return True when budget remains and the status code is retryable."""
        budget_left = self.attempt < self.max_retry
        retryable = status_code in self.retry_codes
        return budget_left and retryable

    def record_error(self, status_code: int, error: Exception):
        """Record one failure and consume one unit of the retry budget."""
        self.last_status = status_code
        self.last_error = error
        self.attempt += 1
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
async def retry_on_status(
    func: Callable,
    *args,
    extract_status: Callable[[Exception], Optional[int]] = None,
    on_retry: Callable[[int, int, Exception], None] = None,
    **kwargs
) -> Any:
    """
    Generic retry helper driven by HTTP status codes.

    Args:
        func: async function to invoke
        *args: positional arguments for func
        extract_status: maps an exception to a status code, or None when
            the error is not status-based (defaults to reading
            ``UpstreamException.details["status"]``)
        on_retry: callback invoked as (attempt, status_code, error) before
            each sleep-and-retry
        **kwargs: keyword arguments for func

    Returns:
        The result of func.

    Raises:
        The last failure's exception once retries are exhausted, or
        immediately for non-retryable errors.
    """
    ctx = RetryContext()

    # Default status extractor: pull "status" from UpstreamException details.
    if extract_status is None:
        def extract_status(e: Exception) -> Optional[int]:
            if isinstance(e, UpstreamException):
                return e.details.get("status") if e.details else None
            return None

    while ctx.attempt <= ctx.max_retry:
        try:
            result = await func(*args, **kwargs)

            # Log recovery when a prior attempt had failed.
            if ctx.attempt > 0:
                logger.info(
                    f"Retry succeeded after {ctx.attempt} attempts"
                )

            return result

        except Exception as e:
            # Map the exception to a status code, if possible.
            status_code = extract_status(e)

            if status_code is None:
                # Unrecognized error: never retried.
                logger.error(f"Non-retryable error: {e}")
                raise

            # Record the failure; this increments ctx.attempt.
            ctx.record_error(status_code, e)

            # NOTE(review): should_retry() compares `attempt < max_retry`
            # AFTER record_error() incremented attempt, so with the default
            # max_retry=1 no retry ever happens — confirm whether
            # `attempt <= max_retry` was intended.
            if ctx.should_retry(status_code):
                delay = 0.5 * (ctx.attempt + 1)  # progressive back-off
                logger.warning(
                    f"Retry {ctx.attempt}/{ctx.max_retry} for status {status_code}, "
                    f"waiting {delay}s"
                )

                # Give the caller a hook before sleeping.
                if on_retry:
                    on_retry(ctx.attempt, status_code, e)

                await asyncio.sleep(delay)
                continue
            else:
                # Not retryable, or the retry budget is exhausted.
                if status_code in ctx.retry_codes:
                    # Log the final (failed) attempt count as well.
                    logger.warning(
                        f"Retry {ctx.attempt}/{ctx.max_retry} for status {status_code}, failed"
                    )
                    logger.error(
                        f"Retry exhausted after {ctx.max_retry} attempts, "
                        f"last status: {status_code}"
                    )
                else:
                    logger.error(
                        f"Non-retryable status code: {status_code}"
                    )

                # Re-raise the last failure.
                raise
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
def with_retry(
    extract_status: Callable[[Exception], Optional[int]] = None,
    on_retry: Callable[[int, int, Exception], None] = None
):
    """
    Decorator form of :func:`retry_on_status`.

    The decorated coroutine is re-invoked on retryable status codes using
    the same arguments it was originally called with.

    Usage:
        @with_retry()
        async def my_api_call():
            ...
    """
    def decorator(func: Callable):
        @wraps(func)
        async def wrapper(*call_args, **call_kwargs):
            return await retry_on_status(
                func,
                *call_args,
                extract_status=extract_status,
                on_retry=on_retry,
                **call_kwargs,
            )

        return wrapper

    return decorator
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
# Public retry API of this module.
__all__ = [
    "RetryConfig",
    "RetryContext",
    "retry_on_status",
    "with_retry",
]
|
app/services/grok/statsig.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Statsig ID 生成服务
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import base64
|
| 6 |
+
import random
|
| 7 |
+
import string
|
| 8 |
+
|
| 9 |
+
from app.core.config import get_config
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class StatsigService:
    """Generates values for the x-statsig-id request header."""

    @staticmethod
    def _rand(length: int, alphanumeric: bool = False) -> str:
        """Return a random lowercase (optionally alphanumeric) string."""
        pool = string.ascii_lowercase + string.digits if alphanumeric else string.ascii_lowercase
        return "".join(random.choices(pool, k=length))

    @staticmethod
    def gen_id() -> str:
        """
        Generate a Statsig ID.

        Returns:
            Base64-encoded ID string.
        """
        # A static value is used unless dynamic generation is enabled.
        if not get_config("grok.dynamic_statsig", True):
            return "ZTpUeXBlRXJyb3I6IENhbm5vdCByZWFkIHByb3BlcnRpZXMgb2YgdW5kZWZpbmVkIChyZWFkaW5nICdjaGlsZE5vZGVzJyk="

        # Randomly pick one of two error-message shapes to encode.
        if random.choice([True, False]):
            rand = StatsigService._rand(5, alphanumeric=True)
            message = f"e:TypeError: Cannot read properties of null (reading 'children['{rand}']')"
        else:
            rand = StatsigService._rand(10)
            message = f"e:TypeError: Cannot read properties of undefined (reading '{rand}')"

        return base64.b64encode(message.encode()).decode()
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
# Public export of this module.
__all__ = ["StatsigService"]
|
app/services/grok/usage.py
ADDED
|
@@ -0,0 +1,162 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Grok 用量服务
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import asyncio
|
| 6 |
+
import uuid
|
| 7 |
+
from typing import Dict
|
| 8 |
+
|
| 9 |
+
import orjson
|
| 10 |
+
from curl_cffi.requests import AsyncSession
|
| 11 |
+
|
| 12 |
+
from app.core.logger import logger
|
| 13 |
+
from app.core.config import get_config
|
| 14 |
+
from app.core.exceptions import UpstreamException, AppException
|
| 15 |
+
from app.services.grok.statsig import StatsigService
|
| 16 |
+
from app.services.grok.retry import retry_on_status
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
# Upstream endpoint for rate-limit queries.
LIMITS_API = "https://grok.com/rest/rate-limits"
# Browser fingerprint used by curl_cffi impersonation.
BROWSER = "chrome136"
# Default request timeout in seconds.
TIMEOUT = 10
# Default cap on concurrent usage queries.
DEFAULT_MAX_CONCURRENT = 25
# Module-level limiter, rebuilt by _get_usage_semaphore() on config change.
_USAGE_SEMAPHORE = asyncio.Semaphore(DEFAULT_MAX_CONCURRENT)
_USAGE_SEM_VALUE = DEFAULT_MAX_CONCURRENT
|
| 25 |
+
|
| 26 |
+
def _get_usage_semaphore() -> asyncio.Semaphore:
    """Return the shared concurrency limiter for usage queries.

    Re-reads the configured limit on every call and rebuilds the semaphore
    when the value changed, so config updates take effect at runtime.
    """
    global _USAGE_SEMAPHORE, _USAGE_SEM_VALUE
    value = get_config("performance.usage_max_concurrent", DEFAULT_MAX_CONCURRENT)
    try:
        value = int(value)
    except Exception:
        # Non-numeric config falls back to the default.
        value = DEFAULT_MAX_CONCURRENT
    value = max(1, value)  # never allow zero/negative concurrency
    if value != _USAGE_SEM_VALUE:
        # NOTE(review): swapping the semaphore while tasks hold/await the old
        # instance briefly allows more than `value` concurrent requests —
        # confirm this is acceptable.
        _USAGE_SEM_VALUE = value
        _USAGE_SEMAPHORE = asyncio.Semaphore(value)
    return _USAGE_SEMAPHORE
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class UsageService:
    """Queries remaining rate-limit quota from grok.com."""

    def __init__(self, proxy: str = None):
        # Explicit proxy wins; otherwise fall back to the configured one.
        self.proxy = proxy or get_config("grok.base_proxy_url", "")
        self.timeout = get_config("grok.timeout", TIMEOUT)

    def _build_headers(self, token: str) -> dict:
        """Build browser-mimicking request headers plus auth cookie."""
        headers = {
            "Accept": "*/*",
            "Accept-Encoding": "gzip, deflate, br, zstd",
            "Accept-Language": "zh-CN,zh;q=0.9",
            "Baggage": "sentry-environment=production,sentry-release=d6add6fb0460641fd482d767a335ef72b9b6abb8,sentry-public_key=b311e0f2690c81f25e2c4cf6d4f7ce1c",
            "Cache-Control": "no-cache",
            "Content-Type": "application/json",
            "Origin": "https://grok.com",
            "Pragma": "no-cache",
            "Priority": "u=1, i",
            "Referer": "https://grok.com/",
            "Sec-Ch-Ua": '"Google Chrome";v="136", "Chromium";v="136", "Not(A:Brand";v="24"',
            "Sec-Ch-Ua-Arch": "arm",
            "Sec-Ch-Ua-Bitness": "64",
            "Sec-Ch-Ua-Mobile": "?0",
            "Sec-Ch-Ua-Model": "",
            "Sec-Ch-Ua-Platform": '"macOS"',
            "Sec-Fetch-Dest": "empty",
            "Sec-Fetch-Mode": "cors",
            "Sec-Fetch-Site": "same-origin",
            "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/136.0.0.0 Safari/537.36",
        }

        # Anti-bot / tracing identifiers.
        headers["x-statsig-id"] = StatsigService.gen_id()
        headers["x-xai-request-id"] = str(uuid.uuid4())

        # Cookie: accept tokens given with or without an "sso=" prefix.
        token = token[4:] if token.startswith("sso=") else token
        cf = get_config("grok.cf_clearance", "")
        headers["Cookie"] = f"sso={token};cf_clearance={cf}" if cf else f"sso={token}"

        return headers

    def _build_proxies(self) -> dict:
        """Proxy mapping for curl_cffi, or None when no proxy is configured."""
        return {"http": self.proxy, "https": self.proxy} if self.proxy else None

    async def get(self, token: str, model_name: str = "grok-4-1-thinking-1129") -> Dict:
        """
        Fetch rate-limit information for a token.

        Args:
            token: auth token (with or without "sso=" prefix)
            model_name: upstream model name to query limits for

        Returns:
            Parsed response data.

        Raises:
            UpstreamException: when the request fails and retries are exhausted.
        """
        # Bound concurrent usage queries via the shared semaphore.
        async with _get_usage_semaphore():
            # Status-code extractor for the retry helper.
            def extract_status(e: Exception) -> int | None:
                if isinstance(e, UpstreamException) and e.details:
                    return e.details.get("status")
                return None

            # The actual HTTP request, wrapped so it can be retried.
            async def do_request():
                try:
                    headers = self._build_headers(token)
                    payload = {
                        "requestKind": "DEFAULT",
                        "modelName": model_name
                    }

                    async with AsyncSession() as session:
                        response = await session.post(
                            LIMITS_API,
                            headers=headers,
                            json=payload,
                            impersonate=BROWSER,
                            timeout=self.timeout,
                            proxies=self._build_proxies()
                        )

                    if response.status_code == 200:
                        data = response.json()
                        remaining = data.get('remainingTokens', 0)
                        logger.info(f"Usage: quota {remaining} remaining")
                        return data

                    logger.error(f"Usage failed: {response.status_code}")

                    raise UpstreamException(
                        message=f"Failed to get usage stats: {response.status_code}",
                        details={"status": response.status_code}
                    )

                except Exception as e:
                    # Wrap non-upstream errors so the retry helper sees a
                    # uniform exception type.
                    if isinstance(e, UpstreamException):
                        raise
                    logger.error(f"Usage error: {e}")
                    raise UpstreamException(
                        message=f"Usage service error: {str(e)}",
                        details={"error": str(e)}
                    )

            # Execute with retry on configured status codes.
            try:
                result = await retry_on_status(
                    do_request,
                    extract_status=extract_status
                )
                return result

            except Exception as e:
                # The final failure was already logged by the retry helper.
                raise
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
# Public export of this module.
__all__ = ["UsageService"]
|
app/services/quota.py
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
API Key daily quota enforcement (local/docker runtime)
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from __future__ import annotations
|
| 6 |
+
|
| 7 |
+
from typing import Optional, Dict
|
| 8 |
+
|
| 9 |
+
from app.core.config import get_config
|
| 10 |
+
from app.core.exceptions import AppException, ErrorType
|
| 11 |
+
from app.services.api_keys import api_key_manager
|
| 12 |
+
from app.services.grok.model import ModelService
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
async def enforce_daily_quota(
    api_key: Optional[str],
    model: str,
    *,
    image_count: Optional[int] = None,
) -> None:
    """
    Enforce per-day quotas for a non-admin API key.

    - chat/heavy/video: count by request (1)
    - image: count by generated images
      - chat endpoint + image model: charge 2 images per request
      - image endpoint: charge `image_count` (n)
    - heavy: consumes both heavy + chat buckets

    Raises:
        AppException: HTTP 429 ``daily_quota_exceeded`` when the key is
        over any applicable bucket limit.
    """

    token = str(api_key or "").strip()
    # Empty keys are not quota-tracked.
    if not token:
        return

    # The global admin key bypasses quota enforcement entirely.
    global_key = str(get_config("app.api_key", "") or "").strip()
    if global_key and token == global_key:
        return

    model_info = ModelService.get(model)
    incs: Dict[str, int] = {}
    bucket_name = "chat"

    if model == "grok-4-heavy":
        # Heavy requests debit both the heavy and the chat bucket.
        incs = {"heavy_used": 1, "chat_used": 1}
        bucket_name = "heavy/chat"
    elif model_info and model_info.is_video:
        incs = {"video_used": 1}
        bucket_name = "video"
    elif model_info and model_info.is_image:
        # grok image model via chat endpoint: upstream usually returns up to 2 images
        incs = {"image_used": max(1, int(image_count or 2))}
        bucket_name = "image"
    else:
        incs = {"chat_used": 1}
        bucket_name = "chat"

    # Atomically consume; False indicates the daily limit is exhausted.
    ok = await api_key_manager.consume_daily_usage(token, incs)
    if ok:
        return

    raise AppException(
        message=f"Daily quota exceeded: {bucket_name}",
        error_type=ErrorType.RATE_LIMIT.value,
        code="daily_quota_exceeded",
        status_code=429,
    )
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
# Public export of this module.
__all__ = ["enforce_daily_quota"]
|
| 70 |
+
|
app/services/register/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Auto registration services."""
|
| 2 |
+
|
| 3 |
+
from app.services.register.manager import get_auto_register_manager, AutoRegisterManager
|
| 4 |
+
|
| 5 |
+
__all__ = ["AutoRegisterManager", "get_auto_register_manager"]
|
app/services/register/account_settings_refresh.py
ADDED
|
@@ -0,0 +1,267 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import asyncio
|
| 4 |
+
from typing import Iterable, Any
|
| 5 |
+
|
| 6 |
+
from app.core.config import get_config
|
| 7 |
+
from app.core.logger import logger
|
| 8 |
+
from app.services.register.services import (
|
| 9 |
+
UserAgreementService,
|
| 10 |
+
BirthDateService,
|
| 11 |
+
NsfwSettingsService,
|
| 12 |
+
)
|
| 13 |
+
from app.services.token.manager import TokenManager, get_token_manager
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
# Default number of tokens processed in parallel during a settings refresh.
DEFAULT_NSFW_REFRESH_CONCURRENCY = 10
# Default number of retries per token after the first failed attempt.
DEFAULT_NSFW_REFRESH_RETRIES = 3
# curl_cffi browser fingerprint used for all settings requests.
DEFAULT_IMPERSONATE = "chrome120"
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def _extract_cookie_value(cookie_str: str, name: str) -> str | None:
|
| 22 |
+
needle = f"{name}="
|
| 23 |
+
if needle not in cookie_str:
|
| 24 |
+
return None
|
| 25 |
+
for part in cookie_str.split(";"):
|
| 26 |
+
part = part.strip()
|
| 27 |
+
if part.startswith(needle):
|
| 28 |
+
value = part[len(needle):].strip()
|
| 29 |
+
return value or None
|
| 30 |
+
return None
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def parse_sso_pair(raw_token: str) -> tuple[str, str]:
    """Parse a raw token string into an ``(sso, sso_rw)`` pair.

    Accepts a full cookie string (``sso=...; sso-rw=...``), a bare
    ``sso=value`` pair, or the plain token value.  When no ``sso-rw``
    cookie is present the ``sso`` value is used for both slots.
    """
    raw = str(raw_token or "").strip()
    if not raw:
        return "", ""

    if ";" in raw:
        # Full cookie header: pull out both cookies, falling back to `sso`.
        sso_value = _extract_cookie_value(raw, "sso") or ""
        rw_value = _extract_cookie_value(raw, "sso-rw") or sso_value
        return sso_value.strip(), rw_value.strip()

    # Single value: drop an optional "sso=" prefix and duplicate it.
    value = raw.removeprefix("sso=").strip()
    return value, value
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def normalize_sso_token(raw_token: str) -> str:
    """Return only the ``sso`` component of a raw token string."""
    sso_value, _unused_rw = parse_sso_pair(raw_token)
    return sso_value
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def _coerce_concurrency(value: Any, default: int = DEFAULT_NSFW_REFRESH_CONCURRENCY) -> int:
    """Coerce *value* to an int concurrency level, clamped to at least 1."""
    try:
        level = int(value)
    except Exception:
        # Non-numeric input falls back to the configured default.
        level = default
    return level if level > 1 else 1
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def _coerce_retries(value: Any, default: int = DEFAULT_NSFW_REFRESH_RETRIES) -> int:
    """Coerce *value* to a non-negative int retry count."""
    try:
        count = int(value)
    except Exception:
        # Non-numeric input falls back to the configured default.
        count = default
    return count if count > 0 else 0
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def _format_step_error(result: dict, fallback: str = "unknown error") -> str:
|
| 70 |
+
if not isinstance(result, dict):
|
| 71 |
+
return fallback
|
| 72 |
+
|
| 73 |
+
text = str(result.get("error") or "").strip()
|
| 74 |
+
if text:
|
| 75 |
+
return text
|
| 76 |
+
|
| 77 |
+
status_code = result.get("status_code")
|
| 78 |
+
if status_code is not None:
|
| 79 |
+
return f"HTTP {status_code}"
|
| 80 |
+
|
| 81 |
+
grpc_status = result.get("grpc_status")
|
| 82 |
+
if grpc_status is not None:
|
| 83 |
+
return f"gRPC {grpc_status}"
|
| 84 |
+
|
| 85 |
+
response_text = str(result.get("response_text") or "").strip()
|
| 86 |
+
if response_text:
|
| 87 |
+
return response_text
|
| 88 |
+
|
| 89 |
+
return fallback
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
class AccountSettingsRefreshService:
    """Re-applies required account settings (TOS, birth date, NSFW) to SSO tokens.

    Tokens that still fail after all retries are marked invalid in the token
    manager so they are not handed out again.
    """

    def __init__(self, token_manager: TokenManager, cf_clearance: str = "") -> None:
        # Token manager used to record per-token outcomes and persist changes.
        self.token_manager = token_manager
        # Optional Cloudflare clearance cookie forwarded to every settings request.
        self.cf_clearance = (cf_clearance or "").strip()

    def _apply_once(self, raw_token: str) -> tuple[bool, str, str]:
        """Run one full settings pass (TOS -> birth date -> NSFW) for a token.

        Returns ``(ok, failed_step, error_text)``; on success the last two
        elements are empty strings.  Blocking — callers run it in a thread.
        """
        sso, sso_rw = parse_sso_pair(raw_token)
        if not sso:
            return False, "parse", "missing sso"
        if not sso_rw:
            # Fall back to the read-only cookie when sso-rw is absent.
            sso_rw = sso

        user_service = UserAgreementService(cf_clearance=self.cf_clearance)
        birth_service = BirthDateService(cf_clearance=self.cf_clearance)
        nsfw_service = NsfwSettingsService(cf_clearance=self.cf_clearance)

        # Steps run in order; the first failure aborts and names the step.
        tos_result = user_service.accept_tos_version(
            sso=sso,
            sso_rw=sso_rw,
            impersonate=DEFAULT_IMPERSONATE,
        )
        if not tos_result.get("ok"):
            return False, "tos", _format_step_error(tos_result, "accept_tos failed")

        birth_result = birth_service.set_birth_date(
            sso=sso,
            sso_rw=sso_rw,
            impersonate=DEFAULT_IMPERSONATE,
        )
        if not birth_result.get("ok"):
            return False, "birth", _format_step_error(birth_result, "set_birth_date failed")

        nsfw_result = nsfw_service.enable_nsfw(
            sso=sso,
            sso_rw=sso_rw,
            impersonate=DEFAULT_IMPERSONATE,
        )
        if not nsfw_result.get("ok"):
            return False, "nsfw", _format_step_error(nsfw_result, "enable_nsfw failed")

        return True, "", ""

    async def refresh_tokens(
        self,
        tokens: Iterable[str],
        concurrency: int = DEFAULT_NSFW_REFRESH_CONCURRENCY,
        retries: int = DEFAULT_NSFW_REFRESH_RETRIES,
    ) -> dict[str, Any]:
        """Apply account settings to *tokens* concurrently and persist results.

        Returns ``{"summary": {...}, "failed": [...]}`` where the summary
        counts total/success/failed/invalidated tokens.
        """
        resolved_concurrency = _coerce_concurrency(concurrency)
        resolved_retries = _coerce_retries(retries)

        # De-duplicate (by normalized sso value) while preserving caller order.
        unique_tokens: list[str] = []
        seen: set[str] = set()
        for token in tokens:
            normalized = normalize_sso_token(str(token or "").strip())
            if not normalized or normalized in seen:
                continue
            seen.add(normalized)
            unique_tokens.append(normalized)

        if not unique_tokens:
            return {
                "summary": {"total": 0, "success": 0, "failed": 0, "invalidated": 0},
                "failed": [],
            }

        # Bounds how many blocking settings passes run at once.
        semaphore = asyncio.Semaphore(resolved_concurrency)

        async def _run_one(token: str) -> dict[str, Any]:
            # Attempt the settings pass up to retries+1 times for one token.
            max_attempts = resolved_retries + 1
            last_step = "unknown"
            last_error = "unknown error"

            async with semaphore:
                for attempt in range(1, max_attempts + 1):
                    try:
                        ok, step, error = await asyncio.to_thread(self._apply_once, token)
                    except Exception as exc:
                        ok, step, error = False, "exception", str(exc)

                    if ok:
                        # save=False: all updates are committed once at the end.
                        updated = await self.token_manager.mark_token_account_settings_success(
                            token,
                            save=False,
                        )
                        if not updated:
                            logger.warning(
                                "Account settings refresh succeeded but token not found: {}...",
                                token[:10],
                            )
                        return {
                            "token": token,
                            "ok": True,
                            "attempts": attempt,
                        }

                    last_step = step or "unknown"
                    last_error = error or "unknown error"

                # All attempts failed: invalidate the token with a diagnostic reason.
                reason = (
                    f"account_settings_refresh_failed step={last_step} "
                    f"attempts={max_attempts} error={last_error}"
                )
                invalidated = await self.token_manager.set_token_invalid(
                    token,
                    reason=reason,
                    save=False,
                )
                return {
                    "token": token,
                    "ok": False,
                    "attempts": max_attempts,
                    "step": last_step,
                    "error": last_error,
                    "invalidated": bool(invalidated),
                }

        results = await asyncio.gather(*[_run_one(token) for token in unique_tokens])

        # Single commit for all per-token changes recorded with save=False above.
        try:
            await self.token_manager.commit()
        except Exception as exc:
            logger.warning("Account settings refresh commit failed: {}", exc)

        success = sum(1 for item in results if item.get("ok"))
        failed_items = [item for item in results if not item.get("ok")]
        invalidated = sum(1 for item in failed_items if item.get("invalidated"))

        summary = {
            "total": len(unique_tokens),
            "success": success,
            "failed": len(failed_items),
            "invalidated": invalidated,
        }

        return {"summary": summary, "failed": failed_items}
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
async def refresh_account_settings_for_tokens(
    tokens: Iterable[str],
    concurrency: int | None = None,
    retries: int | None = None,
) -> dict[str, Any]:
    """Refresh TOS/birth-date/NSFW settings for *tokens* with configured defaults.

    Explicit *concurrency*/*retries* arguments win over the ``token.*`` config
    keys; both values are clamped to sane minimums before use.
    """
    if concurrency is None:
        concurrency = get_config(
            "token.nsfw_refresh_concurrency",
            DEFAULT_NSFW_REFRESH_CONCURRENCY,
        )
    if retries is None:
        retries = get_config(
            "token.nsfw_refresh_retries",
            DEFAULT_NSFW_REFRESH_RETRIES,
        )

    manager = await get_token_manager()
    clearance = str(get_config("grok.cf_clearance", "") or "").strip()
    service = AccountSettingsRefreshService(manager, cf_clearance=clearance)
    return await service.refresh_tokens(
        tokens=tokens,
        concurrency=_coerce_concurrency(concurrency, default=DEFAULT_NSFW_REFRESH_CONCURRENCY),
        retries=_coerce_retries(retries, default=DEFAULT_NSFW_REFRESH_RETRIES),
    )
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
__all__ = [
|
| 261 |
+
"AccountSettingsRefreshService",
|
| 262 |
+
"parse_sso_pair",
|
| 263 |
+
"normalize_sso_token",
|
| 264 |
+
"refresh_account_settings_for_tokens",
|
| 265 |
+
"DEFAULT_NSFW_REFRESH_CONCURRENCY",
|
| 266 |
+
"DEFAULT_NSFW_REFRESH_RETRIES",
|
| 267 |
+
]
|
app/services/register/manager.py
ADDED
|
@@ -0,0 +1,332 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Auto registration manager."""
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
import asyncio
|
| 5 |
+
import queue
|
| 6 |
+
import threading
|
| 7 |
+
import time
|
| 8 |
+
import uuid
|
| 9 |
+
from dataclasses import dataclass, field
|
| 10 |
+
from typing import Dict, List, Optional
|
| 11 |
+
|
| 12 |
+
from app.core.config import get_config
|
| 13 |
+
from app.core.logger import logger
|
| 14 |
+
from app.services.token.manager import get_token_manager
|
| 15 |
+
from app.services.register.runner import RegisterRunner
|
| 16 |
+
from app.services.register.solver import SolverConfig, TurnstileSolverProcess
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
@dataclass
class RegisterJob:
    """Mutable, thread-safe progress record for one auto-registration run."""

    job_id: str
    total: int
    pool: str
    register_threads: int = 10
    status: str = "starting"
    started_at: float = field(default_factory=time.time)
    finished_at: Optional[float] = None
    completed: int = 0
    added: int = 0
    errors: int = 0
    error: Optional[str] = None
    last_error: Optional[str] = None
    tokens: List[str] = field(default_factory=list)
    _lock: threading.Lock = field(default_factory=threading.Lock, repr=False)
    stop_event: threading.Event = field(default_factory=threading.Event, repr=False)

    def record_success(self, token: str) -> None:
        """Count one completed registration and remember its token."""
        with self._lock:
            self.completed += 1
            self.tokens.append(token)

    def record_added(self) -> None:
        """Count one token persisted into the pool."""
        with self._lock:
            self.added += 1

    def record_error(self, message: str) -> None:
        """Count one failure, keeping a truncated copy of the last message."""
        text = (message or "").strip()
        if len(text) > 500:
            # Keep the status payload small even for huge tracebacks.
            text = text[:500] + "..."
        with self._lock:
            self.errors += 1
            if text:
                self.last_error = text

    def to_dict(self) -> Dict[str, object]:
        """Return a JSON-friendly snapshot of the job taken under the lock."""
        with self._lock:
            finished_ms = int(self.finished_at * 1000) if self.finished_at else None
            return {
                "job_id": self.job_id,
                "status": self.status,
                "pool": self.pool,
                "total": self.total,
                "concurrency": self.register_threads,
                "completed": self.completed,
                "added": self.added,
                "errors": self.errors,
                "error": self.error,
                "last_error": self.last_error,
                "started_at": int(self.started_at * 1000),
                "finished_at": finished_ms,
            }
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
class AutoRegisterManager:
    """Single job manager for auto registration.

    Only one registration job may run at a time; `start_job` rejects a second
    job while one is active.  The actual work happens in `_run_job`, which
    drives a thread-based `RegisterRunner` and bridges produced tokens back
    into the async token manager through a queue.
    """

    # Process-wide singleton, populated by get_auto_register_manager().
    _instance: Optional["AutoRegisterManager"] = None

    def __init__(self) -> None:
        # Guards job start/stop transitions.
        self._lock = asyncio.Lock()
        self._job: Optional[RegisterJob] = None
        self._task: Optional[asyncio.Task] = None
        self._solver: Optional[TurnstileSolverProcess] = None

    async def start_job(
        self,
        count: int,
        pool: str,
        concurrency: Optional[int] = None,
    ) -> RegisterJob:
        """Start a new registration job; raises RuntimeError if one is active."""
        async with self._lock:
            if self._job and self._job.status in {"starting", "running", "stopping"}:
                raise RuntimeError("Auto registration already running")

            default_threads = get_config("register.register_threads", 10)
            try:
                default_threads = max(1, int(default_threads))
            except Exception:
                default_threads = 10

            # A caller-supplied positive int wins over the configured default.
            threads = concurrency if isinstance(concurrency, int) and concurrency > 0 else default_threads

            job = RegisterJob(
                job_id=uuid.uuid4().hex[:8],
                total=count,
                pool=pool,
                register_threads=threads,
            )
            self._job = job
            self._task = asyncio.create_task(self._run_job(job))
            return job

    def get_status(self, job_id: Optional[str] = None) -> Dict[str, object]:
        """Return the current job's snapshot, or an idle/not_found marker."""
        if not self._job:
            return {"status": "idle"}
        if job_id and self._job.job_id != job_id:
            return {"status": "not_found"}
        return self._job.to_dict()

    async def stop_job(self) -> None:
        """Best-effort stop for the current job (used on shutdown)."""
        # Snapshot under the lock, then act outside it so start_job isn't
        # blocked for the duration of the shutdown waits below.
        async with self._lock:
            job = self._job
            task = self._task
            solver = self._solver

        if not job or job.status not in {"starting", "running"}:
            return
        job.status = "stopping"
        job.stop_event.set()

        # Stop solver first to avoid noisy retries.
        if solver:
            try:
                await asyncio.to_thread(solver.stop)
            except Exception:
                pass

        # Give the runner a short grace period to exit.
        if task:
            try:
                await asyncio.wait_for(task, timeout=5.0)
            except Exception:
                # Don't block shutdown; the process is exiting anyway.
                pass

    async def _run_job(self, job: RegisterJob) -> None:
        """Drive one registration job from config resolution to completion."""
        job.status = "starting"

        solver_url = get_config("register.solver_url", "http://127.0.0.1:5072")
        solver_threads = get_config("register.solver_threads", 5)
        try:
            solver_threads = max(1, int(solver_threads))
        except Exception:
            solver_threads = 5

        auto_start_solver = get_config("register.auto_start_solver", True)
        if not isinstance(auto_start_solver, bool):
            # Config values may arrive as strings; coerce truthy spellings.
            auto_start_solver = str(auto_start_solver).lower() in {"1", "true", "yes", "on"}

        # Auto-start only for local solver endpoints.
        try:
            from urllib.parse import urlparse

            host = urlparse(str(solver_url)).hostname or ""
            if host and host not in {"127.0.0.1", "localhost", "::1", "0.0.0.0"}:
                auto_start_solver = False
        except Exception:
            pass

        solver_debug = get_config("register.solver_debug", False)
        if not isinstance(solver_debug, bool):
            solver_debug = str(solver_debug).lower() in {"1", "true", "yes", "on"}

        browser_type = str(get_config("register.solver_browser_type", "chromium") or "chromium").strip().lower()
        if browser_type not in {"chromium", "chrome", "msedge", "camoufox"}:
            browser_type = "chromium"

        solver_cfg = SolverConfig(
            url=str(solver_url or "http://127.0.0.1:5072"),
            threads=solver_threads,
            browser_type=browser_type,
            debug=solver_debug,
            auto_start=auto_start_solver,
        )
        solver = TurnstileSolverProcess(solver_cfg)
        self._solver = solver

        use_yescaptcha = bool(str(get_config("register.yescaptcha_key", "") or "").strip())
        if use_yescaptcha:
            # When YesCaptcha is configured we don't need a local solver process.
            auto_start_solver = False
            solver.config.auto_start = False

        # Safety limits to avoid endless loops when upstream is broken.
        max_errors = get_config("register.max_errors", 0)
        try:
            max_errors = int(max_errors)
        except Exception:
            max_errors = 0
        if max_errors <= 0:
            # Default: allow retries, but stop instead of looping "forever".
            max_errors = max(30, int(job.total) * 5)

        max_runtime_minutes = get_config("register.max_runtime_minutes", 0)
        try:
            max_runtime_minutes = float(max_runtime_minutes)
        except Exception:
            max_runtime_minutes = 0
        max_runtime_sec = max_runtime_minutes * 60 if max_runtime_minutes and max_runtime_minutes > 0 else 0

        # Bridge between runner worker threads (producers) and the async
        # token-manager consumer below; `sentinel` signals end-of-stream.
        token_queue: queue.Queue[object] = queue.Queue()
        sentinel = object()

        async def _consume_tokens() -> None:
            # Pulls tokens produced by worker threads and stores them async.
            mgr = await get_token_manager()
            while True:
                item = await asyncio.to_thread(token_queue.get)
                if item is sentinel:
                    break
                token = str(item or "").strip()
                if not token:
                    continue
                try:
                    if await mgr.add(token, pool_name=job.pool):
                        job.record_added()
                except Exception as exc:
                    job.record_error(f"save token failed: {exc}")

        def _on_error(msg: str) -> None:
            job.record_error(msg)
            # Called from worker threads; keep it simple and thread-safe.
            with job._lock:
                if job.status in {"starting", "running"} and job.errors >= max_errors:
                    job.status = "error"
                    job.error = f"Too many failures ({job.errors}/{max_errors}). Check register config/solver."
                    job.stop_event.set()

        async def _watchdog() -> None:
            # Aborts the job when the configured wall-clock budget is exceeded.
            if not max_runtime_sec:
                return
            while True:
                await asyncio.sleep(1.0)
                if job.stop_event.is_set():
                    return
                if job.status not in {"starting", "running"}:
                    return
                if (time.time() - job.started_at) >= max_runtime_sec:
                    with job._lock:
                        if job.status in {"starting", "running"}:
                            job.status = "error"
                            job.error = f"Timeout after {max_runtime_minutes:g} minutes."
                            job.stop_event.set()
                    return

        try:
            if auto_start_solver:
                try:
                    await asyncio.to_thread(solver.start)
                except Exception as exc:
                    if not use_yescaptcha:
                        raise
                    logger.warning("Solver start failed, continuing with YesCaptcha: {}", exc)

            job.status = "running"
            watchdog_task = asyncio.create_task(_watchdog())
            consumer_task = asyncio.create_task(_consume_tokens())
            runner = RegisterRunner(
                target_count=job.total,
                thread_count=job.register_threads,
                stop_event=job.stop_event,
                on_success=lambda _email, _password, token, _done, _total: (
                    job.record_success(token),
                    token_queue.put(token),
                ),
                on_error=_on_error,
            )

            # Runner is fully thread-based; run it off the event loop.
            await asyncio.to_thread(runner.run)

            # Drain token consumer.
            token_queue.put(sentinel)
            await consumer_task
            if job.status == "stopping":
                job.status = "stopped"
            elif job.status != "error":
                # If we returned without reaching the target, treat it as a failure.
                # This makes issues like "TOS/BirthDate/NSFW not enabled" visible to the UI as a failed job.
                if job.completed < job.total:
                    job.status = "error"
                    suffix = f" Last error: {job.last_error}" if job.last_error else ""
                    job.error = f"Registration ended early ({job.completed}/{job.total}).{suffix}".strip()
                else:
                    job.status = "completed"
        except Exception as exc:
            job.status = "error"
            job.error = str(exc)
            logger.exception("Auto registration failed")
        finally:
            job.finished_at = time.time()
            # Ensure consumer exits even on exceptions.
            try:
                token_queue.put(sentinel)
            except Exception:
                pass
            try:
                if "consumer_task" in locals():
                    await asyncio.wait_for(consumer_task, timeout=10)
            except Exception:
                try:
                    consumer_task.cancel()
                except Exception:
                    pass
            try:
                if "watchdog_task" in locals():
                    watchdog_task.cancel()
            except Exception:
                pass
            self._solver = None
            if auto_start_solver:
                try:
                    await asyncio.to_thread(solver.stop)
                except Exception:
                    pass
|
| 324 |
+
|
| 325 |
+
|
| 326 |
+
def get_auto_register_manager() -> AutoRegisterManager:
    """Return the process-wide AutoRegisterManager, creating it on first use."""
    instance = AutoRegisterManager._instance
    if instance is None:
        instance = AutoRegisterManager()
        AutoRegisterManager._instance = instance
    return instance
|
| 330 |
+
|
| 331 |
+
|
| 332 |
+
__all__ = ["AutoRegisterManager", "get_auto_register_manager"]
|
app/services/register/runner.py
ADDED
|
@@ -0,0 +1,415 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Grok account registration runner."""
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
import concurrent.futures
|
| 5 |
+
import random
|
| 6 |
+
import re
|
| 7 |
+
import string
|
| 8 |
+
import struct
|
| 9 |
+
import threading
|
| 10 |
+
import time
|
| 11 |
+
from typing import Callable, Dict, List, Optional, Tuple
|
| 12 |
+
from urllib.parse import urljoin
|
| 13 |
+
|
| 14 |
+
from bs4 import BeautifulSoup
|
| 15 |
+
from curl_cffi import requests as curl_requests
|
| 16 |
+
|
| 17 |
+
from app.core.logger import logger
|
| 18 |
+
from app.services.register.services import (
|
| 19 |
+
EmailService,
|
| 20 |
+
TurnstileService,
|
| 21 |
+
UserAgreementService,
|
| 22 |
+
BirthDateService,
|
| 23 |
+
NsfwSettingsService,
|
| 24 |
+
)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
# Base URL of the x.ai account service used for sign-up requests.
SITE_URL = "https://accounts.x.ai"
# Fallback curl_cffi browser fingerprint when no random profile is used.
DEFAULT_IMPERSONATE = "chrome120"

# Browser fingerprints and the UA version string each one advertises;
# _random_chrome_profile() picks one at random per registration session.
CHROME_PROFILES = [
    {"impersonate": "chrome110", "version": "110.0.0.0", "brand": "chrome"},
    {"impersonate": "chrome119", "version": "119.0.0.0", "brand": "chrome"},
    {"impersonate": "chrome120", "version": "120.0.0.0", "brand": "chrome"},
    {"impersonate": "edge99", "version": "99.0.1150.36", "brand": "edge"},
    {"impersonate": "edge101", "version": "101.0.1210.47", "brand": "edge"},
]
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def _random_chrome_profile() -> Tuple[str, str]:
    """Pick a random browser profile; return ``(impersonate_name, user_agent)``."""
    chosen = random.choice(CHROME_PROFILES)
    base = (
        "Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
        "AppleWebKit/537.36 (KHTML, like Gecko) "
    )
    if chosen.get("brand") == "edge":
        # Edge advertises a Chrome token of the same major version plus an Edg/ suffix.
        major = chosen["version"].split(".")[0]
        user_agent = base + f"Chrome/{major}.0.0.0 Safari/537.36 Edg/{chosen['version']}"
    else:
        user_agent = base + f"Chrome/{chosen['version']} Safari/537.36"
    return chosen["impersonate"], user_agent
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def _generate_random_name() -> str:
|
| 59 |
+
length = random.randint(4, 6)
|
| 60 |
+
return random.choice(string.ascii_uppercase) + "".join(
|
| 61 |
+
random.choice(string.ascii_lowercase) for _ in range(length - 1)
|
| 62 |
+
)
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def _generate_random_string(length: int = 15) -> str:
|
| 66 |
+
return "".join(random.choice(string.ascii_lowercase + string.digits) for _ in range(length))
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def _encode_grpc_message(field_id: int, string_value: str) -> bytes:
|
| 70 |
+
key = (field_id << 3) | 2
|
| 71 |
+
value_bytes = string_value.encode("utf-8")
|
| 72 |
+
payload = struct.pack("B", key) + struct.pack("B", len(value_bytes)) + value_bytes
|
| 73 |
+
return b"\x00" + struct.pack(">I", len(payload)) + payload
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def _encode_grpc_message_verify(email: str, code: str) -> bytes:
|
| 77 |
+
p1 = struct.pack("B", (1 << 3) | 2) + struct.pack("B", len(email)) + email.encode("utf-8")
|
| 78 |
+
p2 = struct.pack("B", (2 << 3) | 2) + struct.pack("B", len(code)) + code.encode("utf-8")
|
| 79 |
+
payload = p1 + p2
|
| 80 |
+
return b"\x00" + struct.pack(">I", len(payload)) + payload
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
class RegisterRunner:
    """Threaded registration runner.

    Spins up ``thread_count`` worker threads that each repeatedly attempt a
    full account registration: temporary mailbox creation, email-code
    verification, Turnstile solving, the sign-up POST, then post-registration
    setup (ToS acceptance, birth date, NSFW settings).  Successful accounts
    accumulate in :attr:`tokens` / :attr:`accounts` until ``target_count``
    registrations have completed or :attr:`stop_event` is set.
    """

    def __init__(
        self,
        target_count: int = 100,
        thread_count: int = 8,
        on_success: Optional[Callable[[str, str, str, int, int], None]] = None,
        on_error: Optional[Callable[[str], None]] = None,
        stop_event: Optional[threading.Event] = None,
    ) -> None:
        """Configure the runner.

        Args:
            target_count: Number of successful registrations to collect.
            thread_count: Worker threads to run concurrently.
            on_success: Callback ``(email, password, token, done, target)``
                invoked after each successful registration.
            on_error: Callback invoked with a short message per failure.
            stop_event: External cancellation event; a private one is created
                when not supplied.
        """
        # Clamp to at least 1 so bad input cannot produce a no-op/endless run.
        self.target_count = max(1, int(target_count))
        self.thread_count = max(1, int(thread_count))
        self.on_success = on_success
        self.on_error = on_error
        self.stop_event = stop_event or threading.Event()

        # _post_lock serializes the final sign-up POST across worker threads;
        # _result_lock guards the success counter and result lists.
        self._post_lock = threading.Lock()
        self._result_lock = threading.Lock()

        self._success_count = 0
        self._start_time = 0.0
        self._tokens: List[str] = []
        self._accounts: List[Dict[str, str]] = []

        # Static fallbacks; _init_config() refreshes these from the live
        # sign-up page before any worker starts.
        self._config: Dict[str, Optional[str]] = {
            "site_key": "0x4AAAAAAAhr9JGVDZbrZOo0",
            "action_id": None,
            "state_tree": "%5B%22%22%2C%7B%22children%22%3A%5B%22(app)%22%2C%7B%22children%22%3A%5B%22(auth)%22%2C%7B%22children%22%3A%5B%22sign-up%22%2C%7B%22children%22%3A%5B%22__PAGE__%22%2C%7B%7D%2C%22%2Fsign-up%22%2C%22refresh%22%5D%7D%5D%7D%2Cnull%2Cnull%5D%7D%2Cnull%2Cnull%5D%7D%2Cnull%2Cnull%2Ctrue%5D",
        }

    @property
    def success_count(self) -> int:
        """Number of registrations completed so far."""
        return self._success_count

    @property
    def tokens(self) -> List[str]:
        """Copy of the SSO tokens collected so far."""
        return list(self._tokens)

    @property
    def accounts(self) -> List[Dict[str, str]]:
        """Copy of the collected account dicts (email/password/token)."""
        return list(self._accounts)

    def _record_success(self, email: str, password: str, token: str) -> None:
        """Record one successful registration and fire the success callback."""
        with self._result_lock:
            # Target already reached by another thread: drop this extra
            # success and make sure everyone is told to stop.
            if self._success_count >= self.target_count:
                if not self.stop_event.is_set():
                    self.stop_event.set()
                return

            self._success_count += 1
            self._tokens.append(token)
            self._accounts.append({"email": email, "password": password, "token": token})

            # Average wall-clock seconds per successful registration.
            avg = (time.time() - self._start_time) / max(1, self._success_count)
            logger.info(
                "Register success: {} | sso={}... | avg={:.1f}s ({}/{})",
                email,
                token[:12],
                avg,
                self._success_count,
                self.target_count,
            )

            # Callbacks are best-effort; a broken observer must not kill the run.
            if self.on_success:
                try:
                    self.on_success(email, password, token, self._success_count, self.target_count)
                except Exception:
                    pass

            if self._success_count >= self.target_count and not self.stop_event.is_set():
                self.stop_event.set()

    def _record_error(self, message: str) -> None:
        """Forward an error message to the (optional) error callback."""
        if self.on_error:
            try:
                self.on_error(message)
            except Exception:
                pass

    def _init_config(self) -> None:
        """Scrape the sign-up page for the Turnstile site key, Next.js router
        state tree and server-action ID needed by the sign-up POST.

        Raises:
            RuntimeError: If no action ID could be found in any script bundle.
        """
        logger.info("Register: initializing action config...")
        start_url = f"{SITE_URL}/sign-up"

        with curl_requests.Session(impersonate=DEFAULT_IMPERSONATE) as session:
            html = session.get(start_url, timeout=15).text

            key_match = re.search(r'sitekey":"(0x4[a-zA-Z0-9_-]+)"', html)
            if key_match:
                self._config["site_key"] = key_match.group(1)

            tree_match = re.search(r'next-router-state-tree":"([^"]+)"', html)
            if tree_match:
                self._config["state_tree"] = tree_match.group(1)

            # The server-action ID (a "7f"-prefixed 42-hex-char token) lives in
            # one of the Next.js static script bundles.
            soup = BeautifulSoup(html, "html.parser")
            js_urls = [
                urljoin(start_url, script["src"])
                for script in soup.find_all("script", src=True)
                if "_next/static" in script["src"]
            ]
            for js_url in js_urls:
                js_content = session.get(js_url, timeout=15).text
                match = re.search(r"7f[a-fA-F0-9]{40}", js_content)
                if match:
                    self._config["action_id"] = match.group(0)
                    logger.info("Register: Action ID found: {}", self._config["action_id"])
                    break

        if not self._config.get("action_id"):
            raise RuntimeError("Register init failed: missing action_id")

    def _send_email_code(self, session: curl_requests.Session, email: str) -> bool:
        """Ask the backend to email a validation code; True on HTTP 200."""
        url = f"{SITE_URL}/auth_mgmt.AuthManagement/CreateEmailValidationCode"
        data = _encode_grpc_message(1, email)
        headers = {
            "content-type": "application/grpc-web+proto",
            "x-grpc-web": "1",
            "x-user-agent": "connect-es/2.1.1",
            "origin": SITE_URL,
            "referer": f"{SITE_URL}/sign-up?redirect=grok-com",
        }
        try:
            res = session.post(url, data=data, headers=headers, timeout=15)
            return res.status_code == 200
        except Exception as exc:
            self._record_error(f"send code error: {email} - {exc}")
            return False

    def _verify_email_code(self, session: curl_requests.Session, email: str, code: str) -> bool:
        """Submit the emailed validation code; True on HTTP 200."""
        url = f"{SITE_URL}/auth_mgmt.AuthManagement/VerifyEmailValidationCode"
        data = _encode_grpc_message_verify(email, code)
        headers = {
            "content-type": "application/grpc-web+proto",
            "x-grpc-web": "1",
            "x-user-agent": "connect-es/2.1.1",
            "origin": SITE_URL,
            "referer": f"{SITE_URL}/sign-up?redirect=grok-com",
        }
        try:
            res = session.post(url, data=data, headers=headers, timeout=15)
            return res.status_code == 200
        except Exception as exc:
            self._record_error(f"verify code error: {email} - {exc}")
            return False

    def _register_single_thread(self) -> None:
        """Worker loop: keep registering accounts until the stop event fires.

        Each iteration of the outer ``while`` attempts one full registration;
        the inner ``for`` retries the Turnstile + sign-up step up to 3 times.
        """
        # Stagger thread start-up to avoid a burst of identical requests.
        time.sleep(random.uniform(0, 5))

        try:
            email_service = EmailService()
            turnstile_service = TurnstileService()
            user_agreement_service = UserAgreementService()
            birth_date_service = BirthDateService()
            nsfw_service = NsfwSettingsService()
        except Exception as exc:
            self._record_error(f"service init failed: {exc}")
            return

        final_action_id = self._config.get("action_id")
        if not final_action_id:
            self._record_error("missing action id")
            return

        while not self.stop_event.is_set():
            try:
                impersonate_fingerprint, account_user_agent = _random_chrome_profile()

                with curl_requests.Session(impersonate=impersonate_fingerprint) as session:
                    # Warm-up GET to collect Cloudflare cookies; best-effort.
                    try:
                        session.get(SITE_URL, timeout=10)
                    except Exception:
                        pass

                    password = _generate_random_string()

                    jwt, email = email_service.create_email()
                    if not email:
                        self._record_error("create_email failed")
                        time.sleep(5)
                        continue

                    if self.stop_event.is_set():
                        return

                    if not self._send_email_code(session, email):
                        self._record_error(f"send_email_code failed: {email}")
                        time.sleep(5)
                        continue

                    # Poll the mailbox (up to ~30s) for the XXX-XXX code.
                    verify_code = None
                    for _ in range(30):
                        time.sleep(1)
                        if self.stop_event.is_set():
                            return
                        content = email_service.fetch_first_email(jwt)
                        if content:
                            match = re.search(r">([A-Z0-9]{3}-[A-Z0-9]{3})<", content)
                            if match:
                                verify_code = match.group(1).replace("-", "")
                                break

                    if not verify_code:
                        self._record_error(f"verify_code not received: {email}")
                        time.sleep(3)
                        continue

                    if not self._verify_email_code(session, email, verify_code):
                        self._record_error(f"verify_email_code failed: {email}")
                        time.sleep(3)
                        continue

                    # Up to 3 attempts at Turnstile + the sign-up POST.
                    for _ in range(3):
                        if self.stop_event.is_set():
                            return

                        try:
                            task_id = turnstile_service.create_task(f"{SITE_URL}/sign-up", self._config["site_key"] or "")
                        except Exception as exc:
                            self._record_error(f"turnstile create_task failed: {exc}")
                            time.sleep(2)
                            continue

                        token = turnstile_service.get_response(task_id, stop_event=self.stop_event)

                        if not token:
                            self._record_error(f"turnstile failed: {turnstile_service.last_error or 'no token'}")
                            time.sleep(2)
                            continue

                        # Next.js server-action POST headers (see _init_config
                        # for where action id / state tree come from).
                        headers = {
                            "user-agent": account_user_agent,
                            "accept": "text/x-component",
                            "content-type": "text/plain;charset=UTF-8",
                            "origin": SITE_URL,
                            "referer": f"{SITE_URL}/sign-up",
                            "cookie": f"__cf_bm={session.cookies.get('__cf_bm','')}",
                            "next-router-state-tree": self._config["state_tree"] or "",
                            "next-action": final_action_id,
                        }
                        payload = [
                            {
                                "emailValidationCode": verify_code,
                                "createUserAndSessionRequest": {
                                    "email": email,
                                    "givenName": _generate_random_name(),
                                    "familyName": _generate_random_name(),
                                    "clearTextPassword": password,
                                    "tosAcceptedVersion": "$undefined",
                                },
                                "turnstileToken": token,
                                "promptOnDuplicateEmail": True,
                            }
                        ]

                        # Only one sign-up POST in flight at a time.
                        with self._post_lock:
                            res = session.post(
                                f"{SITE_URL}/sign-up",
                                json=payload,
                                headers=headers,
                                timeout=20,
                            )

                        if res.status_code != 200:
                            self._record_error(f"sign_up http {res.status_code}")
                            time.sleep(3)
                            continue

                        # The response streams a set-cookie redirect URL that
                        # must be visited to obtain the sso cookies.
                        match = re.search(r'(https://[^" \s]+set-cookie\?q=[^:" \s]+)1:', res.text)
                        if not match:
                            self._record_error("sign_up missing set-cookie redirect")
                            break

                        verify_url = match.group(1)
                        session.get(verify_url, allow_redirects=True, timeout=15)

                        sso = session.cookies.get("sso")
                        sso_rw = session.cookies.get("sso-rw")
                        if not sso:
                            self._record_error("sign_up missing sso cookie")
                            break

                        # Post-registration setup; each step aborts this
                        # account (break) on failure, the outer loop retries
                        # with a fresh one.
                        tos_result = user_agreement_service.accept_tos_version(
                            sso=sso,
                            sso_rw=sso_rw or "",
                            impersonate=impersonate_fingerprint,
                            user_agent=account_user_agent,
                        )
                        if not tos_result.get("ok") or not tos_result.get("hex_reply"):
                            self._record_error(f"accept_tos failed: {tos_result.get('error') or 'unknown'}")
                            break

                        birth_result = birth_date_service.set_birth_date(
                            sso=sso,
                            sso_rw=sso_rw or "",
                            impersonate=impersonate_fingerprint,
                            user_agent=account_user_agent,
                        )
                        if not birth_result.get("ok"):
                            self._record_error(
                                f"set_birth_date failed: {birth_result.get('error') or 'unknown'}"
                            )
                            break

                        nsfw_result = nsfw_service.enable_nsfw(
                            sso=sso,
                            sso_rw=sso_rw or "",
                            impersonate=impersonate_fingerprint,
                            user_agent=account_user_agent,
                        )
                        if not nsfw_result.get("ok") or not nsfw_result.get("hex_reply"):
                            self._record_error(f"enable_nsfw failed: {nsfw_result.get('error') or 'unknown'}")
                            break

                        self._record_success(email, password, sso)
                        break

            except Exception as exc:
                self._record_error(f"thread error: {str(exc)[:80]}")
                time.sleep(3)

    def run(self) -> List[str]:
        """Run the registration process and return collected tokens."""
        self._init_config()
        self._start_time = time.time()

        logger.info("Register: starting {} threads, target {}", self.thread_count, self.target_count)

        with concurrent.futures.ThreadPoolExecutor(max_workers=self.thread_count) as executor:
            futures = [executor.submit(self._register_single_thread) for _ in range(self.thread_count)]
            concurrent.futures.wait(futures)

        return list(self._tokens)
|
app/services/register/services/__init__.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Registration helper services."""
|
| 2 |
+
|
| 3 |
+
from app.services.register.services.email_service import EmailService
|
| 4 |
+
from app.services.register.services.turnstile_service import TurnstileService
|
| 5 |
+
from app.services.register.services.user_agreement_service import UserAgreementService
|
| 6 |
+
from app.services.register.services.birth_date_service import BirthDateService
|
| 7 |
+
from app.services.register.services.nsfw_service import NsfwSettingsService
|
| 8 |
+
|
| 9 |
+
__all__ = [
|
| 10 |
+
"EmailService",
|
| 11 |
+
"TurnstileService",
|
| 12 |
+
"UserAgreementService",
|
| 13 |
+
"BirthDateService",
|
| 14 |
+
"NsfwSettingsService",
|
| 15 |
+
]
|
app/services/register/services/birth_date_service.py
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import datetime
|
| 4 |
+
import random
|
| 5 |
+
from typing import Any, Dict, Optional
|
| 6 |
+
|
| 7 |
+
from curl_cffi import requests
|
| 8 |
+
|
| 9 |
+
DEFAULT_USER_AGENT = (
|
| 10 |
+
"Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
|
| 11 |
+
"AppleWebKit/537.36 (KHTML, like Gecko) "
|
| 12 |
+
"Chrome/120.0.0.0 Safari/537.36"
|
| 13 |
+
)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def generate_random_birthdate() -> str:
    """Generate a random birth date between 20 and 40 years old."""
    now = datetime.date.today()
    years_old = random.randint(20, 40)
    # Day capped at 28 so the date is valid in every month.
    month = random.randint(1, 12)
    day = random.randint(1, 28)
    return f"{now.year - years_old}-{month:02d}-{day:02d}T16:00:00.000Z"
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class BirthDateService:
    """Set the account birth date through Grok's REST endpoint."""

    def __init__(self, cf_clearance: str = ""):
        # Optional Cloudflare clearance cookie used on every request.
        self.cf_clearance = (cf_clearance or "").strip()

    @staticmethod
    def _failure(reason: str) -> Dict[str, Any]:
        """Build the standard failure result payload."""
        return {
            "ok": False,
            "status_code": None,
            "response_text": "",
            "error": reason,
        }

    def set_birth_date(
        self,
        sso: str,
        sso_rw: str,
        impersonate: str,
        user_agent: Optional[str] = None,
        cf_clearance: Optional[str] = None,
        timeout: int = 15,
    ) -> Dict[str, Any]:
        """POST a randomly generated birth date using the given SSO cookies.

        Returns a dict with ``ok``, ``status_code``, ``response_text`` and
        ``error`` keys; ``ok`` is True only on HTTP 200.
        """
        if not sso:
            return self._failure("missing sso")
        if not sso_rw:
            return self._failure("missing sso-rw")

        jar = {"sso": sso, "sso-rw": sso_rw}
        # Per-call clearance overrides the instance-level one when provided.
        clearance = (self.cf_clearance if cf_clearance is None else cf_clearance).strip()
        if clearance:
            jar["cf_clearance"] = clearance

        request_headers = {
            "content-type": "application/json",
            "origin": "https://grok.com",
            "referer": "https://grok.com/",
            "user-agent": user_agent or DEFAULT_USER_AGENT,
        }

        try:
            response = requests.post(
                "https://grok.com/rest/auth/set-birth-date",
                headers=request_headers,
                cookies=jar,
                json={"birthDate": generate_random_birthdate()},
                impersonate=impersonate or "chrome120",
                timeout=timeout,
            )
            code = response.status_code
            succeeded = code == 200
            return {
                "ok": succeeded,
                "status_code": code,
                "response_text": response.text or "",
                "error": None if succeeded else f"HTTP {code}",
            }
        except Exception as exc:
            return self._failure(str(exc))
|
app/services/register/services/email_service.py
ADDED
|
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Email service for temporary inbox creation."""
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
import os
|
| 5 |
+
import random
|
| 6 |
+
import string
|
| 7 |
+
from typing import Tuple, Optional
|
| 8 |
+
|
| 9 |
+
import requests
|
| 10 |
+
|
| 11 |
+
from app.core.config import get_config
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class EmailService:
    """Temporary-mailbox client backed by a Cloudflare email worker."""

    def __init__(
        self,
        worker_domain: Optional[str] = None,
        email_domain: Optional[str] = None,
        admin_password: Optional[str] = None,
    ) -> None:
        def resolve(explicit: Optional[str], config_key: str, env_key: str) -> str:
            # Precedence: explicit argument, then config store, then env var.
            return (explicit or get_config(config_key, "") or os.getenv(env_key, "")).strip()

        self.worker_domain = resolve(worker_domain, "register.worker_domain", "WORKER_DOMAIN")
        self.email_domain = resolve(email_domain, "register.email_domain", "EMAIL_DOMAIN")
        self.admin_password = resolve(admin_password, "register.admin_password", "ADMIN_PASSWORD")

        if not (self.worker_domain and self.email_domain and self.admin_password):
            raise ValueError(
                "Missing required email settings: register.worker_domain, register.email_domain, "
                "register.admin_password"
            )

    def _generate_random_name(self) -> str:
        """Return a mailbox local-part: letters, then digits, then letters."""
        head = "".join(random.choices(string.ascii_lowercase, k=random.randint(4, 6)))
        digits = "".join(random.choices(string.digits, k=random.randint(1, 3)))
        tail = "".join(random.choices(string.ascii_lowercase, k=random.randint(0, 5)))
        return f"{head}{digits}{tail}"

    def create_email(self) -> Tuple[Optional[str], Optional[str]]:
        """Create a temporary mailbox. Returns (jwt, address)."""
        url = f"https://{self.worker_domain}/admin/new_address"
        try:
            response = requests.post(
                url,
                json={
                    "enablePrefix": True,
                    "name": self._generate_random_name(),
                    "domain": self.email_domain,
                },
                headers={
                    "x-admin-auth": self.admin_password,
                    "Content-Type": "application/json",
                },
                timeout=10,
            )
            if response.status_code == 200:
                body = response.json()
                return body.get("jwt"), body.get("address")
            print(f"[-] Email create failed: {response.status_code} - {response.text}")
        except Exception as exc:  # pragma: no cover - network/remote errors
            print(f"[-] Email create error ({url}): {exc}")
        return None, None

    def fetch_first_email(self, jwt: str) -> Optional[str]:
        """Fetch the first email content for the mailbox."""
        try:
            response = requests.get(
                f"https://{self.worker_domain}/api/mails",
                params={"limit": 10, "offset": 0},
                headers={
                    "Authorization": f"Bearer {jwt}",
                    "Content-Type": "application/json",
                },
                timeout=10,
            )
            if response.status_code != 200:
                return None
            listing = response.json()
            if listing.get("results"):
                return listing["results"][0].get("raw")
            return None
        except Exception as exc:  # pragma: no cover - network/remote errors
            print(f"Email fetch failed: {exc}")
            return None
|
app/services/register/services/nsfw_service.py
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from typing import Optional, Dict, Any
|
| 4 |
+
|
| 5 |
+
from curl_cffi import requests
|
| 6 |
+
|
| 7 |
+
DEFAULT_USER_AGENT = (
|
| 8 |
+
"Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
|
| 9 |
+
"AppleWebKit/537.36 (KHTML, like Gecko) "
|
| 10 |
+
"Chrome/120.0.0.0 Safari/537.36"
|
| 11 |
+
)
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class NsfwSettingsService:
    """Enable NSFW-related settings (thread-safe, no global state)."""

    def __init__(self, cf_clearance: str = ""):
        # Optional Cloudflare clearance cookie applied to every request.
        self.cf_clearance = (cf_clearance or "").strip()

    def enable_nsfw(
        self,
        sso: str,
        sso_rw: str,
        impersonate: str,
        user_agent: Optional[str] = None,
        cf_clearance: Optional[str] = None,
        timeout: int = 15,
    ) -> Dict[str, Any]:
        """
        Enable ``always_show_nsfw_content`` for the account.

        Returns: {
            ok: bool,
            hex_reply: str,
            status_code: int | None,
            grpc_status: str | None,
            error: str | None
        }
        """
        if not sso:
            return {
                "ok": False,
                "hex_reply": "",
                "status_code": None,
                "grpc_status": None,
                "error": "缺少 sso",
            }
        if not sso_rw:
            return {
                "ok": False,
                "hex_reply": "",
                "status_code": None,
                "grpc_status": None,
                "error": "缺少 sso-rw",
            }

        url = "https://grok.com/auth_mgmt.AuthManagement/UpdateUserFeatureControls"

        cookies = {
            "sso": sso,
            "sso-rw": sso_rw,
        }
        # Per-call clearance overrides the instance-level one when provided.
        clearance = (cf_clearance if cf_clearance is not None else self.cf_clearance).strip()
        if clearance:
            cookies["cf_clearance"] = clearance

        headers = {
            "content-type": "application/grpc-web+proto",
            "origin": "https://grok.com",
            "referer": "https://grok.com/?_s=data",
            "x-grpc-web": "1",
            "user-agent": user_agent or DEFAULT_USER_AGENT,
        }

        # Pre-built gRPC-Web frame: one 0x00 flag byte, a big-endian uint32
        # payload length (0x00000020 = 32 bytes), then the protobuf payload
        # that sets the "always_show_nsfw_content" feature control to true.
        data = (
            b"\x00\x00\x00\x00"
            b"\x20"
            b"\x0a\x02\x10\x01"
            b"\x12\x1a"
            b"\x0a\x18"
            b"always_show_nsfw_content"
        )

        try:
            response = requests.post(
                url,
                headers=headers,
                cookies=cookies,
                data=data,
                impersonate=impersonate or "chrome120",
                timeout=timeout,
            )
            hex_reply = response.content.hex()
            grpc_status = response.headers.get("grpc-status")

            # Success requires HTTP 200 AND a missing-or-zero grpc-status
            # trailer (gRPC errors can ride on a 200 response).
            error = None
            ok = response.status_code == 200 and (grpc_status in (None, "0"))
            if response.status_code == 403:
                error = "403 Forbidden"
            elif response.status_code != 200:
                error = f"HTTP {response.status_code}"
            elif grpc_status not in (None, "0"):
                error = f"gRPC {grpc_status}"

            return {
                "ok": ok,
                "hex_reply": hex_reply,
                "status_code": response.status_code,
                "grpc_status": grpc_status,
                "error": error,
            }
        except Exception as e:
            return {
                "ok": False,
                "hex_reply": "",
                "status_code": None,
                "grpc_status": None,
                "error": str(e),
            }
|
app/services/register/services/turnstile_service.py
ADDED
|
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Turnstile solving service."""
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
import os
|
| 5 |
+
import time
|
| 6 |
+
from typing import Optional
|
| 7 |
+
|
| 8 |
+
from app.core.logger import logger
|
| 9 |
+
|
| 10 |
+
import requests
|
| 11 |
+
|
| 12 |
+
from app.core.config import get_config
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class TurnstileService:
|
| 16 |
+
"""Turnstile solver wrapper (local solver or YesCaptcha)."""
|
| 17 |
+
|
| 18 |
+
    def __init__(
        self,
        solver_url: Optional[str] = None,
        yescaptcha_key: Optional[str] = None,
    ) -> None:
        """Resolve solver settings: explicit args, then config, then env vars.

        When ``yescaptcha_key`` resolves non-empty, the paid YesCaptcha API is
        used; otherwise requests go to the self-hosted solver at
        ``solver_url`` (default ``http://127.0.0.1:5072``).
        """
        self.yescaptcha_key = (
            (yescaptcha_key or get_config("register.yescaptcha_key", "") or os.getenv("YESCAPTCHA_KEY", "")).strip()
        )
        self.solver_url = (
            solver_url
            or get_config("register.solver_url", "")
            or os.getenv("TURNSTILE_SOLVER_URL", "")
            or "http://127.0.0.1:5072"
        ).strip()
        self.yescaptcha_api = "https://api.yescaptcha.com"
        # Human-readable reason for the most recent failure, if any.
        self.last_error: Optional[str] = None
|
| 34 |
+
|
| 35 |
+
    def create_task(self, siteurl: str, sitekey: str) -> str:
        """Create a Turnstile task and return task ID.

        Args:
            siteurl: Page URL protected by the Turnstile widget.
            sitekey: Turnstile site key scraped from that page.

        Raises:
            RuntimeError: When the backend reports an error or returns no
                task ID (details also stored in ``self.last_error``).
        """
        self.last_error = None
        # Paid YesCaptcha API takes precedence when a client key is configured.
        if self.yescaptcha_key:
            url = f"{self.yescaptcha_api}/createTask"
            payload = {
                "clientKey": self.yescaptcha_key,
                "task": {
                    "type": "TurnstileTaskProxyless",
                    "websiteURL": siteurl,
                    "websiteKey": sitekey,
                },
            }
            response = requests.post(url, json=payload, timeout=20)
            response.raise_for_status()
            data = response.json()
            if data.get("errorId") != 0:
                desc = data.get("errorDescription") or "unknown"
                self.last_error = f"YesCaptcha createTask failed: {desc}"
                raise RuntimeError(self.last_error)
            return data["taskId"]

        # Fall back to the self-hosted solver's HTTP API.
        response = requests.get(
            f"{self.solver_url}/turnstile",
            params={"url": siteurl, "sitekey": sitekey},
            timeout=20,
        )
        response.raise_for_status()
        data = response.json()
        task_id = data.get("taskId")
        if not task_id:
            self.last_error = data.get("errorDescription") or data.get("errorCode") or "missing taskId"
            raise RuntimeError(f"Solver create task failed: {self.last_error}")
        return task_id
|
| 69 |
+
|
| 70 |
+
def get_response(
|
| 71 |
+
self,
|
| 72 |
+
task_id: str,
|
| 73 |
+
max_retries: int = 30,
|
| 74 |
+
initial_delay: int = 5,
|
| 75 |
+
retry_delay: int = 2,
|
| 76 |
+
stop_event: object | None = None,
|
| 77 |
+
) -> Optional[str]:
|
| 78 |
+
"""Fetch a Turnstile solution token."""
|
| 79 |
+
self.last_error = None
|
| 80 |
+
# Make shutdown/cancel responsive.
|
| 81 |
+
if initial_delay > 0:
|
| 82 |
+
for _ in range(int(initial_delay * 10)):
|
| 83 |
+
if stop_event is not None and getattr(stop_event, "is_set", lambda: False)():
|
| 84 |
+
return None
|
| 85 |
+
time.sleep(0.1)
|
| 86 |
+
|
| 87 |
+
for _ in range(max_retries):
|
| 88 |
+
if stop_event is not None and getattr(stop_event, "is_set", lambda: False)():
|
| 89 |
+
return None
|
| 90 |
+
try:
|
| 91 |
+
if self.yescaptcha_key:
|
| 92 |
+
url = f"{self.yescaptcha_api}/getTaskResult"
|
| 93 |
+
payload = {"clientKey": self.yescaptcha_key, "taskId": task_id}
|
| 94 |
+
response = requests.post(url, json=payload, timeout=20)
|
| 95 |
+
response.raise_for_status()
|
| 96 |
+
data = response.json()
|
| 97 |
+
|
| 98 |
+
if data.get("errorId") != 0:
|
| 99 |
+
self.last_error = str(data.get("errorDescription") or "unknown")
|
| 100 |
+
logger.warning("YesCaptcha getTaskResult failed: {}", self.last_error)
|
| 101 |
+
return None
|
| 102 |
+
|
| 103 |
+
status = data.get("status")
|
| 104 |
+
if status == "ready":
|
| 105 |
+
token = data.get("solution", {}).get("token")
|
| 106 |
+
if token:
|
| 107 |
+
return token
|
| 108 |
+
self.last_error = "YesCaptcha returned empty token"
|
| 109 |
+
logger.warning(self.last_error)
|
| 110 |
+
return None
|
| 111 |
+
if status == "processing":
|
| 112 |
+
if retry_delay > 0:
|
| 113 |
+
for _ in range(int(retry_delay * 10)):
|
| 114 |
+
if stop_event is not None and getattr(stop_event, "is_set", lambda: False)():
|
| 115 |
+
return None
|
| 116 |
+
time.sleep(0.1)
|
| 117 |
+
continue
|
| 118 |
+
self.last_error = f"YesCaptcha unexpected status: {status}"
|
| 119 |
+
logger.warning(self.last_error)
|
| 120 |
+
if retry_delay > 0:
|
| 121 |
+
for _ in range(int(retry_delay * 10)):
|
| 122 |
+
if stop_event is not None and getattr(stop_event, "is_set", lambda: False)():
|
| 123 |
+
return None
|
| 124 |
+
time.sleep(0.1)
|
| 125 |
+
continue
|
| 126 |
+
|
| 127 |
+
response = requests.get(
|
| 128 |
+
f"{self.solver_url}/result",
|
| 129 |
+
params={"id": task_id},
|
| 130 |
+
timeout=20,
|
| 131 |
+
)
|
| 132 |
+
response.raise_for_status()
|
| 133 |
+
data = response.json()
|
| 134 |
+
|
| 135 |
+
# Solver error -> stop early (avoid polling forever on unsolvable tasks).
|
| 136 |
+
error_id = data.get("errorId")
|
| 137 |
+
if error_id is not None and error_id != 0:
|
| 138 |
+
self.last_error = str(data.get("errorDescription") or data.get("errorCode") or "solver error")
|
| 139 |
+
return None
|
| 140 |
+
|
| 141 |
+
token = data.get("solution", {}).get("token")
|
| 142 |
+
if token:
|
| 143 |
+
if token != "CAPTCHA_FAIL":
|
| 144 |
+
return token
|
| 145 |
+
self.last_error = "CAPTCHA_FAIL"
|
| 146 |
+
return None
|
| 147 |
+
if retry_delay > 0:
|
| 148 |
+
for _ in range(int(retry_delay * 10)):
|
| 149 |
+
if stop_event is not None and getattr(stop_event, "is_set", lambda: False)():
|
| 150 |
+
return None
|
| 151 |
+
time.sleep(0.1)
|
| 152 |
+
except Exception as exc: # pragma: no cover - network/remote errors
|
| 153 |
+
self.last_error = str(exc)
|
| 154 |
+
logger.debug("Turnstile response error: {}", exc)
|
| 155 |
+
if retry_delay > 0:
|
| 156 |
+
for _ in range(int(retry_delay * 10)):
|
| 157 |
+
if stop_event is not None and getattr(stop_event, "is_set", lambda: False)():
|
| 158 |
+
return None
|
| 159 |
+
time.sleep(0.1)
|
| 160 |
+
|
| 161 |
+
return None
|
app/services/register/services/user_agreement_service.py
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from typing import Optional, Dict, Any
|
| 4 |
+
|
| 5 |
+
from curl_cffi import requests
|
| 6 |
+
|
| 7 |
+
# Fallback browser User-Agent used when the caller does not supply one.
DEFAULT_USER_AGENT = (
    "Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
    "AppleWebKit/537.36 (KHTML, like Gecko) "
    "Chrome/120.0.0.0 Safari/537.36"
)
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class UserAgreementService:
    """Handles the account terms-of-service acceptance flow (thread-safe, no global state)."""

    def __init__(self, cf_clearance: str = ""):
        # Optional Cloudflare clearance cookie used as the default for requests.
        self.cf_clearance = (cf_clearance or "").strip()

    def accept_tos_version(
        self,
        sso: str,
        sso_rw: str,
        impersonate: str,
        user_agent: Optional[str] = None,
        cf_clearance: Optional[str] = None,
        timeout: int = 15,
    ) -> Dict[str, Any]:
        """Accept the TOS version via the accounts.x.ai gRPC-web endpoint.

        Returns a dict with keys:
        ``ok`` (bool), ``hex_reply`` (str), ``status_code`` (int | None),
        ``grpc_status`` (str | None), ``error`` (str | None).
        """
        def _failure(message: str) -> Dict[str, Any]:
            # Uniform shape for every failure result.
            return {
                "ok": False,
                "hex_reply": "",
                "status_code": None,
                "grpc_status": None,
                "error": message,
            }

        if not sso:
            return _failure("缺少 sso")
        if not sso_rw:
            return _failure("缺少 sso-rw")

        endpoint = "https://accounts.x.ai/auth_mgmt.AuthManagement/SetTosAcceptedVersion"

        jar = {"sso": sso, "sso-rw": sso_rw}
        clearance = (self.cf_clearance if cf_clearance is None else cf_clearance).strip()
        if clearance:
            jar["cf_clearance"] = clearance

        request_headers = {
            "content-type": "application/grpc-web+proto",
            "origin": "https://accounts.x.ai",
            "referer": "https://accounts.x.ai/accept-tos",
            "x-grpc-web": "1",
            "user-agent": user_agent or DEFAULT_USER_AGENT,
        }

        # Minimal grpc-web frame: 4-byte header prefix, payload length 0x02,
        # then protobuf field 2 = 1.
        frame = b"\x00\x00\x00\x00" + b"\x02" + b"\x10\x01"

        try:
            response = requests.post(
                endpoint,
                headers=request_headers,
                cookies=jar,
                data=frame,
                impersonate=impersonate or "chrome120",
                timeout=timeout,
            )
            status_code = response.status_code
            grpc_status = response.headers.get("grpc-status")
            hex_reply = response.content.hex()
        except Exception as e:
            return _failure(str(e))

        if status_code == 403:
            error = "403 Forbidden"
        elif status_code != 200:
            error = f"HTTP {status_code}"
        elif grpc_status not in (None, "0"):
            error = f"gRPC {grpc_status}"
        else:
            error = None

        return {
            "ok": status_code == 200 and grpc_status in (None, "0"),
            "hex_reply": hex_reply,
            "status_code": status_code,
            "grpc_status": grpc_status,
            "error": error,
        }
|
app/services/register/solver.py
ADDED
|
@@ -0,0 +1,296 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Local Turnstile solver process manager."""
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
import socket
|
| 5 |
+
import subprocess
|
| 6 |
+
import sys
|
| 7 |
+
import time
|
| 8 |
+
from dataclasses import dataclass
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
from typing import Optional
|
| 11 |
+
from urllib.parse import urlparse
|
| 12 |
+
|
| 13 |
+
from app.core.logger import logger
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def _wait_for_port(host: str, port: int, timeout: float = 20.0) -> bool:
|
| 17 |
+
deadline = time.time() + timeout
|
| 18 |
+
while time.time() < deadline:
|
| 19 |
+
try:
|
| 20 |
+
with socket.create_connection((host, port), timeout=1):
|
| 21 |
+
return True
|
| 22 |
+
except Exception:
|
| 23 |
+
time.sleep(0.5)
|
| 24 |
+
return False
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
@dataclass
class SolverConfig:
    """Configuration for the local Turnstile solver process."""

    # Base URL of the solver HTTP API, e.g. "http://127.0.0.1:5072".
    url: str
    # Number of solver browser threads to run.
    threads: int = 5
    # Browser engine: "chromium", "chrome", "msedge" or "camoufox".
    browser_type: str = "chromium"
    # Pass --debug to the solver process.
    debug: bool = False
    # Spawn a local solver process when none is already listening at `url`.
    auto_start: bool = True
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class TurnstileSolverProcess:
    """Start/stop a local Turnstile solver subprocess.

    Only processes spawned by this instance are ever stopped; an
    externally-started solver already listening on the configured port is
    reused and left untouched.
    """

    def __init__(self, config: SolverConfig) -> None:
        self.config = config
        self._process: Optional[subprocess.Popen] = None
        # True only when *we* spawned the process.
        self._started_by_us = False
        self._repo_root = Path(__file__).resolve().parents[3]
        self._python_exe: str = sys.executable
        self._actual_browser_type: str = config.browser_type

    def _script_path(self) -> Path:
        """Path of the bundled solver entry script."""
        return self._repo_root / "scripts" / "turnstile_solver" / "api_solver.py"

    def _can_import(self, python_exe: str, modules: list[str]) -> bool:
        """Check whether a python executable can import given modules."""
        code = "; ".join([f"import {m}" for m in modules])
        try:
            subprocess.check_call(
                [python_exe, "-c", code],
                stdout=subprocess.DEVNULL,
                stderr=subprocess.DEVNULL,
            )
            return True
        except Exception:
            return False

    def _windows_where_python(self) -> list[str]:
        """List python.exe candidates on Windows using `where python` (best-effort)."""
        if not sys.platform.startswith("win"):
            return []
        try:
            out = subprocess.check_output(
                ["where", "python"],
                stderr=subprocess.DEVNULL,
                text=True,
                encoding="utf-8",
                errors="ignore",
            )
        except Exception:
            return []

        # De-duplicate case-insensitively while preserving order.
        paths: list[str] = []
        seen: set[str] = set()
        for line in out.splitlines():
            p = (line or "").strip().strip('"')
            if not p:
                continue
            key = p.lower()
            if key in seen:
                continue
            seen.add(key)
            paths.append(p)
        return paths

    def _select_runtime(self) -> None:
        """Pick python executable + browser type to run solver with.

        Practical notes (Windows):
        - The API server may run in a venv (e.g. Python 3.13).
        - Many users install the solver dependencies (camoufox/patchright) into their
          system python (e.g. Python 3.12) and start the solver via a `.bat`.

        To match that workflow, we prefer an interpreter that has `patchright` when
        available (it tends to have better anti-bot compatibility). For camoufox,
        we also require `camoufox` import to succeed.
        """
        desired = (self.config.browser_type or "chromium").strip().lower()
        if desired not in {"chromium", "chrome", "msedge", "camoufox"}:
            desired = "chromium"

        # Collect python candidates.
        #
        # NOTE: When the API server runs under `uv run`, `python` on PATH usually points to
        # the venv python, not the system python. On Windows, use `where python` to discover
        # other interpreters (e.g. Python312) where users installed camoufox/patchright.
        candidates: list[str] = [sys.executable]
        for p in self._windows_where_python():
            if p.lower() != sys.executable.lower():
                candidates.append(p)
        # As a last resort, try PATH resolution.
        candidates.append("python")

        # De-duplicate while preserving order.
        dedup: list[str] = []
        seen: set[str] = set()
        for p in candidates:
            k = p.lower()
            if k in seen:
                continue
            seen.add(k)
            dedup.append(p)
        candidates = dedup

        def _pick_with(modules: list[str]) -> str | None:
            # First interpreter that can import every required module wins.
            for exe in candidates:
                if self._can_import(exe, modules):
                    return exe
            return None

        self._actual_browser_type = desired

        if desired == "camoufox":
            # Prefer patchright if possible.
            exe = _pick_with(["quart", "camoufox", "patchright"])
            if exe:
                self._python_exe = exe
                return

            exe = _pick_with(["quart", "camoufox", "playwright"])
            if exe:
                self._python_exe = exe
                return

            # No camoufox in any known interpreter; fallback to chromium.
            logger.warning("Camoufox not available. Falling back solver browser to chromium.")
            self._actual_browser_type = "chromium"

        # For chromium/chrome/msedge, prefer patchright if available.
        exe = _pick_with(["quart", "patchright"])
        if exe:
            self._python_exe = exe
            return

        exe = _pick_with(["quart", "playwright"])
        if exe:
            self._python_exe = exe
            return

        # Last resort: current interpreter (may fail fast with a clear error from the solver process).
        self._python_exe = sys.executable

    def _ensure_playwright_browsers(self, python_exe: str) -> None:
        """Ensure Playwright browser binaries exist (best-effort).

        We only auto-install for bundled Chromium. Branded channels (chrome/msedge)
        rely on system-installed browsers.
        """
        if self._actual_browser_type != "chromium":
            return

        # A lock file marks a previously successful install so we skip the
        # (slow) `playwright install` on subsequent startups.
        lock_dir = self._repo_root / "data" / ".locks"
        lock_dir.mkdir(parents=True, exist_ok=True)
        lock_path = lock_dir / "playwright_chromium_v1.lock"
        if lock_path.exists():
            return

        try:
            logger.info("Installing Playwright Chromium (first run)...")
            args = [python_exe, "-m", "playwright", "install"]
            # On Linux (Docker), install system deps as well.
            if sys.platform.startswith("linux"):
                args.append("--with-deps")
            args.append("chromium")
            subprocess.check_call(args, cwd=str(self._repo_root))
            lock_path.write_text(str(time.time()), encoding="utf-8")
        except Exception as exc:
            # Don't create lock file; let next run retry.
            raise RuntimeError(f"Playwright browser install failed: {exc}") from exc

    def _parse_host_port(self) -> tuple[str, int]:
        """Extract (host, port) from the configured solver URL, with defaults."""
        parsed = urlparse(self.config.url)
        host = parsed.hostname or "127.0.0.1"
        port = parsed.port or 5072
        return host, int(port)

    def start(self) -> None:
        """Start the solver (or reuse one already listening on the port)."""
        if not self.config.auto_start:
            return

        host, port = self._parse_host_port()

        def _spawn() -> None:
            script = self._script_path()
            if not script.exists():
                raise RuntimeError(f"Solver script not found: {script}")

            # Ensure Playwright browsers are present before starting the solver process.
            self._ensure_playwright_browsers(self._python_exe)

            cmd = [
                self._python_exe,
                str(script),
                "--browser_type",
                self._actual_browser_type,
                "--thread",
                str(self.config.threads),
            ]
            if self.config.debug:
                cmd.append("--debug")
            cmd += ["--host", host, "--port", str(port)]

            logger.info("Starting Turnstile solver: {}", " ".join(cmd))
            self._process = subprocess.Popen(
                cmd,
                cwd=str(script.parent),
            )
            self._started_by_us = True

            if not _wait_for_port(host, port, timeout=60.0):
                exit_code = self._process.poll() if self._process else None
                self.stop()
                if exit_code is not None:
                    raise RuntimeError(
                        f"Turnstile solver exited early (code {exit_code}). "
                        "Please check solver dependencies."
                    )
                raise RuntimeError("Turnstile solver did not become ready in time")

        # Decide runtime + browser strategy before checking readiness.
        self._select_runtime()
        logger.info(
            "Turnstile solver runtime selected: python={} browser_type={}",
            self._python_exe,
            self._actual_browser_type,
        )

        if _wait_for_port(host, port, timeout=1.0):
            logger.info("Turnstile solver already running at {}:{}", host, port)
            self._started_by_us = False
            return

        try:
            _spawn()
            return
        except Exception as exc:
            # camoufox is not always stable/available across environments (notably Docker).
            # Fall back to chromium instead of failing the whole auto-register workflow.
            if self._actual_browser_type != "camoufox":
                raise
            logger.warning("Camoufox solver failed to start; falling back to chromium: {}", exc)
            try:
                self.stop()
            except Exception:
                pass
            self.config.browser_type = "chromium"
            self._actual_browser_type = "chromium"
            self._select_runtime()
            logger.info(
                "Turnstile solver runtime selected: python={} browser_type={}",
                self._python_exe,
                self._actual_browser_type,
            )
            _spawn()

    def stop(self) -> None:
        """Terminate the solver process if (and only if) we spawned it."""
        if not self._process or not self._started_by_us:
            return

        try:
            logger.info("Stopping Turnstile solver...")
            self._process.terminate()
            self._process.wait(timeout=10)
        except Exception:
            try:
                self._process.kill()
                # Reap the killed child so it doesn't linger as a zombie
                # (kill() alone never waits on POSIX).
                self._process.wait(timeout=5)
            except Exception:
                pass
        finally:
            self._process = None
            self._started_by_us = False
|
app/services/request_logger.py
ADDED
|
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""请求日志审计 - 记录近期请求"""
|
| 2 |
+
|
| 3 |
+
import time
|
| 4 |
+
import asyncio
|
| 5 |
+
import orjson
|
| 6 |
+
from typing import List, Dict, Deque
|
| 7 |
+
from collections import deque
|
| 8 |
+
from dataclasses import dataclass, asdict
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
|
| 11 |
+
from app.core.logger import logger
|
| 12 |
+
|
| 13 |
+
@dataclass
class RequestLog:
    """One audited API request record."""

    id: str            # millisecond-epoch identifier string
    time: str          # human-readable local time "YYYY-mm-dd HH:MM:SS"
    timestamp: float   # epoch seconds
    ip: str            # client IP address
    model: str         # requested model name
    duration: float    # request duration in seconds
    status: int        # HTTP status code returned to the client
    key_name: str      # name of the API key used
    token_suffix: str  # tail of the upstream token, for identification
    error: str = ""    # error message; empty on success
|
| 25 |
+
|
| 26 |
+
class RequestLogger:
|
| 27 |
+
"""请求日志记录器"""
|
| 28 |
+
|
| 29 |
+
_instance = None
|
| 30 |
+
|
| 31 |
+
def __new__(cls):
|
| 32 |
+
if cls._instance is None:
|
| 33 |
+
cls._instance = super().__new__(cls)
|
| 34 |
+
return cls._instance
|
| 35 |
+
|
| 36 |
+
def __init__(self, max_len: int = 1000):
|
| 37 |
+
if hasattr(self, '_initialized'):
|
| 38 |
+
return
|
| 39 |
+
|
| 40 |
+
self.file_path = Path(__file__).parents[2] / "data" / "logs.json"
|
| 41 |
+
self._logs: Deque[Dict] = deque(maxlen=max_len)
|
| 42 |
+
self._lock = asyncio.Lock()
|
| 43 |
+
self._loaded = False
|
| 44 |
+
|
| 45 |
+
self._initialized = True
|
| 46 |
+
|
| 47 |
+
async def init(self):
|
| 48 |
+
"""初始化加载数据"""
|
| 49 |
+
if not self._loaded:
|
| 50 |
+
await self._load_data()
|
| 51 |
+
|
| 52 |
+
async def _load_data(self):
|
| 53 |
+
"""从磁盘加载日志数据"""
|
| 54 |
+
if self._loaded:
|
| 55 |
+
return
|
| 56 |
+
|
| 57 |
+
if not self.file_path.exists():
|
| 58 |
+
self._loaded = True
|
| 59 |
+
return
|
| 60 |
+
|
| 61 |
+
try:
|
| 62 |
+
async with self._lock:
|
| 63 |
+
content = await asyncio.to_thread(self.file_path.read_bytes)
|
| 64 |
+
if content:
|
| 65 |
+
data = orjson.loads(content)
|
| 66 |
+
if isinstance(data, list):
|
| 67 |
+
self._logs.clear()
|
| 68 |
+
self._logs.extend(data)
|
| 69 |
+
self._loaded = True
|
| 70 |
+
logger.debug(f"[Logger] 加载日志成功: {len(self._logs)} 条")
|
| 71 |
+
except Exception as e:
|
| 72 |
+
logger.error(f"[Logger] 加载日志失败: {e}")
|
| 73 |
+
self._loaded = True
|
| 74 |
+
|
| 75 |
+
async def _save_data(self):
|
| 76 |
+
"""保存日志数据到磁盘"""
|
| 77 |
+
if not self._loaded:
|
| 78 |
+
return
|
| 79 |
+
|
| 80 |
+
try:
|
| 81 |
+
# 确保目录存在
|
| 82 |
+
self.file_path.parent.mkdir(parents=True, exist_ok=True)
|
| 83 |
+
|
| 84 |
+
async with self._lock:
|
| 85 |
+
# 转换为列表保存
|
| 86 |
+
content = orjson.dumps(list(self._logs))
|
| 87 |
+
await asyncio.to_thread(self.file_path.write_bytes, content)
|
| 88 |
+
except Exception as e:
|
| 89 |
+
logger.error(f"[Logger] 保存日志失败: {e}")
|
| 90 |
+
|
| 91 |
+
async def add_log(self,
|
| 92 |
+
ip: str,
|
| 93 |
+
model: str,
|
| 94 |
+
duration: float,
|
| 95 |
+
status: int,
|
| 96 |
+
key_name: str,
|
| 97 |
+
token_suffix: str = "",
|
| 98 |
+
error: str = ""):
|
| 99 |
+
"""添加日志"""
|
| 100 |
+
if not self._loaded:
|
| 101 |
+
await self.init()
|
| 102 |
+
|
| 103 |
+
try:
|
| 104 |
+
now = time.time()
|
| 105 |
+
# 格式化时间
|
| 106 |
+
time_str = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(now))
|
| 107 |
+
|
| 108 |
+
log = {
|
| 109 |
+
"id": str(int(now * 1000)),
|
| 110 |
+
"time": time_str,
|
| 111 |
+
"timestamp": now,
|
| 112 |
+
"ip": ip,
|
| 113 |
+
"model": model,
|
| 114 |
+
"duration": round(duration, 2),
|
| 115 |
+
"status": status,
|
| 116 |
+
"key_name": key_name,
|
| 117 |
+
"token_suffix": token_suffix,
|
| 118 |
+
"error": error
|
| 119 |
+
}
|
| 120 |
+
|
| 121 |
+
async with self._lock:
|
| 122 |
+
self._logs.appendleft(log) # 最新的在前
|
| 123 |
+
|
| 124 |
+
# 异步保存
|
| 125 |
+
asyncio.create_task(self._save_data())
|
| 126 |
+
|
| 127 |
+
except Exception as e:
|
| 128 |
+
logger.error(f"[Logger] 记录日志失败: {e}")
|
| 129 |
+
|
| 130 |
+
async def get_logs(self, limit: int = 1000) -> List[Dict]:
|
| 131 |
+
"""获取日志"""
|
| 132 |
+
async with self._lock:
|
| 133 |
+
return list(self._logs)[:limit]
|
| 134 |
+
|
| 135 |
+
async def clear_logs(self):
|
| 136 |
+
"""清空日志"""
|
| 137 |
+
async with self._lock:
|
| 138 |
+
self._logs.clear()
|
| 139 |
+
await self._save_data()
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
# Module-level singleton instance shared across the app.
request_logger = RequestLogger()
|
app/services/request_stats.py
ADDED
|
@@ -0,0 +1,205 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""请求统计模块 - 按小时/天统计请求数据"""
|
| 2 |
+
|
| 3 |
+
import time
|
| 4 |
+
import asyncio
|
| 5 |
+
import orjson
|
| 6 |
+
from datetime import datetime
|
| 7 |
+
from typing import Dict, Any
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from collections import defaultdict
|
| 10 |
+
|
| 11 |
+
from app.core.logger import logger
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class RequestStats:
|
| 15 |
+
"""请求统计管理器(单例)"""
|
| 16 |
+
|
| 17 |
+
_instance = None
|
| 18 |
+
|
| 19 |
+
    def __new__(cls):
        # Singleton: every construction returns the same shared instance.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance
|
| 23 |
+
|
| 24 |
+
    def __init__(self):
        # __init__ runs on every construction of the singleton; only
        # initialize once.
        if hasattr(self, '_initialized'):
            return

        self.file_path = Path(__file__).parents[2] / "data" / "stats.json"

        # Counters keyed by "YYYY-MM-DDTHH" (hourly), "YYYY-MM-DD" (daily)
        # and model name.
        self._hourly: Dict[str, Dict[str, int]] = defaultdict(lambda: {"total": 0, "success": 0, "failed": 0})
        self._daily: Dict[str, Dict[str, int]] = defaultdict(lambda: {"total": 0, "success": 0, "failed": 0})
        self._models: Dict[str, int] = defaultdict(int)

        # Retention policy.
        self._hourly_keep = 48  # keep 48 hours of hourly buckets
        self._daily_keep = 30  # keep 30 days of daily buckets

        self._lock = asyncio.Lock()
        self._loaded = False
        self._initialized = True
|
| 42 |
+
|
| 43 |
+
async def init(self):
|
| 44 |
+
"""初始化加载数据"""
|
| 45 |
+
if not self._loaded:
|
| 46 |
+
await self._load_data()
|
| 47 |
+
|
| 48 |
+
    async def _load_data(self):
        """Load persisted stats from disk (runs at most once)."""
        if self._loaded:
            return

        if not self.file_path.exists():
            self._loaded = True
            return

        try:
            async with self._lock:
                content = await asyncio.to_thread(self.file_path.read_bytes)
                if content:
                    data = orjson.loads(content)

                    # Rebuild the defaultdict wrappers around the plain dicts.
                    self._hourly = defaultdict(lambda: {"total": 0, "success": 0, "failed": 0})
                    self._hourly.update(data.get("hourly", {}))

                    self._daily = defaultdict(lambda: {"total": 0, "success": 0, "failed": 0})
                    self._daily.update(data.get("daily", {}))

                    self._models = defaultdict(int)
                    self._models.update(data.get("models", {}))

            self._loaded = True
            logger.debug(f"[Stats] 加载统计数据成功")
        except Exception as e:
            logger.error(f"[Stats] 加载数据失败: {e}")
            self._loaded = True  # mark loaded anyway so a later save can't clobber silently
|
| 78 |
+
|
| 79 |
+
    async def _save_data(self):
        """Persist current stats to disk as JSON (best-effort)."""
        if not self._loaded:
            # Never write before loading, or on-disk data would be lost.
            return

        try:
            # Make sure the data directory exists.
            self.file_path.parent.mkdir(parents=True, exist_ok=True)

            async with self._lock:
                # Convert defaultdicts to plain dicts for serialization.
                data = {
                    "hourly": dict(self._hourly),
                    "daily": dict(self._daily),
                    "models": dict(self._models)
                }
                content = orjson.dumps(data)
                await asyncio.to_thread(self.file_path.write_bytes, content)
        except Exception as e:
            logger.error(f"[Stats] 保存数据失败: {e}")
|
| 98 |
+
|
| 99 |
+
async def record_request(self, model: str, success: bool) -> None:
|
| 100 |
+
"""记录一次请求"""
|
| 101 |
+
if not self._loaded:
|
| 102 |
+
await self.init()
|
| 103 |
+
|
| 104 |
+
now = datetime.now()
|
| 105 |
+
hour_key = now.strftime("%Y-%m-%dT%H")
|
| 106 |
+
day_key = now.strftime("%Y-%m-%d")
|
| 107 |
+
|
| 108 |
+
# 小时统计
|
| 109 |
+
self._hourly[hour_key]["total"] += 1
|
| 110 |
+
if success:
|
| 111 |
+
self._hourly[hour_key]["success"] += 1
|
| 112 |
+
else:
|
| 113 |
+
self._hourly[hour_key]["failed"] += 1
|
| 114 |
+
|
| 115 |
+
# 天统计
|
| 116 |
+
self._daily[day_key]["total"] += 1
|
| 117 |
+
if success:
|
| 118 |
+
self._daily[day_key]["success"] += 1
|
| 119 |
+
else:
|
| 120 |
+
self._daily[day_key]["failed"] += 1
|
| 121 |
+
|
| 122 |
+
# 模型统计
|
| 123 |
+
self._models[model] += 1
|
| 124 |
+
|
| 125 |
+
# 定期清理旧数据
|
| 126 |
+
self._cleanup()
|
| 127 |
+
|
| 128 |
+
# 异步保存
|
| 129 |
+
asyncio.create_task(self._save_data())
|
| 130 |
+
|
| 131 |
+
def _cleanup(self) -> None:
|
| 132 |
+
"""清理过期数据"""
|
| 133 |
+
now = datetime.now()
|
| 134 |
+
|
| 135 |
+
# 清理小时数据
|
| 136 |
+
hour_keys = list(self._hourly.keys())
|
| 137 |
+
if len(hour_keys) > self._hourly_keep:
|
| 138 |
+
for key in sorted(hour_keys)[:-self._hourly_keep]:
|
| 139 |
+
del self._hourly[key]
|
| 140 |
+
|
| 141 |
+
# 清理天数据
|
| 142 |
+
day_keys = list(self._daily.keys())
|
| 143 |
+
if len(day_keys) > self._daily_keep:
|
| 144 |
+
for key in sorted(day_keys)[:-self._daily_keep]:
|
| 145 |
+
del self._daily[key]
|
| 146 |
+
|
| 147 |
+
def get_stats(self, hours: int = 24, days: int = 7) -> Dict[str, Any]:
    """Build a dashboard payload of hourly/daily series, top models and totals.

    Args:
        hours: Number of trailing hourly buckets to include.
        days: Number of trailing daily buckets to include.

    Returns:
        Dict with "hourly", "daily", "models" (top 10) and "summary" keys.
    """
    # Hoisted: the original re-executed this import inside both loops.
    from datetime import timedelta

    now = datetime.now()
    empty = {"total": 0, "success": 0, "failed": 0}

    # Trailing N hours, oldest first.
    hourly_data = []
    for i in range(hours - 1, -1, -1):
        dt = now - timedelta(hours=i)
        data = self._hourly.get(dt.strftime("%Y-%m-%dT%H"), empty)
        hourly_data.append({
            "hour": dt.strftime("%H:00"),
            "date": dt.strftime("%m-%d"),
            **data
        })

    # Trailing N days, oldest first.
    daily_data = []
    for i in range(days - 1, -1, -1):
        dt = now - timedelta(days=i)
        data = self._daily.get(dt.strftime("%Y-%m-%d"), empty)
        daily_data.append({
            "date": dt.strftime("%m-%d"),
            **data
        })

    # Top 10 models by request count.
    model_data = sorted(self._models.items(), key=lambda x: x[1], reverse=True)[:10]

    # Totals over every retained hourly bucket.
    total_requests = sum(d["total"] for d in self._hourly.values())
    total_success = sum(d["success"] for d in self._hourly.values())
    total_failed = sum(d["failed"] for d in self._hourly.values())

    return {
        "hourly": hourly_data,
        "daily": daily_data,
        "models": [{"model": m, "count": c} for m, c in model_data],
        "summary": {
            "total": total_requests,
            "success": total_success,
            "failed": total_failed,
            "success_rate": round(total_success / total_requests * 100, 1) if total_requests > 0 else 0
        }
    }
|
| 195 |
+
|
| 196 |
+
async def reset(self) -> None:
    """Drop every counter and persist the now-empty state."""
    for counters in (self._hourly, self._daily, self._models):
        counters.clear()
    await self._save_data()
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
# Module-level singleton shared across the application.
request_stats = RequestStats()
|
app/services/token/__init__.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Token 服务模块"""
|
| 2 |
+
|
| 3 |
+
from app.services.token.models import (
|
| 4 |
+
TokenInfo,
|
| 5 |
+
TokenStatus,
|
| 6 |
+
TokenPoolStats,
|
| 7 |
+
EffortType,
|
| 8 |
+
DEFAULT_QUOTA,
|
| 9 |
+
EFFORT_COST
|
| 10 |
+
)
|
| 11 |
+
from app.services.token.pool import TokenPool
|
| 12 |
+
from app.services.token.manager import TokenManager, get_token_manager
|
| 13 |
+
from app.services.token.service import TokenService
|
| 14 |
+
from app.services.token.scheduler import TokenRefreshScheduler, get_scheduler
|
| 15 |
+
|
| 16 |
+
__all__ = [
|
| 17 |
+
# Models
|
| 18 |
+
"TokenInfo",
|
| 19 |
+
"TokenStatus",
|
| 20 |
+
"TokenPoolStats",
|
| 21 |
+
"EffortType",
|
| 22 |
+
"DEFAULT_QUOTA",
|
| 23 |
+
"EFFORT_COST",
|
| 24 |
+
|
| 25 |
+
# Core
|
| 26 |
+
"TokenPool",
|
| 27 |
+
"TokenManager",
|
| 28 |
+
|
| 29 |
+
# API
|
| 30 |
+
"TokenService",
|
| 31 |
+
"get_token_manager",
|
| 32 |
+
|
| 33 |
+
# Scheduler
|
| 34 |
+
"TokenRefreshScheduler",
|
| 35 |
+
"get_scheduler",
|
| 36 |
+
]
|
app/services/token/manager.py
ADDED
|
@@ -0,0 +1,654 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Token 管理服务"""
|
| 2 |
+
|
| 3 |
+
import asyncio
|
| 4 |
+
import time
|
| 5 |
+
from datetime import datetime
|
| 6 |
+
from typing import Dict, List, Optional
|
| 7 |
+
|
| 8 |
+
from app.core.logger import logger
|
| 9 |
+
from app.services.token.models import TokenInfo, EffortType, TokenPoolStats, FAIL_THRESHOLD, TokenStatus
|
| 10 |
+
from app.core.storage import get_storage
|
| 11 |
+
from app.core.config import get_config
|
| 12 |
+
from app.services.token.pool import TokenPool
|
| 13 |
+
|
| 14 |
+
# 批量刷新配置
|
| 15 |
+
REFRESH_INTERVAL_HOURS = 8
|
| 16 |
+
REFRESH_BATCH_SIZE = 10
|
| 17 |
+
REFRESH_CONCURRENCY = 5
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class TokenManager:
|
| 21 |
+
"""管理 Token 的增删改查和配额同步"""
|
| 22 |
+
|
| 23 |
+
_instance: Optional["TokenManager"] = None
|
| 24 |
+
_lock = asyncio.Lock()
|
| 25 |
+
|
| 26 |
+
def __init__(self):
|
| 27 |
+
self.pools: Dict[str, TokenPool] = {}
|
| 28 |
+
self.initialized = False
|
| 29 |
+
self._save_lock = asyncio.Lock()
|
| 30 |
+
self._dirty = False
|
| 31 |
+
self._save_task: Optional[asyncio.Task] = None
|
| 32 |
+
self._save_delay = 0.5
|
| 33 |
+
self._last_reload_at = 0.0
|
| 34 |
+
|
| 35 |
+
@classmethod
async def get_instance(cls) -> "TokenManager":
    """Return the process-wide singleton, creating and loading it on first use.

    Uses double-checked locking: the unlocked check keeps the hot path
    cheap, and the re-check under the lock prevents two concurrent callers
    from both constructing an instance.
    """
    if cls._instance is None:
        async with cls._lock:
            if cls._instance is None:
                cls._instance = cls()
                await cls._instance._load()
    return cls._instance
|
| 44 |
+
|
| 45 |
+
async def _load(self):
    """One-time initialization: load token pools from the storage backend.

    If the backend returns no data, fall back to the local data/token.json
    and seed the backend with it. Malformed token entries are skipped with
    a warning rather than failing the whole load; on any other error the
    manager starts with empty pools (initialized is still set to avoid a
    retry loop).
    """
    if not self.initialized:
        try:
            storage = get_storage()
            data = await storage.load_tokens()

            # Backend returned None/empty: try seeding it from local data/token.json.
            if not data:
                from app.core.storage import LocalStorage
                local_storage = LocalStorage()
                local_data = await local_storage.load_tokens()
                if local_data:
                    data = local_data
                    await storage.save_tokens(local_data)
                    logger.info(f"Initialized remote token storage ({storage.__class__.__name__}) with local tokens.")
                else:
                    data = {}

            self.pools = {}
            for pool_name, tokens in data.items():
                pool = TokenPool(pool_name)
                for token_data in tokens:
                    try:
                        # Store bare tokens uniformly: strip a leading "sso=" prefix.
                        if isinstance(token_data, dict):
                            raw_token = token_data.get("token")
                            if isinstance(raw_token, str) and raw_token.startswith("sso="):
                                token_data["token"] = raw_token[4:]
                        token_info = TokenInfo(**token_data)
                        pool.add(token_info)
                    except Exception as e:
                        logger.warning(f"Failed to load token in pool '{pool_name}': {e}")
                        continue
                pool._rebuild_index()
                self.pools[pool_name] = pool

            self.initialized = True
            self._last_reload_at = time.monotonic()
            total = sum(p.count() for p in self.pools.values())
            logger.info(f"TokenManager initialized: {len(self.pools)} pools with {total} tokens")
        except Exception as e:
            logger.error(f"Failed to initialize TokenManager: {e}")
            self.pools = {}
            self.initialized = True
|
| 90 |
+
|
| 91 |
+
async def reload(self):
    """Force a fresh load of all token pools from storage.

    Holds the class-level lock so a concurrent get_instance() cannot
    observe the manager in its briefly uninitialized state.
    """
    async with self.__class__._lock:
        self.initialized = False
        await self._load()
|
| 96 |
+
|
| 97 |
+
async def reload_if_stale(self):
    """Reload pools when the last load is older than the configured interval.

    Keeps multiple workers eventually consistent; an interval <= 0
    disables the periodic reload entirely.
    """
    raw = get_config("token.reload_interval_sec", 30)
    try:
        interval = float(raw)
    except Exception:
        interval = 30.0
    if interval <= 0:
        return
    elapsed = time.monotonic() - self._last_reload_at
    if elapsed < interval:
        return
    await self.reload()
|
| 109 |
+
|
| 110 |
+
async def _save(self):
    """Persist all pools to storage under a save lock.

    The local asyncio lock serializes saves within this process; the
    storage-level "tokens_save" lock guards against writers in other
    processes. Errors are logged, never raised to callers.
    """
    async with self._save_lock:
        try:
            data = {}
            for pool_name, pool in self.pools.items():
                data[pool_name] = [
                    info.model_dump() for info in pool.list()
                ]

            storage = get_storage()
            async with storage.acquire_lock("tokens_save", timeout=10):
                await storage.save_tokens(data)
        except Exception as e:
            logger.error(f"Failed to save tokens: {e}")
|
| 125 |
+
|
| 126 |
+
def _schedule_save(self):
    """Coalesce bursts of save requests to reduce write amplification.

    Marks the state dirty and ensures at most one background task exists:
    with a zero delay a save task is launched immediately; otherwise a
    debounce loop (_flush_loop) flushes after `token.save_delay_ms` of
    quiet time. If a task is already running it will observe the dirty
    flag, so no new task is started.
    """
    delay_ms = get_config("token.save_delay_ms", 500)
    try:
        delay_ms = float(delay_ms)
    except Exception:
        delay_ms = 500
    self._save_delay = max(0.0, delay_ms / 1000.0)
    self._dirty = True
    if self._save_delay == 0:
        # Immediate mode: skip when a save task is already in flight.
        if self._save_task and not self._save_task.done():
            return
        self._save_task = asyncio.create_task(self._save())
        return
    # Debounced mode: a running flush loop picks up the dirty flag itself.
    if self._save_task and not self._save_task.done():
        return
    self._save_task = asyncio.create_task(self._flush_loop())
|
| 143 |
+
|
| 144 |
+
async def _flush_loop(self):
    """Debounce loop: keep saving while new changes arrive, then exit.

    If a change lands between the final dirty-check and task teardown,
    the finally block reschedules so no update is lost.
    """
    try:
        while True:
            await asyncio.sleep(self._save_delay)
            if not self._dirty:
                break
            self._dirty = False
            await self._save()
    finally:
        self._save_task = None
        if self._dirty:
            self._schedule_save()
|
| 156 |
+
|
| 157 |
+
@staticmethod
|
| 158 |
+
def _extract_cookie_value(cookie_str: str, name: str) -> str | None:
|
| 159 |
+
needle = f"{name}="
|
| 160 |
+
if needle not in cookie_str:
|
| 161 |
+
return None
|
| 162 |
+
for part in cookie_str.split(";"):
|
| 163 |
+
part = part.strip()
|
| 164 |
+
if part.startswith(needle):
|
| 165 |
+
value = part[len(needle):].strip()
|
| 166 |
+
return value or None
|
| 167 |
+
return None
|
| 168 |
+
|
| 169 |
+
@classmethod
def _normalize_input_token(cls, token_str: str) -> str:
    """Normalize user input down to the bare token value.

    Accepts a bare token, an "sso=<token>" pair, or a full Cookie header
    string; always returns the bare token (or "" when nothing usable).
    """
    raw = str(token_str or "").strip()
    if not raw:
        return ""
    if ";" in raw:
        # Looks like a full Cookie header: pull out the sso cookie value.
        extracted = cls._extract_cookie_value(raw, "sso") or ""
        return extracted.strip()
    if raw.startswith("sso="):
        return raw[4:].strip()
    return raw
|
| 179 |
+
|
| 180 |
+
def _find_token_info(self, token_str: str) -> tuple[Optional[TokenInfo], str]:
    """Locate a token across all pools.

    Returns (TokenInfo or None, normalized bare token string).
    """
    raw = self._normalize_input_token(token_str)
    if not raw:
        return None, ""
    for pool in self.pools.values():
        hit = pool.get(raw)
        if hit:
            return hit, raw
    return None, raw
|
| 189 |
+
|
| 190 |
+
def get_token(self, pool_name: str = "ssoBasic") -> Optional[str]:
|
| 191 |
+
"""
|
| 192 |
+
获取可用 Token
|
| 193 |
+
|
| 194 |
+
Args:
|
| 195 |
+
pool_name: Token 池名称
|
| 196 |
+
|
| 197 |
+
Returns:
|
| 198 |
+
Token 字符串或 None
|
| 199 |
+
"""
|
| 200 |
+
pool = self.pools.get(pool_name)
|
| 201 |
+
if not pool:
|
| 202 |
+
logger.warning(f"Pool '{pool_name}' not found")
|
| 203 |
+
return None
|
| 204 |
+
|
| 205 |
+
token_info = pool.select()
|
| 206 |
+
if not token_info:
|
| 207 |
+
logger.warning(f"No available token in pool '{pool_name}'")
|
| 208 |
+
return None
|
| 209 |
+
|
| 210 |
+
token = token_info.token
|
| 211 |
+
if token.startswith("sso="):
|
| 212 |
+
return token[4:]
|
| 213 |
+
return token
|
| 214 |
+
|
| 215 |
+
def get_token_for_model(self, model_id: str) -> Optional[str]:
    """Select an available token for a model.

    Walks the model's candidate pools in order (basic -> super fallback)
    and selects from the heavy or normal quota bucket as appropriate.
    """
    from app.services.grok.model import ModelService

    bucket = "heavy" if ModelService.is_heavy_bucket_model(model_id) else "normal"
    for candidate in ModelService.pool_candidates_for_model(model_id):
        pool = self.pools.get(candidate)
        if not pool:
            continue
        picked = pool.select(bucket=bucket)
        if not picked:
            continue
        value = picked.token
        return value[4:] if value.startswith("sso=") else value

    logger.warning(f"No available token for model '{model_id}'")
    return None
|
| 232 |
+
|
| 233 |
+
async def consume(self, token_str: str, effort: EffortType = EffortType.LOW, bucket: str = "normal") -> bool:
    """Deduct quota locally (estimate, no API round-trip).

    Args:
        token_str: Token string (may carry an "sso=" prefix).
        effort: Deduction weight.
        bucket: "normal" or "heavy" quota bucket.

    Returns:
        True if the token was found and charged.
    """
    # Normalize via the shared helper. The previous
    # token_str.replace("sso=", "") removed the substring anywhere in
    # the token, corrupting tokens that legitimately contain "sso=".
    raw_token = self._normalize_input_token(token_str)

    for pool in self.pools.values():
        token = pool.get(raw_token)
        if token:
            consumed = token.consume_heavy(effort) if bucket == "heavy" else token.consume(effort)
            logger.debug(
                f"Token {raw_token[:10]}...: consumed {consumed} quota (bucket={bucket}), use_count={token.use_count}"
            )
            self._schedule_save()
            return True

    logger.warning(f"Token {raw_token[:10]}...: not found for consumption")
    return False
|
| 258 |
+
|
| 259 |
+
async def sync_usage(
    self,
    token_str: str,
    model_id: str,
    fallback_effort: EffortType = EffortType.LOW,
    consume_on_fail: bool = True,
    is_usage: bool = True
) -> bool:
    """Synchronize a token's quota with the upstream usage API.

    Prefers the authoritative remaining-quota value from the API and
    falls back to a local estimated deduction when the API call fails.

    Args:
        token_str: Token string (may carry an "sso=" prefix).
        model_id: Model id used to derive the rate-limit model and bucket.
        fallback_effort: Deduction weight used by the local fallback.
        consume_on_fail: Whether to charge locally when the API sync fails.
        is_usage: Whether this counts as a use (affects use_count).

    Returns:
        True on a successful sync or fallback deduction.
    """
    # Normalize via the shared helper: str.replace("sso=", "") would
    # mangle tokens that contain "sso=" in the middle.
    raw_token = self._normalize_input_token(token_str)

    # Locate the token object across pools.
    target_token: Optional[TokenInfo] = None
    for pool in self.pools.values():
        target_token = pool.get(raw_token)
        if target_token:
            break

    if not target_token:
        logger.warning(f"Token {raw_token[:10]}...: not found for sync")
        return False

    from app.services.grok.model import ModelService

    bucket = "heavy" if ModelService.is_heavy_bucket_model(model_id) else "normal"
    rate_limit_model = ModelService.rate_limit_model_for(model_id)

    # Try the authoritative API first.
    try:
        from app.services.grok.usage import UsageService

        usage_service = UsageService()
        result = await usage_service.get(token_str, model_name=rate_limit_model)

        if result and "remainingTokens" in result:
            try:
                new_quota = int(result["remainingTokens"])
            except Exception:
                new_quota = 0

            if bucket == "heavy":
                old_quota = target_token.heavy_quota
                target_token.update_heavy_quota(new_quota)
            else:
                old_quota = target_token.quota
                target_token.update_quota(new_quota)

            target_token.record_success(is_usage=is_usage)

            # A negative old quota means "unknown"; don't report consumption.
            consumed = max(0, old_quota - new_quota) if old_quota >= 0 else 0
            logger.info(
                f"Token {raw_token[:10]}...: synced quota (bucket={bucket}, model={rate_limit_model}) "
                f"{old_quota} -> {new_quota} (consumed: {consumed}, use_count: {target_token.use_count})"
            )

            self._schedule_save()
            return True

    except Exception as e:
        logger.warning(f"Token {raw_token[:10]}...: API sync failed, fallback to local ({e})")

    # Fallback: local estimated deduction.
    if consume_on_fail:
        logger.debug(f"Token {raw_token[:10]}...: using local consumption")
        return await self.consume(token_str, fallback_effort, bucket=bucket)
    logger.debug(f"Token {raw_token[:10]}...: sync failed, skipping local consumption")
    return False
|
| 341 |
+
|
| 342 |
+
async def record_fail(self, token_str: str, status_code: int = 401, reason: str = "") -> bool:
    """Record a failed request against a token.

    Only 401 responses count toward the failure threshold; other status
    codes are logged but not counted.

    Args:
        token_str: Token string (may carry an "sso=" prefix).
        status_code: HTTP status code of the failure.
        reason: Human-readable failure reason.

    Returns:
        True if the token was found.
    """
    # Normalize via the shared helper: str.replace("sso=", "") would
    # corrupt tokens that happen to contain "sso=" elsewhere.
    raw_token = self._normalize_input_token(token_str)

    for pool in self.pools.values():
        token = pool.get(raw_token)
        if token:
            if status_code == 401:
                token.record_fail(status_code, reason)
                logger.warning(
                    f"Token {raw_token[:10]}...: recorded 401 failure "
                    f"({token.fail_count}/{FAIL_THRESHOLD}) - {reason}"
                )
            else:
                logger.info(
                    f"Token {raw_token[:10]}...: non-401 error ({status_code}) - {reason} (not counted)"
                )
            self._schedule_save()
            return True

    logger.warning(f"Token {raw_token[:10]}...: not found for failure record")
    return False
|
| 374 |
+
|
| 375 |
+
# ========== 管理功能 ==========
|
| 376 |
+
|
| 377 |
+
async def add(self, token: str, pool_name: str = "ssoBasic") -> bool:
    """Add a token to a pool, creating the pool on demand.

    Args:
        token: Token string (a leading "sso=" prefix is stripped).
        pool_name: Target pool name.

    Returns:
        True if added; False if the token already exists in the pool.
    """
    if pool_name not in self.pools:
        self.pools[pool_name] = TokenPool(pool_name)
        logger.info(f"Pool '{pool_name}': created")

    pool = self.pools[pool_name]

    bare = token[4:] if token.startswith("sso=") else token
    if pool.get(bare):
        logger.warning(f"Pool '{pool_name}': token already exists")
        return False

    pool.add(TokenInfo(token=bare))
    await self._save()
    logger.info(f"Pool '{pool_name}': token added")
    return True
|
| 403 |
+
|
| 404 |
+
async def mark_asset_clear(self, token: str) -> bool:
    """Record the timestamp of an online asset cleanup for this token."""
    info, _ = self._find_token_info(token)
    if not info:
        return False
    info.last_asset_clear_at = int(datetime.now().timestamp() * 1000)
    self._schedule_save()
    return True
|
| 412 |
+
|
| 413 |
+
async def set_token_invalid(self, token_str: str, reason: str = "", save: bool = True) -> bool:
    """Mark a token as expired/invalid.

    Pins fail_count at (or above) FAIL_THRESHOLD so the token stays out
    of rotation, and stamps the failure time/reason.

    Args:
        token_str: Token string (any accepted input form).
        reason: Optional failure reason (stored truncated to 500 chars).
        save: Persist immediately when True.
    """
    token, raw_token = self._find_token_info(token_str)
    if not token:
        logger.warning(f"Token {raw_token[:10]}...: not found for invalidation")
        return False

    token.status = TokenStatus.EXPIRED
    token.fail_count = max(token.fail_count, FAIL_THRESHOLD)
    token.last_fail_at = int(datetime.now().timestamp() * 1000)
    if reason:
        # Truncate so stored records stay bounded.
        token.last_fail_reason = str(reason)[:500]

    if save:
        await self._save()
    return True
|
| 429 |
+
|
| 430 |
+
async def mark_token_account_settings_success(self, token_str: str, save: bool = True) -> bool:
    """Reset failure state after the account-settings flow succeeded.

    Clears the failure counters, stamps the sync time, and restores the
    token to ACTIVE (or COOLING when its quota is exhausted).
    """
    token, raw_token = self._find_token_info(token_str)
    if not token:
        logger.warning(f"Token {raw_token[:10]}...: not found for account-settings success")
        return False

    token.fail_count = 0
    token.last_fail_at = None
    token.last_fail_reason = None
    token.last_sync_at = int(datetime.now().timestamp() * 1000)
    token.status = TokenStatus.COOLING if token.quota == 0 else TokenStatus.ACTIVE

    if save:
        await self._save()
    return True
|
| 446 |
+
|
| 447 |
+
async def commit(self):
    """Flush the current in-memory token state to persistent storage."""
    await self._save()
|
| 450 |
+
|
| 451 |
+
async def remove(self, token: str) -> bool:
    """Remove a token from whichever pool contains it.

    Args:
        token: Token string.

    Returns:
        True if a pool contained (and removed) the token.
    """
    for pool_name, pool in self.pools.items():
        if pool.remove(token):
            await self._save()
            logger.info(f"Pool '{pool_name}': token removed")
            return True

    # Plain string literal: the original used an f-string with no placeholders.
    logger.warning("Token not found for removal")
    return False
|
| 469 |
+
|
| 470 |
+
async def reset_all(self):
    """Reset quota/state for every token in every pool, then persist."""
    count = 0
    for pool in self.pools.values():
        for token in pool:
            token.reset()
            count += 1

    await self._save()
    logger.info(f"Reset all: {count} tokens updated")
|
| 480 |
+
|
| 481 |
+
async def reset_token(self, token_str: str) -> bool:
    """Reset a single token's quota/state.

    Args:
        token_str: Token string (may carry an "sso=" prefix).

    Returns:
        True if the token was found and reset.
    """
    # Normalize via the shared helper: str.replace("sso=", "") would
    # corrupt tokens containing that substring anywhere but the prefix.
    raw_token = self._normalize_input_token(token_str)

    for pool in self.pools.values():
        token = pool.get(raw_token)
        if token:
            token.reset()
            await self._save()
            logger.info(f"Token {raw_token[:10]}...: reset completed")
            return True

    logger.warning(f"Token {raw_token[:10]}...: not found for reset")
    return False
|
| 503 |
+
|
| 504 |
+
def get_stats(self) -> Dict[str, dict]:
    """Return per-pool statistics as plain dicts, keyed by pool name."""
    return {
        name: pool.get_stats().model_dump()
        for name, pool in self.pools.items()
    }
|
| 511 |
+
|
| 512 |
+
def get_pool_tokens(self, pool_name: str = "ssoBasic") -> List[TokenInfo]:
|
| 513 |
+
"""
|
| 514 |
+
获取指定池的所有 Token
|
| 515 |
+
|
| 516 |
+
Args:
|
| 517 |
+
pool_name: 池名称
|
| 518 |
+
|
| 519 |
+
Returns:
|
| 520 |
+
Token 列表
|
| 521 |
+
"""
|
| 522 |
+
pool = self.pools.get(pool_name)
|
| 523 |
+
if not pool:
|
| 524 |
+
return []
|
| 525 |
+
return pool.list()
|
| 526 |
+
|
| 527 |
+
async def refresh_cooling_tokens(self) -> Dict[str, int]:
    """Batch-refresh quotas for tokens that are due a refresh.

    Tokens are refreshed with bounded concurrency in batches; a token
    that keeps returning 401 after retries is marked EXPIRED.

    Returns:
        {"checked": int, "refreshed": int, "recovered": int, "expired": int}
    """
    from app.services.grok.usage import UsageService

    # Collect the tokens that are due for a refresh.
    to_refresh: List[TokenInfo] = []
    for pool in self.pools.values():
        for token in pool:
            if token.need_refresh(REFRESH_INTERVAL_HOURS):
                to_refresh.append(token)

    if not to_refresh:
        logger.debug("Refresh check: no tokens need refresh")
        return {"checked": 0, "refreshed": 0, "recovered": 0, "expired": 0}

    logger.info(f"Refresh check: found {len(to_refresh)} cooling tokens to refresh")

    # Concurrent refresh, bounded by a semaphore.
    semaphore = asyncio.Semaphore(REFRESH_CONCURRENCY)
    usage_service = UsageService()
    refreshed = 0
    recovered = 0
    expired = 0

    async def _refresh_one(token_info: TokenInfo) -> dict:
        """Refresh one token's quota; up to 2 retries on 401 errors."""
        async with semaphore:
            token_str = token_info.token
            if token_str.startswith("sso="):
                token_str = token_str[4:]

            # Retry loop: at most 2 retries (3 attempts total).
            for retry in range(3):  # 0, 1, 2
                try:
                    result = await usage_service.get(token_str)

                    if result and "remainingTokens" in result:
                        new_quota = result["remainingTokens"]
                        old_quota = token_info.quota
                        old_status = token_info.status

                        token_info.update_quota(new_quota)
                        token_info.mark_synced()

                        logger.info(
                            f"Token {token_info.token[:10]}...: refreshed "
                            f"{old_quota} -> {new_quota}, status: {old_status} -> {token_info.status}"
                        )

                        # "recovered" = quota came back from zero.
                        return {
                            "recovered": new_quota > 0 and old_quota == 0,
                            "expired": False
                        }

                    # No usable payload: still record the sync attempt.
                    token_info.mark_synced()
                    return {"recovered": False, "expired": False}

                except Exception as e:
                    error_str = str(e)

                    # 401/Unauthorized gets the retry-then-expire treatment.
                    if "401" in error_str or "Unauthorized" in error_str:
                        if retry < 2:
                            logger.warning(
                                f"Token {token_info.token[:10]}...: 401 error, "
                                f"retry {retry + 1}/2..."
                            )
                            await asyncio.sleep(0.5)
                            continue
                        else:
                            # Still 401 after 2 retries: mark as expired.
                            logger.error(
                                f"Token {token_info.token[:10]}...: 401 after 2 retries, "
                                f"marking as expired"
                            )
                            token_info.status = TokenStatus.EXPIRED
                            token_info.mark_synced()
                            return {"recovered": False, "expired": True}
                    else:
                        # Non-401 failure: give up immediately, keep status.
                        logger.warning(
                            f"Token {token_info.token[:10]}...: refresh failed ({e})"
                        )
                        token_info.mark_synced()
                        return {"recovered": False, "expired": False}

            # Defensive fallthrough (loop always returns/continues above).
            token_info.mark_synced()
            return {"recovered": False, "expired": False}

    # Process in batches.
    for i in range(0, len(to_refresh), REFRESH_BATCH_SIZE):
        batch = to_refresh[i:i + REFRESH_BATCH_SIZE]
        results = await asyncio.gather(*[_refresh_one(t) for t in batch])
        refreshed += len(batch)
        recovered += sum(r["recovered"] for r in results)
        expired += sum(r["expired"] for r in results)

        # Pause between batches to smooth upstream load.
        if i + REFRESH_BATCH_SIZE < len(to_refresh):
            await asyncio.sleep(1)

    await self._save()

    logger.info(
        f"Refresh completed: "
        f"checked={len(to_refresh)}, refreshed={refreshed}, "
        f"recovered={recovered}, expired={expired}"
    )

    return {
        "checked": len(to_refresh),
        "refreshed": refreshed,
        "recovered": recovered,
        "expired": expired
    }
|
| 646 |
+
|
| 647 |
+
|
| 648 |
+
# Convenience accessor
async def get_token_manager() -> TokenManager:
    """Return the TokenManager singleton (created/loaded on first call)."""
    return await TokenManager.get_instance()


__all__ = ["TokenManager", "get_token_manager"]
|
app/services/token/models.py
ADDED
|
@@ -0,0 +1,221 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Token 数据模型
|
| 3 |
+
|
| 4 |
+
额度规则:
|
| 5 |
+
- 新号默认 80 配额
|
| 6 |
+
- 重置后恢复 80
|
| 7 |
+
- lowEffort 扣 1,highEffort 扣 4
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
from enum import Enum
|
| 11 |
+
from typing import Optional, List
|
| 12 |
+
from pydantic import BaseModel, Field
|
| 13 |
+
from datetime import datetime
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
# Quota granted to a fresh token (and restored by TokenInfo.reset)
DEFAULT_QUOTA = 80

# Consecutive 401 failures before a token is marked EXPIRED (see record_fail)
FAIL_THRESHOLD = 5
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class TokenStatus(str, Enum):
    """Lifecycle state of a token."""
    ACTIVE = "active"      # usable; quota remaining
    DISABLED = "disabled"  # presumably set manually/by an operator — not set in this module
    EXPIRED = "expired"    # reached FAIL_THRESHOLD consecutive 401 failures
    COOLING = "cooling"    # quota exhausted; waiting for a refresh/sync
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class EffortType(str, Enum):
    """Per-request quota consumption class."""
    LOW = "low"    # costs 1 quota unit
    HIGH = "high"  # costs 4 quota units


# Quota cost per effort class (used by TokenInfo.consume / consume_heavy)
EFFORT_COST = {
    EffortType.LOW: 1,
    EffortType.HIGH: 4,
}
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class TokenInfo(BaseModel):
    """Per-token state: remaining quota, lifecycle status and failure tracking.

    Quota rules:
    - a fresh token starts with DEFAULT_QUOTA (80)
    - reset() restores DEFAULT_QUOTA
    - a LOW-effort request costs 1, a HIGH-effort request costs 4
    """

    token: str
    status: TokenStatus = TokenStatus.ACTIVE
    quota: int = DEFAULT_QUOTA
    heavy_quota: int = -1  # grok-4-heavy quota; -1 means "unknown / not yet synced"

    # usage statistics (all timestamps are epoch milliseconds)
    created_at: int = Field(default_factory=lambda: int(datetime.now().timestamp() * 1000))
    last_used_at: Optional[int] = None
    use_count: int = 0

    # failure tracking (only HTTP 401 counts — see record_fail)
    fail_count: int = 0
    last_fail_at: Optional[int] = None
    last_fail_reason: Optional[str] = None

    # cooling management
    last_sync_at: Optional[int] = None  # last quota sync with the upstream API

    # extensions
    tags: List[str] = Field(default_factory=list)
    note: str = ""
    last_asset_clear_at: Optional[int] = None

    @staticmethod
    def _now_ms() -> int:
        """Current wall-clock time in epoch milliseconds (unit used by all timestamps here)."""
        return int(datetime.now().timestamp() * 1000)

    def is_available(self) -> bool:
        """True when the token is ACTIVE and still has quota left."""
        return self.status == TokenStatus.ACTIVE and self.quota > 0

    def consume(self, effort: EffortType = EffortType.LOW) -> int:
        """Deduct quota for one request (local estimate).

        Args:
            effort: LOW costs 1, HIGH costs 4.

        Returns:
            The quota actually deducted (clamped to the remaining quota).
        """
        cost = EFFORT_COST[effort]
        actual_cost = min(cost, self.quota)

        self.last_used_at = self._now_ms()
        self.use_count += 1
        self.quota = max(0, self.quota - cost)

        # A successful request proves the token works: clear all failure tracking
        # (kept consistent with record_success()).
        self.fail_count = 0
        self.last_fail_at = None
        self.last_fail_reason = None

        if self.quota == 0:
            self.status = TokenStatus.COOLING
        elif self.status in (TokenStatus.COOLING, TokenStatus.EXPIRED):
            # A working token with quota left should not stay cooling/expired.
            self.status = TokenStatus.ACTIVE

        return actual_cost

    def update_quota(self, new_quota: int):
        """Overwrite quota with an authoritative value (API sync).

        Status is derived from the new quota: 0 -> COOLING; a positive value
        revives COOLING/EXPIRED tokens back to ACTIVE.
        """
        self.quota = max(0, new_quota)

        if self.quota == 0:
            self.status = TokenStatus.COOLING
        elif self.status in (TokenStatus.COOLING, TokenStatus.EXPIRED):
            self.status = TokenStatus.ACTIVE

    def update_heavy_quota(self, new_quota: int):
        """Overwrite the grok-4-heavy quota (rate-limits sync).

        Heavy quota deliberately never affects `status`, so it cannot make the
        token unavailable for regular models. Non-numeric input is treated as 0.
        """
        try:
            v = int(new_quota)
        except Exception:
            v = 0
        self.heavy_quota = max(0, v)

    def consume_heavy(self, effort: EffortType = EffortType.LOW) -> int:
        """Deduct heavy quota for one request (local estimate).

        When heavy_quota is -1 (unknown) nothing is deducted; the call is only
        recorded as a use.

        Returns:
            The heavy quota actually deducted (0 when unknown).
        """
        cost = EFFORT_COST[effort]

        self.last_used_at = self._now_ms()
        self.use_count += 1

        # Successful request: clear all failure tracking (consistent with record_success()).
        self.fail_count = 0
        self.last_fail_at = None
        self.last_fail_reason = None

        if self.heavy_quota < 0:
            return 0

        actual_cost = min(cost, self.heavy_quota)
        self.heavy_quota = max(0, self.heavy_quota - actual_cost)
        return actual_cost

    def reset(self):
        """Restore the default quota, mark ACTIVE and clear failure tracking."""
        self.quota = DEFAULT_QUOTA
        self.heavy_quota = -1
        self.status = TokenStatus.ACTIVE
        self.fail_count = 0
        self.last_fail_at = None  # previously left stale; now consistent with record_success()
        self.last_fail_reason = None

    def record_fail(self, status_code: int = 401, reason: str = ""):
        """Count a failed request; FAIL_THRESHOLD consecutive 401s expire the token.

        Non-401 status codes are ignored — they do not indicate an invalid token.
        """
        if status_code != 401:
            return

        self.fail_count += 1
        self.last_fail_at = self._now_ms()
        self.last_fail_reason = reason

        if self.fail_count >= FAIL_THRESHOLD:
            self.status = TokenStatus.EXPIRED

    def record_success(self, is_usage: bool = True):
        """Clear failure tracking and derive status from the current quota.

        Args:
            is_usage: when True, also bump use_count / last_used_at.
        """
        self.fail_count = 0
        self.last_fail_at = None
        self.last_fail_reason = None

        if is_usage:
            self.use_count += 1
            self.last_used_at = self._now_ms()

        self.status = TokenStatus.COOLING if self.quota == 0 else TokenStatus.ACTIVE

    def need_refresh(self, interval_hours: int = 8) -> bool:
        """True when a COOLING token has gone `interval_hours` without a sync."""
        if self.status != TokenStatus.COOLING:
            return False

        if self.last_sync_at is None:
            return True

        interval_ms = interval_hours * 3600 * 1000
        return (self._now_ms() - self.last_sync_at) >= interval_ms

    def mark_synced(self):
        """Record that a quota sync just happened."""
        self.last_sync_at = self._now_ms()
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
class TokenPoolStats(BaseModel):
    """Aggregate statistics for one token pool (see TokenPool.get_stats)."""
    total: int = 0        # number of tokens in the pool
    active: int = 0       # tokens in ACTIVE state
    disabled: int = 0     # tokens in DISABLED state
    expired: int = 0      # tokens in EXPIRED state
    cooling: int = 0      # tokens in COOLING state
    total_quota: int = 0  # sum of quota across all tokens
    avg_quota: float = 0.0  # total_quota / total; 0.0 for an empty pool


__all__ = [
    "TokenStatus",
    "TokenInfo",
    "TokenPoolStats",
    "EffortType",
    "EFFORT_COST",
    "DEFAULT_QUOTA",
    "FAIL_THRESHOLD",
]
|
app/services/token/pool.py
ADDED
|
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Token 池管理"""
|
| 2 |
+
|
| 3 |
+
import random
|
| 4 |
+
from typing import Dict, List, Optional, Iterator
|
| 5 |
+
|
| 6 |
+
from app.services.token.models import TokenInfo, TokenStatus, TokenPoolStats
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class TokenPool:
    """A named collection of tokens with selection and aggregate statistics."""

    def __init__(self, name: str):
        self.name = name
        self._tokens: Dict[str, TokenInfo] = {}

    def add(self, token: TokenInfo):
        """Insert (or replace) a token, keyed by its token string."""
        self._tokens[token.token] = token

    def remove(self, token_str: str) -> bool:
        """Drop a token; returns True if it was present."""
        return self._tokens.pop(token_str, None) is not None

    def get(self, token_str: str) -> Optional[TokenInfo]:
        """Look up a token by its string, or None."""
        return self._tokens.get(token_str)

    def select(self, bucket: str = "normal") -> Optional[TokenInfo]:
        """Pick one usable token.

        Strategy:
        1. only tokens whose state/quota allow the requested bucket
        2. prefer the highest remaining quota
        3. break ties randomly to spread concurrent load
        """
        if bucket == "heavy":
            # Heavy bucket: COOLING tokens are still eligible (heavy quota is
            # tracked separately) as long as their heavy quota is not exhausted.
            eligible = [
                tok
                for tok in self._tokens.values()
                if tok.status in (TokenStatus.ACTIVE, TokenStatus.COOLING)
                and tok.heavy_quota != 0
            ]
            if not eligible:
                return None

            # Tokens with unknown heavy quota (-1) get priority so it can be probed.
            unknown = [tok for tok in eligible if tok.heavy_quota < 0]
            if unknown:
                return random.choice(unknown)

            best = max(tok.heavy_quota for tok in eligible)
            return random.choice([tok for tok in eligible if tok.heavy_quota == best])

        eligible = [
            tok
            for tok in self._tokens.values()
            if tok.status == TokenStatus.ACTIVE and tok.quota > 0
        ]
        if not eligible:
            return None

        best = max(tok.quota for tok in eligible)
        return random.choice([tok for tok in eligible if tok.quota == best])

    def count(self) -> int:
        """Number of tokens in the pool."""
        return len(self._tokens)

    def list(self) -> List[TokenInfo]:
        """All tokens as a fresh list."""
        return list(self._tokens.values())

    def get_stats(self) -> TokenPoolStats:
        """Aggregate per-status counts and quota totals."""
        stats = TokenPoolStats(total=len(self._tokens))

        counters = {
            TokenStatus.ACTIVE: "active",
            TokenStatus.DISABLED: "disabled",
            TokenStatus.EXPIRED: "expired",
            TokenStatus.COOLING: "cooling",
        }
        for tok in self._tokens.values():
            stats.total_quota += tok.quota
            field = counters.get(tok.status)
            if field is not None:
                setattr(stats, field, getattr(stats, field) + 1)

        if stats.total > 0:
            stats.avg_quota = stats.total_quota / stats.total

        return stats

    def _rebuild_index(self):
        """Placeholder hook invoked after loading (currently a no-op)."""
        pass

    def __iter__(self) -> Iterator[TokenInfo]:
        return iter(self._tokens.values())


__all__ = ["TokenPool"]
|
app/services/token/scheduler.py
ADDED
|
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Token 刷新调度器"""
|
| 2 |
+
|
| 3 |
+
import asyncio
|
| 4 |
+
from typing import Optional
|
| 5 |
+
|
| 6 |
+
from app.core.logger import logger
|
| 7 |
+
from app.core.storage import get_storage, StorageError, RedisStorage
|
| 8 |
+
from app.services.token.manager import get_token_manager
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class TokenRefreshScheduler:
    """Background task that periodically refreshes cooling tokens.

    Only one worker performs a refresh pass at a time: a non-blocking Redis
    lock is used when Redis storage is configured, otherwise the storage
    backend's own lock.
    """

    def __init__(self, interval_hours: int = 8):
        self.interval_hours = interval_hours
        self.interval_seconds = interval_hours * 3600
        self._task: Optional[asyncio.Task] = None
        self._running = False

    async def _do_refresh(self):
        """Run one refresh pass and log the outcome (caller must hold the lock)."""
        logger.info("Scheduler: starting token refresh...")
        manager = await get_token_manager()
        result = await manager.refresh_cooling_tokens()

        logger.info(
            f"Scheduler: refresh completed - "
            f"checked={result['checked']}, "
            f"refreshed={result['refreshed']}, "
            f"recovered={result['recovered']}, "
            f"expired={result['expired']}"
        )

    async def _refresh_loop(self):
        """Sleep/refresh loop; exits on cancellation or stop()."""
        logger.info(f"Scheduler: started (interval: {self.interval_hours}h)")

        while self._running:
            try:
                await asyncio.sleep(self.interval_seconds)
                storage = get_storage()

                if isinstance(storage, RedisStorage):
                    # Redis: non-blocking lock to avoid multi-worker duplication
                    lock = storage.redis.lock(
                        "grok2api:lock:token_refresh",
                        timeout=self.interval_seconds + 60,
                        blocking_timeout=0,
                    )
                    if not await lock.acquire(blocking=False):
                        logger.info("Scheduler: skipped (lock not acquired)")
                        continue
                    try:
                        await self._do_refresh()
                    finally:
                        try:
                            await lock.release()
                        except Exception:
                            pass
                else:
                    # Non-Redis storage: hold the lock for the WHOLE refresh.
                    # (The previous code exited the `async with` immediately
                    # after acquiring, so the refresh itself ran unlocked.)
                    # Enter/exit manually so a StorageError from acquisition
                    # can be told apart from one raised during the refresh.
                    ctx = storage.acquire_lock("token_refresh", timeout=0)
                    try:
                        await ctx.__aenter__()
                    except StorageError:
                        logger.info("Scheduler: skipped (lock not acquired)")
                        continue
                    try:
                        await self._do_refresh()
                    finally:
                        await ctx.__aexit__(None, None, None)

            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.error(f"Scheduler: refresh error - {e}")

    def start(self):
        """Start the background refresh task (idempotent)."""
        if self._running:
            logger.warning("Scheduler: already running")
            return

        self._running = True
        self._task = asyncio.create_task(self._refresh_loop())
        logger.info("Scheduler: enabled")

    def stop(self):
        """Cancel the background refresh task (idempotent)."""
        if not self._running:
            return

        self._running = False
        if self._task:
            self._task.cancel()
        logger.info("Scheduler: stopped")
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
# Process-wide singleton instance
_scheduler: Optional[TokenRefreshScheduler] = None


def get_scheduler(interval_hours: int = 8) -> TokenRefreshScheduler:
    """Return the scheduler singleton, creating it on the first call.

    NOTE: interval_hours is honored only on the first call; later calls
    return the existing instance with its original interval.
    """
    global _scheduler
    if _scheduler is None:
        _scheduler = TokenRefreshScheduler(interval_hours)
    return _scheduler


__all__ = ["TokenRefreshScheduler", "get_scheduler"]
|
app/services/token/service.py
ADDED
|
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Token 服务外观(Facade)"""
|
| 2 |
+
|
| 3 |
+
from typing import List, Optional, Dict
|
| 4 |
+
|
| 5 |
+
from app.services.token.manager import get_token_manager
|
| 6 |
+
from app.services.token.models import TokenInfo, EffortType
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class TokenService:
    """Thin facade over the TokenManager singleton.

    Every method fetches the manager and delegates, so callers never deal
    with manager acquisition or pool internals directly.
    """

    @staticmethod
    async def get_token(pool_name: str = "ssoBasic") -> Optional[str]:
        """Return an available token from *pool_name* (no ``sso=`` prefix), or None."""
        return (await get_token_manager()).get_token(pool_name)

    @staticmethod
    async def consume(token: str, effort: EffortType = EffortType.LOW) -> bool:
        """Deduct quota for *token* locally (estimate); True on success."""
        return await (await get_token_manager()).consume(token, effort)

    @staticmethod
    async def sync_usage(
        token: str,
        model: str,
        effort: EffortType = EffortType.LOW
    ) -> bool:
        """Sync usage for *token*/*model* (API first, local *effort* as fallback); True on success."""
        return await (await get_token_manager()).sync_usage(token, model, effort)

    @staticmethod
    async def record_fail(token: str, status_code: int = 401, reason: str = "") -> bool:
        """Record a failed request against *token*; True on success."""
        return await (await get_token_manager()).record_fail(token, status_code, reason)

    @staticmethod
    async def add_token(token: str, pool_name: str = "ssoBasic") -> bool:
        """Add *token* to the pool named *pool_name*; True on success."""
        return await (await get_token_manager()).add(token, pool_name)

    @staticmethod
    async def remove_token(token: str) -> bool:
        """Remove *token*; True on success."""
        return await (await get_token_manager()).remove(token)

    @staticmethod
    async def reset_token(token: str) -> bool:
        """Reset a single token; True on success."""
        return await (await get_token_manager()).reset_token(token)

    @staticmethod
    async def reset_all():
        """Reset every managed token."""
        await (await get_token_manager()).reset_all()

    @staticmethod
    async def get_stats() -> Dict[str, dict]:
        """Return per-pool statistics."""
        return (await get_token_manager()).get_stats()

    @staticmethod
    async def list_tokens(pool_name: str = "ssoBasic") -> List[TokenInfo]:
        """Return every TokenInfo held by *pool_name*."""
        return (await get_token_manager()).get_pool_tokens(pool_name)


__all__ = ["TokenService"]
|
app/static/.assetsignore
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
_worker.js
|
| 2 |
+
|
app/static/_worker.js
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Cloudflare Pages worker entry: re-export the default worker from the main source tree.
export { default } from "../../src/index.ts";
|
| 4 |
+
|