Upload folder using huggingface_hub
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .env.example +103 -0
- .gitattributes +8 -0
- .github/workflows/docker-publish.yml +70 -0
- .github/workflows/release.yml +55 -0
- .gitignore +261 -0
- .vscode/launch.json +24 -0
- .vscode/settings.json +3 -0
- Dockerfile +19 -0
- LICENSE +17 -0
- README.md +11 -6
- README_ZH.md +290 -0
- VERSION +1 -0
- app/config/config.py +533 -0
- app/core/application.py +156 -0
- app/core/constants.py +112 -0
- app/core/security.py +90 -0
- app/database/__init__.py +3 -0
- app/database/connection.py +71 -0
- app/database/initialization.py +77 -0
- app/database/models.py +129 -0
- app/database/services.py +805 -0
- app/domain/file_models.py +69 -0
- app/domain/gemini_models.py +115 -0
- app/domain/image_models.py +20 -0
- app/domain/openai_models.py +43 -0
- app/exception/exceptions.py +135 -0
- app/handler/error_handler.py +32 -0
- app/handler/message_converter.py +363 -0
- app/handler/response_handler.py +449 -0
- app/handler/retry_handler.py +51 -0
- app/handler/stream_optimizer.py +143 -0
- app/log/logger.py +349 -0
- app/main.py +15 -0
- app/middleware/middleware.py +81 -0
- app/middleware/request_logging_middleware.py +40 -0
- app/middleware/smart_routing_middleware.py +210 -0
- app/router/config_routes.py +225 -0
- app/router/error_log_routes.py +271 -0
- app/router/files_routes.py +296 -0
- app/router/gemini_routes.py +618 -0
- app/router/key_routes.py +83 -0
- app/router/openai_compatiable_routes.py +146 -0
- app/router/openai_routes.py +206 -0
- app/router/routes.py +290 -0
- app/router/scheduler_routes.py +57 -0
- app/router/stats_routes.py +56 -0
- app/router/version_routes.py +37 -0
- app/router/vertex_express_routes.py +191 -0
- app/scheduler/scheduled_tasks.py +196 -0
- app/service/chat/gemini_chat_service.py +545 -0
.env.example
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 数据库配置
|
| 2 |
+
DATABASE_TYPE=mysql
|
| 3 |
+
#SQLITE_DATABASE=default_db
|
| 4 |
+
MYSQL_HOST=gemini-balance-mysql
|
| 5 |
+
#MYSQL_SOCKET=/run/mysqld/mysqld.sock
|
| 6 |
+
MYSQL_PORT=3306
|
| 7 |
+
MYSQL_USER=gemini
|
| 8 |
+
MYSQL_PASSWORD=change_me
|
| 9 |
+
MYSQL_DATABASE=default_db
|
| 10 |
+
API_KEYS=["AIzaSyxxxxxxxxxxxxxxxxxxx","AIzaSyxxxxxxxxxxxxxxxxxxx"]
|
| 11 |
+
ALLOWED_TOKENS=["sk-123456"]
|
| 12 |
+
AUTH_TOKEN=sk-123456
|
| 13 |
+
# For Vertex AI Platform API Keys
|
| 14 |
+
VERTEX_API_KEYS=["AQ.Abxxxxxxxxxxxxxxxxxxx"]
|
| 15 |
+
# For Vertex AI Platform Express API Base URL
|
| 16 |
+
VERTEX_EXPRESS_BASE_URL=https://aiplatform.googleapis.com/v1beta1/publishers/google
|
| 17 |
+
TEST_MODEL=gemini-2.5-flash-lite
|
| 18 |
+
THINKING_MODELS=["gemini-2.5-flash","gemini-2.5-pro"]
|
| 19 |
+
THINKING_BUDGET_MAP={"gemini-2.5-flash": -1}
|
| 20 |
+
IMAGE_MODELS=["gemini-2.0-flash-exp", "gemini-2.5-flash-image-preview"]
|
| 21 |
+
SEARCH_MODELS=["gemini-2.5-flash","gemini-2.5-pro"]
|
| 22 |
+
FILTERED_MODELS=["gemini-1.0-pro-vision-latest", "gemini-pro-vision", "chat-bison-001", "text-bison-001", "embedding-gecko-001"]
|
| 23 |
+
# 是否启用网址上下文,默认启用
|
| 24 |
+
URL_CONTEXT_ENABLED=false
|
| 25 |
+
URL_CONTEXT_MODELS=["gemini-2.5-pro","gemini-2.5-flash","gemini-2.5-flash-lite","gemini-2.0-flash","gemini-2.0-flash-live-001"]
|
| 26 |
+
TOOLS_CODE_EXECUTION_ENABLED=false
|
| 27 |
+
SHOW_SEARCH_LINK=true
|
| 28 |
+
SHOW_THINKING_PROCESS=true
|
| 29 |
+
BASE_URL=https://generativelanguage.googleapis.com/v1beta
|
| 30 |
+
MAX_FAILURES=10
|
| 31 |
+
MAX_RETRIES=3
|
| 32 |
+
CHECK_INTERVAL_HOURS=1
|
| 33 |
+
TIMEZONE=Asia/Shanghai
|
| 34 |
+
# 请求超时时间(秒)
|
| 35 |
+
TIME_OUT=300
|
| 36 |
+
# 代理服务器配置 (支持 http 和 socks5)
|
| 37 |
+
# 示例: PROXIES=["http://user:pass@host:port", "socks5://host:port"]
|
| 38 |
+
PROXIES=[]
|
| 39 |
+
# 对同一个API_KEY使用代理列表中固定的IP策略
|
| 40 |
+
PROXIES_USE_CONSISTENCY_HASH_BY_API_KEY=true
|
| 41 |
+
#########################image_generate 相关配置###########################
|
| 42 |
+
PAID_KEY=AIzaSyxxxxxxxxxxxxxxxxxxx
|
| 43 |
+
CREATE_IMAGE_MODEL=imagen-3.0-generate-002
|
| 44 |
+
UPLOAD_PROVIDER=smms
|
| 45 |
+
SMMS_SECRET_TOKEN=XXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
|
| 46 |
+
PICGO_API_KEY=xxxx
|
| 47 |
+
PICGO_API_URL=https://www.picgo.net/api/1/upload
|
| 48 |
+
CLOUDFLARE_IMGBED_URL=https://xxxxxxx.pages.dev/upload
|
| 49 |
+
CLOUDFLARE_IMGBED_AUTH_CODE=xxxxxxxxx
|
| 50 |
+
CLOUDFLARE_IMGBED_UPLOAD_FOLDER=
|
| 51 |
+
# 阿里云OSS配置
|
| 52 |
+
OSS_ENDPOINT=oss-cn-shanghai.aliyuncs.com
|
| 53 |
+
OSS_ENDPOINT_INNER=oss-cn-shanghai-internal.aliyuncs.com
|
| 54 |
+
OSS_ACCESS_KEY=LTAI5txxxxxxxxxxxxxxxx
|
| 55 |
+
OSS_ACCESS_KEY_SECRET=yXxxxxxxxxxxxxxxxxxxxxx
|
| 56 |
+
OSS_BUCKET_NAME=your-bucket-name
|
| 57 |
+
OSS_REGION=cn-shanghai
|
| 58 |
+
##########################################################################
|
| 59 |
+
#########################stream_optimizer 相关配置########################
|
| 60 |
+
STREAM_OPTIMIZER_ENABLED=false
|
| 61 |
+
STREAM_MIN_DELAY=0.016
|
| 62 |
+
STREAM_MAX_DELAY=0.024
|
| 63 |
+
STREAM_SHORT_TEXT_THRESHOLD=10
|
| 64 |
+
STREAM_LONG_TEXT_THRESHOLD=50
|
| 65 |
+
STREAM_CHUNK_SIZE=5
|
| 66 |
+
##########################################################################
|
| 67 |
+
######################### 日志配置 #######################################
|
| 68 |
+
# 日志级别 (debug, info, warning, error, critical),默认为 info
|
| 69 |
+
LOG_LEVEL=info
|
| 70 |
+
# 是否记录错误日志的请求体(可能包含敏感信息),默认 false
|
| 71 |
+
ERROR_LOG_RECORD_REQUEST_BODY=false
|
| 72 |
+
# 是否开启自动删除错误日志
|
| 73 |
+
AUTO_DELETE_ERROR_LOGS_ENABLED=true
|
| 74 |
+
# 自动删除多少天前的错误日志 (1, 7, 30)
|
| 75 |
+
AUTO_DELETE_ERROR_LOGS_DAYS=7
|
| 76 |
+
# 是否开启自动删除请求日志
|
| 77 |
+
AUTO_DELETE_REQUEST_LOGS_ENABLED=false
|
| 78 |
+
# 自动删除多少天前的请求日志 (1, 7, 30)
|
| 79 |
+
AUTO_DELETE_REQUEST_LOGS_DAYS=30
|
| 80 |
+
##########################################################################
|
| 81 |
+
|
| 82 |
+
# 假流式配置 (Fake Streaming Configuration)
|
| 83 |
+
# 是否启用假流式输出
|
| 84 |
+
FAKE_STREAM_ENABLED=True
|
| 85 |
+
# 假流式发送空数据的间隔时间(秒)
|
| 86 |
+
FAKE_STREAM_EMPTY_DATA_INTERVAL_SECONDS=5
|
| 87 |
+
|
| 88 |
+
# 安全设置 (JSON 字符串格式)
|
| 89 |
+
# 注意:这里的示例值可能需要根据实际模型支持情况调整
|
| 90 |
+
SAFETY_SETTINGS=[{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "OFF"}, {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "OFF"}, {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "OFF"}, {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "OFF"}, {"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "BLOCK_NONE"}]
|
| 91 |
+
URL_NORMALIZATION_ENABLED=false
|
| 92 |
+
# tts配置
|
| 93 |
+
TTS_MODEL=gemini-2.5-flash-preview-tts
|
| 94 |
+
TTS_VOICE_NAME=Zephyr
|
| 95 |
+
TTS_SPEED=normal
|
| 96 |
+
#########################Files API 相关配置########################
|
| 97 |
+
# 是否启用文件过期自动清理
|
| 98 |
+
FILES_CLEANUP_ENABLED=true
|
| 99 |
+
# 文件过期清理间隔(小时)
|
| 100 |
+
FILES_CLEANUP_INTERVAL_HOURS=1
|
| 101 |
+
# 是否启用用户文件隔离(每个用户只能看到自己上传的文件)
|
| 102 |
+
FILES_USER_ISOLATION_ENABLED=true
|
| 103 |
+
##########################################################################
|
.gitattributes
CHANGED
|
@@ -33,3 +33,11 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
files/image.png filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
files/image1.png filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
files/image2.png filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
files/image3.png filter=lfs diff=lfs merge=lfs -text
|
| 40 |
+
files/image4.png filter=lfs diff=lfs merge=lfs -text
|
| 41 |
+
files/image5.png filter=lfs diff=lfs merge=lfs -text
|
| 42 |
+
files/image6.png filter=lfs diff=lfs merge=lfs -text
|
| 43 |
+
files/image7.png filter=lfs diff=lfs merge=lfs -text
|
.github/workflows/docker-publish.yml
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: Docker Image CI
|
| 2 |
+
|
| 3 |
+
on:
|
| 4 |
+
push:
|
| 5 |
+
pull_request:
|
| 6 |
+
branches: [ "main" ]
|
| 7 |
+
|
| 8 |
+
env:
|
| 9 |
+
REGISTRY: ghcr.io
|
| 10 |
+
# github.repository as <account>/<repo>
|
| 11 |
+
IMAGE_NAME: ${{ github.repository }}
|
| 12 |
+
|
| 13 |
+
jobs:
|
| 14 |
+
build:
|
| 15 |
+
runs-on: ubuntu-latest
|
| 16 |
+
permissions:
|
| 17 |
+
contents: read
|
| 18 |
+
packages: write
|
| 19 |
+
# 这个权限用于标记容器镜像
|
| 20 |
+
id-token: write
|
| 21 |
+
|
| 22 |
+
steps:
|
| 23 |
+
- name: Checkout repository
|
| 24 |
+
uses: actions/checkout@v4
|
| 25 |
+
|
| 26 |
+
- name: Set up Docker Buildx
|
| 27 |
+
uses: docker/setup-buildx-action@v3
|
| 28 |
+
|
| 29 |
+
# 登录到 GitHub Container Registry
|
| 30 |
+
- name: Log into registry ${{ env.REGISTRY }}
|
| 31 |
+
if: github.event_name != 'pull_request'
|
| 32 |
+
uses: docker/login-action@v3
|
| 33 |
+
with:
|
| 34 |
+
registry: ${{ env.REGISTRY }}
|
| 35 |
+
username: ${{ github.actor }}
|
| 36 |
+
password: ${{ secrets.GITHUB_TOKEN }}
|
| 37 |
+
|
| 38 |
+
- name: Extract Docker metadata
|
| 39 |
+
id: meta
|
| 40 |
+
uses: docker/metadata-action@v5
|
| 41 |
+
with:
|
| 42 |
+
images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
|
| 43 |
+
tags: |
|
| 44 |
+
# https://github.com/docker/metadata-action/tree/v5/?tab=readme-ov-file#semver
|
| 45 |
+
# Event: push, Ref: refs/head/main, Tags: main
|
| 46 |
+
# Event: push tag, Ref: refs/tags/v1.2.3, Tags: 1.2.3, 1.2, 1, latest
|
| 47 |
+
# Event: push tag, Ref: refs/tags/v2.0.8-rc1, Tags: 2.0.8-rc1
|
| 48 |
+
type=ref,event=branch
|
| 49 |
+
type=semver,pattern={{version}}
|
| 50 |
+
type=semver,pattern={{major}}.{{minor}}
|
| 51 |
+
type=semver,pattern={{major}}
|
| 52 |
+
labels: |
|
| 53 |
+
org.opencontainers.image.description=OpenAI API Compatible Server
|
| 54 |
+
org.opencontainers.image.source=${{ github.event.repository.html_url }}
|
| 55 |
+
|
| 56 |
+
- name: Set up QEMU
|
| 57 |
+
uses: docker/setup-qemu-action@v3
|
| 58 |
+
|
| 59 |
+
- name: Build and push
|
| 60 |
+
uses: docker/build-push-action@v6
|
| 61 |
+
with:
|
| 62 |
+
file: Dockerfile
|
| 63 |
+
context: .
|
| 64 |
+
platforms: linux/amd64,linux/arm64
|
| 65 |
+
push: ${{ github.event_name != 'pull_request' }}
|
| 66 |
+
load: false
|
| 67 |
+
tags: ${{ steps.meta.outputs.tags }}
|
| 68 |
+
labels: ${{ steps.meta.outputs.labels }}
|
| 69 |
+
cache-from: type=gha,scope=${{ github.workflow }}
|
| 70 |
+
cache-to: type=gha,scope=${{ github.workflow }}
|
.github/workflows/release.yml
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: Publish Release
|
| 2 |
+
|
| 3 |
+
on:
|
| 4 |
+
push:
|
| 5 |
+
tags:
|
| 6 |
+
- "v*" # 当推送以 "v" 开头的标签时触发(如 v1.0.0, v2.1.0)
|
| 7 |
+
|
| 8 |
+
jobs:
|
| 9 |
+
update-release-draft:
|
| 10 |
+
permissions:
|
| 11 |
+
contents: write
|
| 12 |
+
pull-requests: write
|
| 13 |
+
runs-on: ubuntu-latest
|
| 14 |
+
steps:
|
| 15 |
+
# Step 1: 检出代码库
|
| 16 |
+
- name: Checkout code
|
| 17 |
+
uses: actions/checkout@v3
|
| 18 |
+
with:
|
| 19 |
+
fetch-depth: 0
|
| 20 |
+
|
| 21 |
+
# Step 2: 自动生成 Release Notes
|
| 22 |
+
- name: Generate release notes
|
| 23 |
+
id: changelog
|
| 24 |
+
uses: mikepenz/release-changelog-builder-action@v4
|
| 25 |
+
env:
|
| 26 |
+
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
| 27 |
+
|
| 28 |
+
# Step 3: 自动生成 Release
|
| 29 |
+
- name: Create Release
|
| 30 |
+
id: create_release
|
| 31 |
+
uses: actions/create-release@v1
|
| 32 |
+
env:
|
| 33 |
+
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
| 34 |
+
with:
|
| 35 |
+
tag_name: ${{ github.ref_name }}
|
| 36 |
+
release_name: ${{ github.ref_name }}
|
| 37 |
+
body: ${{ steps.changelog.outputs.changelog }}
|
| 38 |
+
draft: false
|
| 39 |
+
prerelease: false
|
| 40 |
+
|
| 41 |
+
# Step 4: 可选,构建zip文件
|
| 42 |
+
- name: Create ZIP file
|
| 43 |
+
run: |
|
| 44 |
+
zip -r gemini-balance.zip . -x "*.git*" "*.github*" "*.env*" "logs/*" "tests/*"
|
| 45 |
+
|
| 46 |
+
# Step 5: 可选,上传构建文件
|
| 47 |
+
- name: Upload Release Asset
|
| 48 |
+
uses: actions/upload-release-asset@v1
|
| 49 |
+
env:
|
| 50 |
+
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
| 51 |
+
with:
|
| 52 |
+
upload_url: ${{ steps.create_release.outputs.upload_url }}
|
| 53 |
+
asset_path: ./gemini-balance.zip # 替换为你的构建文件路径
|
| 54 |
+
asset_name: gemini-balance.zip # 替换为你的文件名
|
| 55 |
+
asset_content_type: application/zip
|
.gitignore
ADDED
|
@@ -0,0 +1,261 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# File created using '.gitignore Generator' for Visual Studio Code: https://bit.ly/vscode-gig
|
| 2 |
+
# Created by https://www.toptal.com/developers/gitignore/api/windows,visualstudiocode,circuitpython,python,pythonvanilla
|
| 3 |
+
# Edit at https://www.toptal.com/developers/gitignore?templates=windows,visualstudiocode,circuitpython,python,pythonvanilla
|
| 4 |
+
|
| 5 |
+
### CircuitPython ###
|
| 6 |
+
.Trashes
|
| 7 |
+
.metadata_never_index
|
| 8 |
+
.fseventsd/
|
| 9 |
+
boot_out.txt
|
| 10 |
+
|
| 11 |
+
### Python ###
|
| 12 |
+
# Byte-compiled / optimized / DLL files
|
| 13 |
+
__pycache__/
|
| 14 |
+
*.py[cod]
|
| 15 |
+
*$py.class
|
| 16 |
+
|
| 17 |
+
# C extensions
|
| 18 |
+
*.so
|
| 19 |
+
|
| 20 |
+
# Distribution / packaging
|
| 21 |
+
.Python
|
| 22 |
+
build/
|
| 23 |
+
develop-eggs/
|
| 24 |
+
dist/
|
| 25 |
+
downloads/
|
| 26 |
+
eggs/
|
| 27 |
+
.eggs/
|
| 28 |
+
lib/
|
| 29 |
+
lib64/
|
| 30 |
+
parts/
|
| 31 |
+
sdist/
|
| 32 |
+
var/
|
| 33 |
+
wheels/
|
| 34 |
+
share/python-wheels/
|
| 35 |
+
*.egg-info/
|
| 36 |
+
.installed.cfg
|
| 37 |
+
*.egg
|
| 38 |
+
MANIFEST
|
| 39 |
+
.idea/
|
| 40 |
+
|
| 41 |
+
# PyInstaller
|
| 42 |
+
# Usually these files are written by a python script from a template
|
| 43 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
| 44 |
+
*.manifest
|
| 45 |
+
*.spec
|
| 46 |
+
|
| 47 |
+
# Installer logs
|
| 48 |
+
pip-log.txt
|
| 49 |
+
pip-delete-this-directory.txt
|
| 50 |
+
|
| 51 |
+
# Unit test / coverage reports
|
| 52 |
+
htmlcov/
|
| 53 |
+
.tox/
|
| 54 |
+
.nox/
|
| 55 |
+
.coverage
|
| 56 |
+
.coverage.*
|
| 57 |
+
.cache
|
| 58 |
+
nosetests.xml
|
| 59 |
+
coverage.xml
|
| 60 |
+
*.cover
|
| 61 |
+
*.py,cover
|
| 62 |
+
.hypothesis/
|
| 63 |
+
.pytest_cache/
|
| 64 |
+
cover/
|
| 65 |
+
|
| 66 |
+
# Translations
|
| 67 |
+
*.mo
|
| 68 |
+
*.pot
|
| 69 |
+
|
| 70 |
+
# Django stuff:
|
| 71 |
+
*.log
|
| 72 |
+
local_settings.py
|
| 73 |
+
db.sqlite3
|
| 74 |
+
db.sqlite3-journal
|
| 75 |
+
|
| 76 |
+
# Flask stuff:
|
| 77 |
+
instance/
|
| 78 |
+
.webassets-cache
|
| 79 |
+
|
| 80 |
+
# Scrapy stuff:
|
| 81 |
+
.scrapy
|
| 82 |
+
|
| 83 |
+
# Sphinx documentation
|
| 84 |
+
docs/_build/
|
| 85 |
+
|
| 86 |
+
# PyBuilder
|
| 87 |
+
.pybuilder/
|
| 88 |
+
target/
|
| 89 |
+
|
| 90 |
+
# Jupyter Notebook
|
| 91 |
+
.ipynb_checkpoints
|
| 92 |
+
|
| 93 |
+
# IPython
|
| 94 |
+
profile_default/
|
| 95 |
+
ipython_config.py
|
| 96 |
+
|
| 97 |
+
# pyenv
|
| 98 |
+
# For a library or package, you might want to ignore these files since the code is
|
| 99 |
+
# intended to run in multiple environments; otherwise, check them in:
|
| 100 |
+
# .python-version
|
| 101 |
+
|
| 102 |
+
# pipenv
|
| 103 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
| 104 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
| 105 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
| 106 |
+
# install all needed dependencies.
|
| 107 |
+
#Pipfile.lock
|
| 108 |
+
|
| 109 |
+
# poetry
|
| 110 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
| 111 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
| 112 |
+
# commonly ignored for libraries.
|
| 113 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
| 114 |
+
#poetry.lock
|
| 115 |
+
|
| 116 |
+
# pdm
|
| 117 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
| 118 |
+
#pdm.lock
|
| 119 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
| 120 |
+
# in version control.
|
| 121 |
+
# https://pdm.fming.dev/#use-with-ide
|
| 122 |
+
.pdm.toml
|
| 123 |
+
|
| 124 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
| 125 |
+
__pypackages__/
|
| 126 |
+
|
| 127 |
+
# Celery stuff
|
| 128 |
+
celerybeat-schedule
|
| 129 |
+
celerybeat.pid
|
| 130 |
+
|
| 131 |
+
# SageMath parsed files
|
| 132 |
+
*.sage.py
|
| 133 |
+
|
| 134 |
+
# Environments
|
| 135 |
+
.env
|
| 136 |
+
.venv
|
| 137 |
+
env/
|
| 138 |
+
venv/
|
| 139 |
+
ENV/
|
| 140 |
+
env.bak/
|
| 141 |
+
venv.bak/
|
| 142 |
+
|
| 143 |
+
# Spyder project settings
|
| 144 |
+
.spyderproject
|
| 145 |
+
.spyproject
|
| 146 |
+
|
| 147 |
+
# Rope project settings
|
| 148 |
+
.ropeproject
|
| 149 |
+
|
| 150 |
+
# mkdocs documentation
|
| 151 |
+
/site
|
| 152 |
+
|
| 153 |
+
# mypy
|
| 154 |
+
.mypy_cache/
|
| 155 |
+
.dmypy.json
|
| 156 |
+
dmypy.json
|
| 157 |
+
|
| 158 |
+
# Pyre type checker
|
| 159 |
+
.pyre/
|
| 160 |
+
|
| 161 |
+
# pytype static type analyzer
|
| 162 |
+
.pytype/
|
| 163 |
+
|
| 164 |
+
# Cython debug symbols
|
| 165 |
+
cython_debug/
|
| 166 |
+
|
| 167 |
+
# PyCharm
|
| 168 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
| 169 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
| 170 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
| 171 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
| 172 |
+
#.idea/
|
| 173 |
+
|
| 174 |
+
### Python Patch ###
|
| 175 |
+
# Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration
|
| 176 |
+
poetry.toml
|
| 177 |
+
|
| 178 |
+
# ruff
|
| 179 |
+
.ruff_cache/
|
| 180 |
+
|
| 181 |
+
# LSP config files
|
| 182 |
+
pyrightconfig.json
|
| 183 |
+
|
| 184 |
+
### PythonVanilla ###
|
| 185 |
+
# Byte-compiled / optimized / DLL files
|
| 186 |
+
|
| 187 |
+
# C extensions
|
| 188 |
+
|
| 189 |
+
# Distribution / packaging
|
| 190 |
+
|
| 191 |
+
# Installer logs
|
| 192 |
+
|
| 193 |
+
# Unit test / coverage reports
|
| 194 |
+
|
| 195 |
+
# Translations
|
| 196 |
+
|
| 197 |
+
# pyenv
|
| 198 |
+
# For a library or package, you might want to ignore these files since the code is
|
| 199 |
+
# intended to run in multiple environments; otherwise, check them in:
|
| 200 |
+
# .python-version
|
| 201 |
+
|
| 202 |
+
# pipenv
|
| 203 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
| 204 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
| 205 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
| 206 |
+
# install all needed dependencies.
|
| 207 |
+
|
| 208 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
### VisualStudioCode ###
|
| 212 |
+
.vscode/*
|
| 213 |
+
!.vscode/settings.json
|
| 214 |
+
!.vscode/tasks.json
|
| 215 |
+
!.vscode/launch.json
|
| 216 |
+
!.vscode/extensions.json
|
| 217 |
+
!.vscode/*.code-snippets
|
| 218 |
+
|
| 219 |
+
# Local History for Visual Studio Code
|
| 220 |
+
.history/
|
| 221 |
+
|
| 222 |
+
# Built Visual Studio Code Extensions
|
| 223 |
+
*.vsix
|
| 224 |
+
|
| 225 |
+
### VisualStudioCode Patch ###
|
| 226 |
+
# Ignore all local history of files
|
| 227 |
+
.history
|
| 228 |
+
.ionide
|
| 229 |
+
|
| 230 |
+
### Windows ###
|
| 231 |
+
# Windows thumbnail cache files
|
| 232 |
+
Thumbs.db
|
| 233 |
+
Thumbs.db:encryptable
|
| 234 |
+
ehthumbs.db
|
| 235 |
+
ehthumbs_vista.db
|
| 236 |
+
|
| 237 |
+
# Dump file
|
| 238 |
+
*.stackdump
|
| 239 |
+
|
| 240 |
+
# Folder config file
|
| 241 |
+
[Dd]esktop.ini
|
| 242 |
+
|
| 243 |
+
# Recycle Bin used on file shares
|
| 244 |
+
$RECYCLE.BIN/
|
| 245 |
+
|
| 246 |
+
# Windows Installer files
|
| 247 |
+
*.cab
|
| 248 |
+
*.msi
|
| 249 |
+
*.msix
|
| 250 |
+
*.msm
|
| 251 |
+
*.msp
|
| 252 |
+
|
| 253 |
+
# Windows shortcuts
|
| 254 |
+
*.lnk
|
| 255 |
+
|
| 256 |
+
# End of https://www.toptal.com/developers/gitignore/api/windows,visualstudiocode,circuitpython,python,pythonvanilla
|
| 257 |
+
|
| 258 |
+
# Custom rules (everything added below won't be overriden by 'Generate .gitignore File' if you use 'Update' option)
|
| 259 |
+
|
| 260 |
+
tests/
|
| 261 |
+
default_db
|
.vscode/launch.json
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
// 使用 IntelliSense 了解相关属性。
|
| 3 |
+
// 悬停以查看现有属性的描述。
|
| 4 |
+
// 欲了解更多信息,请访问: https://go.microsoft.com/fwlink/?linkid=830387
|
| 5 |
+
"version": "0.2.0",
|
| 6 |
+
"configurations": [
|
| 7 |
+
{
|
| 8 |
+
"name": "Python 调试程序: FastAPI",
|
| 9 |
+
"type": "debugpy",
|
| 10 |
+
"request": "launch",
|
| 11 |
+
"module": "uvicorn",
|
| 12 |
+
"args": [
|
| 13 |
+
"app.main:app",
|
| 14 |
+
"--reload",
|
| 15 |
+
"--host",
|
| 16 |
+
"0.0.0.0",
|
| 17 |
+
"--port",
|
| 18 |
+
"8000",
|
| 19 |
+
// "--no-access-log"
|
| 20 |
+
],
|
| 21 |
+
"jinja": true
|
| 22 |
+
}
|
| 23 |
+
]
|
| 24 |
+
}
|
.vscode/settings.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"commentTranslate.source": "upupnoah.chatgpt-comment-translateX-chatgpt"
|
| 3 |
+
}
|
Dockerfile
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.10-slim
|
| 2 |
+
|
| 3 |
+
WORKDIR /app
|
| 4 |
+
|
| 5 |
+
# 复制所需文件到容器中
|
| 6 |
+
COPY ./requirements.txt /app
|
| 7 |
+
COPY ./VERSION /app
|
| 8 |
+
|
| 9 |
+
RUN pip install --no-cache-dir -r requirements.txt
|
| 10 |
+
COPY ./app /app/app
|
| 11 |
+
ENV API_KEYS='["your_api_key_1"]'
|
| 12 |
+
ENV ALLOWED_TOKENS='["your_token_1"]'
|
| 13 |
+
ENV TZ='Asia/Shanghai'
|
| 14 |
+
|
| 15 |
+
# Expose port
|
| 16 |
+
EXPOSE 8000
|
| 17 |
+
|
| 18 |
+
# Run the application
|
| 19 |
+
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000", "--no-access-log"]
|
LICENSE
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
知识共享署名-非商业性使用 4.0 国际 (CC BY-NC 4.0) 协议
|
| 2 |
+
|
| 3 |
+
您可以自由地:
|
| 4 |
+
- 共享 — 在任何媒介以任何形式复制、发行本作品
|
| 5 |
+
- 演绎 — 修改、转换或以本作品为基础进行创作
|
| 6 |
+
|
| 7 |
+
惟须遵守下列条件:
|
| 8 |
+
- 署名 — 您必须给出适当的署名,提供指向本协议的链接,并指明是否(对原作)作了修改。您可以以任何合理方式进行,但不得以任何方式暗示许可方认可您或您的使用。
|
| 9 |
+
- 非商业性使用 — 您不得将本作品用于商业目的,包括但不限于任何形式的商业倒卖、SaaS、API 付费接口、二次销售、打包出售、收费分发或其他直接或间接盈利行为。
|
| 10 |
+
|
| 11 |
+
如需商业授权,请联系原作者获得书面许可。违者将承担相应法律责任。
|
| 12 |
+
|
| 13 |
+
Creative Commons Attribution-NonCommercial 4.0 International Public License
|
| 14 |
+
|
| 15 |
+
By exercising the Licensed Rights (defined below), You accept and agree to be bound by the terms and conditions of this Creative Commons Attribution-NonCommercial 4.0 International Public License ("Public License"). To the extent this Public License may be interpreted as a contract, You are granted the Licensed Rights in consideration of Your acceptance of these terms and conditions, and the Licensor grants You such rights in consideration of benefits the Licensor receives from making the Licensed Material available under these terms and conditions.
|
| 16 |
+
|
| 17 |
+
Full license text: https://creativecommons.org/licenses/by-nc/4.0/legalcode
|
README.md
CHANGED
|
@@ -1,10 +1,15 @@
|
|
| 1 |
---
|
| 2 |
-
title:
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
sdk: docker
|
| 7 |
-
|
| 8 |
---
|
| 9 |
|
| 10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
+
title: "gemini"
|
| 3 |
+
emoji: "🚀"
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: green
|
| 6 |
sdk: docker
|
| 7 |
+
app_port: 7860
|
| 8 |
---
|
| 9 |
|
| 10 |
+
### 🚀 一键部署
|
| 11 |
+
[](https://github.com/kfcx/HFSpaceDeploy)
|
| 12 |
+
|
| 13 |
+
本项目由[HFSpaceDeploy](https://github.com/kfcx/HFSpaceDeploy)一键部署
|
| 14 |
+
|
| 15 |
+
|
README_ZH.md
ADDED
|
@@ -0,0 +1,290 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Gemini Balance - Gemini API 代理和负载均衡器
|
| 2 |
+
|
| 3 |
+
<p align="center">
|
| 4 |
+
<a href="https://trendshift.io/repositories/13692" target="_blank">
|
| 5 |
+
<img src="https://trendshift.io/api/badge/repositories/13692" alt="snailyp%2Fgemini-balance | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/>
|
| 6 |
+
</a>
|
| 7 |
+
</p>
|
| 8 |
+
|
| 9 |
+
<p align="center">
|
| 10 |
+
<a href="https://www.python.org/"><img src="https://img.shields.io/badge/Python-3.9%2B-blue.svg" alt="Python"></a>
|
| 11 |
+
<a href="https://fastapi.tiangolo.com/"><img src="https://img.shields.io/badge/FastAPI-0.100%2B-green.svg" alt="FastAPI"></a>
|
| 12 |
+
<a href="https://www.uvicorn.org/"><img src="https://img.shields.io/badge/Uvicorn-running-purple.svg" alt="Uvicorn"></a>
|
| 13 |
+
<a href="https://t.me/+soaHax5lyI0wZDVl"><img src="https://img.shields.io/badge/Telegram-Group-blue.svg?logo=telegram" alt="Telegram Group"></a>
|
| 14 |
+
</p>
|
| 15 |
+
|
| 16 |
+
> ⚠️ **重要声明**: 本项目采用 [CC BY-NC 4.0](LICENSE) 协议,**禁止任何形式的商业倒卖服务**。
|
| 17 |
+
> 本人从未在任何平台售卖服务,如遇售卖,均为倒卖行为,请勿上当受骗。
|
| 18 |
+
|
| 19 |
+
---
|
| 20 |
+
|
| 21 |
+
## 📖 项目简介
|
| 22 |
+
|
| 23 |
+
**Gemini Balance** 是一个基于 Python FastAPI 构建的应用程序,旨在提供 Google Gemini API 的代理和负载均衡功能。它允许您管理多个 Gemini API Key,并通过简单的配置实现 Key 的轮询、认证、模型过滤和状态监控。此外,项目还集成了图像生成和多种图床上传功能,并支持 OpenAI API 格式的代理。
|
| 24 |
+
|
| 25 |
+
<details>
|
| 26 |
+
<summary>📂 查看项目结构</summary>
|
| 27 |
+
|
| 28 |
+
```plaintext
|
| 29 |
+
app/
|
| 30 |
+
├── config/ # 配置管理
|
| 31 |
+
├── core/ # 核心应用逻辑 (FastAPI 实例创建, 中间件等)
|
| 32 |
+
├── database/ # 数据库模型和连接
|
| 33 |
+
├── domain/ # 业务领域对象
|
| 34 |
+
├── exception/ # 自定义异常
|
| 35 |
+
├── handler/ # 请求处理器
|
| 36 |
+
├── log/ # 日志配置
|
| 37 |
+
├── main.py # 应用入口
|
| 38 |
+
├── middleware/ # FastAPI 中间件
|
| 39 |
+
├── router/ # API 路由 (Gemini, OpenAI, 状态页等)
|
| 40 |
+
├── scheduler/ # 定时任务 (如 Key 状态检查)
|
| 41 |
+
├── service/ # 业务逻辑服务 (聊天, Key 管理, 统计等)
|
| 42 |
+
├── static/ # 静态文件 (CSS, JS)
|
| 43 |
+
├── templates/ # HTML 模板 (如 Key 状态页)
|
| 44 |
+
└── utils/ # 工具函数
|
| 45 |
+
```
|
| 46 |
+
</details>
|
| 47 |
+
|
| 48 |
+
---
|
| 49 |
+
|
| 50 |
+
## ✨ 功能亮点
|
| 51 |
+
|
| 52 |
+
* **多 Key 负载均衡**: 支持配置多个 Gemini API Key (`API_KEYS`),自动按顺序轮询使用,提高可用性和并发能力。
|
| 53 |
+
* **可视化配置即时生效**: 通过管理后台修改配置后,无需重启服务即可生效。
|
| 54 |
+

|
| 55 |
+
* **双协议 API 兼容**: 同时支持 Gemini 和 OpenAI 格式的 CHAT API 请求转发。
|
| 56 |
+
* OpenAI Base URL: `http://localhost:8000(/hf)/v1`
|
| 57 |
+
* Gemini Base URL: `http://localhost:8000(/gemini)/v1beta`
|
| 58 |
+
* **图文对话与修图**: 通过 `IMAGE_MODELS` 配置支持图文对话和修图功能的模型,调用时使用 `配置模型-image` 模型名。
|
| 59 |
+

|
| 60 |
+

|
| 61 |
+
* **联网搜索**: 通过 `SEARCH_MODELS` 配置支持联网搜索的模型,调用时使用 `配置模型-search` 模型名。
|
| 62 |
+

|
| 63 |
+
* **Key 状态监控**: 提供 `/keys_status` 页面(需要认证),实时查看各 Key 的状态和使用情况。
|
| 64 |
+

|
| 65 |
+
* **详细日志记录**: 提供详细的错误日志,方便排查问题。
|
| 66 |
+

|
| 67 |
+

|
| 68 |
+

|
| 69 |
+
* **灵活的密钥添加**: 支持通过正则表达式 `gemini_key` 批量添加密钥,并自动去重。
|
| 70 |
+

|
| 71 |
+
* **失败重试与自动禁用**: 自动处理 API 请求失败,进行重试 (`MAX_RETRIES`),并在 Key 失效次数过多时自动禁用 (`MAX_FAILURES`),定时检查恢复 (`CHECK_INTERVAL_HOURS`)。
|
| 72 |
+
* **全面的 API 兼容**:
|
| 73 |
+
* **Embeddings 接口**: 完美适配 OpenAI 格式的 `embeddings` 接口。
|
| 74 |
+
* **画图接口**: 将 `imagen-3.0-generate-002` 模型接口改造为 OpenAI 画图接口格式。
|
| 75 |
+
* **模型列表自动维护**: 自动获取并同步 Gemini 和 OpenAI 的最新模型列表,兼容 New API。
|
| 76 |
+
* **代理支持**: 支持配置 HTTP/SOCKS5 代理 (`PROXIES`),方便在特殊网络环境下使用。
|
| 77 |
+
* **Docker 支持**: 提供 AMD 和 ARM 架构的 Docker 镜像,方便快速部署。
|
| 78 |
+
* 镜像地址: `ghcr.io/snailyp/gemini-balance:latest`
|
| 79 |
+
|
| 80 |
+
---
|
| 81 |
+
|
| 82 |
+
## 🚀 快速开始
|
| 83 |
+
|
| 84 |
+
### 方式一:使用 Docker Compose (推荐)
|
| 85 |
+
|
| 86 |
+
这是最推荐的部署方式,可以一键启动应用和数据库。
|
| 87 |
+
|
| 88 |
+
1. **下载 `docker-compose.yml`**:
|
| 89 |
+
从项目仓库获取 `docker-compose.yml` 文件。
|
| 90 |
+
2. **准备 `.env` 文件**:
|
| 91 |
+
从 `.env.example` 复制一份并重命名为 `.env`,然后根据需求修改配置。特别注意,`DATABASE_TYPE` 应设置为 `mysql`,并填写 `MYSQL_*` 相关配置。
|
| 92 |
+
3. **启动服务**:
|
| 93 |
+
在 `docker-compose.yml` 和 `.env` 文件所在的目录下,运行以下命令:
|
| 94 |
+
```bash
|
| 95 |
+
docker-compose up -d
|
| 96 |
+
```
|
| 97 |
+
该命令会以后台模式启动 `gemini-balance` 应用和 `mysql` 数据库。
|
| 98 |
+
|
| 99 |
+
### 方式二:使用 Docker 命令
|
| 100 |
+
|
| 101 |
+
1. **拉取镜像**:
|
| 102 |
+
```bash
|
| 103 |
+
docker pull ghcr.io/snailyp/gemini-balance:latest
|
| 104 |
+
```
|
| 105 |
+
2. **准备 `.env` 文件**:
|
| 106 |
+
从 `.env.example` 复制一份并重命名为 `.env`,然后根据需求修改配置。
|
| 107 |
+
3. **运行容器**:
|
| 108 |
+
```bash
|
| 109 |
+
docker run -d -p 8000:8000 --name gemini-balance \
|
| 110 |
+
-v ./data:/app/data \
|
| 111 |
+
--env-file .env \
|
| 112 |
+
ghcr.io/snailyp/gemini-balance:latest
|
| 113 |
+
```
|
| 114 |
+
* `-d`: 后台运行。
|
| 115 |
+
* `-p 8000:8000`: 将容器的 8000 端口映射到主机。
|
| 116 |
+
* `-v ./data:/app/data`: 挂载数据卷以持久化 SQLite 数据和日志。
|
| 117 |
+
* `--env-file .env`: 加载环境变量配置文件。
|
| 118 |
+
|
| 119 |
+
### 方式三:本地运行 (适用于开发)
|
| 120 |
+
|
| 121 |
+
1. **克隆仓库并安装依赖**:
|
| 122 |
+
```bash
|
| 123 |
+
git clone https://github.com/snailyp/gemini-balance.git
|
| 124 |
+
cd gemini-balance
|
| 125 |
+
pip install -r requirements.txt
|
| 126 |
+
```
|
| 127 |
+
2. **配置环境变量**:
|
| 128 |
+
从 `.env.example` 复制一份并重命名为 `.env`,然后根据需求修改配置。
|
| 129 |
+
3. **启动应用**:
|
| 130 |
+
```bash
|
| 131 |
+
uvicorn app.main:app --host 0.0.0.0 --port 8000 --reload
|
| 132 |
+
```
|
| 133 |
+
应用启动后,访问 `http://localhost:8000`。
|
| 134 |
+
|
| 135 |
+
---
|
| 136 |
+
|
| 137 |
+
## ⚙️ API 端点
|
| 138 |
+
|
| 139 |
+
### Gemini API 格式 (`/gemini/v1beta`)
|
| 140 |
+
|
| 141 |
+
此端点将请求直接转发到官方 Gemini API 格式的端点,不包含高级功能。
|
| 142 |
+
|
| 143 |
+
* `GET /models`: 列出可用的 Gemini 模型。
|
| 144 |
+
* `POST /models/{model_name}:generateContent`: 生成内容。
|
| 145 |
+
* `POST /models/{model_name}:streamGenerateContent`: 流式生成内容。
|
| 146 |
+
|
| 147 |
+
### OpenAI API 格式
|
| 148 |
+
|
| 149 |
+
#### 兼容 huggingface (HF) 格式
|
| 150 |
+
|
| 151 |
+
如果您需要使用高级功能(例如假流式输出),请使用此端点。
|
| 152 |
+
|
| 153 |
+
* `GET /hf/v1/models`: 列出模型。
|
| 154 |
+
* `POST /hf/v1/chat/completions`: 聊天补全。
|
| 155 |
+
* `POST /hf/v1/embeddings`: 创建文本嵌入。
|
| 156 |
+
* `POST /hf/v1/images/generations`: 生成图像。
|
| 157 |
+
|
| 158 |
+
#### 标准 OpenAI 格式
|
| 159 |
+
|
| 160 |
+
此端点直接转发至官方的 OpenAI 兼容 API 格式端点,不包含高级功能。
|
| 161 |
+
|
| 162 |
+
* `GET /openai/v1/models`: 列出模型。
|
| 163 |
+
* `POST /openai/v1/chat/completions`: 聊天补全 (推荐,速度更快,防截断)。
|
| 164 |
+
* `POST /openai/v1/embeddings`: 创建文本嵌入。
|
| 165 |
+
* `POST /openai/v1/images/generations`: 生成图像。
|
| 166 |
+
|
| 167 |
+
---
|
| 168 |
+
|
| 169 |
+
<details>
|
| 170 |
+
<summary>📋 查看完整配置项列表</summary>
|
| 171 |
+
|
| 172 |
+
| 配置项 | 说明 | 默认值 |
|
| 173 |
+
| :--- | :--- | :--- |
|
| 174 |
+
| **数据库配置** | | |
|
| 175 |
+
| `DATABASE_TYPE` | 数据库类型: `mysql` 或 `sqlite` | `mysql` |
|
| 176 |
+
| `SQLITE_DATABASE` | 当使用 `sqlite` 时必填,SQLite 数据库文件路径 | `default_db` |
|
| 177 |
+
| `MYSQL_HOST` | 当使用 `mysql` 时必填,MySQL 数据库主机地址 | `localhost` |
|
| 178 |
+
| `MYSQL_SOCKET` | 可选,MySQL 数据库 socket 地址 | `/var/run/mysqld/mysqld.sock` |
|
| 179 |
+
| `MYSQL_PORT` | 当使用 `mysql` 时必填,MySQL 数据库端口 | `3306` |
|
| 180 |
+
| `MYSQL_USER` | 当使用 `mysql` 时必填,MySQL 数据库用户名 | `your_db_user` |
|
| 181 |
+
| `MYSQL_PASSWORD` | 当使用 `mysql` 时必填,MySQL 数据库密码 | `your_db_password` |
|
| 182 |
+
| `MYSQL_DATABASE` | 当使用 `mysql` 时必填,MySQL 数据库名称 | `defaultdb` |
|
| 183 |
+
| **API 相关配置** | | |
|
| 184 |
+
| `API_KEYS` | **必填**, Gemini API 密钥列表,用于负载均衡 | `[]` |
|
| 185 |
+
| `ALLOWED_TOKENS` | **必填**, 允许访问的 Token 列表 | `[]` |
|
| 186 |
+
| `AUTH_TOKEN` | 超级管理员 Token,不填则使用 `ALLOWED_TOKENS` 的第一个 | `sk-123456` |
|
| 187 |
+
| `TEST_MODEL` | 用于测试密钥可用性的模型 | `gemini-2.5-flash-lite` |
|
| 188 |
+
| `IMAGE_MODELS` | 支持绘图功能的模型列表 | `["gemini-2.0-flash-exp", "gemini-2.5-flash-image-preview"]` |
|
| 189 |
+
| `SEARCH_MODELS` | 支持搜索功能的模型列表 | `["gemini-2.5-flash","gemini-2.5-pro"]` |
|
| 190 |
+
| `FILTERED_MODELS` | 被禁用的模型列表 | `[]` |
|
| 191 |
+
| `TOOLS_CODE_EXECUTION_ENABLED` | 是否启用代码执行工具 | `false` |
|
| 192 |
+
| `SHOW_SEARCH_LINK` | 是否在响应中显示搜索结果链接 | `true` |
|
| 193 |
+
| `SHOW_THINKING_PROCESS` | 是否显示模型思考过程 | `true` |
|
| 194 |
+
| `THINKING_MODELS` | 支持思考功能的模型列表 | `[]` |
|
| 195 |
+
| `THINKING_BUDGET_MAP` | 思考功能预算映射 (模型名:预算值) | `{}` |
|
| 196 |
+
| `URL_NORMALIZATION_ENABLED` | 是否启用智能路由映射功能 | `false` |
|
| 197 |
+
| `URL_CONTEXT_ENABLED` | 是否启用URL上下文理解功能 | `false` |
|
| 198 |
+
| `URL_CONTEXT_MODELS` | 支持URL上下文理解功能的模型列表 | `[]` |
|
| 199 |
+
| `BASE_URL` | Gemini API 基础 URL | `https://generativelanguage.googleapis.com/v1beta` |
|
| 200 |
+
| `MAX_FAILURES` | 单个 Key 允许的最大失败次数 | `3` |
|
| 201 |
+
| `MAX_RETRIES` | API 请求失败时的最大重试次数 | `3` |
|
| 202 |
+
| `CHECK_INTERVAL_HOURS` | 禁用 Key 恢复检查间隔 (小时) | `1` |
|
| 203 |
+
| `TIMEZONE` | 应用程序使用���时区 | `Asia/Shanghai` |
|
| 204 |
+
| `TIME_OUT` | 请求超时时间 (秒) | `300` |
|
| 205 |
+
| `PROXIES` | 代理服务器列表 (例如 `http://user:pass@host:port`) | `[]` |
|
| 206 |
+
| **日志与安全** | | |
|
| 207 |
+
| `LOG_LEVEL` | 日志级别: `DEBUG`, `INFO`, `WARNING`, `ERROR` | `INFO` |
|
| 208 |
+
| `ERROR_LOG_RECORD_REQUEST_BODY` | 是否记录错误日志的请求体(可能包含敏感信息) | `false` |
|
| 209 |
+
| `AUTO_DELETE_ERROR_LOGS_ENABLED` | 是否自动删除错误日志 | `true` |
|
| 210 |
+
| `AUTO_DELETE_ERROR_LOGS_DAYS` | 错误日志保留天数 | `7` |
|
| 211 |
+
| `AUTO_DELETE_REQUEST_LOGS_ENABLED`| 是否自动删除请求日志 | `false` |
|
| 212 |
+
| `AUTO_DELETE_REQUEST_LOGS_DAYS` | 请求日志保留天数 | `30` |
|
| 213 |
+
| `SAFETY_SETTINGS` | 内容安全阈值 (JSON 字符串) | `[{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "OFF"}, ...]` |
|
| 214 |
+
| **TTS 相关** | | |
|
| 215 |
+
| `TTS_MODEL` | TTS 模型名称 | `gemini-2.5-flash-preview-tts` |
|
| 216 |
+
| `TTS_VOICE_NAME` | TTS 语音名称 | `Zephyr` |
|
| 217 |
+
| `TTS_SPEED` | TTS 语速 | `normal` |
|
| 218 |
+
| **图像生成相关** | | |
|
| 219 |
+
| `PAID_KEY` | 付费版API Key,用于图片生成等高级功能 | `your-paid-api-key` |
|
| 220 |
+
| `CREATE_IMAGE_MODEL` | 图片生成模型 | `imagen-3.0-generate-002` |
|
| 221 |
+
| `UPLOAD_PROVIDER` | 图片上传提供商: `smms`, `picgo`, `cloudflare_imgbed`, `aliyun_oss` | `smms` |
|
| 222 |
+
| `OSS_ENDPOINT` | 阿里云 OSS 公网 Endpoint | `oss-cn-shanghai.aliyuncs.com` |
|
| 223 |
+
| `OSS_ENDPOINT_INNER` | 阿里云 OSS 内网 Endpoint(同 VPC 内网访问) | `oss-cn-shanghai-internal.aliyuncs.com` |
|
| 224 |
+
| `OSS_ACCESS_KEY` | 阿里云 AccessKey ID | `LTAI5txxxxxxxxxxxxxxxx` |
|
| 225 |
+
| `OSS_ACCESS_KEY_SECRET` | 阿里云 AccessKey Secret | `yXxxxxxxxxxxxxxxxxxxxxx` |
|
| 226 |
+
| `OSS_BUCKET_NAME` | 阿里云 OSS Bucket 名称 | `your-bucket-name` |
|
| 227 |
+
| `OSS_REGION` | 阿里云 OSS 区域 Region | `cn-shanghai` |
|
| 228 |
+
| `SMMS_SECRET_TOKEN` | SM.MS图床的API Token | `your-smms-token` |
|
| 229 |
+
| `PICGO_API_KEY` | [PicoGo](https://www.picgo.net/)图床的API Key | `your-picogo-apikey` |
|
| 230 |
+
| `PICGO_API_URL` | [PicoGo](https://www.picgo.net/)图床的API服务器地址 | `https://www.picgo.net/api/1/upload` |
|
| 231 |
+
| `CLOUDFLARE_IMGBED_URL` | [CloudFlare](https://github.com/MarSeventh/CloudFlare-ImgBed) 图床上传地址 | `https://xxxxxxx.pages.dev/upload` |
|
| 232 |
+
| `CLOUDFLARE_IMGBED_AUTH_CODE`| CloudFlare图床的鉴权key | `your-cloudflare-imgber-auth-code` |
|
| 233 |
+
| `CLOUDFLARE_IMGBED_UPLOAD_FOLDER`| CloudFlare图床的上传文件夹路径 | `""` |
|
| 234 |
+
| **流式优化器相关** | | |
|
| 235 |
+
| `STREAM_OPTIMIZER_ENABLED` | 是否启用流式输出优化 | `false` |
|
| 236 |
+
| `STREAM_MIN_DELAY` | 流式输出最小延迟 | `0.016` |
|
| 237 |
+
| `STREAM_MAX_DELAY` | 流式输出最大延迟 | `0.024` |
|
| 238 |
+
| `STREAM_SHORT_TEXT_THRESHOLD`| 短文本阈值 | `10` |
|
| 239 |
+
| `STREAM_LONG_TEXT_THRESHOLD` | 长文本阈值 | `50` |
|
| 240 |
+
| `STREAM_CHUNK_SIZE` | 流式输出块大小 | `5` |
|
| 241 |
+
| **伪流式 (Fake Stream) 相关** | | |
|
| 242 |
+
| `FAKE_STREAM_ENABLED` | 是否启用伪流式传输 | `false` |
|
| 243 |
+
| `FAKE_STREAM_EMPTY_DATA_INTERVAL_SECONDS` | 伪流式传输时发送心跳空数据的间隔秒数 | `5` |
|
| 244 |
+
|
| 245 |
+
</details>
|
| 246 |
+
|
| 247 |
+
---
|
| 248 |
+
|
| 249 |
+
## 🤝 贡献
|
| 250 |
+
|
| 251 |
+
欢迎通过提交 Pull Request 或 Issue 来为项目做出贡献。
|
| 252 |
+
|
| 253 |
+
[](https://github.com/snailyp/gemini-balance/graphs/contributors)
|
| 254 |
+
|
| 255 |
+
## ⭐ Star History
|
| 256 |
+
|
| 257 |
+
[](https://star-history.com/#snailyp/gemini-balance&Date)
|
| 258 |
+
|
| 259 |
+
## 🎉 特别鸣谢
|
| 260 |
+
|
| 261 |
+
* [PicGo](https://www.picgo.net/)
|
| 262 |
+
* [SM.MS](https://smms.app/)
|
| 263 |
+
* [CloudFlare-ImgBed](https://github.com/MarSeventh/CloudFlare-ImgBed)
|
| 264 |
+
|
| 265 |
+
## 💖 友情项目
|
| 266 |
+
|
| 267 |
+
* **[OneLine](https://github.com/chengtx809/OneLine)** by [chengtx809](https://github.com/chengtx809) - AI 驱动的热点事件时间轴生成工具。
|
| 268 |
+
|
| 269 |
+
## 🎁 项目支持
|
| 270 |
+
|
| 271 |
+
如果你觉得这个项目对你有帮助,可以考虑通过 [爱发电](https://afdian.com/a/snaily) 支持我。
|
| 272 |
+
|
| 273 |
+
## 许可证
|
| 274 |
+
|
| 275 |
+
本项目采用 [CC BY-NC 4.0](LICENSE)(署名-非商业性使用)协议。
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
## 赞助商
|
| 279 |
+
|
| 280 |
+
特别感谢 [DigitalOcean](https://m.do.co/c/b249dd7f3b4c) 为本项目提供稳定可靠的云基础设施支持。
|
| 281 |
+
|
| 282 |
+
<a href="https://m.do.co/c/b249dd7f3b4c">
|
| 283 |
+
<img src="files/dataocean.svg" alt="DigitalOcean Logo" width="200"/>
|
| 284 |
+
</a>
|
| 285 |
+
|
| 286 |
+
本项目的 CDN 加速和安全防护由 [Tencent EdgeOne](https://edgeone.ai/?from=github) 赞助。
|
| 287 |
+
|
| 288 |
+
<a href="https://edgeone.ai/?from=github">
|
| 289 |
+
<img src="https://edgeone.ai/media/34fe3a45-492d-4ea4-ae5d-ea1087ca7b4b.png" alt="EdgeOne Logo" width="200"/>
|
| 290 |
+
</a>
|
VERSION
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
2.2.8
|
app/config/config.py
ADDED
|
@@ -0,0 +1,533 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
应用程序配置模块
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import datetime
|
| 6 |
+
import json
|
| 7 |
+
from typing import Any, Dict, List, Type, get_args, get_origin
|
| 8 |
+
|
| 9 |
+
from pydantic import Field, ValidationError, ValidationInfo, field_validator
|
| 10 |
+
from pydantic_settings import BaseSettings
|
| 11 |
+
from sqlalchemy import insert, select, update
|
| 12 |
+
|
| 13 |
+
from app.core.constants import (
|
| 14 |
+
API_VERSION,
|
| 15 |
+
DEFAULT_CREATE_IMAGE_MODEL,
|
| 16 |
+
DEFAULT_FILTER_MODELS,
|
| 17 |
+
DEFAULT_MODEL,
|
| 18 |
+
DEFAULT_SAFETY_SETTINGS,
|
| 19 |
+
DEFAULT_STREAM_CHUNK_SIZE,
|
| 20 |
+
DEFAULT_STREAM_LONG_TEXT_THRESHOLD,
|
| 21 |
+
DEFAULT_STREAM_MAX_DELAY,
|
| 22 |
+
DEFAULT_STREAM_MIN_DELAY,
|
| 23 |
+
DEFAULT_STREAM_SHORT_TEXT_THRESHOLD,
|
| 24 |
+
DEFAULT_TIMEOUT,
|
| 25 |
+
MAX_RETRIES,
|
| 26 |
+
)
|
| 27 |
+
from app.log.logger import Logger
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class Settings(BaseSettings):
|
| 31 |
+
# 数据库配置
|
| 32 |
+
DATABASE_TYPE: str = "mysql" # sqlite 或 mysql
|
| 33 |
+
SQLITE_DATABASE: str = "default_db"
|
| 34 |
+
MYSQL_HOST: str = ""
|
| 35 |
+
MYSQL_PORT: int = 3306
|
| 36 |
+
MYSQL_USER: str = ""
|
| 37 |
+
MYSQL_PASSWORD: str = ""
|
| 38 |
+
MYSQL_DATABASE: str = ""
|
| 39 |
+
MYSQL_SOCKET: str = ""
|
| 40 |
+
|
| 41 |
+
# 验证 MySQL 配置
|
| 42 |
+
@field_validator(
|
| 43 |
+
"MYSQL_HOST", "MYSQL_PORT", "MYSQL_USER", "MYSQL_PASSWORD", "MYSQL_DATABASE"
|
| 44 |
+
)
|
| 45 |
+
def validate_mysql_config(cls, v: Any, info: ValidationInfo) -> Any:
|
| 46 |
+
if info.data.get("DATABASE_TYPE") == "mysql":
|
| 47 |
+
if v is None or v == "":
|
| 48 |
+
raise ValueError(
|
| 49 |
+
"MySQL configuration is required when DATABASE_TYPE is 'mysql'"
|
| 50 |
+
)
|
| 51 |
+
return v
|
| 52 |
+
|
| 53 |
+
# API相关配置
|
| 54 |
+
API_KEYS: List[str] = []
|
| 55 |
+
ALLOWED_TOKENS: List[str] = []
|
| 56 |
+
BASE_URL: str = f"https://generativelanguage.googleapis.com/{API_VERSION}"
|
| 57 |
+
AUTH_TOKEN: str = ""
|
| 58 |
+
MAX_FAILURES: int = 3
|
| 59 |
+
TEST_MODEL: str = DEFAULT_MODEL
|
| 60 |
+
TIME_OUT: int = DEFAULT_TIMEOUT
|
| 61 |
+
MAX_RETRIES: int = MAX_RETRIES
|
| 62 |
+
PROXIES: List[str] = []
|
| 63 |
+
PROXIES_USE_CONSISTENCY_HASH_BY_API_KEY: bool = True # 是否使用一致性哈希来选择代理
|
| 64 |
+
VERTEX_API_KEYS: List[str] = []
|
| 65 |
+
VERTEX_EXPRESS_BASE_URL: str = (
|
| 66 |
+
"https://aiplatform.googleapis.com/v1beta1/publishers/google"
|
| 67 |
+
)
|
| 68 |
+
|
| 69 |
+
# 智能路由配置
|
| 70 |
+
URL_NORMALIZATION_ENABLED: bool = False # 是否启用智能路由映射功能
|
| 71 |
+
|
| 72 |
+
# 自定义 Headers
|
| 73 |
+
CUSTOM_HEADERS: Dict[str, str] = {}
|
| 74 |
+
|
| 75 |
+
# 模型相关配置
|
| 76 |
+
SEARCH_MODELS: List[str] = ["gemini-2.5-flash", "gemini-2.5-pro"]
|
| 77 |
+
IMAGE_MODELS: List[str] = ["gemini-2.0-flash-exp", "gemini-2.5-flash-image-preview"]
|
| 78 |
+
FILTERED_MODELS: List[str] = DEFAULT_FILTER_MODELS
|
| 79 |
+
TOOLS_CODE_EXECUTION_ENABLED: bool = False
|
| 80 |
+
# 是否启用网址上下文
|
| 81 |
+
URL_CONTEXT_ENABLED: bool = False
|
| 82 |
+
URL_CONTEXT_MODELS: List[str] = [
|
| 83 |
+
"gemini-2.5-pro",
|
| 84 |
+
"gemini-2.5-flash",
|
| 85 |
+
"gemini-2.5-flash-lite",
|
| 86 |
+
"gemini-2.0-flash",
|
| 87 |
+
"gemini-2.0-flash-live-001",
|
| 88 |
+
]
|
| 89 |
+
SHOW_SEARCH_LINK: bool = True
|
| 90 |
+
SHOW_THINKING_PROCESS: bool = True
|
| 91 |
+
THINKING_MODELS: List[str] = []
|
| 92 |
+
THINKING_BUDGET_MAP: Dict[str, float] = {}
|
| 93 |
+
|
| 94 |
+
# TTS相关配置
|
| 95 |
+
TTS_MODEL: str = "gemini-2.5-flash-preview-tts"
|
| 96 |
+
TTS_VOICE_NAME: str = "Zephyr"
|
| 97 |
+
TTS_SPEED: str = "normal"
|
| 98 |
+
|
| 99 |
+
# 图像生成相关配置
|
| 100 |
+
PAID_KEY: str = ""
|
| 101 |
+
CREATE_IMAGE_MODEL: str = DEFAULT_CREATE_IMAGE_MODEL
|
| 102 |
+
UPLOAD_PROVIDER: str = "smms"
|
| 103 |
+
SMMS_SECRET_TOKEN: str = ""
|
| 104 |
+
PICGO_API_KEY: str = ""
|
| 105 |
+
PICGO_API_URL: str = "https://www.picgo.net/api/1/upload"
|
| 106 |
+
CLOUDFLARE_IMGBED_URL: str = ""
|
| 107 |
+
CLOUDFLARE_IMGBED_AUTH_CODE: str = ""
|
| 108 |
+
CLOUDFLARE_IMGBED_UPLOAD_FOLDER: str = ""
|
| 109 |
+
# 阿里云OSS配置
|
| 110 |
+
OSS_ENDPOINT: str = ""
|
| 111 |
+
OSS_ENDPOINT_INNER: str = ""
|
| 112 |
+
OSS_ACCESS_KEY: str = ""
|
| 113 |
+
OSS_ACCESS_KEY_SECRET: str = ""
|
| 114 |
+
OSS_BUCKET_NAME: str = ""
|
| 115 |
+
OSS_REGION: str = ""
|
| 116 |
+
|
| 117 |
+
# 流式输出优化器配置
|
| 118 |
+
STREAM_OPTIMIZER_ENABLED: bool = False
|
| 119 |
+
STREAM_MIN_DELAY: float = DEFAULT_STREAM_MIN_DELAY
|
| 120 |
+
STREAM_MAX_DELAY: float = DEFAULT_STREAM_MAX_DELAY
|
| 121 |
+
STREAM_SHORT_TEXT_THRESHOLD: int = DEFAULT_STREAM_SHORT_TEXT_THRESHOLD
|
| 122 |
+
STREAM_LONG_TEXT_THRESHOLD: int = DEFAULT_STREAM_LONG_TEXT_THRESHOLD
|
| 123 |
+
STREAM_CHUNK_SIZE: int = DEFAULT_STREAM_CHUNK_SIZE
|
| 124 |
+
|
| 125 |
+
# 假流式配置 (Fake Streaming Configuration)
|
| 126 |
+
FAKE_STREAM_ENABLED: bool = False # 是否启用假流式输出
|
| 127 |
+
FAKE_STREAM_EMPTY_DATA_INTERVAL_SECONDS: int = 5 # 假流式发送空数据的间隔时间(秒)
|
| 128 |
+
|
| 129 |
+
# 调度器配置
|
| 130 |
+
CHECK_INTERVAL_HOURS: int = 1 # 默认检查间隔为1小时
|
| 131 |
+
TIMEZONE: str = "Asia/Shanghai" # 默认时区
|
| 132 |
+
|
| 133 |
+
# github
|
| 134 |
+
GITHUB_REPO_OWNER: str = "snailyp"
|
| 135 |
+
GITHUB_REPO_NAME: str = "gemini-balance"
|
| 136 |
+
|
| 137 |
+
# 日志配置
|
| 138 |
+
LOG_LEVEL: str = "INFO"
|
| 139 |
+
ERROR_LOG_RECORD_REQUEST_BODY: bool = False
|
| 140 |
+
AUTO_DELETE_ERROR_LOGS_ENABLED: bool = True
|
| 141 |
+
AUTO_DELETE_ERROR_LOGS_DAYS: int = 7
|
| 142 |
+
AUTO_DELETE_REQUEST_LOGS_ENABLED: bool = False
|
| 143 |
+
AUTO_DELETE_REQUEST_LOGS_DAYS: int = 30
|
| 144 |
+
SAFETY_SETTINGS: List[Dict[str, str]] = DEFAULT_SAFETY_SETTINGS
|
| 145 |
+
|
| 146 |
+
# Files API
|
| 147 |
+
FILES_CLEANUP_ENABLED: bool = True
|
| 148 |
+
FILES_CLEANUP_INTERVAL_HOURS: int = 1
|
| 149 |
+
FILES_USER_ISOLATION_ENABLED: bool = True
|
| 150 |
+
|
| 151 |
+
# Admin Session Configuration
|
| 152 |
+
ADMIN_SESSION_EXPIRE: int = Field(
|
| 153 |
+
default=3600,
|
| 154 |
+
ge=300,
|
| 155 |
+
le=86400,
|
| 156 |
+
description="Admin session expiration time in seconds (5 minutes to 24 hours)",
|
| 157 |
+
)
|
| 158 |
+
|
| 159 |
+
def __init__(self, **kwargs):
|
| 160 |
+
super().__init__(**kwargs)
|
| 161 |
+
# 设置默认AUTH_TOKEN(如果未提供)
|
| 162 |
+
if not self.AUTH_TOKEN and self.ALLOWED_TOKENS:
|
| 163 |
+
self.AUTH_TOKEN = self.ALLOWED_TOKENS[0]
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
# 创建全局配置实例
|
| 167 |
+
settings = Settings()
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
def _parse_db_value(key: str, db_value: str, target_type: Type) -> Any:
|
| 171 |
+
"""尝试将数据库字符串值解析为目标 Python 类型"""
|
| 172 |
+
from app.log.logger import get_config_logger
|
| 173 |
+
|
| 174 |
+
logger = get_config_logger()
|
| 175 |
+
try:
|
| 176 |
+
origin_type = get_origin(target_type)
|
| 177 |
+
args = get_args(target_type)
|
| 178 |
+
|
| 179 |
+
# 处理 List 类型
|
| 180 |
+
if origin_type is list:
|
| 181 |
+
# 处理 List[str]
|
| 182 |
+
if args and args[0] == str:
|
| 183 |
+
try:
|
| 184 |
+
parsed = json.loads(db_value)
|
| 185 |
+
if isinstance(parsed, list):
|
| 186 |
+
return [str(item) for item in parsed]
|
| 187 |
+
except json.JSONDecodeError:
|
| 188 |
+
return [
|
| 189 |
+
item.strip() for item in db_value.split(",") if item.strip()
|
| 190 |
+
]
|
| 191 |
+
logger.warning(
|
| 192 |
+
f"Could not parse '{db_value}' as List[str] for key '{key}', falling back to comma split or empty list."
|
| 193 |
+
)
|
| 194 |
+
return [item.strip() for item in db_value.split(",") if item.strip()]
|
| 195 |
+
# 处理 List[Dict[str, str]]
|
| 196 |
+
elif args and get_origin(args[0]) is dict:
|
| 197 |
+
try:
|
| 198 |
+
parsed = json.loads(db_value)
|
| 199 |
+
if isinstance(parsed, list):
|
| 200 |
+
valid = all(
|
| 201 |
+
isinstance(item, dict)
|
| 202 |
+
and all(isinstance(k, str) for k in item.keys())
|
| 203 |
+
and all(isinstance(v, str) for v in item.values())
|
| 204 |
+
for item in parsed
|
| 205 |
+
)
|
| 206 |
+
if valid:
|
| 207 |
+
return parsed
|
| 208 |
+
else:
|
| 209 |
+
logger.warning(
|
| 210 |
+
f"Invalid structure in List[Dict[str, str]] for key '{key}'. Value: {db_value}"
|
| 211 |
+
)
|
| 212 |
+
return []
|
| 213 |
+
else:
|
| 214 |
+
logger.warning(
|
| 215 |
+
f"Parsed DB value for key '{key}' is not a list type. Value: {db_value}"
|
| 216 |
+
)
|
| 217 |
+
return []
|
| 218 |
+
except json.JSONDecodeError:
|
| 219 |
+
logger.error(
|
| 220 |
+
f"Could not parse '{db_value}' as JSON for List[Dict[str, str]] for key '{key}'. Returning empty list."
|
| 221 |
+
)
|
| 222 |
+
return []
|
| 223 |
+
except Exception as e:
|
| 224 |
+
logger.error(
|
| 225 |
+
f"Error parsing List[Dict[str, str]] for key '{key}': {e}. Value: {db_value}. Returning empty list."
|
| 226 |
+
)
|
| 227 |
+
return []
|
| 228 |
+
# 处理 Dict 类型
|
| 229 |
+
elif origin_type is dict:
|
| 230 |
+
# 处理 Dict[str, str]
|
| 231 |
+
if args and args == (str, str):
|
| 232 |
+
parsed_dict = {}
|
| 233 |
+
try:
|
| 234 |
+
parsed = json.loads(db_value)
|
| 235 |
+
if isinstance(parsed, dict):
|
| 236 |
+
parsed_dict = {str(k): str(v) for k, v in parsed.items()}
|
| 237 |
+
else:
|
| 238 |
+
logger.warning(
|
| 239 |
+
f"Parsed DB value for key '{key}' is not a dictionary type. Value: {db_value}"
|
| 240 |
+
)
|
| 241 |
+
except json.JSONDecodeError:
|
| 242 |
+
logger.error(
|
| 243 |
+
f"Could not parse '{db_value}' as Dict[str, str] for key '{key}'. Returning empty dict."
|
| 244 |
+
)
|
| 245 |
+
return parsed_dict
|
| 246 |
+
# 处理 Dict[str, float]
|
| 247 |
+
elif args and args == (str, float):
|
| 248 |
+
parsed_dict = {}
|
| 249 |
+
try:
|
| 250 |
+
parsed = json.loads(db_value)
|
| 251 |
+
if isinstance(parsed, dict):
|
| 252 |
+
parsed_dict = {str(k): float(v) for k, v in parsed.items()}
|
| 253 |
+
else:
|
| 254 |
+
logger.warning(
|
| 255 |
+
f"Parsed DB value for key '{key}' is not a dictionary type. Value: {db_value}"
|
| 256 |
+
)
|
| 257 |
+
except (json.JSONDecodeError, ValueError, TypeError) as e1:
|
| 258 |
+
if isinstance(e1, json.JSONDecodeError) and "'" in db_value:
|
| 259 |
+
logger.warning(
|
| 260 |
+
f"Failed initial JSON parse for key '{key}'. Attempting to replace single quotes. Error: {e1}"
|
| 261 |
+
)
|
| 262 |
+
try:
|
| 263 |
+
corrected_db_value = db_value.replace("'", '"')
|
| 264 |
+
parsed = json.loads(corrected_db_value)
|
| 265 |
+
if isinstance(parsed, dict):
|
| 266 |
+
parsed_dict = {
|
| 267 |
+
str(k): float(v) for k, v in parsed.items()
|
| 268 |
+
}
|
| 269 |
+
else:
|
| 270 |
+
logger.warning(
|
| 271 |
+
f"Parsed DB value (after quote replacement) for key '{key}' is not a dictionary type. Value: {corrected_db_value}"
|
| 272 |
+
)
|
| 273 |
+
except (json.JSONDecodeError, ValueError, TypeError) as e2:
|
| 274 |
+
logger.error(
|
| 275 |
+
f"Could not parse '{db_value}' as Dict[str, float] for key '{key}' even after replacing quotes: {e2}. Returning empty dict."
|
| 276 |
+
)
|
| 277 |
+
else:
|
| 278 |
+
logger.error(
|
| 279 |
+
f"Could not parse '{db_value}' as Dict[str, float] for key '{key}': {e1}. Returning empty dict."
|
| 280 |
+
)
|
| 281 |
+
return parsed_dict
|
| 282 |
+
# 处理 bool
|
| 283 |
+
elif target_type == bool:
|
| 284 |
+
return db_value.lower() in ("true", "1", "yes", "on")
|
| 285 |
+
# 处理 int
|
| 286 |
+
elif target_type == int:
|
| 287 |
+
return int(db_value)
|
| 288 |
+
# 处理 float
|
| 289 |
+
elif target_type == float:
|
| 290 |
+
return float(db_value)
|
| 291 |
+
# 默认为 str 或其他 pydantic 能直接处理的类型
|
| 292 |
+
else:
|
| 293 |
+
return db_value
|
| 294 |
+
except (ValueError, TypeError, json.JSONDecodeError) as e:
|
| 295 |
+
logger.warning(
|
| 296 |
+
f"Failed to parse db_value '{db_value}' for key '{key}' as type {target_type}: {e}. Using original string value."
|
| 297 |
+
)
|
| 298 |
+
return db_value # 解析失败则返回原始字符串
|
| 299 |
+
|
| 300 |
+
|
| 301 |
+
async def sync_initial_settings():
|
| 302 |
+
"""
|
| 303 |
+
应用启动时同步配置:
|
| 304 |
+
1. 从数据库加载设置。
|
| 305 |
+
2. 将数据库设置合并到内存 settings (数据库优先)。
|
| 306 |
+
3. 将最终的内存 settings 同步回数据库。
|
| 307 |
+
"""
|
| 308 |
+
from app.log.logger import get_config_logger
|
| 309 |
+
|
| 310 |
+
logger = get_config_logger()
|
| 311 |
+
# 延迟导入以避免循环依赖和确保数据库连接已初始化
|
| 312 |
+
from app.database.connection import database
|
| 313 |
+
from app.database.models import Settings as SettingsModel
|
| 314 |
+
|
| 315 |
+
global settings
|
| 316 |
+
logger.info("Starting initial settings synchronization...")
|
| 317 |
+
|
| 318 |
+
if not database.is_connected:
|
| 319 |
+
try:
|
| 320 |
+
await database.connect()
|
| 321 |
+
logger.info("Database connection established for initial sync.")
|
| 322 |
+
except Exception as e:
|
| 323 |
+
logger.error(
|
| 324 |
+
f"Failed to connect to database for initial settings sync: {e}. Skipping sync."
|
| 325 |
+
)
|
| 326 |
+
return
|
| 327 |
+
|
| 328 |
+
try:
|
| 329 |
+
# 1. 从数据库加载设置
|
| 330 |
+
db_settings_raw: List[Dict[str, Any]] = []
|
| 331 |
+
try:
|
| 332 |
+
query = select(SettingsModel.key, SettingsModel.value)
|
| 333 |
+
results = await database.fetch_all(query)
|
| 334 |
+
db_settings_raw = [
|
| 335 |
+
{"key": row["key"], "value": row["value"]} for row in results
|
| 336 |
+
]
|
| 337 |
+
logger.info(f"Fetched {len(db_settings_raw)} settings from database.")
|
| 338 |
+
except Exception as e:
|
| 339 |
+
logger.error(
|
| 340 |
+
f"Failed to fetch settings from database: {e}. Proceeding with environment/dotenv settings."
|
| 341 |
+
)
|
| 342 |
+
# 即使数据库读取失败,也要继续执行,确保基于 env/dotenv 的配置能同步到数据库
|
| 343 |
+
|
| 344 |
+
db_settings_map: Dict[str, str] = {
|
| 345 |
+
s["key"]: s["value"] for s in db_settings_raw
|
| 346 |
+
}
|
| 347 |
+
|
| 348 |
+
# 2. 将数据库设置合并到内存 settings (数据库优先)
|
| 349 |
+
updated_in_memory = False
|
| 350 |
+
|
| 351 |
+
for key, db_value in db_settings_map.items():
|
| 352 |
+
if key == "DATABASE_TYPE":
|
| 353 |
+
logger.debug(
|
| 354 |
+
f"Skipping update of '{key}' in memory from database. "
|
| 355 |
+
"This setting is controlled by environment/dotenv."
|
| 356 |
+
)
|
| 357 |
+
continue
|
| 358 |
+
if hasattr(settings, key):
|
| 359 |
+
target_type = Settings.__annotations__.get(key)
|
| 360 |
+
if target_type:
|
| 361 |
+
try:
|
| 362 |
+
parsed_db_value = _parse_db_value(key, db_value, target_type)
|
| 363 |
+
memory_value = getattr(settings, key)
|
| 364 |
+
|
| 365 |
+
# 比较解析后的值和内存中的值
|
| 366 |
+
# 注意:对于列表等复杂类型,直接比较可能不够健壮,但这里简化处理
|
| 367 |
+
if parsed_db_value != memory_value:
|
| 368 |
+
# 检查类型是否匹配,以防解析函数返回了不兼容的类型
|
| 369 |
+
type_match = False
|
| 370 |
+
origin_type = get_origin(target_type)
|
| 371 |
+
if origin_type: # It's a generic type
|
| 372 |
+
if isinstance(parsed_db_value, origin_type):
|
| 373 |
+
type_match = True
|
| 374 |
+
# It's a non-generic type, or a specific generic we want to handle
|
| 375 |
+
elif isinstance(parsed_db_value, target_type):
|
| 376 |
+
type_match = True
|
| 377 |
+
|
| 378 |
+
if type_match:
|
| 379 |
+
setattr(settings, key, parsed_db_value)
|
| 380 |
+
logger.debug(
|
| 381 |
+
f"Updated setting '{key}' in memory from database value ({target_type})."
|
| 382 |
+
)
|
| 383 |
+
updated_in_memory = True
|
| 384 |
+
else:
|
| 385 |
+
logger.warning(
|
| 386 |
+
f"Parsed DB value type mismatch for key '{key}'. Expected {target_type}, got {type(parsed_db_value)}. Skipping update."
|
| 387 |
+
)
|
| 388 |
+
|
| 389 |
+
except Exception as e:
|
| 390 |
+
logger.error(
|
| 391 |
+
f"Error processing database setting for key '{key}': {e}"
|
| 392 |
+
)
|
| 393 |
+
else:
|
| 394 |
+
logger.warning(
|
| 395 |
+
f"Database setting '{key}' not found in Settings model definition. Ignoring."
|
| 396 |
+
)
|
| 397 |
+
|
| 398 |
+
# 如果内存中有更新,重新验证 Pydantic 模型(可选但推荐)
|
| 399 |
+
if updated_in_memory:
|
| 400 |
+
try:
|
| 401 |
+
# 重新加载以确保类型转换和验证
|
| 402 |
+
settings = Settings(**settings.model_dump())
|
| 403 |
+
logger.info(
|
| 404 |
+
"Settings object re-validated after merging database values."
|
| 405 |
+
)
|
| 406 |
+
except ValidationError as e:
|
| 407 |
+
logger.error(
|
| 408 |
+
f"Validation error after merging database settings: {e}. Settings might be inconsistent."
|
| 409 |
+
)
|
| 410 |
+
|
| 411 |
+
# 3. 将最终的内存 settings 同步回数据库
|
| 412 |
+
final_memory_settings = settings.model_dump()
|
| 413 |
+
settings_to_update: List[Dict[str, Any]] = []
|
| 414 |
+
settings_to_insert: List[Dict[str, Any]] = []
|
| 415 |
+
now = datetime.datetime.now(datetime.timezone.utc)
|
| 416 |
+
|
| 417 |
+
existing_db_keys = set(db_settings_map.keys())
|
| 418 |
+
|
| 419 |
+
for key, value in final_memory_settings.items():
|
| 420 |
+
if key == "DATABASE_TYPE":
|
| 421 |
+
logger.debug(
|
| 422 |
+
f"Skipping synchronization of '{key}' to database. "
|
| 423 |
+
"This setting is controlled by environment/dotenv."
|
| 424 |
+
)
|
| 425 |
+
continue
|
| 426 |
+
|
| 427 |
+
# 序列化值为字符串或 JSON 字符串
|
| 428 |
+
if isinstance(value, (list, dict)):
|
| 429 |
+
db_value = json.dumps(value, ensure_ascii=False)
|
| 430 |
+
elif isinstance(value, bool):
|
| 431 |
+
db_value = str(value).lower()
|
| 432 |
+
elif value is None:
|
| 433 |
+
db_value = ""
|
| 434 |
+
else:
|
| 435 |
+
db_value = str(value)
|
| 436 |
+
|
| 437 |
+
data = {
|
| 438 |
+
"key": key,
|
| 439 |
+
"value": db_value,
|
| 440 |
+
"description": f"{key} configuration setting",
|
| 441 |
+
"updated_at": now,
|
| 442 |
+
}
|
| 443 |
+
|
| 444 |
+
if key in existing_db_keys:
|
| 445 |
+
# 仅当值与数据库中的不同时才更新
|
| 446 |
+
if db_settings_map[key] != db_value:
|
| 447 |
+
settings_to_update.append(data)
|
| 448 |
+
else:
|
| 449 |
+
# 如果键不在数据库中,则插入
|
| 450 |
+
data["created_at"] = now
|
| 451 |
+
settings_to_insert.append(data)
|
| 452 |
+
|
| 453 |
+
# 在事务中执行批量插入和更新
|
| 454 |
+
if settings_to_insert or settings_to_update:
|
| 455 |
+
try:
|
| 456 |
+
async with database.transaction():
|
| 457 |
+
if settings_to_insert:
|
| 458 |
+
# 获取现有描述以避免覆盖
|
| 459 |
+
query_existing = select(
|
| 460 |
+
SettingsModel.key, SettingsModel.description
|
| 461 |
+
).where(
|
| 462 |
+
SettingsModel.key.in_(
|
| 463 |
+
[s["key"] for s in settings_to_insert]
|
| 464 |
+
)
|
| 465 |
+
)
|
| 466 |
+
existing_desc = {
|
| 467 |
+
row["key"]: row["description"]
|
| 468 |
+
for row in await database.fetch_all(query_existing)
|
| 469 |
+
}
|
| 470 |
+
for item in settings_to_insert:
|
| 471 |
+
item["description"] = existing_desc.get(
|
| 472 |
+
item["key"], item["description"]
|
| 473 |
+
)
|
| 474 |
+
|
| 475 |
+
query_insert = insert(SettingsModel).values(settings_to_insert)
|
| 476 |
+
await database.execute(query=query_insert)
|
| 477 |
+
logger.info(
|
| 478 |
+
f"Synced (inserted) {len(settings_to_insert)} settings to database."
|
| 479 |
+
)
|
| 480 |
+
|
| 481 |
+
if settings_to_update:
|
| 482 |
+
# 获取现有描述以避免覆盖
|
| 483 |
+
query_existing = select(
|
| 484 |
+
SettingsModel.key, SettingsModel.description
|
| 485 |
+
).where(
|
| 486 |
+
SettingsModel.key.in_(
|
| 487 |
+
[s["key"] for s in settings_to_update]
|
| 488 |
+
)
|
| 489 |
+
)
|
| 490 |
+
existing_desc = {
|
| 491 |
+
row["key"]: row["description"]
|
| 492 |
+
for row in await database.fetch_all(query_existing)
|
| 493 |
+
}
|
| 494 |
+
|
| 495 |
+
for setting_data in settings_to_update:
|
| 496 |
+
setting_data["description"] = existing_desc.get(
|
| 497 |
+
setting_data["key"], setting_data["description"]
|
| 498 |
+
)
|
| 499 |
+
query_update = (
|
| 500 |
+
update(SettingsModel)
|
| 501 |
+
.where(SettingsModel.key == setting_data["key"])
|
| 502 |
+
.values(
|
| 503 |
+
value=setting_data["value"],
|
| 504 |
+
description=setting_data["description"],
|
| 505 |
+
updated_at=setting_data["updated_at"],
|
| 506 |
+
)
|
| 507 |
+
)
|
| 508 |
+
await database.execute(query=query_update)
|
| 509 |
+
logger.info(
|
| 510 |
+
f"Synced (updated) {len(settings_to_update)} settings to database."
|
| 511 |
+
)
|
| 512 |
+
except Exception as e:
|
| 513 |
+
logger.error(
|
| 514 |
+
f"Failed to sync settings to database during startup: {str(e)}"
|
| 515 |
+
)
|
| 516 |
+
else:
|
| 517 |
+
logger.info(
|
| 518 |
+
"No setting changes detected between memory and database during initial sync."
|
| 519 |
+
)
|
| 520 |
+
|
| 521 |
+
# 刷新日志等级
|
| 522 |
+
Logger.update_log_levels(final_memory_settings.get("LOG_LEVEL"))
|
| 523 |
+
|
| 524 |
+
except Exception as e:
|
| 525 |
+
logger.error(f"An unexpected error occurred during initial settings sync: {e}")
|
| 526 |
+
finally:
|
| 527 |
+
if database.is_connected:
|
| 528 |
+
try:
|
| 529 |
+
pass
|
| 530 |
+
except Exception as e:
|
| 531 |
+
logger.error(f"Error disconnecting database after initial sync: {e}")
|
| 532 |
+
|
| 533 |
+
logger.info("Initial settings synchronization finished.")
|
app/core/application.py
ADDED
|
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from contextlib import asynccontextmanager
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
|
| 4 |
+
from fastapi import FastAPI
|
| 5 |
+
from fastapi.staticfiles import StaticFiles
|
| 6 |
+
from fastapi.templating import Jinja2Templates
|
| 7 |
+
|
| 8 |
+
from app.config.config import settings, sync_initial_settings
|
| 9 |
+
from app.database.connection import connect_to_db, disconnect_from_db
|
| 10 |
+
from app.database.initialization import initialize_database
|
| 11 |
+
from app.exception.exceptions import setup_exception_handlers
|
| 12 |
+
from app.log.logger import get_application_logger, setup_access_logging
|
| 13 |
+
from app.middleware.middleware import setup_middlewares
|
| 14 |
+
from app.router.routes import setup_routers
|
| 15 |
+
from app.scheduler.scheduled_tasks import start_scheduler, stop_scheduler
|
| 16 |
+
from app.service.key.key_manager import get_key_manager_instance
|
| 17 |
+
from app.service.update.update_service import check_for_updates
|
| 18 |
+
from app.utils.helpers import get_current_version
|
| 19 |
+
|
| 20 |
+
logger = get_application_logger()
|
| 21 |
+
|
| 22 |
+
PROJECT_ROOT = Path(__file__).resolve().parent.parent.parent
|
| 23 |
+
STATIC_DIR = PROJECT_ROOT / "app" / "static"
|
| 24 |
+
TEMPLATES_DIR = PROJECT_ROOT / "app" / "templates"
|
| 25 |
+
|
| 26 |
+
# 初始化模板引擎,并添加全局变量
|
| 27 |
+
templates = Jinja2Templates(directory="app/templates")
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
# 定义一个函数来更新模板全局变量
|
| 31 |
+
def update_template_globals(app: FastAPI, update_info: dict):
|
| 32 |
+
# Jinja2Templates 实例没有直接更新全局变量的方法
|
| 33 |
+
# 我们需要在请求上下文中传递这些变量,或者修改 Jinja 环境
|
| 34 |
+
# 更简单的方法是将其存储在 app.state 中,并在渲染时传递
|
| 35 |
+
app.state.update_info = update_info
|
| 36 |
+
logger.info(f"Update info stored in app.state: {update_info}")
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
# --- Helper functions for lifespan ---
|
| 40 |
+
async def _setup_database_and_config(app_settings):
|
| 41 |
+
"""Initializes database, syncs settings, and initializes KeyManager."""
|
| 42 |
+
initialize_database()
|
| 43 |
+
logger.info("Database initialized successfully")
|
| 44 |
+
await connect_to_db()
|
| 45 |
+
await sync_initial_settings()
|
| 46 |
+
await get_key_manager_instance(app_settings.API_KEYS, app_settings.VERTEX_API_KEYS)
|
| 47 |
+
logger.info("Database, config sync, and KeyManager initialized successfully")
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
async def _shutdown_database():
|
| 51 |
+
"""Disconnects from the database."""
|
| 52 |
+
await disconnect_from_db()
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def _start_scheduler():
|
| 56 |
+
"""Starts the background scheduler."""
|
| 57 |
+
try:
|
| 58 |
+
start_scheduler()
|
| 59 |
+
logger.info("Scheduler started successfully.")
|
| 60 |
+
except Exception as e:
|
| 61 |
+
logger.error(f"Failed to start scheduler: {e}")
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def _stop_scheduler():
|
| 65 |
+
"""Stops the background scheduler."""
|
| 66 |
+
stop_scheduler()
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
async def _perform_update_check(app: FastAPI):
|
| 70 |
+
"""Checks for updates and stores the info in app.state."""
|
| 71 |
+
update_available, latest_version, error_message = await check_for_updates()
|
| 72 |
+
current_version = get_current_version()
|
| 73 |
+
update_info = {
|
| 74 |
+
"update_available": update_available,
|
| 75 |
+
"latest_version": latest_version,
|
| 76 |
+
"error_message": error_message,
|
| 77 |
+
"current_version": current_version,
|
| 78 |
+
}
|
| 79 |
+
if not hasattr(app, "state"):
|
| 80 |
+
from starlette.datastructures import State
|
| 81 |
+
|
| 82 |
+
app.state = State()
|
| 83 |
+
app.state.update_info = update_info
|
| 84 |
+
logger.info(f"Update check completed. Info: {update_info}")
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
@asynccontextmanager
|
| 88 |
+
async def lifespan(app: FastAPI):
|
| 89 |
+
"""
|
| 90 |
+
Manages the application startup and shutdown events.
|
| 91 |
+
|
| 92 |
+
Args:
|
| 93 |
+
app: FastAPI应用实例
|
| 94 |
+
"""
|
| 95 |
+
logger.info("Application starting up...")
|
| 96 |
+
try:
|
| 97 |
+
await _setup_database_and_config(settings)
|
| 98 |
+
await _perform_update_check(app)
|
| 99 |
+
_start_scheduler()
|
| 100 |
+
|
| 101 |
+
except Exception as e:
|
| 102 |
+
logger.critical(
|
| 103 |
+
f"Critical error during application startup: {str(e)}", exc_info=True
|
| 104 |
+
)
|
| 105 |
+
|
| 106 |
+
yield
|
| 107 |
+
|
| 108 |
+
logger.info("Application shutting down...")
|
| 109 |
+
_stop_scheduler()
|
| 110 |
+
await _shutdown_database()
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def create_app() -> FastAPI:
|
| 114 |
+
"""
|
| 115 |
+
创建并配置FastAPI应用程序实例
|
| 116 |
+
|
| 117 |
+
Returns:
|
| 118 |
+
FastAPI: 配置好的FastAPI应用程序实例
|
| 119 |
+
"""
|
| 120 |
+
|
| 121 |
+
# 创建FastAPI应用
|
| 122 |
+
current_version = get_current_version()
|
| 123 |
+
app = FastAPI(
|
| 124 |
+
title="Gemini Balance API",
|
| 125 |
+
description="Gemini API代理服务,支持负载均衡和密钥管理",
|
| 126 |
+
version=current_version,
|
| 127 |
+
lifespan=lifespan,
|
| 128 |
+
)
|
| 129 |
+
|
| 130 |
+
if not hasattr(app, "state"):
|
| 131 |
+
from starlette.datastructures import State
|
| 132 |
+
|
| 133 |
+
app.state = State()
|
| 134 |
+
app.state.update_info = {
|
| 135 |
+
"update_available": False,
|
| 136 |
+
"latest_version": None,
|
| 137 |
+
"error_message": "Initializing...",
|
| 138 |
+
"current_version": current_version,
|
| 139 |
+
}
|
| 140 |
+
|
| 141 |
+
# 配置静态文件
|
| 142 |
+
app.mount("/static", StaticFiles(directory=str(STATIC_DIR)), name="static")
|
| 143 |
+
|
| 144 |
+
# 配置中间件
|
| 145 |
+
setup_middlewares(app)
|
| 146 |
+
|
| 147 |
+
# 配置异常处理器
|
| 148 |
+
setup_exception_handlers(app)
|
| 149 |
+
|
| 150 |
+
# 配置路由
|
| 151 |
+
setup_routers(app)
|
| 152 |
+
|
| 153 |
+
# 配置访问日志API密钥隐藏
|
| 154 |
+
setup_access_logging()
|
| 155 |
+
|
| 156 |
+
return app
|
app/core/constants.py
ADDED
|
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
常量定义模块
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
# API相关常量
|
| 6 |
+
API_VERSION = "v1beta"
|
| 7 |
+
DEFAULT_TIMEOUT = 300 # 秒
|
| 8 |
+
MAX_RETRIES = 3 # 最大重试次数
|
| 9 |
+
|
| 10 |
+
# 模型相关常量
|
| 11 |
+
SUPPORTED_ROLES = ["user", "model", "system"]
|
| 12 |
+
DEFAULT_MODEL = "gemini-2.5-flash-lite"
|
| 13 |
+
DEFAULT_TEMPERATURE = 0.7
|
| 14 |
+
DEFAULT_MAX_TOKENS = 8192
|
| 15 |
+
DEFAULT_TOP_P = 0.9
|
| 16 |
+
DEFAULT_TOP_K = 40
|
| 17 |
+
DEFAULT_FILTER_MODELS = [
|
| 18 |
+
"gemini-1.0-pro-vision-latest",
|
| 19 |
+
"gemini-pro-vision",
|
| 20 |
+
"chat-bison-001",
|
| 21 |
+
"text-bison-001",
|
| 22 |
+
"embedding-gecko-001",
|
| 23 |
+
]
|
| 24 |
+
DEFAULT_CREATE_IMAGE_MODEL = "imagen-3.0-generate-002"
|
| 25 |
+
|
| 26 |
+
# 图像生成相关常量
|
| 27 |
+
VALID_IMAGE_RATIOS = ["1:1", "3:4", "4:3", "9:16", "16:9"]
|
| 28 |
+
|
| 29 |
+
# 上传提供商
|
| 30 |
+
UPLOAD_PROVIDERS = ["smms", "picgo", "cloudflare_imgbed", "aliyun_oss"]
|
| 31 |
+
DEFAULT_UPLOAD_PROVIDER = "smms"
|
| 32 |
+
|
| 33 |
+
# 流式输出相关常量
|
| 34 |
+
DEFAULT_STREAM_MIN_DELAY = 0.016
|
| 35 |
+
DEFAULT_STREAM_MAX_DELAY = 0.024
|
| 36 |
+
DEFAULT_STREAM_SHORT_TEXT_THRESHOLD = 10
|
| 37 |
+
DEFAULT_STREAM_LONG_TEXT_THRESHOLD = 50
|
| 38 |
+
DEFAULT_STREAM_CHUNK_SIZE = 5
|
| 39 |
+
|
| 40 |
+
# 正则表达式模式
|
| 41 |
+
IMAGE_URL_PATTERN = r"!\[(.*?)\]\((.*?)\)"
|
| 42 |
+
DATA_URL_PATTERN = r"data:([^;]+);base64,(.+)"
|
| 43 |
+
|
| 44 |
+
# Audio/Video Settings
|
| 45 |
+
SUPPORTED_AUDIO_FORMATS = ["wav", "mp3", "flac", "ogg"]
|
| 46 |
+
SUPPORTED_VIDEO_FORMATS = ["mp4", "mov", "avi", "webm"]
|
| 47 |
+
MAX_AUDIO_SIZE_BYTES = 50 * 1024 * 1024 # Example: 50MB limit for Base64 payload
|
| 48 |
+
MAX_VIDEO_SIZE_BYTES = 200 * 1024 * 1024 # Example: 200MB limit
|
| 49 |
+
|
| 50 |
+
# Optional: Define MIME type mappings if needed, or handle directly in converter
|
| 51 |
+
AUDIO_FORMAT_TO_MIMETYPE = {
|
| 52 |
+
"wav": "audio/wav",
|
| 53 |
+
"mp3": "audio/mpeg",
|
| 54 |
+
"flac": "audio/flac",
|
| 55 |
+
"ogg": "audio/ogg",
|
| 56 |
+
}
|
| 57 |
+
|
| 58 |
+
VIDEO_FORMAT_TO_MIMETYPE = {
|
| 59 |
+
"mp4": "video/mp4",
|
| 60 |
+
"mov": "video/quicktime",
|
| 61 |
+
"avi": "video/x-msvideo",
|
| 62 |
+
"webm": "video/webm",
|
| 63 |
+
}
|
| 64 |
+
|
| 65 |
+
GEMINI_2_FLASH_EXP_SAFETY_SETTINGS = [
|
| 66 |
+
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "OFF"},
|
| 67 |
+
{"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "OFF"},
|
| 68 |
+
{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "OFF"},
|
| 69 |
+
{"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "OFF"},
|
| 70 |
+
{"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "OFF"},
|
| 71 |
+
]
|
| 72 |
+
|
| 73 |
+
DEFAULT_SAFETY_SETTINGS = [
|
| 74 |
+
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "OFF"},
|
| 75 |
+
{"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "OFF"},
|
| 76 |
+
{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "OFF"},
|
| 77 |
+
{"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "OFF"},
|
| 78 |
+
{"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "BLOCK_NONE"},
|
| 79 |
+
]
|
| 80 |
+
|
| 81 |
+
TTS_VOICE_NAMES = [
|
| 82 |
+
"Zephyr",
|
| 83 |
+
"Puck",
|
| 84 |
+
"Charon",
|
| 85 |
+
"Kore",
|
| 86 |
+
"Fenrir",
|
| 87 |
+
"Leda",
|
| 88 |
+
"Orus",
|
| 89 |
+
"Aoede",
|
| 90 |
+
"Callirrhoe",
|
| 91 |
+
"Autonoe",
|
| 92 |
+
"Enceladus",
|
| 93 |
+
"Iapetus",
|
| 94 |
+
"Umbriel",
|
| 95 |
+
"Algieba",
|
| 96 |
+
"Despina",
|
| 97 |
+
"Erinome",
|
| 98 |
+
"Algenib",
|
| 99 |
+
"Rasalgethi",
|
| 100 |
+
"Laomedeia",
|
| 101 |
+
"Achernar",
|
| 102 |
+
"Alnilam",
|
| 103 |
+
"Schedar",
|
| 104 |
+
"Gacrux",
|
| 105 |
+
"Pulcherrima",
|
| 106 |
+
"Achird",
|
| 107 |
+
"Zubenelgenubi",
|
| 108 |
+
"Vindemiatrix",
|
| 109 |
+
"Sadachbia",
|
| 110 |
+
"Sadaltager",
|
| 111 |
+
"Sulafat",
|
| 112 |
+
]
|
app/core/security.py
ADDED
|
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Optional
|
| 2 |
+
|
| 3 |
+
from fastapi import Header, HTTPException
|
| 4 |
+
|
| 5 |
+
from app.config.config import settings
|
| 6 |
+
from app.log.logger import get_security_logger
|
| 7 |
+
|
| 8 |
+
logger = get_security_logger()
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def verify_auth_token(token: str) -> bool:
|
| 12 |
+
return token == settings.AUTH_TOKEN
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class SecurityService:
|
| 16 |
+
|
| 17 |
+
async def verify_key(self, key: str):
|
| 18 |
+
if key not in settings.ALLOWED_TOKENS and key != settings.AUTH_TOKEN:
|
| 19 |
+
logger.error("Invalid key")
|
| 20 |
+
raise HTTPException(status_code=401, detail="Invalid key")
|
| 21 |
+
return key
|
| 22 |
+
|
| 23 |
+
async def verify_authorization(
|
| 24 |
+
self, authorization: Optional[str] = Header(None)
|
| 25 |
+
) -> str:
|
| 26 |
+
if not authorization:
|
| 27 |
+
logger.error("Missing Authorization header")
|
| 28 |
+
raise HTTPException(status_code=401, detail="Missing Authorization header")
|
| 29 |
+
|
| 30 |
+
if not authorization.startswith("Bearer "):
|
| 31 |
+
logger.error("Invalid Authorization header format")
|
| 32 |
+
raise HTTPException(
|
| 33 |
+
status_code=401, detail="Invalid Authorization header format"
|
| 34 |
+
)
|
| 35 |
+
|
| 36 |
+
token = authorization.replace("Bearer ", "")
|
| 37 |
+
if token not in settings.ALLOWED_TOKENS and token != settings.AUTH_TOKEN:
|
| 38 |
+
logger.error("Invalid token")
|
| 39 |
+
raise HTTPException(status_code=401, detail="Invalid token")
|
| 40 |
+
|
| 41 |
+
return token
|
| 42 |
+
|
| 43 |
+
async def verify_goog_api_key(
|
| 44 |
+
self, x_goog_api_key: Optional[str] = Header(None)
|
| 45 |
+
) -> str:
|
| 46 |
+
"""验证Google API Key"""
|
| 47 |
+
if not x_goog_api_key:
|
| 48 |
+
logger.error("Missing x-goog-api-key header")
|
| 49 |
+
raise HTTPException(status_code=401, detail="Missing x-goog-api-key header")
|
| 50 |
+
|
| 51 |
+
if (
|
| 52 |
+
x_goog_api_key not in settings.ALLOWED_TOKENS
|
| 53 |
+
and x_goog_api_key != settings.AUTH_TOKEN
|
| 54 |
+
):
|
| 55 |
+
logger.error("Invalid x-goog-api-key")
|
| 56 |
+
raise HTTPException(status_code=401, detail="Invalid x-goog-api-key")
|
| 57 |
+
|
| 58 |
+
return x_goog_api_key
|
| 59 |
+
|
| 60 |
+
async def verify_auth_token(
|
| 61 |
+
self, authorization: Optional[str] = Header(None)
|
| 62 |
+
) -> str:
|
| 63 |
+
if not authorization:
|
| 64 |
+
logger.error("Missing auth_token header")
|
| 65 |
+
raise HTTPException(status_code=401, detail="Missing auth_token header")
|
| 66 |
+
token = authorization.replace("Bearer ", "")
|
| 67 |
+
if token != settings.AUTH_TOKEN:
|
| 68 |
+
logger.error("Invalid auth_token")
|
| 69 |
+
raise HTTPException(status_code=401, detail="Invalid auth_token")
|
| 70 |
+
|
| 71 |
+
return token
|
| 72 |
+
|
| 73 |
+
async def verify_key_or_goog_api_key(
|
| 74 |
+
self, key: Optional[str] = None , x_goog_api_key: Optional[str] = Header(None)
|
| 75 |
+
) -> str:
|
| 76 |
+
"""验证URL中的key或请求头中的x-goog-api-key"""
|
| 77 |
+
# 如果URL中的key有效,直接返回
|
| 78 |
+
if key in settings.ALLOWED_TOKENS or key == settings.AUTH_TOKEN:
|
| 79 |
+
return key
|
| 80 |
+
|
| 81 |
+
# 否则检查请求头中的x-goog-api-key
|
| 82 |
+
if not x_goog_api_key:
|
| 83 |
+
logger.error("Invalid key and missing x-goog-api-key header")
|
| 84 |
+
raise HTTPException(status_code=401, detail="Invalid key and missing x-goog-api-key header")
|
| 85 |
+
|
| 86 |
+
if x_goog_api_key not in settings.ALLOWED_TOKENS and x_goog_api_key != settings.AUTH_TOKEN:
|
| 87 |
+
logger.error("Invalid key and invalid x-goog-api-key")
|
| 88 |
+
raise HTTPException(status_code=401, detail="Invalid key and invalid x-goog-api-key")
|
| 89 |
+
|
| 90 |
+
return x_goog_api_key
|
app/database/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
数据库模块
|
| 3 |
+
"""
|
app/database/connection.py
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
数据库连接池模块
|
| 3 |
+
"""
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from urllib.parse import quote_plus
|
| 6 |
+
from databases import Database
|
| 7 |
+
from sqlalchemy import create_engine, MetaData
|
| 8 |
+
from sqlalchemy.ext.declarative import declarative_base
|
| 9 |
+
|
| 10 |
+
from app.config.config import settings
|
| 11 |
+
from app.log.logger import get_database_logger
|
| 12 |
+
|
| 13 |
+
logger = get_database_logger()
|
| 14 |
+
|
| 15 |
+
# 数据库URL
|
| 16 |
+
if settings.DATABASE_TYPE == "sqlite":
|
| 17 |
+
# 确保 data 目录存在
|
| 18 |
+
data_dir = Path("data")
|
| 19 |
+
data_dir.mkdir(exist_ok=True)
|
| 20 |
+
db_path = data_dir / settings.SQLITE_DATABASE
|
| 21 |
+
DATABASE_URL = f"sqlite:///{db_path}"
|
| 22 |
+
elif settings.DATABASE_TYPE == "mysql":
|
| 23 |
+
if settings.MYSQL_SOCKET:
|
| 24 |
+
DATABASE_URL = f"mysql+pymysql://{settings.MYSQL_USER}:{quote_plus(settings.MYSQL_PASSWORD)}@/{settings.MYSQL_DATABASE}?unix_socket={settings.MYSQL_SOCKET}"
|
| 25 |
+
else:
|
| 26 |
+
DATABASE_URL = f"mysql+pymysql://{settings.MYSQL_USER}:{quote_plus(settings.MYSQL_PASSWORD)}@{settings.MYSQL_HOST}:{settings.MYSQL_PORT}/{settings.MYSQL_DATABASE}"
|
| 27 |
+
else:
|
| 28 |
+
raise ValueError("Unsupported database type. Please set DATABASE_TYPE to 'sqlite' or 'mysql'.")
|
| 29 |
+
|
| 30 |
+
# 创建数据库引擎
|
| 31 |
+
# pool_pre_ping=True: 在从连接池获取连接前执行简单的 "ping" 测试,确保连接有效
|
| 32 |
+
engine = create_engine(DATABASE_URL, pool_pre_ping=True)
|
| 33 |
+
|
| 34 |
+
# 创建元数据对象
|
| 35 |
+
metadata = MetaData()
|
| 36 |
+
|
| 37 |
+
# 创建基类
|
| 38 |
+
Base = declarative_base(metadata=metadata)
|
| 39 |
+
|
| 40 |
+
# 创建数据库连接池,并配置连接池参数,在sqlite中不使用连接池
|
| 41 |
+
# min_size/max_size: 连接池的最小/最大连接数
|
| 42 |
+
# pool_recycle=3600: 连接在池中允许存在的最大秒数(生命周期)。
|
| 43 |
+
# 设置为 3600 秒(1小时),确保在 MySQL 默认的 wait_timeout (通常8小时) 或其他网络超时之前回收连接。
|
| 44 |
+
# 如果遇到连接失效问题,可以尝试调低此值,使其小于实际的 wait_timeout 或网络超时时间。
|
| 45 |
+
# databases 库会自动处理连接失效后的重连尝试。
|
| 46 |
+
if settings.DATABASE_TYPE == "sqlite":
|
| 47 |
+
database = Database(DATABASE_URL)
|
| 48 |
+
else:
|
| 49 |
+
database = Database(DATABASE_URL, min_size=5, max_size=20, pool_recycle=1800)
|
| 50 |
+
|
| 51 |
+
async def connect_to_db():
|
| 52 |
+
"""
|
| 53 |
+
连接到数据库
|
| 54 |
+
"""
|
| 55 |
+
try:
|
| 56 |
+
await database.connect()
|
| 57 |
+
logger.info(f"Connected to {settings.DATABASE_TYPE}")
|
| 58 |
+
except Exception as e:
|
| 59 |
+
logger.error(f"Failed to connect to database: {str(e)}")
|
| 60 |
+
raise
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
async def disconnect_from_db():
|
| 64 |
+
"""
|
| 65 |
+
断开数据库连接
|
| 66 |
+
"""
|
| 67 |
+
try:
|
| 68 |
+
await database.disconnect()
|
| 69 |
+
logger.info(f"Disconnected from {settings.DATABASE_TYPE}")
|
| 70 |
+
except Exception as e:
|
| 71 |
+
logger.error(f"Failed to disconnect from database: {str(e)}")
|
app/database/initialization.py
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
数据库初始化模块
|
| 3 |
+
"""
|
| 4 |
+
from dotenv import dotenv_values
|
| 5 |
+
|
| 6 |
+
from sqlalchemy import inspect
|
| 7 |
+
from sqlalchemy.orm import Session
|
| 8 |
+
|
| 9 |
+
from app.database.connection import engine, Base
|
| 10 |
+
from app.database.models import Settings
|
| 11 |
+
from app.log.logger import get_database_logger
|
| 12 |
+
|
| 13 |
+
logger = get_database_logger()
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def create_tables():
|
| 17 |
+
"""
|
| 18 |
+
创建数据库表
|
| 19 |
+
"""
|
| 20 |
+
try:
|
| 21 |
+
# 创建所有表
|
| 22 |
+
Base.metadata.create_all(engine)
|
| 23 |
+
logger.info("Database tables created successfully")
|
| 24 |
+
except Exception as e:
|
| 25 |
+
logger.error(f"Failed to create database tables: {str(e)}")
|
| 26 |
+
raise
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def import_env_to_settings():
|
| 30 |
+
"""
|
| 31 |
+
将.env文件中的配置项导入到t_settings表中
|
| 32 |
+
"""
|
| 33 |
+
try:
|
| 34 |
+
# 获取.env文件中的所有配置项
|
| 35 |
+
env_values = dotenv_values(".env")
|
| 36 |
+
|
| 37 |
+
# 获取检查器
|
| 38 |
+
inspector = inspect(engine)
|
| 39 |
+
|
| 40 |
+
# 检查t_settings表是否存在
|
| 41 |
+
if "t_settings" in inspector.get_table_names():
|
| 42 |
+
# 使用Session进行数据库操作
|
| 43 |
+
with Session(engine) as session:
|
| 44 |
+
# 获取所有现有的配置项
|
| 45 |
+
current_settings = {setting.key: setting for setting in session.query(Settings).all()}
|
| 46 |
+
|
| 47 |
+
# 遍历所有配置项
|
| 48 |
+
for key, value in env_values.items():
|
| 49 |
+
# 检查配置项是否已存在
|
| 50 |
+
if key not in current_settings:
|
| 51 |
+
# 插入配置项
|
| 52 |
+
new_setting = Settings(key=key, value=value)
|
| 53 |
+
session.add(new_setting)
|
| 54 |
+
logger.info(f"Inserted setting: {key}")
|
| 55 |
+
|
| 56 |
+
# 提交事务
|
| 57 |
+
session.commit()
|
| 58 |
+
|
| 59 |
+
logger.info("Environment variables imported to settings table successfully")
|
| 60 |
+
except Exception as e:
|
| 61 |
+
logger.error(f"Failed to import environment variables to settings table: {str(e)}")
|
| 62 |
+
raise
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def initialize_database():
|
| 66 |
+
"""
|
| 67 |
+
初始化数据库
|
| 68 |
+
"""
|
| 69 |
+
try:
|
| 70 |
+
# 创建表
|
| 71 |
+
create_tables()
|
| 72 |
+
|
| 73 |
+
# 导入环境变量
|
| 74 |
+
import_env_to_settings()
|
| 75 |
+
except Exception as e:
|
| 76 |
+
logger.error(f"Failed to initialize database: {str(e)}")
|
| 77 |
+
raise
|
app/database/models.py
ADDED
|
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
数据库模型模块
|
| 3 |
+
"""
|
| 4 |
+
import datetime
|
| 5 |
+
from sqlalchemy import Column, Integer, String, Text, DateTime, JSON, Boolean, BigInteger, Enum
|
| 6 |
+
import enum
|
| 7 |
+
|
| 8 |
+
from app.database.connection import Base
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class Settings(Base):
|
| 12 |
+
"""
|
| 13 |
+
设置表,对应.env中的配置项
|
| 14 |
+
"""
|
| 15 |
+
__tablename__ = "t_settings"
|
| 16 |
+
|
| 17 |
+
id = Column(Integer, primary_key=True, autoincrement=True)
|
| 18 |
+
key = Column(String(100), nullable=False, unique=True, comment="配置项键名")
|
| 19 |
+
value = Column(Text, nullable=True, comment="配置项值")
|
| 20 |
+
description = Column(String(255), nullable=True, comment="配置项描述")
|
| 21 |
+
created_at = Column(DateTime, default=datetime.datetime.now, comment="创建时间")
|
| 22 |
+
updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now, comment="更新时间")
|
| 23 |
+
|
| 24 |
+
def __repr__(self):
|
| 25 |
+
return f"<Settings(key='{self.key}', value='{self.value}')>"
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class ErrorLog(Base):
|
| 29 |
+
"""
|
| 30 |
+
错误日志表
|
| 31 |
+
"""
|
| 32 |
+
__tablename__ = "t_error_logs"
|
| 33 |
+
|
| 34 |
+
id = Column(Integer, primary_key=True, autoincrement=True)
|
| 35 |
+
gemini_key = Column(String(100), nullable=True, comment="Gemini API密钥")
|
| 36 |
+
model_name = Column(String(100), nullable=True, comment="模型名称")
|
| 37 |
+
error_type = Column(String(50), nullable=True, comment="错误类型")
|
| 38 |
+
error_log = Column(Text, nullable=True, comment="错误日志")
|
| 39 |
+
error_code = Column(Integer, nullable=True, comment="错误代码")
|
| 40 |
+
request_msg = Column(JSON, nullable=True, comment="请求消息")
|
| 41 |
+
request_time = Column(DateTime, default=datetime.datetime.now, comment="请求时间")
|
| 42 |
+
|
| 43 |
+
def __repr__(self):
|
| 44 |
+
return f"<ErrorLog(id='{self.id}', gemini_key='{self.gemini_key}')>"
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
class RequestLog(Base):
|
| 48 |
+
"""
|
| 49 |
+
API 请求日志表
|
| 50 |
+
"""
|
| 51 |
+
|
| 52 |
+
__tablename__ = "t_request_log"
|
| 53 |
+
|
| 54 |
+
id = Column(Integer, primary_key=True, autoincrement=True)
|
| 55 |
+
request_time = Column(DateTime, default=datetime.datetime.now, comment="请求时间")
|
| 56 |
+
model_name = Column(String(100), nullable=True, comment="模型名称")
|
| 57 |
+
api_key = Column(String(100), nullable=True, comment="使用的API密钥")
|
| 58 |
+
is_success = Column(Boolean, nullable=False, comment="请求是否成功")
|
| 59 |
+
status_code = Column(Integer, nullable=True, comment="API响应状态码")
|
| 60 |
+
latency_ms = Column(Integer, nullable=True, comment="请求耗时(毫秒)")
|
| 61 |
+
|
| 62 |
+
def __repr__(self):
|
| 63 |
+
return f"<RequestLog(id='{self.id}', key='{self.api_key[:4]}...', success='{self.is_success}')>"
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
class FileState(enum.Enum):
|
| 67 |
+
"""文件状态枚举"""
|
| 68 |
+
PROCESSING = "PROCESSING"
|
| 69 |
+
ACTIVE = "ACTIVE"
|
| 70 |
+
FAILED = "FAILED"
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
class FileRecord(Base):
|
| 74 |
+
"""
|
| 75 |
+
文件记录表,用于存储上传到 Gemini 的文件信息
|
| 76 |
+
"""
|
| 77 |
+
__tablename__ = "t_file_records"
|
| 78 |
+
|
| 79 |
+
id = Column(Integer, primary_key=True, autoincrement=True)
|
| 80 |
+
|
| 81 |
+
# 文件基本信息
|
| 82 |
+
name = Column(String(255), unique=True, nullable=False, comment="文件名称,格式: files/{file_id}")
|
| 83 |
+
display_name = Column(String(255), nullable=True, comment="用户上传时的原始文件名")
|
| 84 |
+
mime_type = Column(String(100), nullable=False, comment="MIME 类型")
|
| 85 |
+
size_bytes = Column(BigInteger, nullable=False, comment="文件大小(字节)")
|
| 86 |
+
sha256_hash = Column(String(255), nullable=True, comment="文件的 SHA256 哈希值")
|
| 87 |
+
|
| 88 |
+
# 状态信息
|
| 89 |
+
state = Column(Enum(FileState), nullable=False, default=FileState.PROCESSING, comment="文件状态")
|
| 90 |
+
|
| 91 |
+
# 时间戳
|
| 92 |
+
create_time = Column(DateTime, nullable=False, comment="创建时间")
|
| 93 |
+
update_time = Column(DateTime, nullable=False, comment="更新时间")
|
| 94 |
+
expiration_time = Column(DateTime, nullable=False, comment="过期时间")
|
| 95 |
+
|
| 96 |
+
# API 相关
|
| 97 |
+
uri = Column(String(500), nullable=False, comment="文件访问 URI")
|
| 98 |
+
api_key = Column(String(100), nullable=False, comment="上传时使用的 API Key")
|
| 99 |
+
upload_url = Column(Text, nullable=True, comment="临时上传 URL(用于分块上传)")
|
| 100 |
+
|
| 101 |
+
# 额外信息
|
| 102 |
+
user_token = Column(String(100), nullable=True, comment="上传用户的 token")
|
| 103 |
+
upload_completed = Column(DateTime, nullable=True, comment="上传完成时间")
|
| 104 |
+
|
| 105 |
+
def __repr__(self):
|
| 106 |
+
return f"<FileRecord(name='{self.name}', state='{self.state.value if self.state else 'None'}', api_key='{self.api_key[:8]}...')>"
|
| 107 |
+
|
| 108 |
+
def to_dict(self):
|
| 109 |
+
"""转换为字典格式,用于 API 响应"""
|
| 110 |
+
return {
|
| 111 |
+
"name": self.name,
|
| 112 |
+
"displayName": self.display_name,
|
| 113 |
+
"mimeType": self.mime_type,
|
| 114 |
+
"sizeBytes": str(self.size_bytes),
|
| 115 |
+
"createTime": self.create_time.isoformat() + "Z",
|
| 116 |
+
"updateTime": self.update_time.isoformat() + "Z",
|
| 117 |
+
"expirationTime": self.expiration_time.isoformat() + "Z",
|
| 118 |
+
"sha256Hash": self.sha256_hash,
|
| 119 |
+
"uri": self.uri,
|
| 120 |
+
"state": self.state.value if self.state else "PROCESSING"
|
| 121 |
+
}
|
| 122 |
+
|
| 123 |
+
def is_expired(self):
|
| 124 |
+
"""检查文件是否已过期"""
|
| 125 |
+
# 确保比较时都是 timezone-aware
|
| 126 |
+
expiration_time = self.expiration_time
|
| 127 |
+
if expiration_time.tzinfo is None:
|
| 128 |
+
expiration_time = expiration_time.replace(tzinfo=datetime.timezone.utc)
|
| 129 |
+
return datetime.datetime.now(datetime.timezone.utc) > expiration_time
|
app/database/services.py
ADDED
|
@@ -0,0 +1,805 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
数据库服务模块
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import asyncio
|
| 6 |
+
import json
|
| 7 |
+
from datetime import datetime, timedelta, timezone
|
| 8 |
+
from typing import Any, Dict, List, Optional, Union
|
| 9 |
+
|
| 10 |
+
from sqlalchemy import asc, delete, desc, func, insert, select, update
|
| 11 |
+
|
| 12 |
+
from app.database.connection import database
|
| 13 |
+
from app.database.models import ErrorLog, FileRecord, FileState, RequestLog, Settings
|
| 14 |
+
from app.log.logger import get_database_logger
|
| 15 |
+
from app.utils.helpers import redact_key_for_logging
|
| 16 |
+
|
| 17 |
+
logger = get_database_logger()
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
async def get_all_settings() -> List[Dict[str, Any]]:
|
| 21 |
+
"""
|
| 22 |
+
获取所有设置
|
| 23 |
+
|
| 24 |
+
Returns:
|
| 25 |
+
List[Dict[str, Any]]: 设置列表
|
| 26 |
+
"""
|
| 27 |
+
try:
|
| 28 |
+
query = select(Settings)
|
| 29 |
+
result = await database.fetch_all(query)
|
| 30 |
+
return [dict(row) for row in result]
|
| 31 |
+
except Exception as e:
|
| 32 |
+
logger.error(f"Failed to get all settings: {str(e)}")
|
| 33 |
+
raise
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
async def get_setting(key: str) -> Optional[Dict[str, Any]]:
|
| 37 |
+
"""
|
| 38 |
+
获取指定键的设置
|
| 39 |
+
|
| 40 |
+
Args:
|
| 41 |
+
key: 设置键名
|
| 42 |
+
|
| 43 |
+
Returns:
|
| 44 |
+
Optional[Dict[str, Any]]: 设置信息,如果不存在则返回None
|
| 45 |
+
"""
|
| 46 |
+
try:
|
| 47 |
+
query = select(Settings).where(Settings.key == key)
|
| 48 |
+
result = await database.fetch_one(query)
|
| 49 |
+
return dict(result) if result else None
|
| 50 |
+
except Exception as e:
|
| 51 |
+
logger.error(f"Failed to get setting {key}: {str(e)}")
|
| 52 |
+
raise
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
async def update_setting(
|
| 56 |
+
key: str, value: str, description: Optional[str] = None
|
| 57 |
+
) -> bool:
|
| 58 |
+
"""
|
| 59 |
+
更新设置
|
| 60 |
+
|
| 61 |
+
Args:
|
| 62 |
+
key: 设置键名
|
| 63 |
+
value: 设置值
|
| 64 |
+
description: 设置描述
|
| 65 |
+
|
| 66 |
+
Returns:
|
| 67 |
+
bool: 是否更新成功
|
| 68 |
+
"""
|
| 69 |
+
try:
|
| 70 |
+
# 检查设置是否存在
|
| 71 |
+
setting = await get_setting(key)
|
| 72 |
+
|
| 73 |
+
if setting:
|
| 74 |
+
# 更新设置
|
| 75 |
+
query = (
|
| 76 |
+
update(Settings)
|
| 77 |
+
.where(Settings.key == key)
|
| 78 |
+
.values(
|
| 79 |
+
value=value,
|
| 80 |
+
description=description if description else setting["description"],
|
| 81 |
+
updated_at=datetime.now(),
|
| 82 |
+
)
|
| 83 |
+
)
|
| 84 |
+
await database.execute(query)
|
| 85 |
+
logger.info(f"Updated setting: {key}")
|
| 86 |
+
return True
|
| 87 |
+
else:
|
| 88 |
+
# 插入设置
|
| 89 |
+
query = insert(Settings).values(
|
| 90 |
+
key=key,
|
| 91 |
+
value=value,
|
| 92 |
+
description=description,
|
| 93 |
+
created_at=datetime.now(),
|
| 94 |
+
updated_at=datetime.now(),
|
| 95 |
+
)
|
| 96 |
+
await database.execute(query)
|
| 97 |
+
logger.info(f"Inserted setting: {key}")
|
| 98 |
+
return True
|
| 99 |
+
except Exception as e:
|
| 100 |
+
logger.error(f"Failed to update setting {key}: {str(e)}")
|
| 101 |
+
return False
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
async def add_error_log(
|
| 105 |
+
gemini_key: Optional[str] = None,
|
| 106 |
+
model_name: Optional[str] = None,
|
| 107 |
+
error_type: Optional[str] = None,
|
| 108 |
+
error_log: Optional[str] = None,
|
| 109 |
+
error_code: Optional[int] = None,
|
| 110 |
+
request_msg: Optional[Union[Dict[str, Any], str]] = None,
|
| 111 |
+
request_datetime: Optional[datetime] = None,
|
| 112 |
+
) -> bool:
|
| 113 |
+
"""
|
| 114 |
+
添加错误日志
|
| 115 |
+
|
| 116 |
+
Args:
|
| 117 |
+
gemini_key: Gemini API密钥
|
| 118 |
+
error_log: 错误日志
|
| 119 |
+
error_code: 错误代码 (例如 HTTP 状态码)
|
| 120 |
+
request_msg: 请求消息
|
| 121 |
+
|
| 122 |
+
Returns:
|
| 123 |
+
bool: 是否添加成功
|
| 124 |
+
"""
|
| 125 |
+
try:
|
| 126 |
+
if request_msg is None:
|
| 127 |
+
request_msg_json = None
|
| 128 |
+
else:
|
| 129 |
+
# 如果request_msg是字典,则转换为JSON字符串
|
| 130 |
+
if isinstance(request_msg, dict):
|
| 131 |
+
request_msg_json = request_msg
|
| 132 |
+
elif isinstance(request_msg, str):
|
| 133 |
+
try:
|
| 134 |
+
request_msg_json = json.loads(request_msg)
|
| 135 |
+
except json.JSONDecodeError:
|
| 136 |
+
request_msg_json = {"message": request_msg}
|
| 137 |
+
else:
|
| 138 |
+
request_msg_json = None
|
| 139 |
+
|
| 140 |
+
# 插入错误日志
|
| 141 |
+
query = insert(ErrorLog).values(
|
| 142 |
+
gemini_key=gemini_key,
|
| 143 |
+
error_type=error_type,
|
| 144 |
+
error_log=error_log,
|
| 145 |
+
model_name=model_name,
|
| 146 |
+
error_code=error_code,
|
| 147 |
+
request_msg=request_msg_json,
|
| 148 |
+
request_time=(request_datetime if request_datetime else datetime.now()),
|
| 149 |
+
)
|
| 150 |
+
await database.execute(query)
|
| 151 |
+
logger.info(f"Added error log for key: {redact_key_for_logging(gemini_key)}")
|
| 152 |
+
return True
|
| 153 |
+
except Exception as e:
|
| 154 |
+
logger.error(f"Failed to add error log: {str(e)}")
|
| 155 |
+
return False
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
async def get_error_logs(
|
| 159 |
+
limit: int = 20,
|
| 160 |
+
offset: int = 0,
|
| 161 |
+
key_search: Optional[str] = None,
|
| 162 |
+
error_search: Optional[str] = None,
|
| 163 |
+
error_code_search: Optional[str] = None,
|
| 164 |
+
start_date: Optional[datetime] = None,
|
| 165 |
+
end_date: Optional[datetime] = None,
|
| 166 |
+
sort_by: str = "id",
|
| 167 |
+
sort_order: str = "desc",
|
| 168 |
+
) -> List[Dict[str, Any]]:
|
| 169 |
+
"""
|
| 170 |
+
获取错误日志,支��搜索、日期过滤和排序
|
| 171 |
+
|
| 172 |
+
Args:
|
| 173 |
+
limit (int): 限制数量
|
| 174 |
+
offset (int): 偏移量
|
| 175 |
+
key_search (Optional[str]): Gemini密钥搜索词 (模糊匹配)
|
| 176 |
+
error_search (Optional[str]): 错误类型或日志内容搜索词 (模糊匹配)
|
| 177 |
+
error_code_search (Optional[str]): 错误码搜索词 (精确匹配)
|
| 178 |
+
start_date (Optional[datetime]): 开始日期时间
|
| 179 |
+
end_date (Optional[datetime]): 结束日期时间
|
| 180 |
+
sort_by (str): 排序字段 (例如 'id', 'request_time')
|
| 181 |
+
sort_order (str): 排序顺序 ('asc' or 'desc')
|
| 182 |
+
|
| 183 |
+
Returns:
|
| 184 |
+
List[Dict[str, Any]]: 错误日志列表
|
| 185 |
+
"""
|
| 186 |
+
try:
|
| 187 |
+
query = select(
|
| 188 |
+
ErrorLog.id,
|
| 189 |
+
ErrorLog.gemini_key,
|
| 190 |
+
ErrorLog.model_name,
|
| 191 |
+
ErrorLog.error_type,
|
| 192 |
+
ErrorLog.error_log,
|
| 193 |
+
ErrorLog.error_code,
|
| 194 |
+
ErrorLog.request_time,
|
| 195 |
+
)
|
| 196 |
+
|
| 197 |
+
if key_search:
|
| 198 |
+
query = query.where(ErrorLog.gemini_key.ilike(f"%{key_search}%"))
|
| 199 |
+
if error_search:
|
| 200 |
+
query = query.where(
|
| 201 |
+
(ErrorLog.error_type.ilike(f"%{error_search}%"))
|
| 202 |
+
| (ErrorLog.error_log.ilike(f"%{error_search}%"))
|
| 203 |
+
)
|
| 204 |
+
if start_date:
|
| 205 |
+
query = query.where(ErrorLog.request_time >= start_date)
|
| 206 |
+
if end_date:
|
| 207 |
+
query = query.where(ErrorLog.request_time < end_date)
|
| 208 |
+
if error_code_search:
|
| 209 |
+
try:
|
| 210 |
+
error_code_int = int(error_code_search)
|
| 211 |
+
query = query.where(ErrorLog.error_code == error_code_int)
|
| 212 |
+
except ValueError:
|
| 213 |
+
logger.warning(
|
| 214 |
+
f"Invalid format for error_code_search: '{error_code_search}'. Expected an integer. Skipping error code filter."
|
| 215 |
+
)
|
| 216 |
+
|
| 217 |
+
sort_column = getattr(ErrorLog, sort_by, ErrorLog.id)
|
| 218 |
+
if sort_order.lower() == "asc":
|
| 219 |
+
query = query.order_by(asc(sort_column))
|
| 220 |
+
else:
|
| 221 |
+
query = query.order_by(desc(sort_column))
|
| 222 |
+
|
| 223 |
+
query = query.limit(limit).offset(offset)
|
| 224 |
+
|
| 225 |
+
result = await database.fetch_all(query)
|
| 226 |
+
return [dict(row) for row in result]
|
| 227 |
+
except Exception as e:
|
| 228 |
+
logger.exception(f"Failed to get error logs with filters: {str(e)}")
|
| 229 |
+
raise
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
async def get_error_logs_count(
|
| 233 |
+
key_search: Optional[str] = None,
|
| 234 |
+
error_search: Optional[str] = None,
|
| 235 |
+
error_code_search: Optional[str] = None,
|
| 236 |
+
start_date: Optional[datetime] = None,
|
| 237 |
+
end_date: Optional[datetime] = None,
|
| 238 |
+
) -> int:
|
| 239 |
+
"""
|
| 240 |
+
获取符合条件的错误日志总数
|
| 241 |
+
|
| 242 |
+
Args:
|
| 243 |
+
key_search (Optional[str]): Gemini密钥搜索词 (模糊匹配)
|
| 244 |
+
error_search (Optional[str]): 错误类型或日志内容搜索词 (模糊匹配)
|
| 245 |
+
error_code_search (Optional[str]): 错误码搜索词 (精确匹配)
|
| 246 |
+
start_date (Optional[datetime]): 开始日期时间
|
| 247 |
+
end_date (Optional[datetime]): 结束日期时间
|
| 248 |
+
|
| 249 |
+
Returns:
|
| 250 |
+
int: 日志总数
|
| 251 |
+
"""
|
| 252 |
+
try:
|
| 253 |
+
query = select(func.count()).select_from(ErrorLog)
|
| 254 |
+
|
| 255 |
+
if key_search:
|
| 256 |
+
query = query.where(ErrorLog.gemini_key.ilike(f"%{key_search}%"))
|
| 257 |
+
if error_search:
|
| 258 |
+
query = query.where(
|
| 259 |
+
(ErrorLog.error_type.ilike(f"%{error_search}%"))
|
| 260 |
+
| (ErrorLog.error_log.ilike(f"%{error_search}%"))
|
| 261 |
+
)
|
| 262 |
+
if start_date:
|
| 263 |
+
query = query.where(ErrorLog.request_time >= start_date)
|
| 264 |
+
if end_date:
|
| 265 |
+
query = query.where(ErrorLog.request_time < end_date)
|
| 266 |
+
if error_code_search:
|
| 267 |
+
try:
|
| 268 |
+
error_code_int = int(error_code_search)
|
| 269 |
+
query = query.where(ErrorLog.error_code == error_code_int)
|
| 270 |
+
except ValueError:
|
| 271 |
+
logger.warning(
|
| 272 |
+
f"Invalid format for error_code_search in count: '{error_code_search}'. Expected an integer. Skipping error code filter."
|
| 273 |
+
)
|
| 274 |
+
|
| 275 |
+
count_result = await database.fetch_one(query)
|
| 276 |
+
return count_result[0] if count_result else 0
|
| 277 |
+
except Exception as e:
|
| 278 |
+
logger.exception(f"Failed to count error logs with filters: {str(e)}")
|
| 279 |
+
raise
|
| 280 |
+
|
| 281 |
+
|
| 282 |
+
# 新增函数:获取单条错误日志详情
|
| 283 |
+
async def get_error_log_details(log_id: int) -> Optional[Dict[str, Any]]:
|
| 284 |
+
"""
|
| 285 |
+
根据 ID 获取单个错误日志的详细信息
|
| 286 |
+
|
| 287 |
+
Args:
|
| 288 |
+
log_id (int): 错误日志的 ID
|
| 289 |
+
|
| 290 |
+
Returns:
|
| 291 |
+
Optional[Dict[str, Any]]: 包含日志详细信息的字典,如果未找到则返回 None
|
| 292 |
+
"""
|
| 293 |
+
try:
|
| 294 |
+
query = select(ErrorLog).where(ErrorLog.id == log_id)
|
| 295 |
+
result = await database.fetch_one(query)
|
| 296 |
+
if result:
|
| 297 |
+
# 将 request_msg (JSONB) 转换为字符串以便在 API 中返回
|
| 298 |
+
log_dict = dict(result)
|
| 299 |
+
if "request_msg" in log_dict and log_dict["request_msg"] is not None:
|
| 300 |
+
# 确保即使是 None 或非 JSON 数据也能处理
|
| 301 |
+
try:
|
| 302 |
+
log_dict["request_msg"] = json.dumps(
|
| 303 |
+
log_dict["request_msg"], ensure_ascii=False, indent=2
|
| 304 |
+
)
|
| 305 |
+
except TypeError:
|
| 306 |
+
log_dict["request_msg"] = str(log_dict["request_msg"])
|
| 307 |
+
return log_dict
|
| 308 |
+
else:
|
| 309 |
+
return None
|
| 310 |
+
except Exception as e:
|
| 311 |
+
logger.exception(f"Failed to get error log details for ID {log_id}: {str(e)}")
|
| 312 |
+
raise
|
| 313 |
+
|
| 314 |
+
|
| 315 |
+
# 新增函数:通过 gemini_key / error_code / 时间窗口 查找最接近的错误日志
|
| 316 |
+
async def find_error_log_by_info(
|
| 317 |
+
gemini_key: str,
|
| 318 |
+
timestamp: datetime,
|
| 319 |
+
status_code: Optional[int] = None,
|
| 320 |
+
window_seconds: int = 1,
|
| 321 |
+
) -> Optional[Dict[str, Any]]:
|
| 322 |
+
"""
|
| 323 |
+
在给定时间窗口内,根据 gemini_key(精确匹配)及可选的 status_code 查找最接近 timestamp 的错误日志。
|
| 324 |
+
|
| 325 |
+
假设错误日志的 error_code 存储的是 HTTP 状态码或等价错误码。
|
| 326 |
+
|
| 327 |
+
Args:
|
| 328 |
+
gemini_key: 完整的 Gemini key 字符串。
|
| 329 |
+
timestamp: 目标时间(UTC 或本地,与存储一致)。
|
| 330 |
+
status_code: 可选的错误码,若提供则优先匹配该错误码。
|
| 331 |
+
window_seconds: 允许的时间偏差窗口,单位秒,默认为 1 秒。
|
| 332 |
+
|
| 333 |
+
Returns:
|
| 334 |
+
Optional[Dict[str, Any]]: 最匹配的一条错误日志的完整详情(字段与 get_error_log_details 一致),若未找到则返回 None。
|
| 335 |
+
"""
|
| 336 |
+
try:
|
| 337 |
+
start_time = timestamp - timedelta(seconds=window_seconds)
|
| 338 |
+
end_time = timestamp + timedelta(seconds=window_seconds)
|
| 339 |
+
|
| 340 |
+
base_query = select(ErrorLog).where(
|
| 341 |
+
ErrorLog.gemini_key == gemini_key,
|
| 342 |
+
ErrorLog.request_time >= start_time,
|
| 343 |
+
ErrorLog.request_time <= end_time,
|
| 344 |
+
)
|
| 345 |
+
|
| 346 |
+
# 若提供了状态码,先尝试按状态码过滤
|
| 347 |
+
if status_code is not None:
|
| 348 |
+
query = base_query.where(ErrorLog.error_code == status_code).order_by(
|
| 349 |
+
ErrorLog.request_time.desc()
|
| 350 |
+
)
|
| 351 |
+
candidates = await database.fetch_all(query)
|
| 352 |
+
if not candidates:
|
| 353 |
+
# 回退:不按状态码,仅按时间窗口
|
| 354 |
+
query2 = base_query.order_by(ErrorLog.request_time.desc())
|
| 355 |
+
candidates = await database.fetch_all(query2)
|
| 356 |
+
else:
|
| 357 |
+
query = base_query.order_by(ErrorLog.request_time.desc())
|
| 358 |
+
candidates = await database.fetch_all(query)
|
| 359 |
+
|
| 360 |
+
if not candidates:
|
| 361 |
+
return None
|
| 362 |
+
|
| 363 |
+
# 在 Python 中选择与 timestamp 最接近的一条
|
| 364 |
+
def _to_dict(row: Any) -> Dict[str, Any]:
|
| 365 |
+
d = dict(row)
|
| 366 |
+
if "request_msg" in d and d["request_msg"] is not None:
|
| 367 |
+
try:
|
| 368 |
+
d["request_msg"] = json.dumps(
|
| 369 |
+
d["request_msg"], ensure_ascii=False, indent=2
|
| 370 |
+
)
|
| 371 |
+
except TypeError:
|
| 372 |
+
d["request_msg"] = str(d["request_msg"])
|
| 373 |
+
return d
|
| 374 |
+
|
| 375 |
+
best = min(
|
| 376 |
+
candidates,
|
| 377 |
+
key=lambda r: abs((r["request_time"] - timestamp).total_seconds()),
|
| 378 |
+
)
|
| 379 |
+
return _to_dict(best)
|
| 380 |
+
except Exception as e:
|
| 381 |
+
logger.exception(
|
| 382 |
+
f"Failed to find error log by info (key=***{gemini_key[-4:] if gemini_key else ''}, code={status_code}, ts={timestamp}, window={window_seconds}s): {str(e)}"
|
| 383 |
+
)
|
| 384 |
+
raise
|
| 385 |
+
|
| 386 |
+
|
| 387 |
+
async def delete_error_logs_by_ids(log_ids: List[int]) -> int:
|
| 388 |
+
"""
|
| 389 |
+
根据提供的 ID 列表批量删除错误日志 (异步)。
|
| 390 |
+
|
| 391 |
+
Args:
|
| 392 |
+
log_ids: 要删除的错误日志 ID 列表。
|
| 393 |
+
|
| 394 |
+
Returns:
|
| 395 |
+
int: 实际删除的日志数量。
|
| 396 |
+
"""
|
| 397 |
+
if not log_ids:
|
| 398 |
+
return 0
|
| 399 |
+
try:
|
| 400 |
+
# 使用 databases 执行删除
|
| 401 |
+
query = delete(ErrorLog).where(ErrorLog.id.in_(log_ids))
|
| 402 |
+
# execute 返回受影响的行数,但 databases 库的 execute 不直接返回 rowcount
|
| 403 |
+
# 我们需要先查询是否存在,或者依赖数据库约束/触发器(如果适用)
|
| 404 |
+
# 或者,我们可以执行删除并假设成功,除非抛出异常
|
| 405 |
+
# 为了简单起见,我们执行删除并记录日志,不精确返回删除数量
|
| 406 |
+
# 如果需要精确数量,需要先执行 SELECT COUNT(*)
|
| 407 |
+
await database.execute(query)
|
| 408 |
+
# 注意:databases 的 execute 不返回 rowcount,所以我们不能直接返回删除的数量
|
| 409 |
+
# 返回 log_ids 的长度作为尝试删除的数量,或者返回 0/1 表示操作尝试
|
| 410 |
+
logger.info(f"Attempted bulk deletion for error logs with IDs: {log_ids}")
|
| 411 |
+
return len(log_ids) # 返回尝试删除的数量
|
| 412 |
+
except Exception as e:
|
| 413 |
+
# 数据库连接或执行错误
|
| 414 |
+
logger.error(
|
| 415 |
+
f"Error during bulk deletion of error logs {log_ids}: {e}", exc_info=True
|
| 416 |
+
)
|
| 417 |
+
raise
|
| 418 |
+
|
| 419 |
+
|
| 420 |
+
async def delete_error_log_by_id(log_id: int) -> bool:
|
| 421 |
+
"""
|
| 422 |
+
根据 ID 删除单个错误日志 (异步)。
|
| 423 |
+
|
| 424 |
+
Args:
|
| 425 |
+
log_id: 要删除的错误日志 ID。
|
| 426 |
+
|
| 427 |
+
Returns:
|
| 428 |
+
bool: 如果成功删除返回 True,否则返回 False。
|
| 429 |
+
"""
|
| 430 |
+
try:
|
| 431 |
+
# 先检查是否存在 (可选,但更明确)
|
| 432 |
+
check_query = select(ErrorLog.id).where(ErrorLog.id == log_id)
|
| 433 |
+
exists = await database.fetch_one(check_query)
|
| 434 |
+
|
| 435 |
+
if not exists:
|
| 436 |
+
logger.warning(
|
| 437 |
+
f"Attempted to delete non-existent error log with ID: {log_id}"
|
| 438 |
+
)
|
| 439 |
+
return False
|
| 440 |
+
|
| 441 |
+
# 执行删除
|
| 442 |
+
delete_query = delete(ErrorLog).where(ErrorLog.id == log_id)
|
| 443 |
+
await database.execute(delete_query)
|
| 444 |
+
logger.info(f"Successfully deleted error log with ID: {log_id}")
|
| 445 |
+
return True
|
| 446 |
+
except Exception as e:
|
| 447 |
+
logger.error(f"Error deleting error log with ID {log_id}: {e}", exc_info=True)
|
| 448 |
+
raise
|
| 449 |
+
|
| 450 |
+
|
| 451 |
+
async def delete_all_error_logs() -> int:
|
| 452 |
+
"""
|
| 453 |
+
分批删除所有错误日志,以避免大数据量下的超时和性能问题。
|
| 454 |
+
|
| 455 |
+
Returns:
|
| 456 |
+
int: 被删除的错误日志总数。
|
| 457 |
+
"""
|
| 458 |
+
total_deleted_count = 0
|
| 459 |
+
# SQLite 对 SQL 参数数量有上限(常见为 999),IN 子句中过多参数会报错
|
| 460 |
+
# 统一使用 500,兼容 SQLite/MySQL,必要时可在配置中暴露该值
|
| 461 |
+
batch_size = 200
|
| 462 |
+
|
| 463 |
+
try:
|
| 464 |
+
while True:
|
| 465 |
+
# 1) 读取一批待删除的ID,仅选择ID列以提升效率
|
| 466 |
+
id_query = select(ErrorLog.id).order_by(ErrorLog.id).limit(batch_size)
|
| 467 |
+
rows = await database.fetch_all(id_query)
|
| 468 |
+
if not rows:
|
| 469 |
+
break
|
| 470 |
+
|
| 471 |
+
ids = [row["id"] for row in rows]
|
| 472 |
+
|
| 473 |
+
# 2) 按ID批量删除
|
| 474 |
+
delete_query = delete(ErrorLog).where(ErrorLog.id.in_(ids))
|
| 475 |
+
await database.execute(delete_query)
|
| 476 |
+
|
| 477 |
+
deleted_in_batch = len(ids)
|
| 478 |
+
total_deleted_count += deleted_in_batch
|
| 479 |
+
|
| 480 |
+
logger.debug(f"Deleted a batch of {deleted_in_batch} error logs.")
|
| 481 |
+
|
| 482 |
+
# 若不足一个批次,说明已删除完成
|
| 483 |
+
if deleted_in_batch < batch_size:
|
| 484 |
+
break
|
| 485 |
+
|
| 486 |
+
# 3) 将控制权交还事件循环,缓解长时间占用
|
| 487 |
+
await asyncio.sleep(0)
|
| 488 |
+
|
| 489 |
+
logger.info(
|
| 490 |
+
f"Successfully deleted all error logs in batches. Total deleted: {total_deleted_count}"
|
| 491 |
+
)
|
| 492 |
+
return total_deleted_count
|
| 493 |
+
except Exception as e:
|
| 494 |
+
logger.error(
|
| 495 |
+
f"Failed to delete all error logs in batches: {str(e)}", exc_info=True
|
| 496 |
+
)
|
| 497 |
+
raise
|
| 498 |
+
|
| 499 |
+
|
| 500 |
+
# 新增函数:添加请求日志
|
| 501 |
+
async def add_request_log(
|
| 502 |
+
model_name: Optional[str],
|
| 503 |
+
api_key: Optional[str],
|
| 504 |
+
is_success: bool,
|
| 505 |
+
status_code: Optional[int] = None,
|
| 506 |
+
latency_ms: Optional[int] = None,
|
| 507 |
+
request_time: Optional[datetime] = None,
|
| 508 |
+
) -> bool:
|
| 509 |
+
"""
|
| 510 |
+
添加 API 请求日志
|
| 511 |
+
|
| 512 |
+
Args:
|
| 513 |
+
model_name: 模型名称
|
| 514 |
+
api_key: 使用的 API 密钥
|
| 515 |
+
is_success: 请求是否成功
|
| 516 |
+
status_code: API 响应状态码
|
| 517 |
+
latency_ms: 请求耗时(毫秒)
|
| 518 |
+
request_time: 请求发生时间 (如果为 None, 则使用当前时间)
|
| 519 |
+
|
| 520 |
+
Returns:
|
| 521 |
+
bool: 是否添加成功
|
| 522 |
+
"""
|
| 523 |
+
try:
|
| 524 |
+
log_time = request_time if request_time else datetime.now()
|
| 525 |
+
|
| 526 |
+
query = insert(RequestLog).values(
|
| 527 |
+
request_time=log_time,
|
| 528 |
+
model_name=model_name,
|
| 529 |
+
api_key=api_key,
|
| 530 |
+
is_success=is_success,
|
| 531 |
+
status_code=status_code,
|
| 532 |
+
latency_ms=latency_ms,
|
| 533 |
+
)
|
| 534 |
+
await database.execute(query)
|
| 535 |
+
return True
|
| 536 |
+
except Exception as e:
|
| 537 |
+
logger.error(f"Failed to add request log: {str(e)}")
|
| 538 |
+
return False
|
| 539 |
+
|
| 540 |
+
|
| 541 |
+
# ==================== 文件记录相关函数 ====================
|
| 542 |
+
|
| 543 |
+
|
| 544 |
+
async def create_file_record(
|
| 545 |
+
name: str,
|
| 546 |
+
mime_type: str,
|
| 547 |
+
size_bytes: int,
|
| 548 |
+
api_key: str,
|
| 549 |
+
uri: str,
|
| 550 |
+
create_time: datetime,
|
| 551 |
+
update_time: datetime,
|
| 552 |
+
expiration_time: datetime,
|
| 553 |
+
state: FileState = FileState.PROCESSING,
|
| 554 |
+
display_name: Optional[str] = None,
|
| 555 |
+
sha256_hash: Optional[str] = None,
|
| 556 |
+
upload_url: Optional[str] = None,
|
| 557 |
+
user_token: Optional[str] = None,
|
| 558 |
+
) -> Dict[str, Any]:
|
| 559 |
+
"""
|
| 560 |
+
创建文件记录
|
| 561 |
+
|
| 562 |
+
Args:
|
| 563 |
+
name: 文件名称(格式: files/{file_id})
|
| 564 |
+
mime_type: MIME 类型
|
| 565 |
+
size_bytes: 文件大小(字节)
|
| 566 |
+
api_key: 上传时使用的 API Key
|
| 567 |
+
uri: 文件访问 URI
|
| 568 |
+
create_time: 创建时间
|
| 569 |
+
update_time: 更新时间
|
| 570 |
+
expiration_time: 过期时间
|
| 571 |
+
display_name: 显示名称
|
| 572 |
+
sha256_hash: SHA256 哈希值
|
| 573 |
+
upload_url: 临时上传 URL
|
| 574 |
+
user_token: 上传用户的 token
|
| 575 |
+
|
| 576 |
+
Returns:
|
| 577 |
+
Dict[str, Any]: 创建的文件记录
|
| 578 |
+
"""
|
| 579 |
+
try:
|
| 580 |
+
query = insert(FileRecord).values(
|
| 581 |
+
name=name,
|
| 582 |
+
display_name=display_name,
|
| 583 |
+
mime_type=mime_type,
|
| 584 |
+
size_bytes=size_bytes,
|
| 585 |
+
sha256_hash=sha256_hash,
|
| 586 |
+
state=state,
|
| 587 |
+
create_time=create_time,
|
| 588 |
+
update_time=update_time,
|
| 589 |
+
expiration_time=expiration_time,
|
| 590 |
+
uri=uri,
|
| 591 |
+
api_key=api_key,
|
| 592 |
+
upload_url=upload_url,
|
| 593 |
+
user_token=user_token,
|
| 594 |
+
)
|
| 595 |
+
await database.execute(query)
|
| 596 |
+
|
| 597 |
+
# 返回创建的记录
|
| 598 |
+
return await get_file_record_by_name(name)
|
| 599 |
+
except Exception as e:
|
| 600 |
+
logger.error(f"Failed to create file record: {str(e)}")
|
| 601 |
+
raise
|
| 602 |
+
|
| 603 |
+
|
| 604 |
+
async def get_file_record_by_name(name: str) -> Optional[Dict[str, Any]]:
|
| 605 |
+
"""
|
| 606 |
+
根据文件名获取文件记录
|
| 607 |
+
|
| 608 |
+
Args:
|
| 609 |
+
name: 文件名称(格式: files/{file_id})
|
| 610 |
+
|
| 611 |
+
Returns:
|
| 612 |
+
Optional[Dict[str, Any]]: 文件记录,如果不存在则返回 None
|
| 613 |
+
"""
|
| 614 |
+
try:
|
| 615 |
+
query = select(FileRecord).where(FileRecord.name == name)
|
| 616 |
+
result = await database.fetch_one(query)
|
| 617 |
+
return dict(result) if result else None
|
| 618 |
+
except Exception as e:
|
| 619 |
+
logger.error(f"Failed to get file record by name {name}: {str(e)}")
|
| 620 |
+
raise
|
| 621 |
+
|
| 622 |
+
|
| 623 |
+
async def update_file_record_state(
|
| 624 |
+
file_name: str,
|
| 625 |
+
state: FileState,
|
| 626 |
+
update_time: Optional[datetime] = None,
|
| 627 |
+
upload_completed: Optional[datetime] = None,
|
| 628 |
+
sha256_hash: Optional[str] = None,
|
| 629 |
+
) -> bool:
|
| 630 |
+
"""
|
| 631 |
+
更新文件记录状态
|
| 632 |
+
|
| 633 |
+
Args:
|
| 634 |
+
file_name: 文件名
|
| 635 |
+
state: 新状态
|
| 636 |
+
update_time: 更新时间
|
| 637 |
+
upload_completed: 上传完成时间
|
| 638 |
+
sha256_hash: SHA256 哈希值
|
| 639 |
+
|
| 640 |
+
Returns:
|
| 641 |
+
bool: 是否更新成功
|
| 642 |
+
"""
|
| 643 |
+
try:
|
| 644 |
+
values = {"state": state}
|
| 645 |
+
if update_time:
|
| 646 |
+
values["update_time"] = update_time
|
| 647 |
+
if upload_completed:
|
| 648 |
+
values["upload_completed"] = upload_completed
|
| 649 |
+
if sha256_hash:
|
| 650 |
+
values["sha256_hash"] = sha256_hash
|
| 651 |
+
|
| 652 |
+
query = update(FileRecord).where(FileRecord.name == file_name).values(**values)
|
| 653 |
+
result = await database.execute(query)
|
| 654 |
+
|
| 655 |
+
if result:
|
| 656 |
+
logger.info(f"Updated file record state for {file_name} to {state}")
|
| 657 |
+
return True
|
| 658 |
+
|
| 659 |
+
logger.warning(f"File record not found for update: {file_name}")
|
| 660 |
+
return False
|
| 661 |
+
except Exception as e:
|
| 662 |
+
logger.error(f"Failed to update file record state: {str(e)}")
|
| 663 |
+
return False
|
| 664 |
+
|
| 665 |
+
|
| 666 |
+
async def list_file_records(
|
| 667 |
+
user_token: Optional[str] = None,
|
| 668 |
+
api_key: Optional[str] = None,
|
| 669 |
+
page_size: int = 10,
|
| 670 |
+
page_token: Optional[str] = None,
|
| 671 |
+
) -> tuple[List[Dict[str, Any]], Optional[str]]:
|
| 672 |
+
"""
|
| 673 |
+
列出文件记录
|
| 674 |
+
|
| 675 |
+
Args:
|
| 676 |
+
user_token: 用户 token(如果提供,只返回该用户的文件)
|
| 677 |
+
api_key: API Key(如果提供,只返回使用该 key 的文件)
|
| 678 |
+
page_size: 每页大小
|
| 679 |
+
page_token: 分页标记(偏移量)
|
| 680 |
+
|
| 681 |
+
Returns:
|
| 682 |
+
tuple[List[Dict[str, Any]], Optional[str]]: (文件列表, 下一页标记)
|
| 683 |
+
"""
|
| 684 |
+
try:
|
| 685 |
+
logger.debug(
|
| 686 |
+
f"list_file_records called with page_size={page_size}, page_token={page_token}"
|
| 687 |
+
)
|
| 688 |
+
query = select(FileRecord).where(
|
| 689 |
+
FileRecord.expiration_time > datetime.now(timezone.utc)
|
| 690 |
+
)
|
| 691 |
+
|
| 692 |
+
if user_token:
|
| 693 |
+
query = query.where(FileRecord.user_token == user_token)
|
| 694 |
+
if api_key:
|
| 695 |
+
query = query.where(FileRecord.api_key == api_key)
|
| 696 |
+
|
| 697 |
+
# 使用偏移量进行分页
|
| 698 |
+
offset = 0
|
| 699 |
+
if page_token:
|
| 700 |
+
try:
|
| 701 |
+
offset = int(page_token)
|
| 702 |
+
except ValueError:
|
| 703 |
+
logger.warning(f"Invalid page token: {page_token}")
|
| 704 |
+
offset = 0
|
| 705 |
+
|
| 706 |
+
# 按ID升序排列,使用 OFFSET 和 LIMIT
|
| 707 |
+
query = query.order_by(FileRecord.id).offset(offset).limit(page_size + 1)
|
| 708 |
+
|
| 709 |
+
results = await database.fetch_all(query)
|
| 710 |
+
|
| 711 |
+
logger.debug(f"Query returned {len(results)} records")
|
| 712 |
+
if results:
|
| 713 |
+
logger.debug(
|
| 714 |
+
f"First record ID: {results[0]['id']}, Last record ID: {results[-1]['id']}"
|
| 715 |
+
)
|
| 716 |
+
|
| 717 |
+
# 处理分页
|
| 718 |
+
has_next = len(results) > page_size
|
| 719 |
+
if has_next:
|
| 720 |
+
results = results[:page_size]
|
| 721 |
+
# 下一页的偏移量是当前偏移量加上本页返回的记录数
|
| 722 |
+
next_offset = offset + page_size
|
| 723 |
+
next_page_token = str(next_offset)
|
| 724 |
+
logger.debug(
|
| 725 |
+
f"Has next page, offset={offset}, page_size={page_size}, next_page_token={next_page_token}"
|
| 726 |
+
)
|
| 727 |
+
else:
|
| 728 |
+
next_page_token = None
|
| 729 |
+
logger.debug(f"No next page, returning {len(results)} results")
|
| 730 |
+
|
| 731 |
+
return [dict(row) for row in results], next_page_token
|
| 732 |
+
except Exception as e:
|
| 733 |
+
logger.error(f"Failed to list file records: {str(e)}")
|
| 734 |
+
raise
|
| 735 |
+
|
| 736 |
+
|
| 737 |
+
async def delete_file_record(name: str) -> bool:
|
| 738 |
+
"""
|
| 739 |
+
删除文件记录
|
| 740 |
+
|
| 741 |
+
Args:
|
| 742 |
+
name: 文件名称
|
| 743 |
+
|
| 744 |
+
Returns:
|
| 745 |
+
bool: 是否删除成功
|
| 746 |
+
"""
|
| 747 |
+
try:
|
| 748 |
+
query = delete(FileRecord).where(FileRecord.name == name)
|
| 749 |
+
await database.execute(query)
|
| 750 |
+
return True
|
| 751 |
+
except Exception as e:
|
| 752 |
+
logger.error(f"Failed to delete file record: {str(e)}")
|
| 753 |
+
return False
|
| 754 |
+
|
| 755 |
+
|
| 756 |
+
async def delete_expired_file_records() -> List[Dict[str, Any]]:
|
| 757 |
+
"""
|
| 758 |
+
删除已过期的文件记录
|
| 759 |
+
|
| 760 |
+
Returns:
|
| 761 |
+
List[Dict[str, Any]]: 删除的记录列表
|
| 762 |
+
"""
|
| 763 |
+
try:
|
| 764 |
+
# 先获取要删除的记录
|
| 765 |
+
query = select(FileRecord).where(
|
| 766 |
+
FileRecord.expiration_time <= datetime.now(timezone.utc)
|
| 767 |
+
)
|
| 768 |
+
expired_records = await database.fetch_all(query)
|
| 769 |
+
|
| 770 |
+
if not expired_records:
|
| 771 |
+
return []
|
| 772 |
+
|
| 773 |
+
# 执行删除
|
| 774 |
+
delete_query = delete(FileRecord).where(
|
| 775 |
+
FileRecord.expiration_time <= datetime.now(timezone.utc)
|
| 776 |
+
)
|
| 777 |
+
await database.execute(delete_query)
|
| 778 |
+
|
| 779 |
+
logger.info(f"Deleted {len(expired_records)} expired file records")
|
| 780 |
+
return [dict(record) for record in expired_records]
|
| 781 |
+
except Exception as e:
|
| 782 |
+
logger.error(f"Failed to delete expired file records: {str(e)}")
|
| 783 |
+
raise
|
| 784 |
+
|
| 785 |
+
|
| 786 |
+
async def get_file_api_key(name: str) -> Optional[str]:
|
| 787 |
+
"""
|
| 788 |
+
获取文件对应的 API Key
|
| 789 |
+
|
| 790 |
+
Args:
|
| 791 |
+
name: 文件名称
|
| 792 |
+
|
| 793 |
+
Returns:
|
| 794 |
+
Optional[str]: API Key,如果文件不存在或已过期则返回 None
|
| 795 |
+
"""
|
| 796 |
+
try:
|
| 797 |
+
query = select(FileRecord.api_key).where(
|
| 798 |
+
(FileRecord.name == name)
|
| 799 |
+
& (FileRecord.expiration_time > datetime.now(timezone.utc))
|
| 800 |
+
)
|
| 801 |
+
result = await database.fetch_one(query)
|
| 802 |
+
return result["api_key"] if result else None
|
| 803 |
+
except Exception as e:
|
| 804 |
+
logger.error(f"Failed to get file API key: {str(e)}")
|
| 805 |
+
raise
|
app/domain/file_models.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Files API 相关的领域模型
|
| 3 |
+
"""
|
| 4 |
+
from typing import Optional, Dict, Any, List
|
| 5 |
+
from datetime import datetime
|
| 6 |
+
from pydantic import BaseModel, Field
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class FileUploadConfig(BaseModel):
|
| 10 |
+
"""文件上传配置"""
|
| 11 |
+
mime_type: Optional[str] = Field(None, description="MIME 类型")
|
| 12 |
+
display_name: Optional[str] = Field(None, description="显示名称,最多40个字符")
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class CreateFileRequest(BaseModel):
|
| 16 |
+
"""创建文件请求(用于初始化上传)"""
|
| 17 |
+
file: Optional[Dict[str, Any]] = Field(None, description="文件元数据")
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class FileMetadata(BaseModel):
|
| 21 |
+
"""文件元数据响应"""
|
| 22 |
+
name: str = Field(..., description="文件名称,格式: files/{file_id}")
|
| 23 |
+
displayName: Optional[str] = Field(None, description="显示名称")
|
| 24 |
+
mimeType: str = Field(..., description="MIME 类型")
|
| 25 |
+
sizeBytes: str = Field(..., description="文件大小(字节)")
|
| 26 |
+
createTime: str = Field(..., description="创建时间 (RFC3339)")
|
| 27 |
+
updateTime: str = Field(..., description="更新时间 (RFC3339)")
|
| 28 |
+
expirationTime: str = Field(..., description="过期时间 (RFC3339)")
|
| 29 |
+
sha256Hash: Optional[str] = Field(None, description="SHA256 哈希值")
|
| 30 |
+
uri: str = Field(..., description="文件访问 URI")
|
| 31 |
+
state: str = Field(..., description="文件状态")
|
| 32 |
+
|
| 33 |
+
class Config:
|
| 34 |
+
json_encoders = {
|
| 35 |
+
datetime: lambda v: v.isoformat() + "Z"
|
| 36 |
+
}
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class ListFilesRequest(BaseModel):
|
| 40 |
+
"""列出文件请求参数"""
|
| 41 |
+
pageSize: Optional[int] = Field(10, ge=1, le=100, description="每页大小")
|
| 42 |
+
pageToken: Optional[str] = Field(None, description="分页标记")
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class ListFilesResponse(BaseModel):
|
| 46 |
+
"""列出文件响应"""
|
| 47 |
+
files: List[FileMetadata] = Field(default_factory=list, description="文件列表")
|
| 48 |
+
nextPageToken: Optional[str] = Field(None, description="下一页标记")
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class UploadInitResponse(BaseModel):
|
| 52 |
+
"""上传初始化响应(内部使用)"""
|
| 53 |
+
file_metadata: FileMetadata
|
| 54 |
+
upload_url: str
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class FileKeyMapping(BaseModel):
|
| 58 |
+
"""文件与 API Key 的映射关系(内部使用)"""
|
| 59 |
+
file_name: str
|
| 60 |
+
api_key: str
|
| 61 |
+
user_token: str
|
| 62 |
+
created_at: datetime
|
| 63 |
+
expires_at: datetime
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
class DeleteFileResponse(BaseModel):
|
| 67 |
+
"""删除文件响应"""
|
| 68 |
+
success: bool = Field(..., description="是否删除成功")
|
| 69 |
+
message: Optional[str] = Field(None, description="消息")
|
app/domain/gemini_models.py
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any, Dict, List, Literal, Optional, Union
|
| 2 |
+
|
| 3 |
+
from pydantic import BaseModel, Field
|
| 4 |
+
|
| 5 |
+
from app.core.constants import DEFAULT_TEMPERATURE, DEFAULT_TOP_K, DEFAULT_TOP_P
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class SafetySetting(BaseModel):
|
| 9 |
+
category: Optional[
|
| 10 |
+
Literal[
|
| 11 |
+
"HARM_CATEGORY_HATE_SPEECH",
|
| 12 |
+
"HARM_CATEGORY_DANGEROUS_CONTENT",
|
| 13 |
+
"HARM_CATEGORY_HARASSMENT",
|
| 14 |
+
"HARM_CATEGORY_SEXUALLY_EXPLICIT",
|
| 15 |
+
"HARM_CATEGORY_CIVIC_INTEGRITY",
|
| 16 |
+
]
|
| 17 |
+
] = None
|
| 18 |
+
threshold: Optional[
|
| 19 |
+
Literal[
|
| 20 |
+
"HARM_BLOCK_THRESHOLD_UNSPECIFIED",
|
| 21 |
+
"BLOCK_LOW_AND_ABOVE",
|
| 22 |
+
"BLOCK_MEDIUM_AND_ABOVE",
|
| 23 |
+
"BLOCK_ONLY_HIGH",
|
| 24 |
+
"BLOCK_NONE",
|
| 25 |
+
"OFF",
|
| 26 |
+
]
|
| 27 |
+
] = None
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class GenerationConfig(BaseModel):
|
| 31 |
+
stopSequences: Optional[List[str]] = None
|
| 32 |
+
responseMimeType: Optional[str] = None
|
| 33 |
+
responseSchema: Optional[Dict[str, Any]] = None
|
| 34 |
+
candidateCount: Optional[int] = 1
|
| 35 |
+
maxOutputTokens: Optional[int] = None
|
| 36 |
+
temperature: Optional[float] = DEFAULT_TEMPERATURE
|
| 37 |
+
topP: Optional[float] = DEFAULT_TOP_P
|
| 38 |
+
topK: Optional[int] = DEFAULT_TOP_K
|
| 39 |
+
presencePenalty: Optional[float] = None
|
| 40 |
+
frequencyPenalty: Optional[float] = None
|
| 41 |
+
responseLogprobs: Optional[bool] = None
|
| 42 |
+
logprobs: Optional[int] = None
|
| 43 |
+
thinkingConfig: Optional[Dict[str, Any]] = None
|
| 44 |
+
# TTS相关字段
|
| 45 |
+
responseModalities: Optional[List[str]] = None
|
| 46 |
+
speechConfig: Optional[Dict[str, Any]] = None
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class SystemInstruction(BaseModel):
|
| 50 |
+
role: Optional[str] = "system"
|
| 51 |
+
parts: Union[List[Dict[str, Any]], Dict[str, Any]]
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
class GeminiContent(BaseModel):
|
| 55 |
+
role: Optional[str] = None
|
| 56 |
+
parts: List[Dict[str, Any]]
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
class GeminiRequest(BaseModel):
|
| 60 |
+
contents: List[GeminiContent] = []
|
| 61 |
+
tools: Optional[Union[List[Dict[str, Any]], Dict[str, Any]]] = []
|
| 62 |
+
safetySettings: Optional[List[SafetySetting]] = Field(
|
| 63 |
+
default=None, alias="safety_settings"
|
| 64 |
+
)
|
| 65 |
+
generationConfig: Optional[GenerationConfig] = Field(
|
| 66 |
+
default=None, alias="generation_config"
|
| 67 |
+
)
|
| 68 |
+
systemInstruction: Optional[SystemInstruction] = Field(
|
| 69 |
+
default=None, alias="system_instruction"
|
| 70 |
+
)
|
| 71 |
+
|
| 72 |
+
class Config:
|
| 73 |
+
populate_by_name = True
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
class ResetSelectedKeysRequest(BaseModel):
|
| 77 |
+
keys: List[str]
|
| 78 |
+
key_type: str
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
class VerifySelectedKeysRequest(BaseModel):
|
| 82 |
+
keys: List[str]
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
class GeminiEmbedContent(BaseModel):
|
| 86 |
+
"""嵌入内容模型"""
|
| 87 |
+
|
| 88 |
+
parts: List[Dict[str, str]]
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
class GeminiEmbedRequest(BaseModel):
|
| 92 |
+
"""单一嵌入请求模型"""
|
| 93 |
+
|
| 94 |
+
content: GeminiEmbedContent
|
| 95 |
+
taskType: Optional[
|
| 96 |
+
Literal[
|
| 97 |
+
"TASK_TYPE_UNSPECIFIED",
|
| 98 |
+
"RETRIEVAL_QUERY",
|
| 99 |
+
"RETRIEVAL_DOCUMENT",
|
| 100 |
+
"SEMANTIC_SIMILARITY",
|
| 101 |
+
"CLASSIFICATION",
|
| 102 |
+
"CLUSTERING",
|
| 103 |
+
"QUESTION_ANSWERING",
|
| 104 |
+
"FACT_VERIFICATION",
|
| 105 |
+
"CODE_RETRIEVAL_QUERY",
|
| 106 |
+
]
|
| 107 |
+
] = None
|
| 108 |
+
title: Optional[str] = None
|
| 109 |
+
outputDimensionality: Optional[int] = None
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
class GeminiBatchEmbedRequest(BaseModel):
|
| 113 |
+
"""批量嵌入请求模型"""
|
| 114 |
+
|
| 115 |
+
requests: List[GeminiEmbedRequest]
|
app/domain/image_models.py
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Union
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class ImageMetadata:
|
| 5 |
+
def __init__(self, width: int, height: int, filename: str, size: int, url: str, delete_url: Union[str, None] = None):
|
| 6 |
+
self.width = width
|
| 7 |
+
self.height = height
|
| 8 |
+
self.filename = filename
|
| 9 |
+
self.size = size
|
| 10 |
+
self.url = url
|
| 11 |
+
self.delete_url = delete_url
|
| 12 |
+
class UploadResponse:
|
| 13 |
+
def __init__(self, success: bool, code: str, message: str, data: ImageMetadata):
|
| 14 |
+
self.success = success
|
| 15 |
+
self.code = code
|
| 16 |
+
self.message = message
|
| 17 |
+
self.data = data
|
| 18 |
+
class ImageUploader:
|
| 19 |
+
def upload(self, file: bytes, filename: str) -> UploadResponse:
|
| 20 |
+
raise NotImplementedError
|
app/domain/openai_models.py
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pydantic import BaseModel
|
| 2 |
+
from typing import Any, Dict, List, Optional, Union
|
| 3 |
+
|
| 4 |
+
from app.core.constants import DEFAULT_MODEL, DEFAULT_TEMPERATURE, DEFAULT_TOP_K, DEFAULT_TOP_P
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class ChatRequest(BaseModel):
|
| 8 |
+
messages: List[dict]
|
| 9 |
+
model: str = DEFAULT_MODEL
|
| 10 |
+
temperature: Optional[float] = DEFAULT_TEMPERATURE
|
| 11 |
+
stream: Optional[bool] = False
|
| 12 |
+
max_tokens: Optional[int] = None
|
| 13 |
+
top_p: Optional[float] = DEFAULT_TOP_P
|
| 14 |
+
top_k: Optional[int] = DEFAULT_TOP_K
|
| 15 |
+
n: Optional[int] = 1
|
| 16 |
+
stop: Optional[Union[List[str],str]] = None
|
| 17 |
+
reasoning_effort: Optional[str] = None
|
| 18 |
+
tools: Optional[Union[List[Dict[str, Any]], Dict[str, Any]]] = []
|
| 19 |
+
tool_choice: Optional[str] = None
|
| 20 |
+
response_format: Optional[dict] = None
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class EmbeddingRequest(BaseModel):
|
| 24 |
+
input: Union[str, List[str]]
|
| 25 |
+
model: str = "text-embedding-004"
|
| 26 |
+
encoding_format: Optional[str] = "float"
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class ImageGenerationRequest(BaseModel):
|
| 30 |
+
model: str = "imagen-3.0-generate-002"
|
| 31 |
+
prompt: str = ""
|
| 32 |
+
n: int = 1
|
| 33 |
+
size: Optional[str] = "1024x1024"
|
| 34 |
+
quality: Optional[str] = None
|
| 35 |
+
style: Optional[str] = None
|
| 36 |
+
response_format: Optional[str] = "url"
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class TTSRequest(BaseModel):
|
| 40 |
+
model: str = "gemini-2.5-flash-preview-tts"
|
| 41 |
+
input: str
|
| 42 |
+
voice: str = "Kore"
|
| 43 |
+
response_format: Optional[str] = "wav"
|
app/exception/exceptions.py
ADDED
|
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
异常处理模块,定义应用程序中使用的自定义异常和异常处理器
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from fastapi import FastAPI, Request
|
| 6 |
+
from fastapi.exceptions import RequestValidationError
|
| 7 |
+
from fastapi.responses import JSONResponse
|
| 8 |
+
from starlette.exceptions import HTTPException as StarletteHTTPException
|
| 9 |
+
|
| 10 |
+
from app.log.logger import get_exceptions_logger
|
| 11 |
+
|
| 12 |
+
logger = get_exceptions_logger()
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class APIError(Exception):
|
| 16 |
+
"""API错误基类"""
|
| 17 |
+
|
| 18 |
+
def __init__(self, status_code: int, detail: str, error_code: str = None):
|
| 19 |
+
self.status_code = status_code
|
| 20 |
+
self.detail = detail
|
| 21 |
+
self.error_code = error_code or "api_error"
|
| 22 |
+
super().__init__(self.detail)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class AuthenticationError(APIError):
|
| 26 |
+
"""认证错误"""
|
| 27 |
+
|
| 28 |
+
def __init__(self, detail: str = "Authentication failed"):
|
| 29 |
+
super().__init__(
|
| 30 |
+
status_code=401, detail=detail, error_code="authentication_error"
|
| 31 |
+
)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class AuthorizationError(APIError):
|
| 35 |
+
"""授权错误"""
|
| 36 |
+
|
| 37 |
+
def __init__(self, detail: str = "Not authorized to access this resource"):
|
| 38 |
+
super().__init__(
|
| 39 |
+
status_code=403, detail=detail, error_code="authorization_error"
|
| 40 |
+
)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class ResourceNotFoundError(APIError):
|
| 44 |
+
"""资源未找到错误"""
|
| 45 |
+
|
| 46 |
+
def __init__(self, detail: str = "Resource not found"):
|
| 47 |
+
super().__init__(
|
| 48 |
+
status_code=404, detail=detail, error_code="resource_not_found"
|
| 49 |
+
)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class ModelNotSupportedError(APIError):
|
| 53 |
+
"""模型不支持错误"""
|
| 54 |
+
|
| 55 |
+
def __init__(self, model: str):
|
| 56 |
+
super().__init__(
|
| 57 |
+
status_code=400,
|
| 58 |
+
detail=f"Model {model} is not supported",
|
| 59 |
+
error_code="model_not_supported",
|
| 60 |
+
)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
class APIKeyError(APIError):
|
| 64 |
+
"""API密钥错误"""
|
| 65 |
+
|
| 66 |
+
def __init__(self, detail: str = "Invalid or expired API key"):
|
| 67 |
+
super().__init__(status_code=401, detail=detail, error_code="api_key_error")
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
class ServiceUnavailableError(APIError):
|
| 71 |
+
"""服务不可用错误"""
|
| 72 |
+
|
| 73 |
+
def __init__(self, detail: str = "Service temporarily unavailable"):
|
| 74 |
+
super().__init__(
|
| 75 |
+
status_code=503, detail=detail, error_code="service_unavailable"
|
| 76 |
+
)
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def setup_exception_handlers(app: FastAPI) -> None:
|
| 80 |
+
"""
|
| 81 |
+
设置应用程序的异常处理器
|
| 82 |
+
|
| 83 |
+
Args:
|
| 84 |
+
app: FastAPI应用程序实例
|
| 85 |
+
"""
|
| 86 |
+
|
| 87 |
+
@app.exception_handler(APIError)
|
| 88 |
+
async def api_error_handler(request: Request, exc: APIError):
|
| 89 |
+
"""处理API错误"""
|
| 90 |
+
logger.error(f"API Error: {exc.detail} (Code: {exc.error_code})")
|
| 91 |
+
return JSONResponse(
|
| 92 |
+
status_code=exc.status_code,
|
| 93 |
+
content={"error": {"code": exc.error_code, "message": exc.detail}},
|
| 94 |
+
)
|
| 95 |
+
|
| 96 |
+
@app.exception_handler(StarletteHTTPException)
|
| 97 |
+
async def http_exception_handler(request: Request, exc: StarletteHTTPException):
|
| 98 |
+
"""处理HTTP异常"""
|
| 99 |
+
logger.error(f"HTTP Exception: {exc.detail} (Status: {exc.status_code})")
|
| 100 |
+
return JSONResponse(
|
| 101 |
+
status_code=exc.status_code,
|
| 102 |
+
content={"error": {"code": "http_error", "message": exc.detail}},
|
| 103 |
+
)
|
| 104 |
+
|
| 105 |
+
@app.exception_handler(RequestValidationError)
|
| 106 |
+
async def validation_exception_handler(
|
| 107 |
+
request: Request, exc: RequestValidationError
|
| 108 |
+
):
|
| 109 |
+
"""处理请求验证错误"""
|
| 110 |
+
error_details = []
|
| 111 |
+
for error in exc.errors():
|
| 112 |
+
error_details.append(
|
| 113 |
+
{"loc": error["loc"], "msg": error["msg"], "type": error["type"]}
|
| 114 |
+
)
|
| 115 |
+
|
| 116 |
+
logger.error(f"Validation Error: {error_details}")
|
| 117 |
+
return JSONResponse(
|
| 118 |
+
status_code=422,
|
| 119 |
+
content={
|
| 120 |
+
"error": {
|
| 121 |
+
"code": "validation_error",
|
| 122 |
+
"message": "Request validation failed",
|
| 123 |
+
"details": error_details,
|
| 124 |
+
}
|
| 125 |
+
},
|
| 126 |
+
)
|
| 127 |
+
|
| 128 |
+
@app.exception_handler(Exception)
|
| 129 |
+
async def general_exception_handler(request: Request, exc: Exception):
|
| 130 |
+
"""处理通用异常"""
|
| 131 |
+
logger.exception(f"Unhandled Exception: {str(exc)}")
|
| 132 |
+
return JSONResponse(
|
| 133 |
+
status_code=500,
|
| 134 |
+
content=str(exc),
|
| 135 |
+
)
|
app/handler/error_handler.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from contextlib import asynccontextmanager
|
| 2 |
+
from fastapi import HTTPException
|
| 3 |
+
import logging
|
| 4 |
+
|
| 5 |
+
@asynccontextmanager
|
| 6 |
+
async def handle_route_errors(logger: logging.Logger, operation_name: str, success_message: str = None, failure_message: str = None):
|
| 7 |
+
"""
|
| 8 |
+
一个异步上下文管理器,用于统一处理 FastAPI 路由中的常见错误和日志记录。
|
| 9 |
+
|
| 10 |
+
Args:
|
| 11 |
+
logger: 用于记录日志的 Logger 实例。
|
| 12 |
+
operation_name: 操作的名称,用于日志记录和错误详情。
|
| 13 |
+
success_message: 操作成功时记录的自定义消息 (可选)。
|
| 14 |
+
failure_message: 操作失败时记录的自定义消息 (可选)。
|
| 15 |
+
"""
|
| 16 |
+
default_success_msg = f"{operation_name} request successful"
|
| 17 |
+
default_failure_msg = f"{operation_name} request failed"
|
| 18 |
+
|
| 19 |
+
logger.info("-" * 50 + operation_name + "-" * 50)
|
| 20 |
+
try:
|
| 21 |
+
yield
|
| 22 |
+
logger.info(success_message or default_success_msg)
|
| 23 |
+
except HTTPException as http_exc:
|
| 24 |
+
# 如果已经是 HTTPException,直接重新抛出,保留原始状态码和详情
|
| 25 |
+
logger.error(f"{failure_message or default_failure_msg}: {http_exc.detail} (Status: {http_exc.status_code})")
|
| 26 |
+
raise http_exc
|
| 27 |
+
except Exception as e:
|
| 28 |
+
# 对于其他所有异常,记录错误并抛出标准的 500 错误
|
| 29 |
+
logger.error(f"{failure_message or default_failure_msg}: {str(e)}")
|
| 30 |
+
raise HTTPException(
|
| 31 |
+
status_code=500, detail=f"Internal server error during {operation_name}"
|
| 32 |
+
) from e
|
app/handler/message_converter.py
ADDED
|
@@ -0,0 +1,363 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import base64
|
| 2 |
+
import json
|
| 3 |
+
import re
|
| 4 |
+
from abc import ABC, abstractmethod
|
| 5 |
+
from typing import Any, Dict, List, Optional
|
| 6 |
+
|
| 7 |
+
import requests
|
| 8 |
+
|
| 9 |
+
from app.core.constants import (
|
| 10 |
+
AUDIO_FORMAT_TO_MIMETYPE,
|
| 11 |
+
DATA_URL_PATTERN,
|
| 12 |
+
IMAGE_URL_PATTERN,
|
| 13 |
+
MAX_AUDIO_SIZE_BYTES,
|
| 14 |
+
MAX_VIDEO_SIZE_BYTES,
|
| 15 |
+
SUPPORTED_AUDIO_FORMATS,
|
| 16 |
+
SUPPORTED_ROLES,
|
| 17 |
+
SUPPORTED_VIDEO_FORMATS,
|
| 18 |
+
VIDEO_FORMAT_TO_MIMETYPE,
|
| 19 |
+
)
|
| 20 |
+
from app.log.logger import get_message_converter_logger
|
| 21 |
+
|
| 22 |
+
logger = get_message_converter_logger()
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class MessageConverter(ABC):
|
| 26 |
+
"""消息转换器基类"""
|
| 27 |
+
|
| 28 |
+
@abstractmethod
|
| 29 |
+
def convert(
|
| 30 |
+
self, messages: List[Dict[str, Any]], model: str
|
| 31 |
+
) -> tuple[List[Dict[str, Any]], Optional[Dict[str, Any]]]:
|
| 32 |
+
pass
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def _get_mime_type_and_data(base64_string):
|
| 36 |
+
"""
|
| 37 |
+
从 base64 字符串中提取 MIME 类型和数据。
|
| 38 |
+
|
| 39 |
+
参数:
|
| 40 |
+
base64_string (str): 可能包含 MIME 类型信息的 base64 字符串
|
| 41 |
+
|
| 42 |
+
返回:
|
| 43 |
+
tuple: (mime_type, encoded_data)
|
| 44 |
+
"""
|
| 45 |
+
# 检查字符串是否以 "data:" 格式开始
|
| 46 |
+
if base64_string.startswith("data:"):
|
| 47 |
+
# 提取 MIME 类型和数据
|
| 48 |
+
pattern = DATA_URL_PATTERN
|
| 49 |
+
match = re.match(pattern, base64_string)
|
| 50 |
+
if match:
|
| 51 |
+
mime_type = (
|
| 52 |
+
"image/jpeg" if match.group(1) == "image/jpg" else match.group(1)
|
| 53 |
+
)
|
| 54 |
+
encoded_data = match.group(2)
|
| 55 |
+
return mime_type, encoded_data
|
| 56 |
+
|
| 57 |
+
# 如果不是预期格式,假定它只是数据部分
|
| 58 |
+
return None, base64_string
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def _convert_image(image_url: str) -> Dict[str, Any]:
|
| 62 |
+
if image_url.startswith("data:image"):
|
| 63 |
+
mime_type, encoded_data = _get_mime_type_and_data(image_url)
|
| 64 |
+
return {"inline_data": {"mime_type": mime_type, "data": encoded_data}}
|
| 65 |
+
else:
|
| 66 |
+
encoded_data = _convert_image_to_base64(image_url)
|
| 67 |
+
return {"inline_data": {"mime_type": "image/png", "data": encoded_data}}
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def _convert_image_to_base64(url: str) -> str:
|
| 71 |
+
"""
|
| 72 |
+
将图片URL转换为base64编码
|
| 73 |
+
Args:
|
| 74 |
+
url: 图片URL
|
| 75 |
+
Returns:
|
| 76 |
+
str: base64编码的图片数据
|
| 77 |
+
"""
|
| 78 |
+
response = requests.get(url)
|
| 79 |
+
if response.status_code == 200:
|
| 80 |
+
# 将图片内容转换为base64
|
| 81 |
+
img_data = base64.b64encode(response.content).decode("utf-8")
|
| 82 |
+
return img_data
|
| 83 |
+
else:
|
| 84 |
+
raise Exception(f"Failed to fetch image: {response.status_code}")
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def _process_text_with_image(text: str, model: str) -> List[Dict[str, Any]]:
|
| 88 |
+
"""
|
| 89 |
+
处理可能包含图片URL的文本,提取图片并转换为base64
|
| 90 |
+
|
| 91 |
+
Args:
|
| 92 |
+
text: 可能包含图片URL的文本
|
| 93 |
+
|
| 94 |
+
Returns:
|
| 95 |
+
List[Dict[str, Any]]: 包含文本和图片的部分列表
|
| 96 |
+
"""
|
| 97 |
+
# 如果模型名中没有包含image,当作普通文本处理
|
| 98 |
+
if "image" not in model:
|
| 99 |
+
return [{"text": text}]
|
| 100 |
+
parts = []
|
| 101 |
+
img_url_match = re.search(IMAGE_URL_PATTERN, text)
|
| 102 |
+
if img_url_match:
|
| 103 |
+
# 提取URL
|
| 104 |
+
img_url = img_url_match.group(2)
|
| 105 |
+
# 先判断是否是base64url如果是,直接用,不过不是,再将URL对应的图片转换为base64
|
| 106 |
+
try:
|
| 107 |
+
base64_url_match = re.search(DATA_URL_PATTERN, img_url)
|
| 108 |
+
if base64_url_match:
|
| 109 |
+
parts.append(
|
| 110 |
+
{
|
| 111 |
+
"inline_data": {
|
| 112 |
+
"mimeType": base64_url_match.group(1),
|
| 113 |
+
"data": base64_url_match.group(2),
|
| 114 |
+
}
|
| 115 |
+
}
|
| 116 |
+
)
|
| 117 |
+
else:
|
| 118 |
+
base64_data = _convert_image_to_base64(img_url)
|
| 119 |
+
parts.append(
|
| 120 |
+
{"inline_data": {"mimeType": "image/png", "data": base64_data}}
|
| 121 |
+
)
|
| 122 |
+
except Exception:
|
| 123 |
+
# 如果转换失败,回退到文本模式
|
| 124 |
+
parts.append({"text": text})
|
| 125 |
+
else:
|
| 126 |
+
# 没有图片URL,作为纯文本处理
|
| 127 |
+
parts.append({"text": text})
|
| 128 |
+
return parts
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
class OpenAIMessageConverter(MessageConverter):
|
| 132 |
+
"""OpenAI消息格式转换器"""
|
| 133 |
+
|
| 134 |
+
def _validate_media_data(
|
| 135 |
+
self, format: str, data: str, supported_formats: List[str], max_size: int
|
| 136 |
+
) -> tuple[Optional[str], Optional[str]]:
|
| 137 |
+
"""Validates format and size of Base64 media data."""
|
| 138 |
+
if format.lower() not in supported_formats:
|
| 139 |
+
logger.error(
|
| 140 |
+
f"Unsupported media format: {format}. Supported: {supported_formats}"
|
| 141 |
+
)
|
| 142 |
+
raise ValueError(f"Unsupported media format: {format}")
|
| 143 |
+
|
| 144 |
+
try:
|
| 145 |
+
decoded_data = base64.b64decode(data, validate=True)
|
| 146 |
+
if len(decoded_data) > max_size:
|
| 147 |
+
logger.error(
|
| 148 |
+
f"Media data size ({len(decoded_data)} bytes) exceeds limit ({max_size} bytes)."
|
| 149 |
+
)
|
| 150 |
+
raise ValueError(
|
| 151 |
+
f"Media data size exceeds limit of {max_size // 1024 // 1024}MB"
|
| 152 |
+
)
|
| 153 |
+
return data
|
| 154 |
+
except base64.binascii.Error as e:
|
| 155 |
+
logger.error(f"Invalid Base64 data provided: {e}")
|
| 156 |
+
raise ValueError("Invalid Base64 data")
|
| 157 |
+
except Exception as e:
|
| 158 |
+
logger.error(f"Error validating media data: {e}")
|
| 159 |
+
raise
|
| 160 |
+
|
| 161 |
+
def convert(
|
| 162 |
+
self, messages: List[Dict[str, Any]], model: str
|
| 163 |
+
) -> tuple[List[Dict[str, Any]], Optional[Dict[str, Any]]]:
|
| 164 |
+
converted_messages = []
|
| 165 |
+
system_instruction_parts = []
|
| 166 |
+
|
| 167 |
+
for idx, msg in enumerate(messages):
|
| 168 |
+
role = msg.get("role", "")
|
| 169 |
+
parts = []
|
| 170 |
+
|
| 171 |
+
if "content" in msg and isinstance(msg["content"], list):
|
| 172 |
+
for content_item in msg["content"]:
|
| 173 |
+
if not isinstance(content_item, dict):
|
| 174 |
+
logger.warning(
|
| 175 |
+
f"Skipping unexpected content item format: {type(content_item)}"
|
| 176 |
+
)
|
| 177 |
+
continue
|
| 178 |
+
|
| 179 |
+
content_type = content_item.get("type")
|
| 180 |
+
|
| 181 |
+
if content_type == "text" and content_item.get("text"):
|
| 182 |
+
parts.append({"text": content_item["text"]})
|
| 183 |
+
elif content_type == "image_url" and content_item.get(
|
| 184 |
+
"image_url", {}
|
| 185 |
+
).get("url"):
|
| 186 |
+
try:
|
| 187 |
+
parts.append(
|
| 188 |
+
_convert_image(content_item["image_url"]["url"])
|
| 189 |
+
)
|
| 190 |
+
except Exception as e:
|
| 191 |
+
logger.error(
|
| 192 |
+
f"Failed to convert image URL {content_item['image_url']['url']}: {e}"
|
| 193 |
+
)
|
| 194 |
+
parts.append(
|
| 195 |
+
{
|
| 196 |
+
"text": f"[Error processing image: {content_item['image_url']['url']}]"
|
| 197 |
+
}
|
| 198 |
+
)
|
| 199 |
+
elif content_type == "input_audio" and content_item.get(
|
| 200 |
+
"input_audio"
|
| 201 |
+
):
|
| 202 |
+
audio_info = content_item["input_audio"]
|
| 203 |
+
audio_data = audio_info.get("data")
|
| 204 |
+
audio_format = audio_info.get("format", "").lower()
|
| 205 |
+
|
| 206 |
+
if not audio_data or not audio_format:
|
| 207 |
+
logger.warning(
|
| 208 |
+
"Skipping audio part due to missing data or format."
|
| 209 |
+
)
|
| 210 |
+
continue
|
| 211 |
+
|
| 212 |
+
try:
|
| 213 |
+
validated_data = self._validate_media_data(
|
| 214 |
+
audio_format,
|
| 215 |
+
audio_data,
|
| 216 |
+
SUPPORTED_AUDIO_FORMATS,
|
| 217 |
+
MAX_AUDIO_SIZE_BYTES,
|
| 218 |
+
)
|
| 219 |
+
|
| 220 |
+
# Get MIME type
|
| 221 |
+
mime_type = AUDIO_FORMAT_TO_MIMETYPE.get(audio_format)
|
| 222 |
+
if not mime_type:
|
| 223 |
+
# Should not happen if format validation passed, but double-check
|
| 224 |
+
logger.error(
|
| 225 |
+
f"Could not find MIME type for supported format: {audio_format}"
|
| 226 |
+
)
|
| 227 |
+
raise ValueError(
|
| 228 |
+
f"Internal error: MIME type mapping missing for {audio_format}"
|
| 229 |
+
)
|
| 230 |
+
|
| 231 |
+
parts.append(
|
| 232 |
+
{
|
| 233 |
+
"inline_data": {
|
| 234 |
+
"mimeType": mime_type,
|
| 235 |
+
"data": validated_data, # Use the validated Base64 data
|
| 236 |
+
}
|
| 237 |
+
}
|
| 238 |
+
)
|
| 239 |
+
logger.debug(
|
| 240 |
+
f"Successfully added audio part (format: {audio_format})"
|
| 241 |
+
)
|
| 242 |
+
|
| 243 |
+
except ValueError as e:
|
| 244 |
+
logger.error(
|
| 245 |
+
f"Skipping audio part due to validation error: {e}"
|
| 246 |
+
)
|
| 247 |
+
parts.append({"text": f"[Error processing audio: {e}]"})
|
| 248 |
+
except Exception:
|
| 249 |
+
logger.exception("Unexpected error processing audio part.")
|
| 250 |
+
parts.append(
|
| 251 |
+
{"text": "[Unexpected error processing audio]"}
|
| 252 |
+
)
|
| 253 |
+
|
| 254 |
+
elif content_type == "input_video" and content_item.get(
|
| 255 |
+
"input_video"
|
| 256 |
+
):
|
| 257 |
+
video_info = content_item["input_video"]
|
| 258 |
+
video_data = video_info.get("data")
|
| 259 |
+
video_format = video_info.get("format", "").lower()
|
| 260 |
+
|
| 261 |
+
if not video_data or not video_format:
|
| 262 |
+
logger.warning(
|
| 263 |
+
"Skipping video part due to missing data or format."
|
| 264 |
+
)
|
| 265 |
+
continue
|
| 266 |
+
|
| 267 |
+
try:
|
| 268 |
+
validated_data = self._validate_media_data(
|
| 269 |
+
video_format,
|
| 270 |
+
video_data,
|
| 271 |
+
SUPPORTED_VIDEO_FORMATS,
|
| 272 |
+
MAX_VIDEO_SIZE_BYTES,
|
| 273 |
+
)
|
| 274 |
+
mime_type = VIDEO_FORMAT_TO_MIMETYPE.get(video_format)
|
| 275 |
+
if not mime_type:
|
| 276 |
+
raise ValueError(
|
| 277 |
+
f"Internal error: MIME type mapping missing for {video_format}"
|
| 278 |
+
)
|
| 279 |
+
|
| 280 |
+
parts.append(
|
| 281 |
+
{
|
| 282 |
+
"inline_data": {
|
| 283 |
+
"mimeType": mime_type,
|
| 284 |
+
"data": validated_data,
|
| 285 |
+
}
|
| 286 |
+
}
|
| 287 |
+
)
|
| 288 |
+
logger.debug(
|
| 289 |
+
f"Successfully added video part (format: {video_format})"
|
| 290 |
+
)
|
| 291 |
+
|
| 292 |
+
except ValueError as e:
|
| 293 |
+
logger.error(
|
| 294 |
+
f"Skipping video part due to validation error: {e}"
|
| 295 |
+
)
|
| 296 |
+
parts.append({"text": f"[Error processing video: {e}]"})
|
| 297 |
+
except Exception:
|
| 298 |
+
logger.exception("Unexpected error processing video part.")
|
| 299 |
+
parts.append(
|
| 300 |
+
{"text": "[Unexpected error processing video]"}
|
| 301 |
+
)
|
| 302 |
+
|
| 303 |
+
else:
|
| 304 |
+
# Log unrecognized but present types
|
| 305 |
+
if content_type:
|
| 306 |
+
logger.warning(
|
| 307 |
+
f"Unsupported content type or missing data in structured content: {content_type}"
|
| 308 |
+
)
|
| 309 |
+
|
| 310 |
+
elif (
|
| 311 |
+
"content" in msg and isinstance(msg["content"], str) and msg["content"]
|
| 312 |
+
):
|
| 313 |
+
parts.extend(_process_text_with_image(msg["content"], model))
|
| 314 |
+
elif "tool_calls" in msg and isinstance(msg["tool_calls"], list):
|
| 315 |
+
# Keep existing tool call processing
|
| 316 |
+
for tool_call in msg["tool_calls"]:
|
| 317 |
+
function_call = tool_call.get("function", {})
|
| 318 |
+
# Sanitize arguments loading
|
| 319 |
+
arguments_str = function_call.get("arguments", "{}")
|
| 320 |
+
try:
|
| 321 |
+
function_call["args"] = json.loads(arguments_str)
|
| 322 |
+
except json.JSONDecodeError:
|
| 323 |
+
logger.warning(
|
| 324 |
+
f"Failed to decode tool call arguments: {arguments_str}"
|
| 325 |
+
)
|
| 326 |
+
function_call["args"] = {}
|
| 327 |
+
if "arguments" in function_call:
|
| 328 |
+
if "arguments" in function_call:
|
| 329 |
+
del function_call["arguments"]
|
| 330 |
+
|
| 331 |
+
parts.append({"functionCall": function_call})
|
| 332 |
+
|
| 333 |
+
if role not in SUPPORTED_ROLES:
|
| 334 |
+
if role == "tool":
|
| 335 |
+
role = "user"
|
| 336 |
+
else:
|
| 337 |
+
# 如果是最后一条消息,则认为是用户消息
|
| 338 |
+
if idx == len(messages) - 1:
|
| 339 |
+
role = "user"
|
| 340 |
+
else:
|
| 341 |
+
role = "model"
|
| 342 |
+
if parts:
|
| 343 |
+
if role == "system":
|
| 344 |
+
text_only_parts = [p for p in parts if "text" in p]
|
| 345 |
+
if len(text_only_parts) != len(parts):
|
| 346 |
+
logger.warning(
|
| 347 |
+
"Non-text parts found in system message; discarding them."
|
| 348 |
+
)
|
| 349 |
+
if text_only_parts:
|
| 350 |
+
system_instruction_parts.extend(text_only_parts)
|
| 351 |
+
|
| 352 |
+
else:
|
| 353 |
+
converted_messages.append({"role": role, "parts": parts})
|
| 354 |
+
|
| 355 |
+
system_instruction = (
|
| 356 |
+
None
|
| 357 |
+
if not system_instruction_parts
|
| 358 |
+
else {
|
| 359 |
+
"role": "system",
|
| 360 |
+
"parts": system_instruction_parts,
|
| 361 |
+
}
|
| 362 |
+
)
|
| 363 |
+
return converted_messages, system_instruction
|
app/handler/response_handler.py
ADDED
|
@@ -0,0 +1,449 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import base64
|
| 2 |
+
import json
|
| 3 |
+
import random
|
| 4 |
+
import string
|
| 5 |
+
import time
|
| 6 |
+
import uuid
|
| 7 |
+
from abc import ABC, abstractmethod
|
| 8 |
+
from typing import Any, Dict, List, Optional
|
| 9 |
+
|
| 10 |
+
from app.config.config import settings
|
| 11 |
+
from app.log.logger import get_openai_logger
|
| 12 |
+
from app.utils.helpers import is_image_upload_configured
|
| 13 |
+
from app.utils.uploader import ImageUploaderFactory
|
| 14 |
+
|
| 15 |
+
logger = get_openai_logger()
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class ResponseHandler(ABC):
|
| 19 |
+
"""响应处理器基类"""
|
| 20 |
+
|
| 21 |
+
@abstractmethod
|
| 22 |
+
def handle_response(
|
| 23 |
+
self, response: Dict[str, Any], model: str, stream: bool = False
|
| 24 |
+
) -> Dict[str, Any]:
|
| 25 |
+
pass
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class GeminiResponseHandler(ResponseHandler):
|
| 29 |
+
"""Gemini响应处理器"""
|
| 30 |
+
|
| 31 |
+
def __init__(self):
|
| 32 |
+
self.thinking_first = True
|
| 33 |
+
self.thinking_status = False
|
| 34 |
+
|
| 35 |
+
def handle_response(
|
| 36 |
+
self,
|
| 37 |
+
response: Dict[str, Any],
|
| 38 |
+
model: str,
|
| 39 |
+
stream: bool = False,
|
| 40 |
+
usage_metadata: Optional[Dict[str, Any]] = None,
|
| 41 |
+
) -> Dict[str, Any]:
|
| 42 |
+
if stream:
|
| 43 |
+
return _handle_gemini_stream_response(response, model, stream)
|
| 44 |
+
return _handle_gemini_normal_response(response, model, stream)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def _handle_openai_stream_response(
|
| 48 |
+
response: Dict[str, Any],
|
| 49 |
+
model: str,
|
| 50 |
+
finish_reason: str,
|
| 51 |
+
usage_metadata: Optional[Dict[str, Any]],
|
| 52 |
+
) -> Dict[str, Any]:
|
| 53 |
+
choices = []
|
| 54 |
+
candidates = response.get("candidates", [])
|
| 55 |
+
|
| 56 |
+
for candidate in candidates:
|
| 57 |
+
index = candidate.get("index", 0)
|
| 58 |
+
text, reasoning_content, tool_calls, _ = _extract_result(
|
| 59 |
+
{"candidates": [candidate]}, model, stream=True, gemini_format=False
|
| 60 |
+
)
|
| 61 |
+
|
| 62 |
+
if not text and not tool_calls and not reasoning_content:
|
| 63 |
+
delta = {}
|
| 64 |
+
else:
|
| 65 |
+
delta = {
|
| 66 |
+
"content": text,
|
| 67 |
+
"reasoning_content": reasoning_content,
|
| 68 |
+
"role": "assistant",
|
| 69 |
+
}
|
| 70 |
+
if tool_calls:
|
| 71 |
+
delta["tool_calls"] = tool_calls
|
| 72 |
+
|
| 73 |
+
choice = {"index": index, "delta": delta, "finish_reason": finish_reason}
|
| 74 |
+
choices.append(choice)
|
| 75 |
+
|
| 76 |
+
template_chunk = {
|
| 77 |
+
"id": f"chatcmpl-{uuid.uuid4()}",
|
| 78 |
+
"object": "chat.completion.chunk",
|
| 79 |
+
"created": int(time.time()),
|
| 80 |
+
"model": model,
|
| 81 |
+
"choices": choices,
|
| 82 |
+
}
|
| 83 |
+
if usage_metadata:
|
| 84 |
+
template_chunk["usage"] = {
|
| 85 |
+
"prompt_tokens": usage_metadata.get("promptTokenCount", 0),
|
| 86 |
+
"completion_tokens": usage_metadata.get("candidatesTokenCount", 0),
|
| 87 |
+
"total_tokens": usage_metadata.get("totalTokenCount", 0),
|
| 88 |
+
}
|
| 89 |
+
return template_chunk
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def _handle_openai_normal_response(
|
| 93 |
+
response: Dict[str, Any],
|
| 94 |
+
model: str,
|
| 95 |
+
finish_reason: str,
|
| 96 |
+
usage_metadata: Optional[Dict[str, Any]],
|
| 97 |
+
) -> Dict[str, Any]:
|
| 98 |
+
choices = []
|
| 99 |
+
candidates = response.get("candidates", [])
|
| 100 |
+
|
| 101 |
+
for i, candidate in enumerate(candidates):
|
| 102 |
+
text, reasoning_content, tool_calls, _ = _extract_result(
|
| 103 |
+
{"candidates": [candidate]}, model, stream=False, gemini_format=False
|
| 104 |
+
)
|
| 105 |
+
choice = {
|
| 106 |
+
"index": i,
|
| 107 |
+
"message": {
|
| 108 |
+
"role": "assistant",
|
| 109 |
+
"content": text,
|
| 110 |
+
"reasoning_content": reasoning_content,
|
| 111 |
+
"tool_calls": tool_calls,
|
| 112 |
+
},
|
| 113 |
+
"finish_reason": finish_reason,
|
| 114 |
+
}
|
| 115 |
+
choices.append(choice)
|
| 116 |
+
|
| 117 |
+
return {
|
| 118 |
+
"id": f"chatcmpl-{uuid.uuid4()}",
|
| 119 |
+
"object": "chat.completion",
|
| 120 |
+
"created": int(time.time()),
|
| 121 |
+
"model": model,
|
| 122 |
+
"choices": choices,
|
| 123 |
+
"usage": {
|
| 124 |
+
"prompt_tokens": usage_metadata.get("promptTokenCount", 0),
|
| 125 |
+
"completion_tokens": usage_metadata.get("candidatesTokenCount", 0),
|
| 126 |
+
"total_tokens": usage_metadata.get("totalTokenCount", 0),
|
| 127 |
+
},
|
| 128 |
+
}
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
class OpenAIResponseHandler(ResponseHandler):
|
| 132 |
+
"""OpenAI响应处理器"""
|
| 133 |
+
|
| 134 |
+
def __init__(self, config):
|
| 135 |
+
self.config = config
|
| 136 |
+
self.thinking_first = True
|
| 137 |
+
self.thinking_status = False
|
| 138 |
+
|
| 139 |
+
def handle_response(
|
| 140 |
+
self,
|
| 141 |
+
response: Dict[str, Any],
|
| 142 |
+
model: str,
|
| 143 |
+
stream: bool = False,
|
| 144 |
+
finish_reason: str = None,
|
| 145 |
+
usage_metadata: Optional[Dict[str, Any]] = None,
|
| 146 |
+
) -> Optional[Dict[str, Any]]:
|
| 147 |
+
if stream:
|
| 148 |
+
return _handle_openai_stream_response(
|
| 149 |
+
response, model, finish_reason, usage_metadata
|
| 150 |
+
)
|
| 151 |
+
return _handle_openai_normal_response(
|
| 152 |
+
response, model, finish_reason, usage_metadata
|
| 153 |
+
)
|
| 154 |
+
|
| 155 |
+
def handle_image_chat_response(
|
| 156 |
+
self, image_str: str, model: str, stream=False, finish_reason="stop"
|
| 157 |
+
):
|
| 158 |
+
if stream:
|
| 159 |
+
return _handle_openai_stream_image_response(image_str, model, finish_reason)
|
| 160 |
+
return _handle_openai_normal_image_response(image_str, model, finish_reason)
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
def _handle_openai_stream_image_response(
|
| 164 |
+
image_str: str, model: str, finish_reason: str
|
| 165 |
+
) -> Dict[str, Any]:
|
| 166 |
+
return {
|
| 167 |
+
"id": f"chatcmpl-{uuid.uuid4()}",
|
| 168 |
+
"object": "chat.completion.chunk",
|
| 169 |
+
"created": int(time.time()),
|
| 170 |
+
"model": model,
|
| 171 |
+
"choices": [
|
| 172 |
+
{
|
| 173 |
+
"index": 0,
|
| 174 |
+
"delta": {"content": image_str} if image_str else {},
|
| 175 |
+
"finish_reason": finish_reason,
|
| 176 |
+
}
|
| 177 |
+
],
|
| 178 |
+
}
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
def _handle_openai_normal_image_response(
|
| 182 |
+
image_str: str, model: str, finish_reason: str
|
| 183 |
+
) -> Dict[str, Any]:
|
| 184 |
+
return {
|
| 185 |
+
"id": f"chatcmpl-{uuid.uuid4()}",
|
| 186 |
+
"object": "chat.completion",
|
| 187 |
+
"created": int(time.time()),
|
| 188 |
+
"model": model,
|
| 189 |
+
"choices": [
|
| 190 |
+
{
|
| 191 |
+
"index": 0,
|
| 192 |
+
"message": {"role": "assistant", "content": image_str},
|
| 193 |
+
"finish_reason": finish_reason,
|
| 194 |
+
}
|
| 195 |
+
],
|
| 196 |
+
"usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
|
| 197 |
+
}
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
def _extract_result(
|
| 201 |
+
response: Dict[str, Any],
|
| 202 |
+
model: str,
|
| 203 |
+
stream: bool = False,
|
| 204 |
+
gemini_format: bool = False,
|
| 205 |
+
) -> tuple[str, Optional[str], List[Dict[str, Any]], Optional[bool]]:
|
| 206 |
+
text, reasoning_content, tool_calls, thought = "", "", [], None
|
| 207 |
+
|
| 208 |
+
if stream:
|
| 209 |
+
if response.get("candidates"):
|
| 210 |
+
candidate = response["candidates"][0]
|
| 211 |
+
content = candidate.get("content", {})
|
| 212 |
+
parts = content.get("parts", [])
|
| 213 |
+
if not parts:
|
| 214 |
+
logger.warning("No parts found in stream response")
|
| 215 |
+
return "", None, [], None
|
| 216 |
+
|
| 217 |
+
if "text" in parts[0]:
|
| 218 |
+
text = parts[0].get("text")
|
| 219 |
+
if "thought" in parts[0]:
|
| 220 |
+
if not gemini_format and settings.SHOW_THINKING_PROCESS:
|
| 221 |
+
reasoning_content = text
|
| 222 |
+
text = ""
|
| 223 |
+
thought = parts[0].get("thought")
|
| 224 |
+
elif "executableCode" in parts[0]:
|
| 225 |
+
text = _format_code_block(parts[0]["executableCode"])
|
| 226 |
+
elif "codeExecution" in parts[0]:
|
| 227 |
+
text = _format_code_block(parts[0]["codeExecution"])
|
| 228 |
+
elif "executableCodeResult" in parts[0]:
|
| 229 |
+
text = _format_execution_result(parts[0]["executableCodeResult"])
|
| 230 |
+
elif "codeExecutionResult" in parts[0]:
|
| 231 |
+
text = _format_execution_result(parts[0]["codeExecutionResult"])
|
| 232 |
+
elif "inlineData" in parts[0]:
|
| 233 |
+
text = _extract_image_data(parts[0])
|
| 234 |
+
else:
|
| 235 |
+
text = ""
|
| 236 |
+
text = _add_search_link_text(model, candidate, text)
|
| 237 |
+
tool_calls = _extract_tool_calls(parts, gemini_format)
|
| 238 |
+
else:
|
| 239 |
+
if response.get("candidates"):
|
| 240 |
+
candidate = response["candidates"][0]
|
| 241 |
+
text, reasoning_content = "", ""
|
| 242 |
+
|
| 243 |
+
# 使用安全的访问方式
|
| 244 |
+
content = candidate.get("content", {})
|
| 245 |
+
|
| 246 |
+
if content and isinstance(content, dict):
|
| 247 |
+
parts = content.get("parts", [])
|
| 248 |
+
|
| 249 |
+
if parts:
|
| 250 |
+
for part in parts:
|
| 251 |
+
if "text" in part:
|
| 252 |
+
if "thought" in part and settings.SHOW_THINKING_PROCESS:
|
| 253 |
+
reasoning_content += part["text"]
|
| 254 |
+
else:
|
| 255 |
+
text += part["text"]
|
| 256 |
+
if "thought" in part and thought is None:
|
| 257 |
+
thought = part.get("thought")
|
| 258 |
+
elif "inlineData" in part:
|
| 259 |
+
text += _extract_image_data(part)
|
| 260 |
+
else:
|
| 261 |
+
logger.warning(f"No parts found in content for model: {model}")
|
| 262 |
+
else:
|
| 263 |
+
logger.error(f"Invalid content structure for model: {model}")
|
| 264 |
+
|
| 265 |
+
text = _add_search_link_text(model, candidate, text)
|
| 266 |
+
|
| 267 |
+
# 安全地获取 parts 用于工具调用提取
|
| 268 |
+
parts = candidate.get("content", {}).get("parts", [])
|
| 269 |
+
tool_calls = _extract_tool_calls(parts, gemini_format)
|
| 270 |
+
else:
|
| 271 |
+
logger.warning(f"No candidates found in response for model: {model}")
|
| 272 |
+
text = "暂无返回"
|
| 273 |
+
|
| 274 |
+
return text, reasoning_content, tool_calls, thought
|
| 275 |
+
|
| 276 |
+
|
| 277 |
+
def _has_inline_image_part(response: Dict[str, Any]) -> bool:
|
| 278 |
+
try:
|
| 279 |
+
for c in response.get("candidates", []):
|
| 280 |
+
for p in c.get("content", {}).get("parts", []):
|
| 281 |
+
if isinstance(p, dict) and ("inlineData" in p):
|
| 282 |
+
return True
|
| 283 |
+
except Exception:
|
| 284 |
+
return False
|
| 285 |
+
return False
|
| 286 |
+
|
| 287 |
+
|
| 288 |
+
def _extract_image_data(part: dict) -> str:
|
| 289 |
+
image_uploader = None
|
| 290 |
+
if settings.UPLOAD_PROVIDER == "smms":
|
| 291 |
+
image_uploader = ImageUploaderFactory.create(
|
| 292 |
+
provider=settings.UPLOAD_PROVIDER, api_key=settings.SMMS_SECRET_TOKEN
|
| 293 |
+
)
|
| 294 |
+
elif settings.UPLOAD_PROVIDER == "picgo":
|
| 295 |
+
image_uploader = ImageUploaderFactory.create(
|
| 296 |
+
provider=settings.UPLOAD_PROVIDER,
|
| 297 |
+
api_key=settings.PICGO_API_KEY,
|
| 298 |
+
api_url=settings.PICGO_API_URL
|
| 299 |
+
)
|
| 300 |
+
elif settings.UPLOAD_PROVIDER == "cloudflare_imgbed":
|
| 301 |
+
image_uploader = ImageUploaderFactory.create(
|
| 302 |
+
provider=settings.UPLOAD_PROVIDER,
|
| 303 |
+
base_url=settings.CLOUDFLARE_IMGBED_URL,
|
| 304 |
+
auth_code=settings.CLOUDFLARE_IMGBED_AUTH_CODE,
|
| 305 |
+
upload_folder=settings.CLOUDFLARE_IMGBED_UPLOAD_FOLDER,
|
| 306 |
+
)
|
| 307 |
+
elif settings.UPLOAD_PROVIDER == "aliyun_oss":
|
| 308 |
+
image_uploader = ImageUploaderFactory.create(
|
| 309 |
+
provider=settings.UPLOAD_PROVIDER,
|
| 310 |
+
access_key=settings.OSS_ACCESS_KEY,
|
| 311 |
+
access_key_secret=settings.OSS_ACCESS_KEY_SECRET,
|
| 312 |
+
bucket_name=settings.OSS_BUCKET_NAME,
|
| 313 |
+
endpoint=settings.OSS_ENDPOINT,
|
| 314 |
+
region=settings.OSS_REGION,
|
| 315 |
+
use_internal=False
|
| 316 |
+
)
|
| 317 |
+
current_date = time.strftime("%Y/%m/%d")
|
| 318 |
+
filename = f"{current_date}/{uuid.uuid4().hex[:8]}.png"
|
| 319 |
+
base64_data = part["inlineData"]["data"]
|
| 320 |
+
mime_type = part["inlineData"]["mimeType"]
|
| 321 |
+
# 将base64_data转成bytes数组
|
| 322 |
+
# Return empty string if no uploader is configured
|
| 323 |
+
if not is_image_upload_configured(settings):
|
| 324 |
+
return f"\n\n\n\n"
|
| 325 |
+
bytes_data = base64.b64decode(base64_data)
|
| 326 |
+
upload_response = image_uploader.upload(bytes_data, filename)
|
| 327 |
+
if upload_response.success:
|
| 328 |
+
text = f"\n\n\n\n"
|
| 329 |
+
else:
|
| 330 |
+
text = f"\n\n\n\n"
|
| 331 |
+
return text
|
| 332 |
+
|
| 333 |
+
|
| 334 |
+
def _extract_tool_calls(
|
| 335 |
+
parts: List[Dict[str, Any]], gemini_format: bool
|
| 336 |
+
) -> List[Dict[str, Any]]:
|
| 337 |
+
"""提取工具调用信息"""
|
| 338 |
+
if not parts or not isinstance(parts, list):
|
| 339 |
+
return []
|
| 340 |
+
|
| 341 |
+
letters = string.ascii_lowercase + string.digits
|
| 342 |
+
tool_calls = list()
|
| 343 |
+
|
| 344 |
+
for i in range(len(parts)):
|
| 345 |
+
part = parts[i]
|
| 346 |
+
if not part or not isinstance(part, dict):
|
| 347 |
+
continue
|
| 348 |
+
|
| 349 |
+
item = part.get("functionCall", {})
|
| 350 |
+
if not item or not isinstance(item, dict):
|
| 351 |
+
continue
|
| 352 |
+
|
| 353 |
+
if gemini_format:
|
| 354 |
+
tool_calls.append(part)
|
| 355 |
+
else:
|
| 356 |
+
id = f"call_{''.join(random.sample(letters, 32))}"
|
| 357 |
+
name = item.get("name", "")
|
| 358 |
+
arguments = json.dumps(item.get("args", None) or {})
|
| 359 |
+
|
| 360 |
+
tool_calls.append(
|
| 361 |
+
{
|
| 362 |
+
"index": i,
|
| 363 |
+
"id": id,
|
| 364 |
+
"type": "function",
|
| 365 |
+
"function": {"name": name, "arguments": arguments},
|
| 366 |
+
}
|
| 367 |
+
)
|
| 368 |
+
|
| 369 |
+
return tool_calls
|
| 370 |
+
|
| 371 |
+
|
| 372 |
+
def _handle_gemini_stream_response(
|
| 373 |
+
response: Dict[str, Any], model: str, stream: bool
|
| 374 |
+
) -> Dict[str, Any]:
|
| 375 |
+
# Early return raw Gemini response if no uploader configured and contains inline images
|
| 376 |
+
if not is_image_upload_configured(settings) and _has_inline_image_part(response):
|
| 377 |
+
return response
|
| 378 |
+
|
| 379 |
+
text, reasoning_content, tool_calls, thought = _extract_result(
|
| 380 |
+
response, model, stream=stream, gemini_format=True
|
| 381 |
+
)
|
| 382 |
+
if tool_calls:
|
| 383 |
+
content = {"parts": tool_calls, "role": "model"}
|
| 384 |
+
else:
|
| 385 |
+
part = {"text": text}
|
| 386 |
+
if thought is not None:
|
| 387 |
+
part["thought"] = thought
|
| 388 |
+
content = {"parts": [part], "role": "model"}
|
| 389 |
+
response["candidates"][0]["content"] = content
|
| 390 |
+
return response
|
| 391 |
+
|
| 392 |
+
|
| 393 |
+
def _handle_gemini_normal_response(
|
| 394 |
+
response: Dict[str, Any], model: str, stream: bool
|
| 395 |
+
) -> Dict[str, Any]:
|
| 396 |
+
# Early return raw Gemini response if no uploader configured and contains inline images
|
| 397 |
+
if not is_image_upload_configured(settings) and _has_inline_image_part(response):
|
| 398 |
+
return response
|
| 399 |
+
|
| 400 |
+
text, reasoning_content, tool_calls, thought = _extract_result(
|
| 401 |
+
response, model, stream=stream, gemini_format=True
|
| 402 |
+
)
|
| 403 |
+
parts = []
|
| 404 |
+
if tool_calls:
|
| 405 |
+
parts = tool_calls
|
| 406 |
+
else:
|
| 407 |
+
if thought is not None:
|
| 408 |
+
parts.append({"text": reasoning_content, "thought": thought})
|
| 409 |
+
part = {"text": text}
|
| 410 |
+
parts.append(part)
|
| 411 |
+
content = {"parts": parts, "role": "model"}
|
| 412 |
+
response["candidates"][0]["content"] = content
|
| 413 |
+
return response
|
| 414 |
+
|
| 415 |
+
|
| 416 |
+
def _format_code_block(code_data: dict) -> str:
|
| 417 |
+
"""格式化代码块输出"""
|
| 418 |
+
language = code_data.get("language", "").lower()
|
| 419 |
+
code = code_data.get("code", "").strip()
|
| 420 |
+
return f"""\n\n---\n\n【代码执行】\n```{language}\n{code}\n```\n"""
|
| 421 |
+
|
| 422 |
+
|
| 423 |
+
def _add_search_link_text(model: str, candidate: dict, text: str) -> str:
|
| 424 |
+
if (
|
| 425 |
+
settings.SHOW_SEARCH_LINK
|
| 426 |
+
and model.endswith("-search")
|
| 427 |
+
and "groundingMetadata" in candidate
|
| 428 |
+
and "groundingChunks" in candidate["groundingMetadata"]
|
| 429 |
+
):
|
| 430 |
+
grounding_chunks = candidate["groundingMetadata"]["groundingChunks"]
|
| 431 |
+
text += "\n\n---\n\n"
|
| 432 |
+
text += "**【引用来源】**\n\n"
|
| 433 |
+
for _, grounding_chunk in enumerate(grounding_chunks, 1):
|
| 434 |
+
if "web" in grounding_chunk:
|
| 435 |
+
text += _create_search_link(grounding_chunk["web"])
|
| 436 |
+
return text
|
| 437 |
+
else:
|
| 438 |
+
return text
|
| 439 |
+
|
| 440 |
+
|
| 441 |
+
def _create_search_link(grounding_chunk: dict) -> str:
|
| 442 |
+
return f'\n- [{grounding_chunk["title"]}]({grounding_chunk["uri"]})'
|
| 443 |
+
|
| 444 |
+
|
| 445 |
+
def _format_execution_result(result_data: dict) -> str:
|
| 446 |
+
"""格式化执行结果输出"""
|
| 447 |
+
outcome = result_data.get("outcome", "")
|
| 448 |
+
output = result_data.get("output", "").strip()
|
| 449 |
+
return f"""\n【执行结果】\n> outcome: {outcome}\n\n【输出结果】\n```plaintext\n{output}\n```\n\n---\n\n"""
|
app/handler/retry_handler.py
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
from functools import wraps
|
| 3 |
+
from typing import Callable, TypeVar
|
| 4 |
+
|
| 5 |
+
from app.config.config import settings
|
| 6 |
+
from app.log.logger import get_retry_logger
|
| 7 |
+
from app.utils.helpers import redact_key_for_logging
|
| 8 |
+
|
| 9 |
+
T = TypeVar("T")
|
| 10 |
+
logger = get_retry_logger()
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class RetryHandler:
|
| 14 |
+
"""重试处理装饰器"""
|
| 15 |
+
|
| 16 |
+
def __init__(self, key_arg: str = "api_key"):
|
| 17 |
+
self.key_arg = key_arg
|
| 18 |
+
|
| 19 |
+
def __call__(self, func: Callable[..., T]) -> Callable[..., T]:
|
| 20 |
+
@wraps(func)
|
| 21 |
+
async def wrapper(*args, **kwargs) -> T:
|
| 22 |
+
last_exception = None
|
| 23 |
+
|
| 24 |
+
for attempt in range(settings.MAX_RETRIES):
|
| 25 |
+
retries = attempt + 1
|
| 26 |
+
try:
|
| 27 |
+
return await func(*args, **kwargs)
|
| 28 |
+
except Exception as e:
|
| 29 |
+
last_exception = e
|
| 30 |
+
logger.warning(
|
| 31 |
+
f"API call failed with error: {str(e)}. Attempt {retries} of {settings.MAX_RETRIES}"
|
| 32 |
+
)
|
| 33 |
+
|
| 34 |
+
# 从函数参数中获取 key_manager
|
| 35 |
+
key_manager = kwargs.get("key_manager")
|
| 36 |
+
if key_manager:
|
| 37 |
+
old_key = kwargs.get(self.key_arg)
|
| 38 |
+
new_key = await key_manager.handle_api_failure(old_key, retries)
|
| 39 |
+
if new_key:
|
| 40 |
+
kwargs[self.key_arg] = new_key
|
| 41 |
+
logger.info(f"Switched to new API key: {redact_key_for_logging(new_key)}")
|
| 42 |
+
else:
|
| 43 |
+
logger.error(f"No valid API key available after {retries} retries.")
|
| 44 |
+
break
|
| 45 |
+
|
| 46 |
+
logger.error(
|
| 47 |
+
f"All retry attempts failed, raising final exception: {str(last_exception)}"
|
| 48 |
+
)
|
| 49 |
+
raise last_exception
|
| 50 |
+
|
| 51 |
+
return wrapper
|
app/handler/stream_optimizer.py
ADDED
|
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import asyncio
|
| 3 |
+
import math
|
| 4 |
+
from typing import Any, AsyncGenerator, Callable, List
|
| 5 |
+
|
| 6 |
+
from app.config.config import settings
|
| 7 |
+
from app.core.constants import (
|
| 8 |
+
DEFAULT_STREAM_CHUNK_SIZE,
|
| 9 |
+
DEFAULT_STREAM_LONG_TEXT_THRESHOLD,
|
| 10 |
+
DEFAULT_STREAM_MAX_DELAY,
|
| 11 |
+
DEFAULT_STREAM_MIN_DELAY,
|
| 12 |
+
DEFAULT_STREAM_SHORT_TEXT_THRESHOLD,
|
| 13 |
+
)
|
| 14 |
+
from app.log.logger import get_gemini_logger, get_openai_logger
|
| 15 |
+
|
| 16 |
+
logger_openai = get_openai_logger()
|
| 17 |
+
logger_gemini = get_gemini_logger()
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class StreamOptimizer:
|
| 21 |
+
"""流式输出优化器
|
| 22 |
+
|
| 23 |
+
提供流式输出优化功能,包括智能延迟调整和长文本分块输出。
|
| 24 |
+
"""
|
| 25 |
+
|
| 26 |
+
def __init__(
|
| 27 |
+
self,
|
| 28 |
+
logger=None,
|
| 29 |
+
min_delay: float = DEFAULT_STREAM_MIN_DELAY,
|
| 30 |
+
max_delay: float = DEFAULT_STREAM_MAX_DELAY,
|
| 31 |
+
short_text_threshold: int = DEFAULT_STREAM_SHORT_TEXT_THRESHOLD,
|
| 32 |
+
long_text_threshold: int = DEFAULT_STREAM_LONG_TEXT_THRESHOLD,
|
| 33 |
+
chunk_size: int = DEFAULT_STREAM_CHUNK_SIZE,
|
| 34 |
+
):
|
| 35 |
+
"""初始化流式输出优化器
|
| 36 |
+
|
| 37 |
+
参数:
|
| 38 |
+
logger: 日志记录器
|
| 39 |
+
min_delay: 最小延迟时间(秒)
|
| 40 |
+
max_delay: 最大延迟时间(秒)
|
| 41 |
+
short_text_threshold: 短文本阈值(字符数)
|
| 42 |
+
long_text_threshold: 长文本阈值(字符数)
|
| 43 |
+
chunk_size: 长文本分块大小(字符数)
|
| 44 |
+
"""
|
| 45 |
+
self.logger = logger
|
| 46 |
+
self.min_delay = min_delay
|
| 47 |
+
self.max_delay = max_delay
|
| 48 |
+
self.short_text_threshold = short_text_threshold
|
| 49 |
+
self.long_text_threshold = long_text_threshold
|
| 50 |
+
self.chunk_size = chunk_size
|
| 51 |
+
|
| 52 |
+
def calculate_delay(self, text_length: int) -> float:
|
| 53 |
+
"""根据文本长度计算延迟时间
|
| 54 |
+
|
| 55 |
+
参数:
|
| 56 |
+
text_length: 文本长度
|
| 57 |
+
|
| 58 |
+
返回:
|
| 59 |
+
延迟时间(秒)
|
| 60 |
+
"""
|
| 61 |
+
if text_length <= self.short_text_threshold:
|
| 62 |
+
# 短文本使用较大延迟
|
| 63 |
+
return self.max_delay
|
| 64 |
+
elif text_length >= self.long_text_threshold:
|
| 65 |
+
# 长文本使用较小延迟
|
| 66 |
+
return self.min_delay
|
| 67 |
+
else:
|
| 68 |
+
# 中等长度文本使用线性插值计算延迟
|
| 69 |
+
# 使用对数函数使延迟变化更平滑
|
| 70 |
+
ratio = math.log(text_length / self.short_text_threshold) / math.log(
|
| 71 |
+
self.long_text_threshold / self.short_text_threshold
|
| 72 |
+
)
|
| 73 |
+
return self.max_delay - ratio * (self.max_delay - self.min_delay)
|
| 74 |
+
|
| 75 |
+
def split_text_into_chunks(self, text: str) -> List[str]:
|
| 76 |
+
"""将文本分割成小块
|
| 77 |
+
|
| 78 |
+
参数:
|
| 79 |
+
text: 要分割的文本
|
| 80 |
+
|
| 81 |
+
返回:
|
| 82 |
+
文本块列表
|
| 83 |
+
"""
|
| 84 |
+
return [
|
| 85 |
+
text[i : i + self.chunk_size] for i in range(0, len(text), self.chunk_size)
|
| 86 |
+
]
|
| 87 |
+
|
| 88 |
+
async def optimize_stream_output(
|
| 89 |
+
self,
|
| 90 |
+
text: str,
|
| 91 |
+
create_response_chunk: Callable[[str], Any],
|
| 92 |
+
format_chunk: Callable[[Any], str],
|
| 93 |
+
) -> AsyncGenerator[str, None]:
|
| 94 |
+
"""优化流式输出
|
| 95 |
+
|
| 96 |
+
参数:
|
| 97 |
+
text: 要输出的文本
|
| 98 |
+
create_response_chunk: 创建响应块的函数,接收文本,返回响应块
|
| 99 |
+
format_chunk: 格式化响应块的函数,接收响应块,返回格式化后的字符串
|
| 100 |
+
|
| 101 |
+
返回:
|
| 102 |
+
异步生成器,生成格式化后的响应块
|
| 103 |
+
"""
|
| 104 |
+
if not text:
|
| 105 |
+
return
|
| 106 |
+
|
| 107 |
+
# 计算智能延迟时间
|
| 108 |
+
delay = self.calculate_delay(len(text))
|
| 109 |
+
|
| 110 |
+
# 根据文本长度决定输出方式
|
| 111 |
+
if len(text) >= self.long_text_threshold:
|
| 112 |
+
# 长文本:分块输出
|
| 113 |
+
chunks = self.split_text_into_chunks(text)
|
| 114 |
+
for chunk_text in chunks:
|
| 115 |
+
chunk_response = create_response_chunk(chunk_text)
|
| 116 |
+
yield format_chunk(chunk_response)
|
| 117 |
+
await asyncio.sleep(delay)
|
| 118 |
+
else:
|
| 119 |
+
# 短文本:逐字符输出
|
| 120 |
+
for char in text:
|
| 121 |
+
char_chunk = create_response_chunk(char)
|
| 122 |
+
yield format_chunk(char_chunk)
|
| 123 |
+
await asyncio.sleep(delay)
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
# 创建默认的优化器实例,可以直接导入使用
|
| 127 |
+
openai_optimizer = StreamOptimizer(
|
| 128 |
+
logger=logger_openai,
|
| 129 |
+
min_delay=settings.STREAM_MIN_DELAY,
|
| 130 |
+
max_delay=settings.STREAM_MAX_DELAY,
|
| 131 |
+
short_text_threshold=settings.STREAM_SHORT_TEXT_THRESHOLD,
|
| 132 |
+
long_text_threshold=settings.STREAM_LONG_TEXT_THRESHOLD,
|
| 133 |
+
chunk_size=settings.STREAM_CHUNK_SIZE,
|
| 134 |
+
)
|
| 135 |
+
|
| 136 |
+
gemini_optimizer = StreamOptimizer(
|
| 137 |
+
logger=logger_gemini,
|
| 138 |
+
min_delay=settings.STREAM_MIN_DELAY,
|
| 139 |
+
max_delay=settings.STREAM_MAX_DELAY,
|
| 140 |
+
short_text_threshold=settings.STREAM_SHORT_TEXT_THRESHOLD,
|
| 141 |
+
long_text_threshold=settings.STREAM_LONG_TEXT_THRESHOLD,
|
| 142 |
+
chunk_size=settings.STREAM_CHUNK_SIZE,
|
| 143 |
+
)
|
app/log/logger.py
ADDED
|
@@ -0,0 +1,349 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import platform
|
| 3 |
+
import re
|
| 4 |
+
import sys
|
| 5 |
+
from typing import Dict, Optional
|
| 6 |
+
|
| 7 |
+
# ANSI转义序列颜色代码
|
| 8 |
+
COLORS = {
|
| 9 |
+
"DEBUG": "\033[34m", # 蓝色
|
| 10 |
+
"INFO": "\033[32m", # 绿色
|
| 11 |
+
"WARNING": "\033[33m", # 黄色
|
| 12 |
+
"ERROR": "\033[31m", # 红色
|
| 13 |
+
"CRITICAL": "\033[1;31m", # 红色加粗
|
| 14 |
+
}
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
# Windows系统启用ANSI支持
|
| 18 |
+
if platform.system() == "Windows":
|
| 19 |
+
import ctypes
|
| 20 |
+
|
| 21 |
+
kernel32 = ctypes.windll.kernel32
|
| 22 |
+
kernel32.SetConsoleMode(kernel32.GetStdHandle(-11), 7)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class ColoredFormatter(logging.Formatter):
|
| 26 |
+
"""
|
| 27 |
+
自定义的日志格式化器,添加颜色支持
|
| 28 |
+
"""
|
| 29 |
+
|
| 30 |
+
def format(self, record):
|
| 31 |
+
# 获取对应级别的颜色代码
|
| 32 |
+
color = COLORS.get(record.levelname, "")
|
| 33 |
+
# 添加颜色代码和重置代码
|
| 34 |
+
record.levelname = f"{color}{record.levelname}\033[0m"
|
| 35 |
+
# 创建包含文件名和行号的固定宽度字符串
|
| 36 |
+
record.fileloc = f"[{record.filename}:{record.lineno}]"
|
| 37 |
+
return super().format(record)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class AccessLogFormatter(logging.Formatter):
|
| 41 |
+
"""
|
| 42 |
+
Custom access log formatter that redacts API keys in URLs
|
| 43 |
+
"""
|
| 44 |
+
|
| 45 |
+
# API key patterns to match in URLs
|
| 46 |
+
API_KEY_PATTERNS = [
|
| 47 |
+
r"\bAIza[0-9A-Za-z_-]{35}", # Google API keys (like Gemini)
|
| 48 |
+
r"\bsk-[0-9A-Za-z_-]{20,}", # OpenAI and general sk- prefixed keys
|
| 49 |
+
]
|
| 50 |
+
|
| 51 |
+
def __init__(self, *args, **kwargs):
|
| 52 |
+
super().__init__(*args, **kwargs)
|
| 53 |
+
# Compile regex patterns for better performance
|
| 54 |
+
self.compiled_patterns = [
|
| 55 |
+
re.compile(pattern) for pattern in self.API_KEY_PATTERNS
|
| 56 |
+
]
|
| 57 |
+
|
| 58 |
+
def format(self, record):
|
| 59 |
+
# Format the record normally first
|
| 60 |
+
formatted_msg = super().format(record)
|
| 61 |
+
|
| 62 |
+
# Redact API keys in the formatted message
|
| 63 |
+
return self._redact_api_keys_in_message(formatted_msg)
|
| 64 |
+
|
| 65 |
+
def _redact_api_keys_in_message(self, message: str) -> str:
|
| 66 |
+
"""
|
| 67 |
+
Replace API keys in log message with redacted versions
|
| 68 |
+
"""
|
| 69 |
+
try:
|
| 70 |
+
for pattern in self.compiled_patterns:
|
| 71 |
+
|
| 72 |
+
def replace_key(match):
|
| 73 |
+
key = match.group(0)
|
| 74 |
+
return redact_key_for_logging(key)
|
| 75 |
+
|
| 76 |
+
message = pattern.sub(replace_key, message)
|
| 77 |
+
|
| 78 |
+
return message
|
| 79 |
+
except Exception as e:
|
| 80 |
+
# Log the error but don't expose the original message in case it contains keys
|
| 81 |
+
import logging
|
| 82 |
+
|
| 83 |
+
logger = logging.getLogger(__name__)
|
| 84 |
+
logger.error(f"Error redacting API keys in access log: {e}")
|
| 85 |
+
return "[LOG_REDACTION_ERROR]"
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def redact_key_for_logging(key: str) -> str:
|
| 89 |
+
"""
|
| 90 |
+
Redacts API key for secure logging by showing only first and last 6 characters.
|
| 91 |
+
|
| 92 |
+
Args:
|
| 93 |
+
key: API key to redact
|
| 94 |
+
|
| 95 |
+
Returns:
|
| 96 |
+
str: Redacted key in format "first6...last6" or descriptive placeholder for edge cases
|
| 97 |
+
"""
|
| 98 |
+
if not key:
|
| 99 |
+
return key
|
| 100 |
+
|
| 101 |
+
if len(key) <= 12:
|
| 102 |
+
return f"{key[:3]}...{key[-3:]}"
|
| 103 |
+
else:
|
| 104 |
+
return f"{key[:6]}...{key[-6:]}"
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
# 日志格式 - 使用 fileloc 并设置固定宽度 (例如 30)
|
| 108 |
+
FORMATTER = ColoredFormatter(
|
| 109 |
+
"%(asctime)s | %(levelname)-17s | %(fileloc)-30s | %(message)s"
|
| 110 |
+
)
|
| 111 |
+
|
| 112 |
+
# 日志级别映射
|
| 113 |
+
LOG_LEVELS = {
|
| 114 |
+
"debug": logging.DEBUG,
|
| 115 |
+
"info": logging.INFO,
|
| 116 |
+
"warning": logging.WARNING,
|
| 117 |
+
"error": logging.ERROR,
|
| 118 |
+
"critical": logging.CRITICAL,
|
| 119 |
+
}
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
class Logger:
|
| 123 |
+
def __init__(self):
|
| 124 |
+
pass
|
| 125 |
+
|
| 126 |
+
_loggers: Dict[str, logging.Logger] = {}
|
| 127 |
+
|
| 128 |
+
@staticmethod
|
| 129 |
+
def setup_logger(name: str) -> logging.Logger:
|
| 130 |
+
"""
|
| 131 |
+
设置并获取logger
|
| 132 |
+
:param name: logger名称
|
| 133 |
+
:return: logger实例
|
| 134 |
+
"""
|
| 135 |
+
# 导入 settings 对象
|
| 136 |
+
from app.config.config import settings
|
| 137 |
+
|
| 138 |
+
# 从全局配置获取日志级别
|
| 139 |
+
log_level_str = settings.LOG_LEVEL.lower()
|
| 140 |
+
level = LOG_LEVELS.get(log_level_str, logging.INFO)
|
| 141 |
+
|
| 142 |
+
if name in Logger._loggers:
|
| 143 |
+
# 如果 logger 已存在,检查并更新其级别(如果需要)
|
| 144 |
+
existing_logger = Logger._loggers[name]
|
| 145 |
+
if existing_logger.level != level:
|
| 146 |
+
existing_logger.setLevel(level)
|
| 147 |
+
return existing_logger
|
| 148 |
+
|
| 149 |
+
logger = logging.getLogger(name)
|
| 150 |
+
logger.setLevel(level)
|
| 151 |
+
logger.propagate = False
|
| 152 |
+
|
| 153 |
+
# 添加控制台输出
|
| 154 |
+
console_handler = logging.StreamHandler(sys.stdout)
|
| 155 |
+
console_handler.setFormatter(FORMATTER)
|
| 156 |
+
logger.addHandler(console_handler)
|
| 157 |
+
|
| 158 |
+
Logger._loggers[name] = logger
|
| 159 |
+
return logger
|
| 160 |
+
|
| 161 |
+
@staticmethod
|
| 162 |
+
def get_logger(name: str) -> Optional[logging.Logger]:
|
| 163 |
+
"""
|
| 164 |
+
获取已存在的logger
|
| 165 |
+
:param name: logger名称
|
| 166 |
+
:return: logger实例或None
|
| 167 |
+
"""
|
| 168 |
+
return Logger._loggers.get(name)
|
| 169 |
+
|
| 170 |
+
@staticmethod
|
| 171 |
+
def update_log_levels(log_level: str):
|
| 172 |
+
"""
|
| 173 |
+
根据当前��全局配置更新所有已创建 logger 的日志级别。
|
| 174 |
+
"""
|
| 175 |
+
log_level_str = log_level.lower()
|
| 176 |
+
new_level = LOG_LEVELS.get(log_level_str, logging.INFO)
|
| 177 |
+
|
| 178 |
+
updated_count = 0
|
| 179 |
+
for logger_name, logger_instance in Logger._loggers.items():
|
| 180 |
+
if logger_instance.level != new_level:
|
| 181 |
+
logger_instance.setLevel(new_level)
|
| 182 |
+
# 可选:记录级别变更日志,但注意避免在日志模块内部产生过多日志
|
| 183 |
+
# print(f"Updated log level for logger '{logger_name}' to {log_level_str.upper()}")
|
| 184 |
+
updated_count += 1
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
# 预定义的loggers
|
| 188 |
+
def get_openai_logger():
|
| 189 |
+
return Logger.setup_logger("openai")
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
def get_gemini_logger():
|
| 193 |
+
return Logger.setup_logger("gemini")
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
def get_chat_logger():
|
| 197 |
+
return Logger.setup_logger("chat")
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
def get_model_logger():
|
| 201 |
+
return Logger.setup_logger("model")
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
def get_security_logger():
|
| 205 |
+
return Logger.setup_logger("security")
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
def get_key_manager_logger():
|
| 209 |
+
return Logger.setup_logger("key_manager")
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
def get_main_logger():
|
| 213 |
+
return Logger.setup_logger("main")
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
def get_embeddings_logger():
|
| 217 |
+
return Logger.setup_logger("embeddings")
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
def get_request_logger():
|
| 221 |
+
return Logger.setup_logger("request")
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
def get_retry_logger():
|
| 225 |
+
return Logger.setup_logger("retry")
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
def get_image_create_logger():
|
| 229 |
+
return Logger.setup_logger("image_create")
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
def get_exceptions_logger():
|
| 233 |
+
return Logger.setup_logger("exceptions")
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
def get_application_logger():
|
| 237 |
+
return Logger.setup_logger("application")
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
def get_initialization_logger():
|
| 241 |
+
return Logger.setup_logger("initialization")
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
def get_middleware_logger():
|
| 245 |
+
return Logger.setup_logger("middleware")
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
def get_routes_logger():
|
| 249 |
+
return Logger.setup_logger("routes")
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
def get_config_routes_logger():
|
| 253 |
+
return Logger.setup_logger("config_routes")
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
def get_config_logger():
|
| 257 |
+
return Logger.setup_logger("config")
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
def get_database_logger():
|
| 261 |
+
return Logger.setup_logger("database")
|
| 262 |
+
|
| 263 |
+
|
| 264 |
+
def get_log_routes_logger():
|
| 265 |
+
return Logger.setup_logger("log_routes")
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
def get_stats_logger():
|
| 269 |
+
return Logger.setup_logger("stats")
|
| 270 |
+
|
| 271 |
+
|
| 272 |
+
def get_update_logger():
|
| 273 |
+
return Logger.setup_logger("update_service")
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
def get_scheduler_routes():
|
| 277 |
+
return Logger.setup_logger("scheduler_routes")
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
def get_message_converter_logger():
|
| 281 |
+
return Logger.setup_logger("message_converter")
|
| 282 |
+
|
| 283 |
+
|
| 284 |
+
def get_api_client_logger():
|
| 285 |
+
return Logger.setup_logger("api_client")
|
| 286 |
+
|
| 287 |
+
|
| 288 |
+
def get_openai_compatible_logger():
|
| 289 |
+
return Logger.setup_logger("openai_compatible")
|
| 290 |
+
|
| 291 |
+
|
| 292 |
+
def get_error_log_logger():
|
| 293 |
+
return Logger.setup_logger("error_log")
|
| 294 |
+
|
| 295 |
+
|
| 296 |
+
def get_request_log_logger():
|
| 297 |
+
return Logger.setup_logger("request_log")
|
| 298 |
+
|
| 299 |
+
|
| 300 |
+
def get_files_logger():
|
| 301 |
+
return Logger.setup_logger("files")
|
| 302 |
+
|
| 303 |
+
|
| 304 |
+
def get_vertex_express_logger():
|
| 305 |
+
return Logger.setup_logger("vertex_express")
|
| 306 |
+
|
| 307 |
+
|
| 308 |
+
def get_gemini_embedding_logger():
|
| 309 |
+
return Logger.setup_logger("gemini_embedding")
|
| 310 |
+
|
| 311 |
+
|
| 312 |
+
def setup_access_logging():
|
| 313 |
+
"""
|
| 314 |
+
Configure uvicorn access logging with API key redaction
|
| 315 |
+
|
| 316 |
+
This function sets up a custom access log formatter that automatically
|
| 317 |
+
redacts API keys in HTTP access logs. It works by:
|
| 318 |
+
|
| 319 |
+
1. Intercepting uvicorn's access log messages
|
| 320 |
+
2. Using regex patterns to find API keys in URLs
|
| 321 |
+
3. Replacing them with redacted versions (first6...last6)
|
| 322 |
+
|
| 323 |
+
Supported API key formats:
|
| 324 |
+
- Google/Gemini API keys: AIza[35 chars]
|
| 325 |
+
- OpenAI API keys: sk-[48 chars]
|
| 326 |
+
- General sk- prefixed keys: sk-[20+ chars]
|
| 327 |
+
|
| 328 |
+
Usage:
|
| 329 |
+
- Automatically called in main.py when running with uvicorn
|
| 330 |
+
- For production deployment with gunicorn, ensure this is called in startup
|
| 331 |
+
"""
|
| 332 |
+
# Get the uvicorn access logger
|
| 333 |
+
access_logger = logging.getLogger("uvicorn.access")
|
| 334 |
+
|
| 335 |
+
# Remove existing handlers to avoid duplicate logs
|
| 336 |
+
for handler in access_logger.handlers[:]:
|
| 337 |
+
access_logger.removeHandler(handler)
|
| 338 |
+
|
| 339 |
+
# Create new handler with our custom formatter that includes timestamp and log level
|
| 340 |
+
handler = logging.StreamHandler(sys.stdout)
|
| 341 |
+
access_formatter = AccessLogFormatter("%(asctime)s | %(levelname)-8s | %(message)s")
|
| 342 |
+
handler.setFormatter(access_formatter)
|
| 343 |
+
|
| 344 |
+
# Add the handler to uvicorn access logger
|
| 345 |
+
access_logger.addHandler(handler)
|
| 346 |
+
access_logger.setLevel(logging.INFO)
|
| 347 |
+
access_logger.propagate = False
|
| 348 |
+
|
| 349 |
+
return access_logger
|
app/main.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import uvicorn
|
| 2 |
+
from dotenv import load_dotenv
|
| 3 |
+
|
| 4 |
+
# 在导入应用程序配置之前加载 .env 文件到环境变量
|
| 5 |
+
load_dotenv()
|
| 6 |
+
|
| 7 |
+
from app.core.application import create_app
|
| 8 |
+
from app.log.logger import get_main_logger
|
| 9 |
+
|
| 10 |
+
app = create_app()
|
| 11 |
+
|
| 12 |
+
if __name__ == "__main__":
|
| 13 |
+
logger = get_main_logger()
|
| 14 |
+
logger.info("Starting application server...")
|
| 15 |
+
uvicorn.run(app, host="0.0.0.0", port=8001)
|
app/middleware/middleware.py
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
中间件配置模块,负责设置和配置应用程序的中间件
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from fastapi import FastAPI, Request
|
| 6 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 7 |
+
from fastapi.responses import RedirectResponse
|
| 8 |
+
from starlette.middleware.base import BaseHTTPMiddleware
|
| 9 |
+
|
| 10 |
+
# from app.middleware.request_logging_middleware import RequestLoggingMiddleware
|
| 11 |
+
from app.middleware.smart_routing_middleware import SmartRoutingMiddleware
|
| 12 |
+
from app.core.constants import API_VERSION
|
| 13 |
+
from app.core.security import verify_auth_token
|
| 14 |
+
from app.log.logger import get_middleware_logger
|
| 15 |
+
|
| 16 |
+
logger = get_middleware_logger()
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class AuthMiddleware(BaseHTTPMiddleware):
|
| 20 |
+
"""
|
| 21 |
+
认证中间件,处理未经身份验证的请求
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
async def dispatch(self, request: Request, call_next):
|
| 25 |
+
# 允许特定路径绕过身份验证
|
| 26 |
+
if (
|
| 27 |
+
request.url.path not in ["/", "/auth"]
|
| 28 |
+
and not request.url.path.startswith("/static")
|
| 29 |
+
and not request.url.path.startswith("/gemini")
|
| 30 |
+
and not request.url.path.startswith("/v1")
|
| 31 |
+
and not request.url.path.startswith(f"/{API_VERSION}")
|
| 32 |
+
and not request.url.path.startswith("/health")
|
| 33 |
+
and not request.url.path.startswith("/hf")
|
| 34 |
+
and not request.url.path.startswith("/openai")
|
| 35 |
+
and not request.url.path.startswith("/api/version/check")
|
| 36 |
+
and not request.url.path.startswith("/vertex-express")
|
| 37 |
+
and not request.url.path.startswith("/upload")
|
| 38 |
+
):
|
| 39 |
+
|
| 40 |
+
auth_token = request.cookies.get("auth_token")
|
| 41 |
+
if not auth_token or not verify_auth_token(auth_token):
|
| 42 |
+
logger.warning(f"Unauthorized access attempt to {request.url.path}")
|
| 43 |
+
return RedirectResponse(url="/")
|
| 44 |
+
logger.debug("Request authenticated successfully")
|
| 45 |
+
|
| 46 |
+
response = await call_next(request)
|
| 47 |
+
return response
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def setup_middlewares(app: FastAPI) -> None:
|
| 51 |
+
"""
|
| 52 |
+
设置应用程序的中间件
|
| 53 |
+
|
| 54 |
+
Args:
|
| 55 |
+
app: FastAPI应用程序实例
|
| 56 |
+
"""
|
| 57 |
+
# 添加智能路由中间件(必须在认证中间件之前)
|
| 58 |
+
app.add_middleware(SmartRoutingMiddleware)
|
| 59 |
+
|
| 60 |
+
# 添加认证中间件
|
| 61 |
+
app.add_middleware(AuthMiddleware)
|
| 62 |
+
|
| 63 |
+
# 添加请求日志中间件(可选,默认注释掉)
|
| 64 |
+
# app.add_middleware(RequestLoggingMiddleware)
|
| 65 |
+
|
| 66 |
+
# 配置CORS中间件
|
| 67 |
+
app.add_middleware(
|
| 68 |
+
CORSMiddleware,
|
| 69 |
+
allow_origins=["*"],
|
| 70 |
+
allow_credentials=True,
|
| 71 |
+
allow_methods=[
|
| 72 |
+
"GET",
|
| 73 |
+
"POST",
|
| 74 |
+
"PUT",
|
| 75 |
+
"DELETE",
|
| 76 |
+
"OPTIONS",
|
| 77 |
+
],
|
| 78 |
+
allow_headers=["*"],
|
| 79 |
+
expose_headers=["*"],
|
| 80 |
+
max_age=600,
|
| 81 |
+
)
|
app/middleware/request_logging_middleware.py
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
|
| 3 |
+
from fastapi import Request
|
| 4 |
+
from starlette.middleware.base import BaseHTTPMiddleware
|
| 5 |
+
|
| 6 |
+
from app.log.logger import get_request_logger
|
| 7 |
+
|
| 8 |
+
logger = get_request_logger()
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
# 添加中间件类
|
| 12 |
+
class RequestLoggingMiddleware(BaseHTTPMiddleware):
|
| 13 |
+
async def dispatch(self, request: Request, call_next):
|
| 14 |
+
# 记录请求路径
|
| 15 |
+
logger.info(f"Request path: {request.url.path}")
|
| 16 |
+
|
| 17 |
+
# 获取并记录请求体
|
| 18 |
+
try:
|
| 19 |
+
body = await request.body()
|
| 20 |
+
if body:
|
| 21 |
+
body_str = body.decode()
|
| 22 |
+
# 尝试格式化JSON
|
| 23 |
+
try:
|
| 24 |
+
formatted_body = json.loads(body_str)
|
| 25 |
+
logger.info(
|
| 26 |
+
f"Formatted request body:\n{json.dumps(formatted_body, indent=2, ensure_ascii=False)}"
|
| 27 |
+
)
|
| 28 |
+
except json.JSONDecodeError:
|
| 29 |
+
logger.error("Request body is not valid JSON.")
|
| 30 |
+
except Exception as e:
|
| 31 |
+
logger.error(f"Error reading request body: {str(e)}")
|
| 32 |
+
|
| 33 |
+
# 重置请求的接收器,以便后续处理器可以继续读取请求体
|
| 34 |
+
async def receive():
|
| 35 |
+
return {"type": "http.request", "body": body, "more_body": False}
|
| 36 |
+
|
| 37 |
+
request._receive = receive
|
| 38 |
+
|
| 39 |
+
response = await call_next(request)
|
| 40 |
+
return response
|
app/middleware/smart_routing_middleware.py
ADDED
|
@@ -0,0 +1,210 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import Request
|
| 2 |
+
from starlette.middleware.base import BaseHTTPMiddleware
|
| 3 |
+
from app.config.config import settings
|
| 4 |
+
from app.log.logger import get_main_logger
|
| 5 |
+
import re
|
| 6 |
+
|
| 7 |
+
logger = get_main_logger()
|
| 8 |
+
|
| 9 |
+
class SmartRoutingMiddleware(BaseHTTPMiddleware):
|
| 10 |
+
def __init__(self, app):
|
| 11 |
+
super().__init__(app)
|
| 12 |
+
# 简化的路由规则 - 直接根据检测结果路由
|
| 13 |
+
pass
|
| 14 |
+
|
| 15 |
+
async def dispatch(self, request: Request, call_next):
|
| 16 |
+
if not settings.URL_NORMALIZATION_ENABLED:
|
| 17 |
+
return await call_next(request)
|
| 18 |
+
logger.debug(f"request: {request}")
|
| 19 |
+
original_path = str(request.url.path)
|
| 20 |
+
method = request.method
|
| 21 |
+
|
| 22 |
+
# 尝试修复URL
|
| 23 |
+
fixed_path, fix_info = self.fix_request_url(original_path, method, request)
|
| 24 |
+
|
| 25 |
+
if fixed_path != original_path:
|
| 26 |
+
logger.info(f"URL fixed: {method} {original_path} → {fixed_path}")
|
| 27 |
+
if fix_info:
|
| 28 |
+
logger.debug(f"Fix details: {fix_info}")
|
| 29 |
+
|
| 30 |
+
# 重写请求路径
|
| 31 |
+
request.scope["path"] = fixed_path
|
| 32 |
+
request.scope["raw_path"] = fixed_path.encode()
|
| 33 |
+
|
| 34 |
+
return await call_next(request)
|
| 35 |
+
|
| 36 |
+
def fix_request_url(self, path: str, method: str, request: Request) -> tuple:
|
| 37 |
+
"""简化的URL修复逻辑"""
|
| 38 |
+
|
| 39 |
+
# 首先检查是否已经是正确的格式,如果是则不处理
|
| 40 |
+
if self.is_already_correct_format(path):
|
| 41 |
+
return path, None
|
| 42 |
+
|
| 43 |
+
# 1. 最高优先级:包含generateContent → Gemini格式
|
| 44 |
+
if "generatecontent" in path.lower() or "v1beta/models" in path.lower():
|
| 45 |
+
return self.fix_gemini_by_operation(path, method, request)
|
| 46 |
+
|
| 47 |
+
# 2. 第二优先级:包含/openai/ → OpenAI格式
|
| 48 |
+
if "/openai/" in path.lower():
|
| 49 |
+
return self.fix_openai_by_operation(path, method)
|
| 50 |
+
|
| 51 |
+
# 3. 第三优先级:包含/v1/ → v1格式
|
| 52 |
+
if "/v1/" in path.lower():
|
| 53 |
+
return self.fix_v1_by_operation(path, method)
|
| 54 |
+
|
| 55 |
+
# 4. 第四优先级:包含/chat/completions → chat功能
|
| 56 |
+
if "/chat/completions" in path.lower():
|
| 57 |
+
return "/v1/chat/completions", {"type": "v1_chat"}
|
| 58 |
+
|
| 59 |
+
# 5. 默认:原样传递
|
| 60 |
+
return path, None
|
| 61 |
+
|
| 62 |
+
def is_already_correct_format(self, path: str) -> bool:
|
| 63 |
+
"""检查是否已经是正确的API格式"""
|
| 64 |
+
# 检查是否已经是正确的端点格式
|
| 65 |
+
correct_patterns = [
|
| 66 |
+
r"^/v1beta/models/[^/:]+:(generate|streamGenerate)Content$", # Gemini原生
|
| 67 |
+
r"^/gemini/v1beta/models/[^/:]+:(generate|streamGenerate)Content$", # Gemini带前缀
|
| 68 |
+
r"^/v1beta/models$", # Gemini模型列表
|
| 69 |
+
r"^/gemini/v1beta/models$", # Gemini带前缀的模型列表
|
| 70 |
+
r"^/v1/(chat/completions|models|embeddings|images/generations|audio/speech)$", # v1格式
|
| 71 |
+
r"^/openai/v1/(chat/completions|models|embeddings|images/generations|audio/speech)$", # OpenAI格式
|
| 72 |
+
r"^/hf/v1/(chat/completions|models|embeddings|images/generations|audio/speech)$", # HF格式
|
| 73 |
+
r"^/vertex-express/v1beta/models/[^/:]+:(generate|streamGenerate)Content$", # Vertex Express Gemini格式
|
| 74 |
+
r"^/vertex-express/v1beta/models$", # Vertex Express模型列表
|
| 75 |
+
r"^/vertex-express/v1/(chat/completions|models|embeddings|images/generations)$", # Vertex Express OpenAI格式
|
| 76 |
+
]
|
| 77 |
+
|
| 78 |
+
for pattern in correct_patterns:
|
| 79 |
+
if re.match(pattern, path):
|
| 80 |
+
return True
|
| 81 |
+
|
| 82 |
+
return False
|
| 83 |
+
|
| 84 |
+
def fix_gemini_by_operation(
|
| 85 |
+
self, path: str, method: str, request: Request
|
| 86 |
+
) -> tuple:
|
| 87 |
+
"""根据Gemini操作修复,考虑端点偏好"""
|
| 88 |
+
if method == "GET":
|
| 89 |
+
return "/v1beta/models", {
|
| 90 |
+
"role": "gemini_models",
|
| 91 |
+
}
|
| 92 |
+
|
| 93 |
+
# 提取模型名称
|
| 94 |
+
try:
|
| 95 |
+
model_name = self.extract_model_name(path, request)
|
| 96 |
+
except ValueError:
|
| 97 |
+
# 无法提取模型名称,返回原路径不做处理
|
| 98 |
+
return path, None
|
| 99 |
+
|
| 100 |
+
# 检测是否为流式请求
|
| 101 |
+
is_stream = self.detect_stream_request(path, request)
|
| 102 |
+
|
| 103 |
+
# 检查是否有vertex-express偏好
|
| 104 |
+
if "/vertex-express/" in path.lower():
|
| 105 |
+
if is_stream:
|
| 106 |
+
target_url = (
|
| 107 |
+
f"/vertex-express/v1beta/models/{model_name}:streamGenerateContent"
|
| 108 |
+
)
|
| 109 |
+
else:
|
| 110 |
+
target_url = (
|
| 111 |
+
f"/vertex-express/v1beta/models/{model_name}:generateContent"
|
| 112 |
+
)
|
| 113 |
+
|
| 114 |
+
fix_info = {
|
| 115 |
+
"rule": (
|
| 116 |
+
"vertex_express_generate"
|
| 117 |
+
if not is_stream
|
| 118 |
+
else "vertex_express_stream"
|
| 119 |
+
),
|
| 120 |
+
"preference": "vertex_express_format",
|
| 121 |
+
"is_stream": is_stream,
|
| 122 |
+
"model": model_name,
|
| 123 |
+
}
|
| 124 |
+
else:
|
| 125 |
+
# 标准Gemini端点
|
| 126 |
+
if is_stream:
|
| 127 |
+
target_url = f"/v1beta/models/{model_name}:streamGenerateContent"
|
| 128 |
+
else:
|
| 129 |
+
target_url = f"/v1beta/models/{model_name}:generateContent"
|
| 130 |
+
|
| 131 |
+
fix_info = {
|
| 132 |
+
"rule": "gemini_generate" if not is_stream else "gemini_stream",
|
| 133 |
+
"preference": "gemini_format",
|
| 134 |
+
"is_stream": is_stream,
|
| 135 |
+
"model": model_name,
|
| 136 |
+
}
|
| 137 |
+
|
| 138 |
+
return target_url, fix_info
|
| 139 |
+
|
| 140 |
+
def fix_openai_by_operation(self, path: str, method: str) -> tuple:
|
| 141 |
+
"""根据操作类型修复OpenAI格式"""
|
| 142 |
+
if method == "POST":
|
| 143 |
+
if "chat" in path.lower() or "completion" in path.lower():
|
| 144 |
+
return "/openai/v1/chat/completions", {"type": "openai_chat"}
|
| 145 |
+
elif "embedding" in path.lower():
|
| 146 |
+
return "/openai/v1/embeddings", {"type": "openai_embeddings"}
|
| 147 |
+
elif "image" in path.lower():
|
| 148 |
+
return "/openai/v1/images/generations", {"type": "openai_images"}
|
| 149 |
+
elif "audio" in path.lower():
|
| 150 |
+
return "/openai/v1/audio/speech", {"type": "openai_audio"}
|
| 151 |
+
elif method == "GET":
|
| 152 |
+
if "model" in path.lower():
|
| 153 |
+
return "/openai/v1/models", {"type": "openai_models"}
|
| 154 |
+
|
| 155 |
+
return path, None
|
| 156 |
+
|
| 157 |
+
def fix_v1_by_operation(self, path: str, method: str) -> tuple:
|
| 158 |
+
"""根据操作类型修复v1格式"""
|
| 159 |
+
if method == "POST":
|
| 160 |
+
if "chat" in path.lower() or "completion" in path.lower():
|
| 161 |
+
return "/v1/chat/completions", {"type": "v1_chat"}
|
| 162 |
+
elif "embedding" in path.lower():
|
| 163 |
+
return "/v1/embeddings", {"type": "v1_embeddings"}
|
| 164 |
+
elif "image" in path.lower():
|
| 165 |
+
return "/v1/images/generations", {"type": "v1_images"}
|
| 166 |
+
elif "audio" in path.lower():
|
| 167 |
+
return "/v1/audio/speech", {"type": "v1_audio"}
|
| 168 |
+
elif method == "GET":
|
| 169 |
+
if "model" in path.lower():
|
| 170 |
+
return "/v1/models", {"type": "v1_models"}
|
| 171 |
+
|
| 172 |
+
return path, None
|
| 173 |
+
|
| 174 |
+
def detect_stream_request(self, path: str, request: Request) -> bool:
|
| 175 |
+
"""检测是否为流式请求"""
|
| 176 |
+
# 1. 路径中包含stream关键词
|
| 177 |
+
if "stream" in path.lower():
|
| 178 |
+
return True
|
| 179 |
+
|
| 180 |
+
# 2. 查询参数
|
| 181 |
+
if request.query_params.get("stream") == "true":
|
| 182 |
+
return True
|
| 183 |
+
|
| 184 |
+
return False
|
| 185 |
+
|
| 186 |
+
def extract_model_name(self, path: str, request: Request) -> str:
|
| 187 |
+
"""从请求中提取模型名称,用于构建Gemini API URL"""
|
| 188 |
+
# 1. 从请求体中提取
|
| 189 |
+
try:
|
| 190 |
+
if hasattr(request, "_body") and request._body:
|
| 191 |
+
import json
|
| 192 |
+
|
| 193 |
+
body = json.loads(request._body.decode())
|
| 194 |
+
if "model" in body and body["model"]:
|
| 195 |
+
return body["model"]
|
| 196 |
+
except Exception:
|
| 197 |
+
pass
|
| 198 |
+
|
| 199 |
+
# 2. 从查询参数中提取
|
| 200 |
+
model_param = request.query_params.get("model")
|
| 201 |
+
if model_param:
|
| 202 |
+
return model_param
|
| 203 |
+
|
| 204 |
+
# 3. 从路径中提取(用于已包含模型名称的路径)
|
| 205 |
+
match = re.search(r"/models/([^/:]+)", path, re.IGNORECASE)
|
| 206 |
+
if match:
|
| 207 |
+
return match.group(1)
|
| 208 |
+
|
| 209 |
+
# 4. 如果无法提取模型名称,抛出异常
|
| 210 |
+
raise ValueError("Unable to extract model name from request")
|
app/router/config_routes.py
ADDED
|
@@ -0,0 +1,225 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
配置路由模块
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from typing import Any, Dict, List
|
| 6 |
+
|
| 7 |
+
from fastapi import APIRouter, HTTPException, Request
|
| 8 |
+
from fastapi.responses import RedirectResponse
|
| 9 |
+
from pydantic import BaseModel, Field
|
| 10 |
+
|
| 11 |
+
from app.core.security import verify_auth_token
|
| 12 |
+
from app.log.logger import Logger, get_config_routes_logger
|
| 13 |
+
from app.service.config.config_service import ConfigService
|
| 14 |
+
from app.service.proxy.proxy_check_service import get_proxy_check_service, ProxyCheckResult
|
| 15 |
+
from app.utils.helpers import redact_key_for_logging
|
| 16 |
+
|
| 17 |
+
router = APIRouter(prefix="/api/config", tags=["config"])
|
| 18 |
+
|
| 19 |
+
logger = get_config_routes_logger()
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
@router.get("", response_model=Dict[str, Any])
|
| 23 |
+
async def get_config(request: Request):
|
| 24 |
+
auth_token = request.cookies.get("auth_token")
|
| 25 |
+
if not auth_token or not verify_auth_token(auth_token):
|
| 26 |
+
logger.warning("Unauthorized access attempt to config page")
|
| 27 |
+
return RedirectResponse(url="/", status_code=302)
|
| 28 |
+
return await ConfigService.get_config()
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
@router.put("", response_model=Dict[str, Any])
|
| 32 |
+
async def update_config(config_data: Dict[str, Any], request: Request):
|
| 33 |
+
auth_token = request.cookies.get("auth_token")
|
| 34 |
+
if not auth_token or not verify_auth_token(auth_token):
|
| 35 |
+
logger.warning("Unauthorized access attempt to config page")
|
| 36 |
+
return RedirectResponse(url="/", status_code=302)
|
| 37 |
+
try:
|
| 38 |
+
result = await ConfigService.update_config(config_data)
|
| 39 |
+
# 配置更新成功后,立即更新所有 logger 的级别
|
| 40 |
+
Logger.update_log_levels(config_data["LOG_LEVEL"])
|
| 41 |
+
logger.info("Log levels updated after configuration change.")
|
| 42 |
+
return result
|
| 43 |
+
except Exception as e:
|
| 44 |
+
logger.error(f"Error updating config or log levels: {e}", exc_info=True)
|
| 45 |
+
raise HTTPException(status_code=400, detail=str(e))
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
@router.post("/reset", response_model=Dict[str, Any])
|
| 49 |
+
async def reset_config(request: Request):
|
| 50 |
+
auth_token = request.cookies.get("auth_token")
|
| 51 |
+
if not auth_token or not verify_auth_token(auth_token):
|
| 52 |
+
logger.warning("Unauthorized access attempt to config page")
|
| 53 |
+
return RedirectResponse(url="/", status_code=302)
|
| 54 |
+
try:
|
| 55 |
+
return await ConfigService.reset_config()
|
| 56 |
+
except Exception as e:
|
| 57 |
+
raise HTTPException(status_code=400, detail=str(e))
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
class DeleteKeysRequest(BaseModel):
|
| 61 |
+
keys: List[str] = Field(..., description="List of API keys to delete")
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
@router.delete("/keys/{key_to_delete}", response_model=Dict[str, Any])
|
| 65 |
+
async def delete_single_key(key_to_delete: str, request: Request):
|
| 66 |
+
auth_token = request.cookies.get("auth_token")
|
| 67 |
+
if not auth_token or not verify_auth_token(auth_token):
|
| 68 |
+
logger.warning(f"Unauthorized attempt to delete key: {redact_key_for_logging(key_to_delete)}")
|
| 69 |
+
return RedirectResponse(url="/", status_code=302)
|
| 70 |
+
try:
|
| 71 |
+
logger.info(f"Attempting to delete key: {redact_key_for_logging(key_to_delete)}")
|
| 72 |
+
result = await ConfigService.delete_key(key_to_delete)
|
| 73 |
+
if not result.get("success"):
|
| 74 |
+
raise HTTPException(
|
| 75 |
+
status_code=(
|
| 76 |
+
404 if "not found" in result.get("message", "").lower() else 400
|
| 77 |
+
),
|
| 78 |
+
detail=result.get("message"),
|
| 79 |
+
)
|
| 80 |
+
return result
|
| 81 |
+
except HTTPException as e:
|
| 82 |
+
raise e
|
| 83 |
+
except Exception as e:
|
| 84 |
+
logger.error(f"Error deleting key '{redact_key_for_logging(key_to_delete)}': {e}", exc_info=True)
|
| 85 |
+
raise HTTPException(status_code=500, detail=f"Error deleting key: {str(e)}")
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
@router.post("/keys/delete-selected", response_model=Dict[str, Any])
|
| 89 |
+
async def delete_selected_keys_route(
|
| 90 |
+
delete_request: DeleteKeysRequest, request: Request
|
| 91 |
+
):
|
| 92 |
+
auth_token = request.cookies.get("auth_token")
|
| 93 |
+
if not auth_token or not verify_auth_token(auth_token):
|
| 94 |
+
logger.warning("Unauthorized attempt to bulk delete keys")
|
| 95 |
+
return RedirectResponse(url="/", status_code=302)
|
| 96 |
+
|
| 97 |
+
if not delete_request.keys:
|
| 98 |
+
logger.warning("Attempt to bulk delete keys with an empty list.")
|
| 99 |
+
raise HTTPException(status_code=400, detail="No keys provided for deletion.")
|
| 100 |
+
|
| 101 |
+
try:
|
| 102 |
+
logger.info(f"Attempting to bulk delete {len(delete_request.keys)} keys.")
|
| 103 |
+
result = await ConfigService.delete_selected_keys(delete_request.keys)
|
| 104 |
+
if not result.get("success") and result.get("deleted_count", 0) == 0:
|
| 105 |
+
raise HTTPException(
|
| 106 |
+
status_code=400, detail=result.get("message", "Failed to delete keys.")
|
| 107 |
+
)
|
| 108 |
+
return result
|
| 109 |
+
except HTTPException as e:
|
| 110 |
+
raise e
|
| 111 |
+
except Exception as e:
|
| 112 |
+
logger.error(f"Error bulk deleting keys: {e}", exc_info=True)
|
| 113 |
+
raise HTTPException(
|
| 114 |
+
status_code=500, detail=f"Error bulk deleting keys: {str(e)}"
|
| 115 |
+
)
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
@router.get("/ui/models")
|
| 119 |
+
async def get_ui_models(request: Request):
|
| 120 |
+
auth_token_cookie = request.cookies.get("auth_token")
|
| 121 |
+
if not auth_token_cookie or not verify_auth_token(auth_token_cookie):
|
| 122 |
+
logger.warning("Unauthorized access attempt to /api/config/ui/models")
|
| 123 |
+
raise HTTPException(status_code=403, detail="Not authenticated")
|
| 124 |
+
|
| 125 |
+
try:
|
| 126 |
+
models = await ConfigService.fetch_ui_models()
|
| 127 |
+
return models
|
| 128 |
+
except HTTPException as e:
|
| 129 |
+
raise e
|
| 130 |
+
except Exception as e:
|
| 131 |
+
logger.error(f"Unexpected error in /ui/models endpoint: {e}", exc_info=True)
|
| 132 |
+
raise HTTPException(
|
| 133 |
+
status_code=500,
|
| 134 |
+
detail=f"An unexpected error occurred while fetching UI models: {str(e)}",
|
| 135 |
+
)
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
class ProxyCheckRequest(BaseModel):
|
| 139 |
+
"""Proxy check request"""
|
| 140 |
+
proxy: str = Field(..., description="Proxy address to check")
|
| 141 |
+
use_cache: bool = Field(True, description="Whether to use cached results")
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
class ProxyBatchCheckRequest(BaseModel):
|
| 145 |
+
"""Batch proxy check request"""
|
| 146 |
+
proxies: List[str] = Field(..., description="List of proxy addresses to check")
|
| 147 |
+
use_cache: bool = Field(True, description="Whether to use cached results")
|
| 148 |
+
max_concurrent: int = Field(5, description="Maximum concurrent check count", ge=1, le=10)
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
@router.post("/proxy/check", response_model=ProxyCheckResult)
|
| 152 |
+
async def check_single_proxy(proxy_request: ProxyCheckRequest, request: Request):
|
| 153 |
+
"""Check if a single proxy is available"""
|
| 154 |
+
auth_token = request.cookies.get("auth_token")
|
| 155 |
+
if not auth_token or not verify_auth_token(auth_token):
|
| 156 |
+
logger.warning("Unauthorized access attempt to proxy check")
|
| 157 |
+
return RedirectResponse(url="/", status_code=302)
|
| 158 |
+
|
| 159 |
+
try:
|
| 160 |
+
logger.info(f"Checking single proxy: {proxy_request.proxy}")
|
| 161 |
+
proxy_service = get_proxy_check_service()
|
| 162 |
+
result = await proxy_service.check_single_proxy(
|
| 163 |
+
proxy_request.proxy,
|
| 164 |
+
proxy_request.use_cache
|
| 165 |
+
)
|
| 166 |
+
return result
|
| 167 |
+
except Exception as e:
|
| 168 |
+
logger.error(f"Proxy check failed: {str(e)}", exc_info=True)
|
| 169 |
+
raise HTTPException(status_code=500, detail=f"Proxy check failed: {str(e)}")
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
@router.post("/proxy/check-all", response_model=List[ProxyCheckResult])
|
| 173 |
+
async def check_all_proxies(batch_request: ProxyBatchCheckRequest, request: Request):
|
| 174 |
+
"""Check multiple proxies availability"""
|
| 175 |
+
auth_token = request.cookies.get("auth_token")
|
| 176 |
+
if not auth_token or not verify_auth_token(auth_token):
|
| 177 |
+
logger.warning("Unauthorized access attempt to batch proxy check")
|
| 178 |
+
return RedirectResponse(url="/", status_code=302)
|
| 179 |
+
|
| 180 |
+
try:
|
| 181 |
+
logger.info(f"Batch checking {len(batch_request.proxies)} proxies")
|
| 182 |
+
proxy_service = get_proxy_check_service()
|
| 183 |
+
results = await proxy_service.check_multiple_proxies(
|
| 184 |
+
batch_request.proxies,
|
| 185 |
+
batch_request.use_cache,
|
| 186 |
+
batch_request.max_concurrent
|
| 187 |
+
)
|
| 188 |
+
return results
|
| 189 |
+
except Exception as e:
|
| 190 |
+
logger.error(f"Batch proxy check failed: {str(e)}", exc_info=True)
|
| 191 |
+
raise HTTPException(status_code=500, detail=f"Batch proxy check failed: {str(e)}")
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
@router.get("/proxy/cache-stats")
|
| 195 |
+
async def get_proxy_cache_stats(request: Request):
|
| 196 |
+
"""Get proxy check cache statistics"""
|
| 197 |
+
auth_token = request.cookies.get("auth_token")
|
| 198 |
+
if not auth_token or not verify_auth_token(auth_token):
|
| 199 |
+
logger.warning("Unauthorized access attempt to proxy cache stats")
|
| 200 |
+
return RedirectResponse(url="/", status_code=302)
|
| 201 |
+
|
| 202 |
+
try:
|
| 203 |
+
proxy_service = get_proxy_check_service()
|
| 204 |
+
stats = proxy_service.get_cache_stats()
|
| 205 |
+
return stats
|
| 206 |
+
except Exception as e:
|
| 207 |
+
logger.error(f"Get proxy cache stats failed: {str(e)}", exc_info=True)
|
| 208 |
+
raise HTTPException(status_code=500, detail=f"Get cache stats failed: {str(e)}")
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
@router.post("/proxy/clear-cache")
|
| 212 |
+
async def clear_proxy_cache(request: Request):
|
| 213 |
+
"""Clear proxy check cache"""
|
| 214 |
+
auth_token = request.cookies.get("auth_token")
|
| 215 |
+
if not auth_token or not verify_auth_token(auth_token):
|
| 216 |
+
logger.warning("Unauthorized access attempt to clear proxy cache")
|
| 217 |
+
return RedirectResponse(url="/", status_code=302)
|
| 218 |
+
|
| 219 |
+
try:
|
| 220 |
+
proxy_service = get_proxy_check_service()
|
| 221 |
+
proxy_service.clear_cache()
|
| 222 |
+
return {"success": True, "message": "Proxy check cache cleared"}
|
| 223 |
+
except Exception as e:
|
| 224 |
+
logger.error(f"Clear proxy cache failed: {str(e)}", exc_info=True)
|
| 225 |
+
raise HTTPException(status_code=500, detail=f"Clear cache failed: {str(e)}")
|
app/router/error_log_routes.py
ADDED
|
@@ -0,0 +1,271 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
日志路由模块
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from datetime import datetime
|
| 6 |
+
from typing import Dict, List, Optional
|
| 7 |
+
|
| 8 |
+
from fastapi import (
|
| 9 |
+
APIRouter,
|
| 10 |
+
Body,
|
| 11 |
+
HTTPException,
|
| 12 |
+
Path,
|
| 13 |
+
Query,
|
| 14 |
+
Request,
|
| 15 |
+
Response,
|
| 16 |
+
status,
|
| 17 |
+
)
|
| 18 |
+
from pydantic import BaseModel
|
| 19 |
+
|
| 20 |
+
from app.core.security import verify_auth_token
|
| 21 |
+
from app.log.logger import get_log_routes_logger
|
| 22 |
+
from app.service.error_log import error_log_service
|
| 23 |
+
|
| 24 |
+
router = APIRouter(prefix="/api/logs", tags=["logs"])
|
| 25 |
+
|
| 26 |
+
logger = get_log_routes_logger()
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class ErrorLogListItem(BaseModel):
|
| 30 |
+
id: int
|
| 31 |
+
gemini_key: Optional[str] = None
|
| 32 |
+
error_type: Optional[str] = None
|
| 33 |
+
error_code: Optional[int] = None
|
| 34 |
+
model_name: Optional[str] = None
|
| 35 |
+
request_time: Optional[datetime] = None
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class ErrorLogListResponse(BaseModel):
|
| 39 |
+
logs: List[ErrorLogListItem]
|
| 40 |
+
total: int
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
@router.get("/errors", response_model=ErrorLogListResponse)
|
| 44 |
+
async def get_error_logs_api(
|
| 45 |
+
request: Request,
|
| 46 |
+
limit: int = Query(10, ge=1, le=1000),
|
| 47 |
+
offset: int = Query(0, ge=0),
|
| 48 |
+
key_search: Optional[str] = Query(
|
| 49 |
+
None, description="Search term for Gemini key (partial match)"
|
| 50 |
+
),
|
| 51 |
+
error_search: Optional[str] = Query(
|
| 52 |
+
None, description="Search term for error type or log message"
|
| 53 |
+
),
|
| 54 |
+
error_code_search: Optional[str] = Query(
|
| 55 |
+
None, description="Search term for error code"
|
| 56 |
+
),
|
| 57 |
+
start_date: Optional[datetime] = Query(
|
| 58 |
+
None, description="Start datetime for filtering"
|
| 59 |
+
),
|
| 60 |
+
end_date: Optional[datetime] = Query(
|
| 61 |
+
None, description="End datetime for filtering"
|
| 62 |
+
),
|
| 63 |
+
sort_by: str = Query(
|
| 64 |
+
"id", description="Field to sort by (e.g., 'id', 'request_time')"
|
| 65 |
+
),
|
| 66 |
+
sort_order: str = Query("desc", description="Sort order ('asc' or 'desc')"),
|
| 67 |
+
):
|
| 68 |
+
"""
|
| 69 |
+
获取错误日志列表 (返回错误码),支持过滤和排序
|
| 70 |
+
|
| 71 |
+
Args:
|
| 72 |
+
request: 请求对象
|
| 73 |
+
limit: 限制数量
|
| 74 |
+
offset: 偏移量
|
| 75 |
+
key_search: 密钥搜索
|
| 76 |
+
error_search: 错误搜索 (可能搜索类型或日志内容,由DB层决定)
|
| 77 |
+
error_code_search: 错误码搜索
|
| 78 |
+
start_date: 开始日期
|
| 79 |
+
end_date: 结束日期
|
| 80 |
+
sort_by: 排序字段
|
| 81 |
+
sort_order: 排序顺序
|
| 82 |
+
|
| 83 |
+
Returns:
|
| 84 |
+
ErrorLogListResponse: An object containing the list of logs (with error_code) and the total count.
|
| 85 |
+
"""
|
| 86 |
+
auth_token = request.cookies.get("auth_token")
|
| 87 |
+
if not auth_token or not verify_auth_token(auth_token):
|
| 88 |
+
logger.warning("Unauthorized access attempt to error logs list")
|
| 89 |
+
raise HTTPException(status_code=401, detail="Not authenticated")
|
| 90 |
+
|
| 91 |
+
try:
|
| 92 |
+
result = await error_log_service.process_get_error_logs(
|
| 93 |
+
limit=limit,
|
| 94 |
+
offset=offset,
|
| 95 |
+
key_search=key_search,
|
| 96 |
+
error_search=error_search,
|
| 97 |
+
error_code_search=error_code_search,
|
| 98 |
+
start_date=start_date,
|
| 99 |
+
end_date=end_date,
|
| 100 |
+
sort_by=sort_by,
|
| 101 |
+
sort_order=sort_order,
|
| 102 |
+
)
|
| 103 |
+
logs_data = result["logs"]
|
| 104 |
+
total_count = result["total"]
|
| 105 |
+
|
| 106 |
+
validated_logs = [ErrorLogListItem(**log) for log in logs_data]
|
| 107 |
+
return ErrorLogListResponse(logs=validated_logs, total=total_count)
|
| 108 |
+
except Exception as e:
|
| 109 |
+
logger.exception(f"Failed to get error logs list: {str(e)}")
|
| 110 |
+
raise HTTPException(
|
| 111 |
+
status_code=500, detail=f"Failed to get error logs list: {str(e)}"
|
| 112 |
+
)
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
class ErrorLogDetailResponse(BaseModel):
|
| 116 |
+
id: int
|
| 117 |
+
gemini_key: Optional[str] = None
|
| 118 |
+
error_type: Optional[str] = None
|
| 119 |
+
error_log: Optional[str] = None
|
| 120 |
+
request_msg: Optional[str] = None
|
| 121 |
+
model_name: Optional[str] = None
|
| 122 |
+
request_time: Optional[datetime] = None
|
| 123 |
+
error_code: Optional[int] = None
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
@router.get("/errors/{log_id}/details", response_model=ErrorLogDetailResponse)
|
| 127 |
+
async def get_error_log_detail_api(request: Request, log_id: int = Path(..., ge=1)):
|
| 128 |
+
"""
|
| 129 |
+
根据日志 ID 获取错误日志的详细信息 (包括 error_log 和 request_msg)
|
| 130 |
+
"""
|
| 131 |
+
auth_token = request.cookies.get("auth_token")
|
| 132 |
+
if not auth_token or not verify_auth_token(auth_token):
|
| 133 |
+
logger.warning(
|
| 134 |
+
f"Unauthorized access attempt to error log details for ID: {log_id}"
|
| 135 |
+
)
|
| 136 |
+
raise HTTPException(status_code=401, detail="Not authenticated")
|
| 137 |
+
|
| 138 |
+
try:
|
| 139 |
+
log_details = await error_log_service.process_get_error_log_details(
|
| 140 |
+
log_id=log_id
|
| 141 |
+
)
|
| 142 |
+
if not log_details:
|
| 143 |
+
raise HTTPException(status_code=404, detail="Error log not found")
|
| 144 |
+
|
| 145 |
+
return ErrorLogDetailResponse(**log_details)
|
| 146 |
+
except HTTPException as http_exc:
|
| 147 |
+
raise http_exc
|
| 148 |
+
except Exception as e:
|
| 149 |
+
logger.exception(f"Failed to get error log details for ID {log_id}: {str(e)}")
|
| 150 |
+
raise HTTPException(
|
| 151 |
+
status_code=500, detail=f"Failed to get error log details: {str(e)}"
|
| 152 |
+
)
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
@router.get("/errors/lookup", response_model=ErrorLogDetailResponse)
|
| 156 |
+
async def lookup_error_log_by_info(
|
| 157 |
+
request: Request,
|
| 158 |
+
gemini_key: str = Query(..., description="完整的 Gemini key"),
|
| 159 |
+
timestamp: datetime = Query(..., description="请求时间 (ISO8601)"),
|
| 160 |
+
status_code: Optional[int] = Query(None, description="错误码 (可选)"),
|
| 161 |
+
window_seconds: int = Query(
|
| 162 |
+
100, ge=1, le=300, description="时间窗口(秒), 默认100秒"
|
| 163 |
+
),
|
| 164 |
+
):
|
| 165 |
+
"""
|
| 166 |
+
通过 key / 错误码 / 时间窗口 查找最匹配的一条错误日志详情。
|
| 167 |
+
"""
|
| 168 |
+
auth_token = request.cookies.get("auth_token")
|
| 169 |
+
if not auth_token or not verify_auth_token(auth_token):
|
| 170 |
+
logger.warning("Unauthorized access attempt to lookup error log by info")
|
| 171 |
+
raise HTTPException(status_code=401, detail="Not authenticated")
|
| 172 |
+
|
| 173 |
+
try:
|
| 174 |
+
detail = await error_log_service.process_find_error_log_by_info(
|
| 175 |
+
gemini_key=gemini_key,
|
| 176 |
+
timestamp=timestamp,
|
| 177 |
+
status_code=status_code,
|
| 178 |
+
window_seconds=window_seconds,
|
| 179 |
+
)
|
| 180 |
+
if not detail:
|
| 181 |
+
raise HTTPException(status_code=404, detail="No matching error log found")
|
| 182 |
+
return ErrorLogDetailResponse(**detail)
|
| 183 |
+
except HTTPException as http_exc:
|
| 184 |
+
raise http_exc
|
| 185 |
+
except Exception as e:
|
| 186 |
+
logger.exception(
|
| 187 |
+
f"Failed to lookup error log by info for key=***{gemini_key[-4:] if gemini_key else ''}: {str(e)}"
|
| 188 |
+
)
|
| 189 |
+
raise HTTPException(status_code=500, detail="Internal server error")
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
@router.delete("/errors", status_code=status.HTTP_204_NO_CONTENT)
|
| 193 |
+
async def delete_error_logs_bulk_api(
|
| 194 |
+
request: Request, payload: Dict[str, List[int]] = Body(...)
|
| 195 |
+
):
|
| 196 |
+
"""
|
| 197 |
+
批量删除错误日志 (异步)
|
| 198 |
+
"""
|
| 199 |
+
auth_token = request.cookies.get("auth_token")
|
| 200 |
+
if not auth_token or not verify_auth_token(auth_token):
|
| 201 |
+
logger.warning("Unauthorized access attempt to bulk delete error logs")
|
| 202 |
+
raise HTTPException(status_code=401, detail="Not authenticated")
|
| 203 |
+
|
| 204 |
+
log_ids = payload.get("ids")
|
| 205 |
+
if not log_ids:
|
| 206 |
+
raise HTTPException(status_code=400, detail="No log IDs provided for deletion.")
|
| 207 |
+
|
| 208 |
+
try:
|
| 209 |
+
deleted_count = await error_log_service.process_delete_error_logs_by_ids(
|
| 210 |
+
log_ids
|
| 211 |
+
)
|
| 212 |
+
# 注意:异步函数返回的是尝试删除的数量,可能不是精确值
|
| 213 |
+
logger.info(
|
| 214 |
+
f"Attempted bulk deletion for {deleted_count} error logs with IDs: {log_ids}"
|
| 215 |
+
)
|
| 216 |
+
return Response(status_code=status.HTTP_204_NO_CONTENT)
|
| 217 |
+
except Exception as e:
|
| 218 |
+
logger.exception(f"Error bulk deleting error logs with IDs {log_ids}: {str(e)}")
|
| 219 |
+
raise HTTPException(
|
| 220 |
+
status_code=500, detail="Internal server error during bulk deletion"
|
| 221 |
+
)
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
@router.delete("/errors/all", status_code=status.HTTP_204_NO_CONTENT)
|
| 225 |
+
async def delete_all_error_logs_api(request: Request):
|
| 226 |
+
"""
|
| 227 |
+
删除所有错误日志 (异步)
|
| 228 |
+
"""
|
| 229 |
+
auth_token = request.cookies.get("auth_token")
|
| 230 |
+
if not auth_token or not verify_auth_token(auth_token):
|
| 231 |
+
logger.warning("Unauthorized access attempt to delete all error logs")
|
| 232 |
+
raise HTTPException(status_code=401, detail="Not authenticated")
|
| 233 |
+
|
| 234 |
+
try:
|
| 235 |
+
await error_log_service.process_delete_all_error_logs()
|
| 236 |
+
logger.info("Successfully deleted all error logs.")
|
| 237 |
+
# No body needed for 204 response
|
| 238 |
+
return Response(status_code=status.HTTP_204_NO_CONTENT)
|
| 239 |
+
except Exception as e:
|
| 240 |
+
logger.exception(f"Error deleting all error logs: {str(e)}")
|
| 241 |
+
raise HTTPException(
|
| 242 |
+
status_code=500, detail="Internal server error during deletion of all logs"
|
| 243 |
+
)
|
| 244 |
+
|
| 245 |
+
|
| 246 |
+
@router.delete("/errors/{log_id}", status_code=status.HTTP_204_NO_CONTENT)
|
| 247 |
+
async def delete_error_log_api(request: Request, log_id: int = Path(..., ge=1)):
|
| 248 |
+
"""
|
| 249 |
+
删除单个错误日志 (异步)
|
| 250 |
+
"""
|
| 251 |
+
auth_token = request.cookies.get("auth_token")
|
| 252 |
+
if not auth_token or not verify_auth_token(auth_token):
|
| 253 |
+
logger.warning(f"Unauthorized access attempt to delete error log ID: {log_id}")
|
| 254 |
+
raise HTTPException(status_code=401, detail="Not authenticated")
|
| 255 |
+
|
| 256 |
+
try:
|
| 257 |
+
success = await error_log_service.process_delete_error_log_by_id(log_id)
|
| 258 |
+
if not success:
|
| 259 |
+
# 服务层现在在未找到时返回 False,我们在这里转换为 404
|
| 260 |
+
raise HTTPException(
|
| 261 |
+
status_code=404, detail=f"Error log with ID {log_id} not found"
|
| 262 |
+
)
|
| 263 |
+
logger.info(f"Successfully deleted error log with ID: {log_id}")
|
| 264 |
+
return Response(status_code=status.HTTP_204_NO_CONTENT)
|
| 265 |
+
except HTTPException as http_exc:
|
| 266 |
+
raise http_exc
|
| 267 |
+
except Exception as e:
|
| 268 |
+
logger.exception(f"Error deleting error log with ID {log_id}: {str(e)}")
|
| 269 |
+
raise HTTPException(
|
| 270 |
+
status_code=500, detail="Internal server error during deletion"
|
| 271 |
+
)
|
app/router/files_routes.py
ADDED
|
@@ -0,0 +1,296 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Files API 路由
|
| 3 |
+
"""
|
| 4 |
+
from typing import Optional
|
| 5 |
+
from fastapi import APIRouter, Request, Query, Depends, Header, HTTPException
|
| 6 |
+
from fastapi.responses import JSONResponse
|
| 7 |
+
|
| 8 |
+
from app.config.config import settings
|
| 9 |
+
from app.domain.file_models import (
|
| 10 |
+
FileMetadata,
|
| 11 |
+
ListFilesResponse,
|
| 12 |
+
DeleteFileResponse
|
| 13 |
+
)
|
| 14 |
+
from app.log.logger import get_files_logger
|
| 15 |
+
from app.core.security import SecurityService
|
| 16 |
+
from app.service.files.files_service import get_files_service
|
| 17 |
+
from app.service.files.file_upload_handler import get_upload_handler
|
| 18 |
+
from app.utils.helpers import redact_key_for_logging
|
| 19 |
+
|
| 20 |
+
logger = get_files_logger()
|
| 21 |
+
|
| 22 |
+
router = APIRouter()
|
| 23 |
+
security_service = SecurityService()
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
@router.post("/upload/v1beta/files")
|
| 27 |
+
async def upload_file_init(
|
| 28 |
+
request: Request,
|
| 29 |
+
auth_token: str = Depends(security_service.verify_key_or_goog_api_key),
|
| 30 |
+
x_goog_upload_protocol: Optional[str] = Header(None),
|
| 31 |
+
x_goog_upload_command: Optional[str] = Header(None),
|
| 32 |
+
x_goog_upload_header_content_length: Optional[str] = Header(None),
|
| 33 |
+
x_goog_upload_header_content_type: Optional[str] = Header(None),
|
| 34 |
+
):
|
| 35 |
+
"""初始化文件上传"""
|
| 36 |
+
logger.debug(f"Upload file request: {request.method=}, {request.url=}, {auth_token=}, {x_goog_upload_protocol=}, {x_goog_upload_command=}, {x_goog_upload_header_content_length=}, {x_goog_upload_header_content_type=}")
|
| 37 |
+
|
| 38 |
+
# 檢查是否是實際的上傳請求(有 upload_id)
|
| 39 |
+
if request.query_params.get("upload_id") and x_goog_upload_command in ["upload", "upload, finalize"]:
|
| 40 |
+
logger.debug("This is an upload request, not initialization. Redirecting to handle_upload.")
|
| 41 |
+
return await handle_upload(
|
| 42 |
+
upload_path="v1beta/files",
|
| 43 |
+
request=request,
|
| 44 |
+
key=request.query_params.get("key"),
|
| 45 |
+
auth_token=auth_token
|
| 46 |
+
)
|
| 47 |
+
|
| 48 |
+
try:
|
| 49 |
+
# 使用认证 token 作为 user_token
|
| 50 |
+
user_token = auth_token
|
| 51 |
+
# 获取请求体
|
| 52 |
+
body = await request.body()
|
| 53 |
+
|
| 54 |
+
# 构建请求主机 URL
|
| 55 |
+
request_host = f"{request.url.scheme}://{request.url.netloc}"
|
| 56 |
+
logger.info(f"Request host: {request_host}")
|
| 57 |
+
|
| 58 |
+
# 准备请求头
|
| 59 |
+
headers = {
|
| 60 |
+
"x-goog-upload-protocol": x_goog_upload_protocol or "resumable",
|
| 61 |
+
"x-goog-upload-command": x_goog_upload_command or "start",
|
| 62 |
+
}
|
| 63 |
+
|
| 64 |
+
if x_goog_upload_header_content_length:
|
| 65 |
+
headers["x-goog-upload-header-content-length"] = x_goog_upload_header_content_length
|
| 66 |
+
if x_goog_upload_header_content_type:
|
| 67 |
+
headers["x-goog-upload-header-content-type"] = x_goog_upload_header_content_type
|
| 68 |
+
|
| 69 |
+
# 调用服务
|
| 70 |
+
files_service = await get_files_service()
|
| 71 |
+
response_data, response_headers = await files_service.initialize_upload(
|
| 72 |
+
headers=headers,
|
| 73 |
+
body=body,
|
| 74 |
+
user_token=user_token,
|
| 75 |
+
request_host=request_host # 傳遞請求主機
|
| 76 |
+
)
|
| 77 |
+
|
| 78 |
+
logger.info(f"Upload initialization response: {response_data}")
|
| 79 |
+
logger.info(f"Upload initialization response headers: {response_headers}")
|
| 80 |
+
|
| 81 |
+
logger.info(f"Upload initialization response headers: {response_data}")
|
| 82 |
+
# 返回响应
|
| 83 |
+
return JSONResponse(
|
| 84 |
+
content=response_data,
|
| 85 |
+
headers=response_headers
|
| 86 |
+
)
|
| 87 |
+
|
| 88 |
+
except HTTPException as e:
|
| 89 |
+
logger.error(f"Upload initialization failed: {e.detail}")
|
| 90 |
+
return JSONResponse(
|
| 91 |
+
content={"error": {"message": e.detail}},
|
| 92 |
+
status_code=e.status_code
|
| 93 |
+
)
|
| 94 |
+
except Exception as e:
|
| 95 |
+
logger.error(f"Unexpected error in upload initialization: {str(e)}")
|
| 96 |
+
return JSONResponse(
|
| 97 |
+
content={"error": {"message": "Internal server error"}},
|
| 98 |
+
status_code=500
|
| 99 |
+
)
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
@router.get("/v1beta/files")
|
| 103 |
+
async def list_files(
|
| 104 |
+
page_size: int = Query(10, ge=1, le=100, description="每页大小", alias="pageSize"),
|
| 105 |
+
page_token: Optional[str] = Query(None, description="分页标记", alias="pageToken"),
|
| 106 |
+
auth_token: str = Depends(security_service.verify_key_or_goog_api_key)
|
| 107 |
+
) -> ListFilesResponse:
|
| 108 |
+
"""列出文件"""
|
| 109 |
+
logger.debug(f"List files: {page_size=}, {page_token=}, {auth_token=}")
|
| 110 |
+
try:
|
| 111 |
+
# 使用认证 token 作为 user_token(如果启用用户隔离)
|
| 112 |
+
user_token = auth_token if settings.FILES_USER_ISOLATION_ENABLED else None
|
| 113 |
+
# 调用服务
|
| 114 |
+
files_service = await get_files_service()
|
| 115 |
+
return await files_service.list_files(
|
| 116 |
+
page_size=page_size,
|
| 117 |
+
page_token=page_token,
|
| 118 |
+
user_token=user_token
|
| 119 |
+
)
|
| 120 |
+
|
| 121 |
+
except HTTPException as e:
|
| 122 |
+
logger.error(f"List files failed: {e.detail}")
|
| 123 |
+
return JSONResponse(
|
| 124 |
+
content={"error": {"message": e.detail}},
|
| 125 |
+
status_code=e.status_code
|
| 126 |
+
)
|
| 127 |
+
except Exception as e:
|
| 128 |
+
logger.error(f"Unexpected error in list files: {str(e)}")
|
| 129 |
+
return JSONResponse(
|
| 130 |
+
content={"error": {"message": "Internal server error"}},
|
| 131 |
+
status_code=500
|
| 132 |
+
)
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
@router.get("/v1beta/files/{file_id:path}")
|
| 136 |
+
async def get_file(
|
| 137 |
+
file_id: str,
|
| 138 |
+
auth_token: str = Depends(security_service.verify_key_or_goog_api_key)
|
| 139 |
+
) -> FileMetadata:
|
| 140 |
+
"""获取文件信息"""
|
| 141 |
+
logger.debug(f"Get file request: {file_id=}, {auth_token=}")
|
| 142 |
+
try:
|
| 143 |
+
# 使用认证 token 作为 user_token
|
| 144 |
+
user_token = auth_token
|
| 145 |
+
# 调用服务
|
| 146 |
+
files_service = await get_files_service()
|
| 147 |
+
return await files_service.get_file(f"files/{file_id}", user_token)
|
| 148 |
+
|
| 149 |
+
except HTTPException as e:
|
| 150 |
+
logger.error(f"Get file failed: {e.detail}")
|
| 151 |
+
return JSONResponse(
|
| 152 |
+
content={"error": {"message": e.detail}},
|
| 153 |
+
status_code=e.status_code
|
| 154 |
+
)
|
| 155 |
+
except Exception as e:
|
| 156 |
+
logger.error(f"Unexpected error in get file: {str(e)}")
|
| 157 |
+
return JSONResponse(
|
| 158 |
+
content={"error": {"message": "Internal server error"}},
|
| 159 |
+
status_code=500
|
| 160 |
+
)
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
@router.delete("/v1beta/files/{file_id:path}")
|
| 164 |
+
async def delete_file(
|
| 165 |
+
file_id: str,
|
| 166 |
+
auth_token: str = Depends(security_service.verify_key_or_goog_api_key)
|
| 167 |
+
) -> DeleteFileResponse:
|
| 168 |
+
"""删除文件"""
|
| 169 |
+
logger.info(f"Delete file: {file_id=}, {auth_token=}")
|
| 170 |
+
try:
|
| 171 |
+
# 使用认证 token 作为 user_token
|
| 172 |
+
user_token = auth_token
|
| 173 |
+
# 调用服务
|
| 174 |
+
files_service = await get_files_service()
|
| 175 |
+
success = await files_service.delete_file(f"files/{file_id}", user_token)
|
| 176 |
+
|
| 177 |
+
return DeleteFileResponse(
|
| 178 |
+
success=success,
|
| 179 |
+
message="File deleted successfully" if success else "Failed to delete file"
|
| 180 |
+
)
|
| 181 |
+
|
| 182 |
+
except HTTPException as e:
|
| 183 |
+
logger.error(f"Delete file failed: {e.detail}")
|
| 184 |
+
return JSONResponse(
|
| 185 |
+
content={"error": {"message": e.detail}},
|
| 186 |
+
status_code=e.status_code
|
| 187 |
+
)
|
| 188 |
+
except Exception as e:
|
| 189 |
+
logger.error(f"Unexpected error in delete file: {str(e)}")
|
| 190 |
+
return JSONResponse(
|
| 191 |
+
content={"error": {"message": "Internal server error"}},
|
| 192 |
+
status_code=500
|
| 193 |
+
)
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
# 处理上传请求的通配符路由
|
| 197 |
+
@router.api_route("/upload/{upload_path:path}", methods=["GET", "POST", "PUT"])
|
| 198 |
+
async def handle_upload(
|
| 199 |
+
upload_path: str,
|
| 200 |
+
request: Request,
|
| 201 |
+
key: Optional[str] = Query(None), # 從查詢參數獲取 key
|
| 202 |
+
auth_token: str = Depends(security_service.verify_key_or_goog_api_key)
|
| 203 |
+
):
|
| 204 |
+
"""处理文件上传请求"""
|
| 205 |
+
try:
|
| 206 |
+
logger.info(f"Handling upload request: {request.method} {upload_path}, key={redact_key_for_logging(key)}")
|
| 207 |
+
|
| 208 |
+
# 從查詢參數獲取 upload_id
|
| 209 |
+
upload_id = request.query_params.get("upload_id")
|
| 210 |
+
if not upload_id:
|
| 211 |
+
raise HTTPException(status_code=400, detail="Missing upload_id")
|
| 212 |
+
|
| 213 |
+
# 從 session 獲取真實的 API key
|
| 214 |
+
files_service = await get_files_service()
|
| 215 |
+
session_info = await files_service.get_upload_session(upload_id)
|
| 216 |
+
if not session_info:
|
| 217 |
+
logger.error(f"No session found for upload_id: {upload_id}")
|
| 218 |
+
raise HTTPException(status_code=404, detail="Upload session not found")
|
| 219 |
+
|
| 220 |
+
real_api_key = session_info["api_key"]
|
| 221 |
+
original_upload_url = session_info["upload_url"]
|
| 222 |
+
|
| 223 |
+
# 使用真實的 API key 構建完整的 Google 上傳 URL
|
| 224 |
+
# 保留原始 URL 的所有參數,但使用真實的 API key
|
| 225 |
+
upload_url = original_upload_url
|
| 226 |
+
logger.info(f"Using real API key for upload: {redact_key_for_logging(real_api_key)}")
|
| 227 |
+
|
| 228 |
+
# 代理上传请求
|
| 229 |
+
upload_handler = get_upload_handler()
|
| 230 |
+
return await upload_handler.proxy_upload_request(
|
| 231 |
+
request=request,
|
| 232 |
+
upload_url=upload_url,
|
| 233 |
+
files_service=files_service
|
| 234 |
+
)
|
| 235 |
+
|
| 236 |
+
except HTTPException as e:
|
| 237 |
+
logger.error(f"Upload handling failed: {e.detail}")
|
| 238 |
+
return JSONResponse(
|
| 239 |
+
content={"error": {"message": e.detail}},
|
| 240 |
+
status_code=e.status_code
|
| 241 |
+
)
|
| 242 |
+
except Exception as e:
|
| 243 |
+
logger.error(f"Unexpected error in upload handling: {str(e)}")
|
| 244 |
+
return JSONResponse(
|
| 245 |
+
content={"error": {"message": "Internal server error"}},
|
| 246 |
+
status_code=500
|
| 247 |
+
)
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
# 为兼容性添加 /gemini 前缀的路由
|
| 251 |
+
@router.post("/gemini/upload/v1beta/files")
|
| 252 |
+
async def gemini_upload_file_init(
|
| 253 |
+
request: Request,
|
| 254 |
+
auth_token: str = Depends(security_service.verify_key_or_goog_api_key),
|
| 255 |
+
x_goog_upload_protocol: Optional[str] = Header(None),
|
| 256 |
+
x_goog_upload_command: Optional[str] = Header(None),
|
| 257 |
+
x_goog_upload_header_content_length: Optional[str] = Header(None),
|
| 258 |
+
x_goog_upload_header_content_type: Optional[str] = Header(None),
|
| 259 |
+
):
|
| 260 |
+
"""初始化文件上传(Gemini 前缀)"""
|
| 261 |
+
return await upload_file_init(
|
| 262 |
+
request,
|
| 263 |
+
auth_token,
|
| 264 |
+
x_goog_upload_protocol,
|
| 265 |
+
x_goog_upload_command,
|
| 266 |
+
x_goog_upload_header_content_length,
|
| 267 |
+
x_goog_upload_header_content_type
|
| 268 |
+
)
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
@router.get("/gemini/v1beta/files")
|
| 272 |
+
async def gemini_list_files(
|
| 273 |
+
page_size: int = Query(10, ge=1, le=100, alias="pageSize"),
|
| 274 |
+
page_token: Optional[str] = Query(None, alias="pageToken"),
|
| 275 |
+
auth_token: str = Depends(security_service.verify_key_or_goog_api_key)
|
| 276 |
+
) -> ListFilesResponse:
|
| 277 |
+
"""列出文件(Gemini 前缀)"""
|
| 278 |
+
return await list_files(page_size, page_token, auth_token)
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
@router.get("/gemini/v1beta/files/{file_id:path}")
|
| 282 |
+
async def gemini_get_file(
|
| 283 |
+
file_id: str,
|
| 284 |
+
auth_token: str = Depends(security_service.verify_key_or_goog_api_key)
|
| 285 |
+
) -> FileMetadata:
|
| 286 |
+
"""获取文件信息(Gemini 前缀)"""
|
| 287 |
+
return await get_file(file_id, auth_token)
|
| 288 |
+
|
| 289 |
+
|
| 290 |
+
@router.delete("/gemini/v1beta/files/{file_id:path}")
|
| 291 |
+
async def gemini_delete_file(
|
| 292 |
+
file_id: str,
|
| 293 |
+
auth_token: str = Depends(security_service.verify_key_or_goog_api_key)
|
| 294 |
+
) -> DeleteFileResponse:
|
| 295 |
+
"""删除文件(Gemini 前缀)"""
|
| 296 |
+
return await delete_file(file_id, auth_token)
|
app/router/gemini_routes.py
ADDED
|
@@ -0,0 +1,618 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import asyncio
|
| 2 |
+
from copy import deepcopy
|
| 3 |
+
|
| 4 |
+
from fastapi import APIRouter, Depends, HTTPException
|
| 5 |
+
from fastapi.responses import JSONResponse, StreamingResponse
|
| 6 |
+
|
| 7 |
+
from app.config.config import settings
|
| 8 |
+
from app.core.constants import API_VERSION
|
| 9 |
+
from app.core.security import SecurityService
|
| 10 |
+
from app.domain.gemini_models import (
|
| 11 |
+
GeminiBatchEmbedRequest,
|
| 12 |
+
GeminiContent,
|
| 13 |
+
GeminiEmbedRequest,
|
| 14 |
+
GeminiRequest,
|
| 15 |
+
ResetSelectedKeysRequest,
|
| 16 |
+
VerifySelectedKeysRequest,
|
| 17 |
+
)
|
| 18 |
+
from app.handler.error_handler import handle_route_errors
|
| 19 |
+
from app.handler.retry_handler import RetryHandler
|
| 20 |
+
from app.log.logger import get_gemini_logger
|
| 21 |
+
from app.service.chat.gemini_chat_service import GeminiChatService
|
| 22 |
+
from app.service.embedding.gemini_embedding_service import GeminiEmbeddingService
|
| 23 |
+
from app.service.key.key_manager import KeyManager, get_key_manager_instance
|
| 24 |
+
from app.service.model.model_service import ModelService
|
| 25 |
+
from app.service.tts.native.tts_routes import get_tts_chat_service
|
| 26 |
+
from app.utils.helpers import redact_key_for_logging
|
| 27 |
+
|
| 28 |
+
router = APIRouter(prefix=f"/gemini/{API_VERSION}")
|
| 29 |
+
router_v1beta = APIRouter(prefix=f"/{API_VERSION}")
|
| 30 |
+
logger = get_gemini_logger()
|
| 31 |
+
|
| 32 |
+
security_service = SecurityService()
|
| 33 |
+
model_service = ModelService()
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
async def get_key_manager():
|
| 37 |
+
"""获取密钥管理器实例"""
|
| 38 |
+
return await get_key_manager_instance()
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
async def get_next_working_key(key_manager: KeyManager = Depends(get_key_manager)):
|
| 42 |
+
"""获取下一个可用的API密钥"""
|
| 43 |
+
return await key_manager.get_next_working_key()
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
async def get_chat_service(key_manager: KeyManager = Depends(get_key_manager)):
|
| 47 |
+
"""获取Gemini聊天服务实例"""
|
| 48 |
+
return GeminiChatService(settings.BASE_URL, key_manager)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
async def get_embedding_service(key_manager: KeyManager = Depends(get_key_manager)):
|
| 52 |
+
"""获取Gemini嵌入服务实例"""
|
| 53 |
+
return GeminiEmbeddingService(settings.BASE_URL, key_manager)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
@router.get("/models")
|
| 57 |
+
@router_v1beta.get("/models")
|
| 58 |
+
async def list_models(
|
| 59 |
+
allowed_token=Depends(security_service.verify_key_or_goog_api_key),
|
| 60 |
+
key_manager: KeyManager = Depends(get_key_manager),
|
| 61 |
+
):
|
| 62 |
+
"""获取可用的 Gemini 模型列表,并根据配置添加衍生模型(搜索、图像、非思考)。"""
|
| 63 |
+
operation_name = "list_gemini_models"
|
| 64 |
+
logger.info("-" * 50 + operation_name + "-" * 50)
|
| 65 |
+
logger.info("Handling Gemini models list request")
|
| 66 |
+
|
| 67 |
+
try:
|
| 68 |
+
api_key = await key_manager.get_random_valid_key()
|
| 69 |
+
if not api_key:
|
| 70 |
+
raise HTTPException(
|
| 71 |
+
status_code=503, detail="No valid API keys available to fetch models."
|
| 72 |
+
)
|
| 73 |
+
logger.info(f"Using allowed token: {allowed_token}")
|
| 74 |
+
logger.info(f"Using API key: {redact_key_for_logging(api_key)}")
|
| 75 |
+
|
| 76 |
+
models_data = await model_service.get_gemini_models(api_key)
|
| 77 |
+
if not models_data or "models" not in models_data:
|
| 78 |
+
raise HTTPException(
|
| 79 |
+
status_code=500, detail="Failed to fetch base models list."
|
| 80 |
+
)
|
| 81 |
+
|
| 82 |
+
models_json = deepcopy(models_data)
|
| 83 |
+
model_mapping = {
|
| 84 |
+
x.get("name", "").split("/", maxsplit=1)[-1]: x
|
| 85 |
+
for x in models_json.get("models", [])
|
| 86 |
+
}
|
| 87 |
+
|
| 88 |
+
def add_derived_model(base_name, suffix, display_suffix):
|
| 89 |
+
model = model_mapping.get(base_name)
|
| 90 |
+
if not model:
|
| 91 |
+
logger.warning(
|
| 92 |
+
f"Base model '{base_name}' not found for derived model '{suffix}'."
|
| 93 |
+
)
|
| 94 |
+
return
|
| 95 |
+
item = deepcopy(model)
|
| 96 |
+
item["name"] = f"models/{base_name}{suffix}"
|
| 97 |
+
display_name = f'{item.get("displayName", base_name)}{display_suffix}'
|
| 98 |
+
item["displayName"] = display_name
|
| 99 |
+
item["description"] = display_name
|
| 100 |
+
models_json["models"].append(item)
|
| 101 |
+
|
| 102 |
+
if settings.SEARCH_MODELS:
|
| 103 |
+
for name in settings.SEARCH_MODELS:
|
| 104 |
+
add_derived_model(name, "-search", " For Search")
|
| 105 |
+
if settings.IMAGE_MODELS:
|
| 106 |
+
for name in settings.IMAGE_MODELS:
|
| 107 |
+
add_derived_model(name, "-image", " For Image")
|
| 108 |
+
if settings.THINKING_MODELS:
|
| 109 |
+
for name in settings.THINKING_MODELS:
|
| 110 |
+
add_derived_model(name, "-non-thinking", " Non Thinking")
|
| 111 |
+
|
| 112 |
+
logger.info("Gemini models list request successful")
|
| 113 |
+
return models_json
|
| 114 |
+
except HTTPException as http_exc:
|
| 115 |
+
raise http_exc
|
| 116 |
+
except Exception as e:
|
| 117 |
+
logger.error(f"Error getting Gemini models list: {str(e)}")
|
| 118 |
+
raise HTTPException(
|
| 119 |
+
status_code=500,
|
| 120 |
+
detail="Internal server error while fetching Gemini models list",
|
| 121 |
+
) from e
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
@router.post("/models/{model_name}:generateContent")
|
| 125 |
+
@router_v1beta.post("/models/{model_name}:generateContent")
|
| 126 |
+
@RetryHandler(key_arg="api_key")
|
| 127 |
+
async def generate_content(
|
| 128 |
+
model_name: str,
|
| 129 |
+
request: GeminiRequest,
|
| 130 |
+
allowed_token=Depends(security_service.verify_key_or_goog_api_key),
|
| 131 |
+
api_key: str = Depends(get_next_working_key),
|
| 132 |
+
key_manager: KeyManager = Depends(get_key_manager),
|
| 133 |
+
chat_service: GeminiChatService = Depends(get_chat_service),
|
| 134 |
+
):
|
| 135 |
+
"""处理 Gemini 非流式内容生成请求。"""
|
| 136 |
+
operation_name = "gemini_generate_content"
|
| 137 |
+
async with handle_route_errors(
|
| 138 |
+
logger, operation_name, failure_message="Content generation failed"
|
| 139 |
+
):
|
| 140 |
+
logger.info(
|
| 141 |
+
f"Handling Gemini content generation request for model: {model_name}"
|
| 142 |
+
)
|
| 143 |
+
logger.debug(f"Request: \n{request.model_dump_json(indent=2)}")
|
| 144 |
+
|
| 145 |
+
# 检测是否为原生Gemini TTS请求
|
| 146 |
+
is_native_tts = False
|
| 147 |
+
if "tts" in model_name.lower() and request.generationConfig:
|
| 148 |
+
# 直接从解析后的request对象获取TTS配置
|
| 149 |
+
response_modalities = request.generationConfig.responseModalities or []
|
| 150 |
+
speech_config = request.generationConfig.speechConfig or {}
|
| 151 |
+
|
| 152 |
+
# 如果包含AUDIO模态和语音配置,则认为是原生TTS请求
|
| 153 |
+
if "AUDIO" in response_modalities and speech_config:
|
| 154 |
+
is_native_tts = True
|
| 155 |
+
logger.info("Detected native Gemini TTS request")
|
| 156 |
+
logger.info(f"TTS responseModalities: {response_modalities}")
|
| 157 |
+
logger.info(f"TTS speechConfig: {speech_config}")
|
| 158 |
+
|
| 159 |
+
logger.info(f"Using allowed token: {allowed_token}")
|
| 160 |
+
logger.info(f"Using API key: {redact_key_for_logging(api_key)}")
|
| 161 |
+
|
| 162 |
+
if not await model_service.check_model_support(model_name):
|
| 163 |
+
raise HTTPException(
|
| 164 |
+
status_code=400, detail=f"Model {model_name} is not supported"
|
| 165 |
+
)
|
| 166 |
+
|
| 167 |
+
# 所有原生TTS请求都使用TTS增强服务
|
| 168 |
+
if is_native_tts:
|
| 169 |
+
try:
|
| 170 |
+
logger.info("Using native TTS enhanced service")
|
| 171 |
+
tts_service = await get_tts_chat_service(key_manager)
|
| 172 |
+
response = await tts_service.generate_content(
|
| 173 |
+
model=model_name, request=request, api_key=api_key
|
| 174 |
+
)
|
| 175 |
+
return response
|
| 176 |
+
except Exception as e:
|
| 177 |
+
logger.warning(
|
| 178 |
+
f"Native TTS processing failed, falling back to standard service: {e}"
|
| 179 |
+
)
|
| 180 |
+
|
| 181 |
+
# 使用标准服务处理所有其他请求(非TTS)
|
| 182 |
+
response = await chat_service.generate_content(
|
| 183 |
+
model=model_name, request=request, api_key=api_key
|
| 184 |
+
)
|
| 185 |
+
return response
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
@router.post("/models/{model_name}:streamGenerateContent")
|
| 189 |
+
@router_v1beta.post("/models/{model_name}:streamGenerateContent")
|
| 190 |
+
@RetryHandler(key_arg="api_key")
|
| 191 |
+
async def stream_generate_content(
|
| 192 |
+
model_name: str,
|
| 193 |
+
request: GeminiRequest,
|
| 194 |
+
allowed_token=Depends(security_service.verify_key_or_goog_api_key),
|
| 195 |
+
api_key: str = Depends(get_next_working_key),
|
| 196 |
+
key_manager: KeyManager = Depends(get_key_manager),
|
| 197 |
+
chat_service: GeminiChatService = Depends(get_chat_service),
|
| 198 |
+
):
|
| 199 |
+
"""处理 Gemini 流式内容生成请求。"""
|
| 200 |
+
operation_name = "gemini_stream_generate_content"
|
| 201 |
+
async with handle_route_errors(
|
| 202 |
+
logger, operation_name, failure_message="Streaming request initiation failed"
|
| 203 |
+
):
|
| 204 |
+
logger.info(
|
| 205 |
+
f"Handling Gemini streaming content generation for model: {model_name}"
|
| 206 |
+
)
|
| 207 |
+
logger.debug(f"Request: \n{request.model_dump_json(indent=2)}")
|
| 208 |
+
logger.info(f"Using allowed token: {allowed_token}")
|
| 209 |
+
logger.info(f"Using API key: {redact_key_for_logging(api_key)}")
|
| 210 |
+
|
| 211 |
+
if not await model_service.check_model_support(model_name):
|
| 212 |
+
raise HTTPException(
|
| 213 |
+
status_code=400, detail=f"Model {model_name} is not supported"
|
| 214 |
+
)
|
| 215 |
+
|
| 216 |
+
raw_stream = chat_service.stream_generate_content(
|
| 217 |
+
model=model_name, request=request, api_key=api_key
|
| 218 |
+
)
|
| 219 |
+
try:
|
| 220 |
+
# 尝试获取第一条数据,判断是正常 SSE(data: 前缀)还是错误 JSON
|
| 221 |
+
first_chunk = await raw_stream.__anext__()
|
| 222 |
+
except StopAsyncIteration:
|
| 223 |
+
# 如果流直接结束,退回标准 SSE 输出
|
| 224 |
+
return StreamingResponse(raw_stream, media_type="text/event-stream")
|
| 225 |
+
except Exception as e:
|
| 226 |
+
# 初始化流异常,直接返回 500 错误
|
| 227 |
+
return JSONResponse(
|
| 228 |
+
content={"error": {"code": e.args[0], "message": e.args[1]}},
|
| 229 |
+
status_code=e.args[0],
|
| 230 |
+
)
|
| 231 |
+
|
| 232 |
+
# 如果以 "data:" 开头,代表正常 SSE,将首块和后续块一起发送
|
| 233 |
+
if isinstance(first_chunk, str) and first_chunk.startswith("data:"):
|
| 234 |
+
|
| 235 |
+
async def combined():
|
| 236 |
+
yield first_chunk
|
| 237 |
+
async for chunk in raw_stream:
|
| 238 |
+
yield chunk
|
| 239 |
+
|
| 240 |
+
return StreamingResponse(combined(), media_type="text/event-stream")
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
@router.post("/models/{model_name}:countTokens")
|
| 244 |
+
@router_v1beta.post("/models/{model_name}:countTokens")
|
| 245 |
+
@RetryHandler(key_arg="api_key")
|
| 246 |
+
async def count_tokens(
|
| 247 |
+
model_name: str,
|
| 248 |
+
request: GeminiRequest,
|
| 249 |
+
allowed_token=Depends(security_service.verify_key_or_goog_api_key),
|
| 250 |
+
api_key: str = Depends(get_next_working_key),
|
| 251 |
+
key_manager: KeyManager = Depends(get_key_manager),
|
| 252 |
+
chat_service: GeminiChatService = Depends(get_chat_service),
|
| 253 |
+
):
|
| 254 |
+
"""处理 Gemini token 计数请求。"""
|
| 255 |
+
operation_name = "gemini_count_tokens"
|
| 256 |
+
async with handle_route_errors(
|
| 257 |
+
logger, operation_name, failure_message="Token counting failed"
|
| 258 |
+
):
|
| 259 |
+
logger.info(f"Handling Gemini token count request for model: {model_name}")
|
| 260 |
+
logger.debug(f"Request: \n{request.model_dump_json(indent=2)}")
|
| 261 |
+
logger.info(f"Using allowed token: {allowed_token}")
|
| 262 |
+
logger.info(f"Using API key: {redact_key_for_logging(api_key)}")
|
| 263 |
+
|
| 264 |
+
if not await model_service.check_model_support(model_name):
|
| 265 |
+
raise HTTPException(
|
| 266 |
+
status_code=400, detail=f"Model {model_name} is not supported"
|
| 267 |
+
)
|
| 268 |
+
|
| 269 |
+
response = await chat_service.count_tokens(
|
| 270 |
+
model=model_name, request=request, api_key=api_key
|
| 271 |
+
)
|
| 272 |
+
return response
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
@router.post("/models/{model_name}:embedContent")
|
| 276 |
+
@router_v1beta.post("/models/{model_name}:embedContent")
|
| 277 |
+
@RetryHandler(key_arg="api_key")
|
| 278 |
+
async def embed_content(
|
| 279 |
+
model_name: str,
|
| 280 |
+
request: GeminiEmbedRequest,
|
| 281 |
+
allowed_token=Depends(security_service.verify_key_or_goog_api_key),
|
| 282 |
+
api_key: str = Depends(get_next_working_key),
|
| 283 |
+
key_manager: KeyManager = Depends(get_key_manager),
|
| 284 |
+
embedding_service: GeminiEmbeddingService = Depends(get_embedding_service),
|
| 285 |
+
):
|
| 286 |
+
"""处理 Gemini 单一嵌入请求"""
|
| 287 |
+
operation_name = "gemini_embed_content"
|
| 288 |
+
async with handle_route_errors(
|
| 289 |
+
logger, operation_name, failure_message="Embedding content generation failed"
|
| 290 |
+
):
|
| 291 |
+
logger.info(f"Handling Gemini embedding request for model: {model_name}")
|
| 292 |
+
logger.debug(f"Request: \n{request.model_dump_json(indent=2)}")
|
| 293 |
+
logger.info(f"Using allowed token: {allowed_token}")
|
| 294 |
+
logger.info(f"Using API key: {redact_key_for_logging(api_key)}")
|
| 295 |
+
|
| 296 |
+
if not await model_service.check_model_support(model_name):
|
| 297 |
+
raise HTTPException(
|
| 298 |
+
status_code=400, detail=f"Model {model_name} is not supported"
|
| 299 |
+
)
|
| 300 |
+
|
| 301 |
+
response = await embedding_service.embed_content(
|
| 302 |
+
model=model_name, request=request, api_key=api_key
|
| 303 |
+
)
|
| 304 |
+
return response
|
| 305 |
+
|
| 306 |
+
|
| 307 |
+
@router.post("/models/{model_name}:batchEmbedContents")
|
| 308 |
+
@router_v1beta.post("/models/{model_name}:batchEmbedContents")
|
| 309 |
+
@RetryHandler(key_arg="api_key")
|
| 310 |
+
async def batch_embed_contents(
|
| 311 |
+
model_name: str,
|
| 312 |
+
request: GeminiBatchEmbedRequest,
|
| 313 |
+
allowed_token=Depends(security_service.verify_key_or_goog_api_key),
|
| 314 |
+
api_key: str = Depends(get_next_working_key),
|
| 315 |
+
key_manager: KeyManager = Depends(get_key_manager),
|
| 316 |
+
embedding_service: GeminiEmbeddingService = Depends(get_embedding_service),
|
| 317 |
+
):
|
| 318 |
+
"""处理 Gemini 批量嵌入请求"""
|
| 319 |
+
operation_name = "gemini_batch_embed_contents"
|
| 320 |
+
async with handle_route_errors(
|
| 321 |
+
logger,
|
| 322 |
+
operation_name,
|
| 323 |
+
failure_message="Batch embedding content generation failed",
|
| 324 |
+
):
|
| 325 |
+
logger.info(f"Handling Gemini batch embedding request for model: {model_name}")
|
| 326 |
+
logger.debug(f"Request: \n{request.model_dump_json(indent=2)}")
|
| 327 |
+
logger.info(f"Using allowed token: {allowed_token}")
|
| 328 |
+
logger.info(f"Using API key: {redact_key_for_logging(api_key)}")
|
| 329 |
+
|
| 330 |
+
if not await model_service.check_model_support(model_name):
|
| 331 |
+
raise HTTPException(
|
| 332 |
+
status_code=400, detail=f"Model {model_name} is not supported"
|
| 333 |
+
)
|
| 334 |
+
|
| 335 |
+
response = await embedding_service.batch_embed_contents(
|
| 336 |
+
model=model_name, request=request, api_key=api_key
|
| 337 |
+
)
|
| 338 |
+
return response
|
| 339 |
+
|
| 340 |
+
|
| 341 |
+
@router.post("/reset-all-fail-counts")
|
| 342 |
+
async def reset_all_key_fail_counts(
|
| 343 |
+
key_type: str = None, key_manager: KeyManager = Depends(get_key_manager)
|
| 344 |
+
):
|
| 345 |
+
"""批量重置Gemini API密钥的失败计数,可选择性地仅重置有效或无效密钥"""
|
| 346 |
+
logger.info("-" * 50 + "reset_all_gemini_key_fail_counts" + "-" * 50)
|
| 347 |
+
logger.info(f"Received reset request with key_type: {key_type}")
|
| 348 |
+
|
| 349 |
+
try:
|
| 350 |
+
# 获取分类后的密钥
|
| 351 |
+
keys_by_status = await key_manager.get_keys_by_status()
|
| 352 |
+
valid_keys = keys_by_status.get("valid_keys", {})
|
| 353 |
+
invalid_keys = keys_by_status.get("invalid_keys", {})
|
| 354 |
+
|
| 355 |
+
# 根据类型选择要重置的密钥
|
| 356 |
+
keys_to_reset = []
|
| 357 |
+
if key_type == "valid":
|
| 358 |
+
keys_to_reset = list(valid_keys.keys())
|
| 359 |
+
logger.info(f"Resetting only valid keys, count: {len(keys_to_reset)}")
|
| 360 |
+
elif key_type == "invalid":
|
| 361 |
+
keys_to_reset = list(invalid_keys.keys())
|
| 362 |
+
logger.info(f"Resetting only invalid keys, count: {len(keys_to_reset)}")
|
| 363 |
+
else:
|
| 364 |
+
# 重置所有密钥
|
| 365 |
+
await key_manager.reset_failure_counts()
|
| 366 |
+
return JSONResponse(
|
| 367 |
+
{"success": True, "message": "所有密钥的失败计数已重置"}
|
| 368 |
+
)
|
| 369 |
+
|
| 370 |
+
# 批量重置指定类型的密钥
|
| 371 |
+
for key in keys_to_reset:
|
| 372 |
+
await key_manager.reset_key_failure_count(key)
|
| 373 |
+
|
| 374 |
+
return JSONResponse(
|
| 375 |
+
{
|
| 376 |
+
"success": True,
|
| 377 |
+
"message": f"{key_type}密钥的失败计数已重置",
|
| 378 |
+
"reset_count": len(keys_to_reset),
|
| 379 |
+
}
|
| 380 |
+
)
|
| 381 |
+
except Exception as e:
|
| 382 |
+
logger.error(f"Failed to reset key failure counts: {str(e)}")
|
| 383 |
+
return JSONResponse(
|
| 384 |
+
{"success": False, "message": f"批量重置失败: {str(e)}"}, status_code=500
|
| 385 |
+
)
|
| 386 |
+
|
| 387 |
+
|
| 388 |
+
@router.post("/reset-selected-fail-counts")
|
| 389 |
+
async def reset_selected_key_fail_counts(
|
| 390 |
+
request: ResetSelectedKeysRequest,
|
| 391 |
+
key_manager: KeyManager = Depends(get_key_manager),
|
| 392 |
+
):
|
| 393 |
+
"""批量重置选定Gemini API密钥的失败计数"""
|
| 394 |
+
logger.info("-" * 50 + "reset_selected_gemini_key_fail_counts" + "-" * 50)
|
| 395 |
+
keys_to_reset = request.keys
|
| 396 |
+
key_type = request.key_type
|
| 397 |
+
logger.info(
|
| 398 |
+
f"Received reset request for {len(keys_to_reset)} selected {key_type} keys."
|
| 399 |
+
)
|
| 400 |
+
|
| 401 |
+
if not keys_to_reset:
|
| 402 |
+
return JSONResponse(
|
| 403 |
+
{"success": False, "message": "没有提供需要重置的密钥"}, status_code=400
|
| 404 |
+
)
|
| 405 |
+
|
| 406 |
+
reset_count = 0
|
| 407 |
+
errors = []
|
| 408 |
+
|
| 409 |
+
try:
|
| 410 |
+
for key in keys_to_reset:
|
| 411 |
+
try:
|
| 412 |
+
result = await key_manager.reset_key_failure_count(key)
|
| 413 |
+
if result:
|
| 414 |
+
reset_count += 1
|
| 415 |
+
else:
|
| 416 |
+
logger.warning(
|
| 417 |
+
f"Key not found during selective reset: {redact_key_for_logging(key)}"
|
| 418 |
+
)
|
| 419 |
+
except Exception as key_error:
|
| 420 |
+
logger.error(
|
| 421 |
+
f"Error resetting key {redact_key_for_logging(key)}: {str(key_error)}"
|
| 422 |
+
)
|
| 423 |
+
errors.append(f"Key {key}: {str(key_error)}")
|
| 424 |
+
|
| 425 |
+
if errors:
|
| 426 |
+
error_message = f"批量重置完成,但出现错误: {'; '.join(errors)}"
|
| 427 |
+
final_success = reset_count > 0
|
| 428 |
+
status_code = 207 if final_success and errors else 500
|
| 429 |
+
return JSONResponse(
|
| 430 |
+
{
|
| 431 |
+
"success": final_success,
|
| 432 |
+
"message": error_message,
|
| 433 |
+
"reset_count": reset_count,
|
| 434 |
+
},
|
| 435 |
+
status_code=status_code,
|
| 436 |
+
)
|
| 437 |
+
|
| 438 |
+
return JSONResponse(
|
| 439 |
+
{
|
| 440 |
+
"success": True,
|
| 441 |
+
"message": f"成功重置 {reset_count} 个选定 {key_type} 密钥的失败计数",
|
| 442 |
+
"reset_count": reset_count,
|
| 443 |
+
}
|
| 444 |
+
)
|
| 445 |
+
except Exception as e:
|
| 446 |
+
logger.error(
|
| 447 |
+
f"Failed to process reset selected key failure counts request: {str(e)}"
|
| 448 |
+
)
|
| 449 |
+
return JSONResponse(
|
| 450 |
+
{"success": False, "message": f"批量重置处理失败: {str(e)}"},
|
| 451 |
+
status_code=500,
|
| 452 |
+
)
|
| 453 |
+
|
| 454 |
+
|
| 455 |
+
@router.post("/reset-fail-count/{api_key}")
|
| 456 |
+
async def reset_key_fail_count(
|
| 457 |
+
api_key: str, key_manager: KeyManager = Depends(get_key_manager)
|
| 458 |
+
):
|
| 459 |
+
"""重置指定Gemini API密钥的失败计数"""
|
| 460 |
+
logger.info("-" * 50 + "reset_gemini_key_fail_count" + "-" * 50)
|
| 461 |
+
logger.info(
|
| 462 |
+
f"Resetting failure count for API key: {redact_key_for_logging(api_key)}"
|
| 463 |
+
)
|
| 464 |
+
|
| 465 |
+
try:
|
| 466 |
+
result = await key_manager.reset_key_failure_count(api_key)
|
| 467 |
+
if result:
|
| 468 |
+
return JSONResponse({"success": True, "message": "失败计数已重置"})
|
| 469 |
+
return JSONResponse(
|
| 470 |
+
{"success": False, "message": "未找到指定密钥"}, status_code=404
|
| 471 |
+
)
|
| 472 |
+
except Exception as e:
|
| 473 |
+
logger.error(f"Failed to reset key failure count: {str(e)}")
|
| 474 |
+
return JSONResponse(
|
| 475 |
+
{"success": False, "message": f"重置失败: {str(e)}"}, status_code=500
|
| 476 |
+
)
|
| 477 |
+
|
| 478 |
+
|
| 479 |
+
@router.post("/verify-key/{api_key}")
|
| 480 |
+
async def verify_key(
|
| 481 |
+
api_key: str,
|
| 482 |
+
chat_service: GeminiChatService = Depends(get_chat_service),
|
| 483 |
+
key_manager: KeyManager = Depends(get_key_manager),
|
| 484 |
+
):
|
| 485 |
+
"""验证Gemini API密钥的有效性"""
|
| 486 |
+
logger.info("-" * 50 + "verify_gemini_key" + "-" * 50)
|
| 487 |
+
logger.info("Verifying API key validity")
|
| 488 |
+
|
| 489 |
+
try:
|
| 490 |
+
gemini_request = GeminiRequest(
|
| 491 |
+
contents=[
|
| 492 |
+
GeminiContent(
|
| 493 |
+
role="user",
|
| 494 |
+
parts=[{"text": "hi"}],
|
| 495 |
+
)
|
| 496 |
+
],
|
| 497 |
+
generation_config={"temperature": 0.7, "topP": 1.0, "maxOutputTokens": 10},
|
| 498 |
+
)
|
| 499 |
+
|
| 500 |
+
response = await chat_service.generate_content(
|
| 501 |
+
settings.TEST_MODEL, gemini_request, api_key
|
| 502 |
+
)
|
| 503 |
+
|
| 504 |
+
if response:
|
| 505 |
+
# 如果密钥验证成功,则重置其失败计数
|
| 506 |
+
await key_manager.reset_key_failure_count(api_key)
|
| 507 |
+
return JSONResponse({"status": "valid"})
|
| 508 |
+
except Exception as e:
|
| 509 |
+
logger.error(f"Key verification failed: {str(e)}")
|
| 510 |
+
|
| 511 |
+
async with key_manager.failure_count_lock:
|
| 512 |
+
if api_key in key_manager.key_failure_counts:
|
| 513 |
+
key_manager.key_failure_counts[api_key] += 1
|
| 514 |
+
logger.warning(
|
| 515 |
+
f"Verification exception for key: {redact_key_for_logging(api_key)}, incrementing failure count"
|
| 516 |
+
)
|
| 517 |
+
|
| 518 |
+
return JSONResponse({"status": "invalid", "error": e.args[1]})
|
| 519 |
+
|
| 520 |
+
|
| 521 |
+
@router.post("/verify-selected-keys")
|
| 522 |
+
async def verify_selected_keys(
|
| 523 |
+
request: VerifySelectedKeysRequest,
|
| 524 |
+
chat_service: GeminiChatService = Depends(get_chat_service),
|
| 525 |
+
key_manager: KeyManager = Depends(get_key_manager),
|
| 526 |
+
):
|
| 527 |
+
"""批量验证选定Gemini API密钥的有效性"""
|
| 528 |
+
logger.info("-" * 50 + "verify_selected_gemini_keys" + "-" * 50)
|
| 529 |
+
keys_to_verify = request.keys
|
| 530 |
+
logger.info(
|
| 531 |
+
f"Received verification request for {len(keys_to_verify)} selected keys."
|
| 532 |
+
)
|
| 533 |
+
|
| 534 |
+
if not keys_to_verify:
|
| 535 |
+
return JSONResponse(
|
| 536 |
+
{"success": False, "message": "没有提供需要验证的密钥"}, status_code=400
|
| 537 |
+
)
|
| 538 |
+
|
| 539 |
+
successful_keys = []
|
| 540 |
+
failed_keys = {}
|
| 541 |
+
|
| 542 |
+
async def _verify_single_key(api_key: str):
|
| 543 |
+
"""内部函数,用于验证单个密钥并处理异常"""
|
| 544 |
+
nonlocal successful_keys, failed_keys
|
| 545 |
+
try:
|
| 546 |
+
gemini_request = GeminiRequest(
|
| 547 |
+
contents=[GeminiContent(role="user", parts=[{"text": "hi"}])],
|
| 548 |
+
generation_config={
|
| 549 |
+
"temperature": 0.7,
|
| 550 |
+
"topP": 1.0,
|
| 551 |
+
"maxOutputTokens": 10,
|
| 552 |
+
},
|
| 553 |
+
)
|
| 554 |
+
await chat_service.generate_content(
|
| 555 |
+
settings.TEST_MODEL, gemini_request, api_key
|
| 556 |
+
)
|
| 557 |
+
successful_keys.append(api_key)
|
| 558 |
+
# 如果密钥验证成功,则重置其失败计数
|
| 559 |
+
await key_manager.reset_key_failure_count(api_key)
|
| 560 |
+
return api_key, "valid", None
|
| 561 |
+
except Exception as e:
|
| 562 |
+
error_message = e.args[1]
|
| 563 |
+
logger.warning(
|
| 564 |
+
f"Key verification failed for {redact_key_for_logging(api_key)}: {error_message}"
|
| 565 |
+
)
|
| 566 |
+
async with key_manager.failure_count_lock:
|
| 567 |
+
if api_key in key_manager.key_failure_counts:
|
| 568 |
+
key_manager.key_failure_counts[api_key] += 1
|
| 569 |
+
logger.warning(
|
| 570 |
+
f"Bulk verification exception for key: {redact_key_for_logging(api_key)}, incrementing failure count"
|
| 571 |
+
)
|
| 572 |
+
else:
|
| 573 |
+
key_manager.key_failure_counts[api_key] = 1
|
| 574 |
+
logger.warning(
|
| 575 |
+
f"Bulk verification exception for key: {redact_key_for_logging(api_key)}, initializing failure count to 1"
|
| 576 |
+
)
|
| 577 |
+
failed_keys[api_key] = {"error_message": e.args[1], "error_code": e.args[0]}
|
| 578 |
+
return api_key, "invalid", error_message
|
| 579 |
+
|
| 580 |
+
tasks = [_verify_single_key(key) for key in keys_to_verify]
|
| 581 |
+
results = await asyncio.gather(*tasks, return_exceptions=True)
|
| 582 |
+
|
| 583 |
+
for result in results:
|
| 584 |
+
if isinstance(result, Exception):
|
| 585 |
+
logger.error(
|
| 586 |
+
f"An unexpected error occurred during bulk verification task: {result}"
|
| 587 |
+
)
|
| 588 |
+
|
| 589 |
+
valid_count = len(successful_keys)
|
| 590 |
+
invalid_count = len(failed_keys)
|
| 591 |
+
logger.info(
|
| 592 |
+
f"Bulk verification finished. Valid: {valid_count}, Invalid: {invalid_count}"
|
| 593 |
+
)
|
| 594 |
+
|
| 595 |
+
if failed_keys:
|
| 596 |
+
message = f"批量验证完成。成功: {valid_count}, 失败: {invalid_count}。"
|
| 597 |
+
return JSONResponse(
|
| 598 |
+
{
|
| 599 |
+
"success": True,
|
| 600 |
+
"message": message,
|
| 601 |
+
"successful_keys": successful_keys,
|
| 602 |
+
"failed_keys": failed_keys,
|
| 603 |
+
"valid_count": valid_count,
|
| 604 |
+
"invalid_count": invalid_count,
|
| 605 |
+
}
|
| 606 |
+
)
|
| 607 |
+
else:
|
| 608 |
+
message = f"批量验证成功完成。所有 {valid_count} 个密钥均有效。"
|
| 609 |
+
return JSONResponse(
|
| 610 |
+
{
|
| 611 |
+
"success": True,
|
| 612 |
+
"message": message,
|
| 613 |
+
"successful_keys": successful_keys,
|
| 614 |
+
"failed_keys": {},
|
| 615 |
+
"valid_count": valid_count,
|
| 616 |
+
"invalid_count": 0,
|
| 617 |
+
}
|
| 618 |
+
)
|
app/router/key_routes.py
ADDED
|
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import APIRouter, Depends, Request
|
| 2 |
+
from app.service.key.key_manager import KeyManager, get_key_manager_instance
|
| 3 |
+
from app.core.security import verify_auth_token
|
| 4 |
+
from fastapi.responses import JSONResponse
|
| 5 |
+
|
| 6 |
+
router = APIRouter()
|
| 7 |
+
|
| 8 |
+
@router.get("/api/keys")
|
| 9 |
+
async def get_keys_paginated(
|
| 10 |
+
request: Request,
|
| 11 |
+
page: int = 1,
|
| 12 |
+
limit: int = 10,
|
| 13 |
+
search: str = None,
|
| 14 |
+
fail_count_threshold: int = None,
|
| 15 |
+
status: str = "all", # 'valid', 'invalid', 'all'
|
| 16 |
+
key_manager: KeyManager = Depends(get_key_manager_instance),
|
| 17 |
+
):
|
| 18 |
+
"""
|
| 19 |
+
Get paginated, filtered, and searched keys.
|
| 20 |
+
"""
|
| 21 |
+
auth_token = request.cookies.get("auth_token")
|
| 22 |
+
if not auth_token or not verify_auth_token(auth_token):
|
| 23 |
+
return JSONResponse(status_code=401, content={"detail": "Unauthorized"})
|
| 24 |
+
|
| 25 |
+
all_keys_with_status = await key_manager.get_all_keys_with_fail_count()
|
| 26 |
+
|
| 27 |
+
# Filter by status
|
| 28 |
+
if status == "valid":
|
| 29 |
+
keys_to_filter = all_keys_with_status["valid_keys"]
|
| 30 |
+
elif status == "invalid":
|
| 31 |
+
keys_to_filter = all_keys_with_status["invalid_keys"]
|
| 32 |
+
else:
|
| 33 |
+
# Combine both for 'all' status, which might be useful for a unified view if ever needed
|
| 34 |
+
keys_to_filter = {**all_keys_with_status["valid_keys"], **all_keys_with_status["invalid_keys"]}
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
# Further filtering (search and fail_count_threshold)
|
| 38 |
+
filtered_keys = {}
|
| 39 |
+
for key, fail_count in keys_to_filter.items():
|
| 40 |
+
search_match = True
|
| 41 |
+
if search:
|
| 42 |
+
search_match = search.lower() in key.lower()
|
| 43 |
+
|
| 44 |
+
fail_count_match = True
|
| 45 |
+
if fail_count_threshold is not None:
|
| 46 |
+
fail_count_match = fail_count >= fail_count_threshold
|
| 47 |
+
|
| 48 |
+
if search_match and fail_count_match:
|
| 49 |
+
filtered_keys[key] = fail_count
|
| 50 |
+
|
| 51 |
+
# Pagination
|
| 52 |
+
keys_list = list(filtered_keys.items())
|
| 53 |
+
total_items = len(keys_list)
|
| 54 |
+
start_index = (page - 1) * limit
|
| 55 |
+
end_index = start_index + limit
|
| 56 |
+
paginated_keys = dict(keys_list[start_index:end_index])
|
| 57 |
+
|
| 58 |
+
return {
|
| 59 |
+
"keys": paginated_keys,
|
| 60 |
+
"total_items": total_items,
|
| 61 |
+
"total_pages": (total_items + limit - 1) // limit,
|
| 62 |
+
"current_page": page,
|
| 63 |
+
}
|
| 64 |
+
|
| 65 |
+
@router.get("/api/keys/all")
|
| 66 |
+
async def get_all_keys(
|
| 67 |
+
request: Request,
|
| 68 |
+
key_manager: KeyManager = Depends(get_key_manager_instance),
|
| 69 |
+
):
|
| 70 |
+
"""
|
| 71 |
+
Get all keys (both valid and invalid) for bulk operations.
|
| 72 |
+
"""
|
| 73 |
+
auth_token = request.cookies.get("auth_token")
|
| 74 |
+
if not auth_token or not verify_auth_token(auth_token):
|
| 75 |
+
return JSONResponse(status_code=401, content={"detail": "Unauthorized"})
|
| 76 |
+
|
| 77 |
+
all_keys_with_status = await key_manager.get_all_keys_with_fail_count()
|
| 78 |
+
|
| 79 |
+
return {
|
| 80 |
+
"valid_keys": list(all_keys_with_status["valid_keys"].keys()),
|
| 81 |
+
"invalid_keys": list(all_keys_with_status["invalid_keys"].keys()),
|
| 82 |
+
"total_count": len(all_keys_with_status["valid_keys"]) + len(all_keys_with_status["invalid_keys"])
|
| 83 |
+
}
|
app/router/openai_compatiable_routes.py
ADDED
|
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import APIRouter, Depends
|
| 2 |
+
from fastapi.responses import JSONResponse, StreamingResponse
|
| 3 |
+
|
| 4 |
+
from app.config.config import settings
|
| 5 |
+
from app.core.security import SecurityService
|
| 6 |
+
from app.domain.openai_models import (
|
| 7 |
+
ChatRequest,
|
| 8 |
+
EmbeddingRequest,
|
| 9 |
+
ImageGenerationRequest,
|
| 10 |
+
)
|
| 11 |
+
from app.handler.error_handler import handle_route_errors
|
| 12 |
+
from app.handler.retry_handler import RetryHandler
|
| 13 |
+
from app.log.logger import get_openai_compatible_logger
|
| 14 |
+
from app.service.key.key_manager import KeyManager, get_key_manager_instance
|
| 15 |
+
from app.service.openai_compatiable.openai_compatiable_service import (
|
| 16 |
+
OpenAICompatiableService,
|
| 17 |
+
)
|
| 18 |
+
from app.utils.helpers import redact_key_for_logging
|
| 19 |
+
|
| 20 |
+
router = APIRouter()
|
| 21 |
+
logger = get_openai_compatible_logger()
|
| 22 |
+
|
| 23 |
+
security_service = SecurityService()
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
async def get_key_manager():
|
| 27 |
+
return await get_key_manager_instance()
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
async def get_next_working_key_wrapper(
|
| 31 |
+
key_manager: KeyManager = Depends(get_key_manager),
|
| 32 |
+
):
|
| 33 |
+
return await key_manager.get_next_working_key()
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
async def get_openai_service(key_manager: KeyManager = Depends(get_key_manager)):
|
| 37 |
+
"""获取OpenAI聊天服务实例"""
|
| 38 |
+
return OpenAICompatiableService(settings.BASE_URL, key_manager)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
@router.get("/openai/v1/models")
|
| 42 |
+
async def list_models(
|
| 43 |
+
allowed_token=Depends(security_service.verify_authorization),
|
| 44 |
+
key_manager: KeyManager = Depends(get_key_manager),
|
| 45 |
+
openai_service: OpenAICompatiableService = Depends(get_openai_service),
|
| 46 |
+
):
|
| 47 |
+
"""获取可用模型列表。"""
|
| 48 |
+
operation_name = "list_models"
|
| 49 |
+
async with handle_route_errors(logger, operation_name):
|
| 50 |
+
logger.info("Handling models list request")
|
| 51 |
+
api_key = await key_manager.get_random_valid_key()
|
| 52 |
+
logger.info(f"Using allowed token: {allowed_token}")
|
| 53 |
+
logger.info(f"Using API key: {redact_key_for_logging(api_key)}")
|
| 54 |
+
return await openai_service.get_models(api_key)
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
@router.post("/openai/v1/chat/completions")
|
| 58 |
+
@RetryHandler(key_arg="api_key")
|
| 59 |
+
async def chat_completion(
|
| 60 |
+
request: ChatRequest,
|
| 61 |
+
allowed_token=Depends(security_service.verify_authorization),
|
| 62 |
+
api_key: str = Depends(get_next_working_key_wrapper),
|
| 63 |
+
key_manager: KeyManager = Depends(get_key_manager),
|
| 64 |
+
openai_service: OpenAICompatiableService = Depends(get_openai_service),
|
| 65 |
+
):
|
| 66 |
+
"""处理聊天补全请求,支持流式响应和特定模型切换。"""
|
| 67 |
+
operation_name = "chat_completion"
|
| 68 |
+
is_image_chat = request.model == f"{settings.CREATE_IMAGE_MODEL}-chat"
|
| 69 |
+
current_api_key = api_key
|
| 70 |
+
if is_image_chat:
|
| 71 |
+
current_api_key = await key_manager.get_paid_key()
|
| 72 |
+
|
| 73 |
+
async with handle_route_errors(logger, operation_name):
|
| 74 |
+
logger.info(f"Handling chat completion request for model: {request.model}")
|
| 75 |
+
logger.debug(f"Request: \n{request.model_dump_json(indent=2)}")
|
| 76 |
+
logger.info(f"Using allowed token: {allowed_token}")
|
| 77 |
+
logger.info(f"Using API key: {redact_key_for_logging(current_api_key)}")
|
| 78 |
+
|
| 79 |
+
raw_response = None
|
| 80 |
+
if is_image_chat:
|
| 81 |
+
raw_response = await openai_service.create_image_chat_completion(
|
| 82 |
+
request, current_api_key
|
| 83 |
+
)
|
| 84 |
+
else:
|
| 85 |
+
raw_response = await openai_service.create_chat_completion(
|
| 86 |
+
request, current_api_key
|
| 87 |
+
)
|
| 88 |
+
if request.stream:
|
| 89 |
+
try:
|
| 90 |
+
# 尝试获取第一条数据,判断是正常 SSE(data: 前缀)还是错误 JSON
|
| 91 |
+
first_chunk = await raw_response.__anext__()
|
| 92 |
+
except StopAsyncIteration:
|
| 93 |
+
# 如果流直接结束,退回标准 SSE 输出
|
| 94 |
+
return StreamingResponse(raw_response, media_type="text/event-stream")
|
| 95 |
+
except Exception as e:
|
| 96 |
+
# 初始化流异常,直接返回 500 错误
|
| 97 |
+
return JSONResponse(
|
| 98 |
+
content={"error": {"code": e.args[0], "message": e.args[1]}},
|
| 99 |
+
status_code=e.args[0],
|
| 100 |
+
)
|
| 101 |
+
|
| 102 |
+
# 如果以 "data:" 开头,代表正常 SSE,将首块和后续块一起发送
|
| 103 |
+
if isinstance(first_chunk, str) and first_chunk.startswith("data:"):
|
| 104 |
+
|
| 105 |
+
async def combined():
|
| 106 |
+
yield first_chunk
|
| 107 |
+
async for chunk in raw_response:
|
| 108 |
+
yield chunk
|
| 109 |
+
|
| 110 |
+
return StreamingResponse(combined(), media_type="text/event-stream")
|
| 111 |
+
else:
|
| 112 |
+
return raw_response
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
@router.post("/openai/v1/images/generations")
|
| 116 |
+
async def generate_image(
|
| 117 |
+
request: ImageGenerationRequest,
|
| 118 |
+
allowed_token=Depends(security_service.verify_authorization),
|
| 119 |
+
openai_service: OpenAICompatiableService = Depends(get_openai_service),
|
| 120 |
+
):
|
| 121 |
+
"""处理图像生成请求。"""
|
| 122 |
+
operation_name = "generate_image"
|
| 123 |
+
async with handle_route_errors(logger, operation_name):
|
| 124 |
+
logger.info(f"Handling image generation request for prompt: {request.prompt}")
|
| 125 |
+
logger.info(f"Using allowed token: {allowed_token}")
|
| 126 |
+
request.model = settings.CREATE_IMAGE_MODEL
|
| 127 |
+
return await openai_service.generate_images(request)
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
@router.post("/openai/v1/embeddings")
|
| 131 |
+
async def embedding(
|
| 132 |
+
request: EmbeddingRequest,
|
| 133 |
+
allowed_token=Depends(security_service.verify_authorization),
|
| 134 |
+
key_manager: KeyManager = Depends(get_key_manager),
|
| 135 |
+
openai_service: OpenAICompatiableService = Depends(get_openai_service),
|
| 136 |
+
):
|
| 137 |
+
"""处理文本嵌入请求。"""
|
| 138 |
+
operation_name = "embedding"
|
| 139 |
+
async with handle_route_errors(logger, operation_name):
|
| 140 |
+
logger.info(f"Handling embedding request for model: {request.model}")
|
| 141 |
+
api_key = await key_manager.get_next_working_key()
|
| 142 |
+
logger.info(f"Using allowed token: {allowed_token}")
|
| 143 |
+
logger.info(f"Using API key: {redact_key_for_logging(api_key)}")
|
| 144 |
+
return await openai_service.create_embeddings(
|
| 145 |
+
input_text=request.input, model=request.model, api_key=api_key
|
| 146 |
+
)
|
app/router/openai_routes.py
ADDED
|
@@ -0,0 +1,206 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import APIRouter, Depends, HTTPException, Response
|
| 2 |
+
from fastapi.responses import JSONResponse, StreamingResponse
|
| 3 |
+
|
| 4 |
+
from app.config.config import settings
|
| 5 |
+
from app.core.security import SecurityService
|
| 6 |
+
from app.domain.openai_models import (
|
| 7 |
+
ChatRequest,
|
| 8 |
+
EmbeddingRequest,
|
| 9 |
+
ImageGenerationRequest,
|
| 10 |
+
TTSRequest,
|
| 11 |
+
)
|
| 12 |
+
from app.handler.error_handler import handle_route_errors
|
| 13 |
+
from app.handler.retry_handler import RetryHandler
|
| 14 |
+
from app.log.logger import get_openai_logger
|
| 15 |
+
from app.service.chat.openai_chat_service import OpenAIChatService
|
| 16 |
+
from app.service.embedding.embedding_service import EmbeddingService
|
| 17 |
+
from app.service.image.image_create_service import ImageCreateService
|
| 18 |
+
from app.service.key.key_manager import KeyManager, get_key_manager_instance
|
| 19 |
+
from app.service.model.model_service import ModelService
|
| 20 |
+
from app.service.tts.tts_service import TTSService
|
| 21 |
+
from app.utils.helpers import redact_key_for_logging
|
| 22 |
+
|
| 23 |
+
router = APIRouter()
|
| 24 |
+
logger = get_openai_logger()
|
| 25 |
+
|
| 26 |
+
security_service = SecurityService()
|
| 27 |
+
model_service = ModelService()
|
| 28 |
+
embedding_service = EmbeddingService()
|
| 29 |
+
image_create_service = ImageCreateService()
|
| 30 |
+
tts_service = TTSService()
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
async def get_key_manager():
|
| 34 |
+
return await get_key_manager_instance()
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
async def get_next_working_key_wrapper(
|
| 38 |
+
key_manager: KeyManager = Depends(get_key_manager),
|
| 39 |
+
):
|
| 40 |
+
return await key_manager.get_next_working_key()
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
async def get_openai_chat_service(key_manager: KeyManager = Depends(get_key_manager)):
|
| 44 |
+
"""获取OpenAI聊天服务实例"""
|
| 45 |
+
return OpenAIChatService(settings.BASE_URL, key_manager)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
async def get_tts_service():
|
| 49 |
+
"""获取TTS服务实例"""
|
| 50 |
+
return tts_service
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
@router.get("/v1/models")
|
| 54 |
+
@router.get("/hf/v1/models")
|
| 55 |
+
async def list_models(
|
| 56 |
+
allowed_token=Depends(security_service.verify_authorization),
|
| 57 |
+
key_manager: KeyManager = Depends(get_key_manager),
|
| 58 |
+
):
|
| 59 |
+
"""获取可用的 OpenAI 模型列表 (兼容 Gemini 和 OpenAI)。"""
|
| 60 |
+
operation_name = "list_models"
|
| 61 |
+
async with handle_route_errors(logger, operation_name):
|
| 62 |
+
logger.info("Handling models list request")
|
| 63 |
+
api_key = await key_manager.get_random_valid_key()
|
| 64 |
+
logger.info(f"Using allowed token: {allowed_token}")
|
| 65 |
+
logger.info(f"Using API key: {redact_key_for_logging(api_key)}")
|
| 66 |
+
return await model_service.get_gemini_openai_models(api_key)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
@router.post("/v1/chat/completions")
|
| 70 |
+
@router.post("/hf/v1/chat/completions")
|
| 71 |
+
@RetryHandler(key_arg="api_key")
|
| 72 |
+
async def chat_completion(
|
| 73 |
+
request: ChatRequest,
|
| 74 |
+
allowed_token=Depends(security_service.verify_authorization),
|
| 75 |
+
api_key: str = Depends(get_next_working_key_wrapper),
|
| 76 |
+
key_manager: KeyManager = Depends(get_key_manager),
|
| 77 |
+
chat_service: OpenAIChatService = Depends(get_openai_chat_service),
|
| 78 |
+
):
|
| 79 |
+
"""处理 OpenAI 聊天补全请求,支持流式响应和特定模型切换。"""
|
| 80 |
+
operation_name = "chat_completion"
|
| 81 |
+
is_image_chat = request.model == f"{settings.CREATE_IMAGE_MODEL}-chat"
|
| 82 |
+
current_api_key = api_key
|
| 83 |
+
if is_image_chat:
|
| 84 |
+
current_api_key = await key_manager.get_paid_key()
|
| 85 |
+
|
| 86 |
+
async with handle_route_errors(logger, operation_name):
|
| 87 |
+
logger.info(f"Handling chat completion request for model: {request.model}")
|
| 88 |
+
logger.debug(f"Request: \n{request.model_dump_json(indent=2)}")
|
| 89 |
+
logger.info(f"Using allowed token: {allowed_token}")
|
| 90 |
+
logger.info(f"Using API key: {redact_key_for_logging(current_api_key)}")
|
| 91 |
+
|
| 92 |
+
if not await model_service.check_model_support(request.model):
|
| 93 |
+
raise HTTPException(
|
| 94 |
+
status_code=400, detail=f"Model {request.model} is not supported"
|
| 95 |
+
)
|
| 96 |
+
|
| 97 |
+
raw_response = None
|
| 98 |
+
if is_image_chat:
|
| 99 |
+
raw_response = await chat_service.create_image_chat_completion(
|
| 100 |
+
request, current_api_key
|
| 101 |
+
)
|
| 102 |
+
else:
|
| 103 |
+
raw_response = await chat_service.create_chat_completion(
|
| 104 |
+
request, current_api_key
|
| 105 |
+
)
|
| 106 |
+
|
| 107 |
+
if request.stream:
|
| 108 |
+
try:
|
| 109 |
+
# 尝试获取第一条数据,判断是正常 SSE(data: 前缀)还是错误 JSON
|
| 110 |
+
first_chunk = await raw_response.__anext__()
|
| 111 |
+
except StopAsyncIteration:
|
| 112 |
+
# 如果流直接结束,退回标准 SSE 输出
|
| 113 |
+
return StreamingResponse(raw_response, media_type="text/event-stream")
|
| 114 |
+
except Exception as e:
|
| 115 |
+
# 初始化流异常,直接返回 500 错误
|
| 116 |
+
return JSONResponse(
|
| 117 |
+
content={"error": {"code": e.args[0], "message": e.args[1]}},
|
| 118 |
+
status_code=e.args[0],
|
| 119 |
+
)
|
| 120 |
+
|
| 121 |
+
# 如果以 "data:" 开头,代表正常 SSE,将首块和后续块一起发送
|
| 122 |
+
if isinstance(first_chunk, str) and first_chunk.startswith("data:"):
|
| 123 |
+
|
| 124 |
+
async def combined():
|
| 125 |
+
yield first_chunk
|
| 126 |
+
async for chunk in raw_response:
|
| 127 |
+
yield chunk
|
| 128 |
+
|
| 129 |
+
return StreamingResponse(combined(), media_type="text/event-stream")
|
| 130 |
+
else:
|
| 131 |
+
return raw_response
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
@router.post("/v1/images/generations")
|
| 135 |
+
@router.post("/hf/v1/images/generations")
|
| 136 |
+
async def generate_image(
|
| 137 |
+
request: ImageGenerationRequest,
|
| 138 |
+
allowed_token=Depends(security_service.verify_authorization),
|
| 139 |
+
):
|
| 140 |
+
"""处理 OpenAI 图像生成请求。"""
|
| 141 |
+
operation_name = "generate_image"
|
| 142 |
+
async with handle_route_errors(logger, operation_name):
|
| 143 |
+
logger.info(f"Handling image generation request for prompt: {request.prompt}")
|
| 144 |
+
logger.info(f"Using allowed token: {allowed_token}")
|
| 145 |
+
response = image_create_service.generate_images(request)
|
| 146 |
+
return response
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
@router.post("/v1/embeddings")
|
| 150 |
+
@router.post("/hf/v1/embeddings")
|
| 151 |
+
async def embedding(
|
| 152 |
+
request: EmbeddingRequest,
|
| 153 |
+
allowed_token=Depends(security_service.verify_authorization),
|
| 154 |
+
key_manager: KeyManager = Depends(get_key_manager),
|
| 155 |
+
):
|
| 156 |
+
"""处理 OpenAI 文本嵌入请求。"""
|
| 157 |
+
operation_name = "embedding"
|
| 158 |
+
async with handle_route_errors(logger, operation_name):
|
| 159 |
+
logger.info(f"Handling embedding request for model: {request.model}")
|
| 160 |
+
api_key = await key_manager.get_next_working_key()
|
| 161 |
+
logger.info(f"Using allowed token: {allowed_token}")
|
| 162 |
+
logger.info(f"Using API key: {redact_key_for_logging(api_key)}")
|
| 163 |
+
response = await embedding_service.create_embedding(
|
| 164 |
+
input_text=request.input, model=request.model, api_key=api_key
|
| 165 |
+
)
|
| 166 |
+
return response
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
@router.get("/v1/keys/list")
|
| 170 |
+
@router.get("/hf/v1/keys/list")
|
| 171 |
+
async def get_keys_list(
|
| 172 |
+
_=Depends(security_service.verify_auth_token),
|
| 173 |
+
key_manager: KeyManager = Depends(get_key_manager),
|
| 174 |
+
):
|
| 175 |
+
"""获取有效和无效的API key列表 (需要管理 Token 认证)。"""
|
| 176 |
+
operation_name = "get_keys_list"
|
| 177 |
+
async with handle_route_errors(logger, operation_name):
|
| 178 |
+
logger.info("Handling keys list request")
|
| 179 |
+
keys_status = await key_manager.get_keys_by_status()
|
| 180 |
+
return {
|
| 181 |
+
"status": "success",
|
| 182 |
+
"data": {
|
| 183 |
+
"valid_keys": keys_status["valid_keys"],
|
| 184 |
+
"invalid_keys": keys_status["invalid_keys"],
|
| 185 |
+
},
|
| 186 |
+
"total": len(keys_status["valid_keys"]) + len(keys_status["invalid_keys"]),
|
| 187 |
+
}
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
@router.post("/v1/audio/speech")
|
| 191 |
+
@router.post("/hf/v1/audio/speech")
|
| 192 |
+
async def text_to_speech(
|
| 193 |
+
request: TTSRequest,
|
| 194 |
+
allowed_token=Depends(security_service.verify_authorization),
|
| 195 |
+
api_key: str = Depends(get_next_working_key_wrapper),
|
| 196 |
+
tts_service: TTSService = Depends(get_tts_service),
|
| 197 |
+
):
|
| 198 |
+
"""处理 OpenAI TTS 请求。"""
|
| 199 |
+
operation_name = "text_to_speech"
|
| 200 |
+
async with handle_route_errors(logger, operation_name):
|
| 201 |
+
logger.info(f"Handling TTS request for model: {request.model}")
|
| 202 |
+
logger.debug(f"Request: \n{request.model_dump_json(indent=2)}")
|
| 203 |
+
logger.info(f"Using allowed token: {allowed_token}")
|
| 204 |
+
logger.info(f"Using API key: {redact_key_for_logging(api_key)}")
|
| 205 |
+
audio_data = await tts_service.create_tts(request, api_key)
|
| 206 |
+
return Response(content=audio_data, media_type="audio/wav")
|
app/router/routes.py
ADDED
|
@@ -0,0 +1,290 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
路由配置模块,负责设置和配置应用程序的路由
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from fastapi import FastAPI, Request
|
| 6 |
+
from fastapi.responses import HTMLResponse, RedirectResponse
|
| 7 |
+
from fastapi.templating import Jinja2Templates
|
| 8 |
+
|
| 9 |
+
from app.config.config import settings
|
| 10 |
+
from app.core.security import verify_auth_token
|
| 11 |
+
from app.log.logger import get_routes_logger
|
| 12 |
+
from app.router import (
|
| 13 |
+
config_routes,
|
| 14 |
+
error_log_routes,
|
| 15 |
+
files_routes,
|
| 16 |
+
gemini_routes,
|
| 17 |
+
key_routes,
|
| 18 |
+
openai_compatiable_routes,
|
| 19 |
+
openai_routes,
|
| 20 |
+
scheduler_routes,
|
| 21 |
+
stats_routes,
|
| 22 |
+
version_routes,
|
| 23 |
+
vertex_express_routes,
|
| 24 |
+
)
|
| 25 |
+
from app.service.key.key_manager import get_key_manager_instance
|
| 26 |
+
from app.service.stats.stats_service import StatsService
|
| 27 |
+
from app.utils.static_version import get_static_url
|
| 28 |
+
|
| 29 |
+
logger = get_routes_logger()
|
| 30 |
+
|
| 31 |
+
templates = Jinja2Templates(directory="app/templates")
|
| 32 |
+
# 设置模板全局变量
|
| 33 |
+
templates.env.globals["static_url"] = get_static_url
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def setup_routers(app: FastAPI) -> None:
|
| 37 |
+
"""
|
| 38 |
+
设置应用程序的路由
|
| 39 |
+
|
| 40 |
+
Args:
|
| 41 |
+
app: FastAPI应用程序实例
|
| 42 |
+
"""
|
| 43 |
+
app.include_router(openai_routes.router)
|
| 44 |
+
app.include_router(gemini_routes.router)
|
| 45 |
+
app.include_router(gemini_routes.router_v1beta)
|
| 46 |
+
app.include_router(config_routes.router)
|
| 47 |
+
app.include_router(error_log_routes.router)
|
| 48 |
+
app.include_router(scheduler_routes.router)
|
| 49 |
+
app.include_router(stats_routes.router)
|
| 50 |
+
app.include_router(version_routes.router)
|
| 51 |
+
app.include_router(openai_compatiable_routes.router)
|
| 52 |
+
app.include_router(vertex_express_routes.router)
|
| 53 |
+
app.include_router(files_routes.router)
|
| 54 |
+
app.include_router(key_routes.router)
|
| 55 |
+
|
| 56 |
+
setup_page_routes(app)
|
| 57 |
+
|
| 58 |
+
setup_health_routes(app)
|
| 59 |
+
setup_api_stats_routes(app)
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def setup_page_routes(app: FastAPI) -> None:
|
| 63 |
+
"""
|
| 64 |
+
设置页面相关的路由
|
| 65 |
+
|
| 66 |
+
Args:
|
| 67 |
+
app: FastAPI应用程序实例
|
| 68 |
+
"""
|
| 69 |
+
|
| 70 |
+
@app.get("/", response_class=HTMLResponse)
|
| 71 |
+
async def auth_page(request: Request):
|
| 72 |
+
"""认证页面"""
|
| 73 |
+
return templates.TemplateResponse("auth.html", {"request": request})
|
| 74 |
+
|
| 75 |
+
@app.post("/auth")
|
| 76 |
+
async def authenticate(request: Request):
|
| 77 |
+
"""处理认证请求"""
|
| 78 |
+
try:
|
| 79 |
+
form = await request.form()
|
| 80 |
+
auth_token = form.get("auth_token")
|
| 81 |
+
if not auth_token:
|
| 82 |
+
logger.warning("Authentication attempt with empty token")
|
| 83 |
+
return RedirectResponse(url="/", status_code=302)
|
| 84 |
+
|
| 85 |
+
if verify_auth_token(auth_token):
|
| 86 |
+
logger.info("Successful authentication")
|
| 87 |
+
response = RedirectResponse(url="/keys", status_code=302)
|
| 88 |
+
response.set_cookie(
|
| 89 |
+
key="auth_token",
|
| 90 |
+
value=auth_token,
|
| 91 |
+
httponly=True,
|
| 92 |
+
max_age=settings.ADMIN_SESSION_EXPIRE,
|
| 93 |
+
)
|
| 94 |
+
return response
|
| 95 |
+
logger.warning("Failed authentication attempt with invalid token")
|
| 96 |
+
return RedirectResponse(url="/", status_code=302)
|
| 97 |
+
except Exception as e:
|
| 98 |
+
logger.error(f"Authentication error: {str(e)}")
|
| 99 |
+
return RedirectResponse(url="/", status_code=302)
|
| 100 |
+
|
| 101 |
+
@app.get("/keys", response_class=HTMLResponse)
|
| 102 |
+
async def keys_page(request: Request):
|
| 103 |
+
"""密钥管理页面"""
|
| 104 |
+
try:
|
| 105 |
+
auth_token = request.cookies.get("auth_token")
|
| 106 |
+
if not auth_token or not verify_auth_token(auth_token):
|
| 107 |
+
logger.warning("Unauthorized access attempt to keys page")
|
| 108 |
+
return RedirectResponse(url="/", status_code=302)
|
| 109 |
+
|
| 110 |
+
key_manager = await get_key_manager_instance()
|
| 111 |
+
keys_status = await key_manager.get_keys_by_status()
|
| 112 |
+
total_keys = len(keys_status["valid_keys"]) + len(
|
| 113 |
+
keys_status["invalid_keys"]
|
| 114 |
+
)
|
| 115 |
+
valid_key_count = len(keys_status["valid_keys"])
|
| 116 |
+
invalid_key_count = len(keys_status["invalid_keys"])
|
| 117 |
+
|
| 118 |
+
stats_service = StatsService()
|
| 119 |
+
api_stats = await stats_service.get_api_usage_stats()
|
| 120 |
+
logger.info(f"API stats retrieved: {api_stats}")
|
| 121 |
+
|
| 122 |
+
logger.info(f"Keys status retrieved successfully. Total keys: {total_keys}")
|
| 123 |
+
return templates.TemplateResponse(
|
| 124 |
+
"keys_status.html",
|
| 125 |
+
{
|
| 126 |
+
"request": request,
|
| 127 |
+
"valid_keys": {},
|
| 128 |
+
"invalid_keys": {},
|
| 129 |
+
"total_keys": total_keys,
|
| 130 |
+
"valid_key_count": valid_key_count,
|
| 131 |
+
"invalid_key_count": invalid_key_count,
|
| 132 |
+
"api_stats": api_stats,
|
| 133 |
+
},
|
| 134 |
+
)
|
| 135 |
+
except Exception as e:
|
| 136 |
+
logger.error(f"Error retrieving keys status or API stats: {str(e)}")
|
| 137 |
+
# Even if there's an error, render the page with whatever data is available
|
| 138 |
+
# or with empty/default values, so the frontend can still load.
|
| 139 |
+
return templates.TemplateResponse(
|
| 140 |
+
"keys_status.html",
|
| 141 |
+
{
|
| 142 |
+
"request": request,
|
| 143 |
+
"valid_keys": {},
|
| 144 |
+
"invalid_keys": {},
|
| 145 |
+
"total_keys": 0,
|
| 146 |
+
"valid_key_count": 0,
|
| 147 |
+
"invalid_key_count": 0,
|
| 148 |
+
"api_stats": { # Provide a default structure for api_stats
|
| 149 |
+
"calls_1m": {"total": 0, "success": 0, "failure": 0},
|
| 150 |
+
"calls_1h": {"total": 0, "success": 0, "failure": 0},
|
| 151 |
+
"calls_24h": {"total": 0, "success": 0, "failure": 0},
|
| 152 |
+
"calls_month": {"total": 0, "success": 0, "failure": 0},
|
| 153 |
+
},
|
| 154 |
+
},
|
| 155 |
+
)
|
| 156 |
+
|
| 157 |
+
@app.get("/config", response_class=HTMLResponse)
|
| 158 |
+
async def config_page(request: Request):
|
| 159 |
+
"""配置编辑页面"""
|
| 160 |
+
try:
|
| 161 |
+
auth_token = request.cookies.get("auth_token")
|
| 162 |
+
if not auth_token or not verify_auth_token(auth_token):
|
| 163 |
+
logger.warning("Unauthorized access attempt to config page")
|
| 164 |
+
return RedirectResponse(url="/", status_code=302)
|
| 165 |
+
|
| 166 |
+
logger.info("Config page accessed successfully")
|
| 167 |
+
return templates.TemplateResponse(
|
| 168 |
+
"config_editor.html", {"request": request}
|
| 169 |
+
)
|
| 170 |
+
except Exception as e:
|
| 171 |
+
logger.error(f"Error accessing config page: {str(e)}")
|
| 172 |
+
raise
|
| 173 |
+
|
| 174 |
+
@app.get("/logs", response_class=HTMLResponse)
|
| 175 |
+
async def logs_page(request: Request):
|
| 176 |
+
"""错误日志页面"""
|
| 177 |
+
try:
|
| 178 |
+
auth_token = request.cookies.get("auth_token")
|
| 179 |
+
if not auth_token or not verify_auth_token(auth_token):
|
| 180 |
+
logger.warning("Unauthorized access attempt to logs page")
|
| 181 |
+
return RedirectResponse(url="/", status_code=302)
|
| 182 |
+
|
| 183 |
+
logger.info("Logs page accessed successfully")
|
| 184 |
+
return templates.TemplateResponse("error_logs.html", {"request": request})
|
| 185 |
+
except Exception as e:
|
| 186 |
+
logger.error(f"Error accessing logs page: {str(e)}")
|
| 187 |
+
raise
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
def setup_health_routes(app: FastAPI) -> None:
|
| 191 |
+
"""
|
| 192 |
+
设置健康检查相关的路由
|
| 193 |
+
|
| 194 |
+
Args:
|
| 195 |
+
app: FastAPI应用程序实例
|
| 196 |
+
"""
|
| 197 |
+
|
| 198 |
+
@app.get("/health")
|
| 199 |
+
async def health_check(request: Request):
|
| 200 |
+
"""健康检查端点"""
|
| 201 |
+
logger.info("Health check endpoint called")
|
| 202 |
+
return {"status": "healthy"}
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
def setup_api_stats_routes(app: FastAPI) -> None:
|
| 206 |
+
"""
|
| 207 |
+
设置 API 统计相关的路由
|
| 208 |
+
|
| 209 |
+
Args:
|
| 210 |
+
app: FastAPI应用程序实例
|
| 211 |
+
"""
|
| 212 |
+
|
| 213 |
+
@app.get("/api/stats/details")
|
| 214 |
+
async def api_stats_details(request: Request, period: str):
|
| 215 |
+
"""获取指定时间段内的 API 调用详情"""
|
| 216 |
+
try:
|
| 217 |
+
auth_token = request.cookies.get("auth_token")
|
| 218 |
+
if not auth_token or not verify_auth_token(auth_token):
|
| 219 |
+
logger.warning("Unauthorized access attempt to API stats details")
|
| 220 |
+
return {"error": "Unauthorized"}, 401
|
| 221 |
+
|
| 222 |
+
logger.info(f"Fetching API call details for period: {period}")
|
| 223 |
+
stats_service = StatsService()
|
| 224 |
+
details = await stats_service.get_api_call_details(period)
|
| 225 |
+
return details
|
| 226 |
+
except ValueError as e:
|
| 227 |
+
logger.warning(
|
| 228 |
+
f"Invalid period requested for API stats details: {period} - {str(e)}"
|
| 229 |
+
)
|
| 230 |
+
return {"error": str(e)}, 400
|
| 231 |
+
except Exception as e:
|
| 232 |
+
logger.error(
|
| 233 |
+
f"Error fetching API stats details for period {period}: {str(e)}"
|
| 234 |
+
)
|
| 235 |
+
return {"error": "Internal server error"}, 500
|
| 236 |
+
|
| 237 |
+
@app.get("/api/stats/attention-keys")
|
| 238 |
+
async def api_stats_attention_keys(
|
| 239 |
+
request: Request, limit: int = 20, status_code: int = 429
|
| 240 |
+
):
|
| 241 |
+
"""返回最近24小时指定错误码次数最多的Key(仅包含内存Key列表中的)。默认错误码429。"""
|
| 242 |
+
try:
|
| 243 |
+
auth_token = request.cookies.get("auth_token")
|
| 244 |
+
if not auth_token or not verify_auth_token(auth_token):
|
| 245 |
+
logger.warning("Unauthorized access attempt to attention-keys")
|
| 246 |
+
return {"error": "Unauthorized"}, 401
|
| 247 |
+
|
| 248 |
+
# 支持所有标准HTTP状态码范围
|
| 249 |
+
# if not isinstance(status_code, int) or status_code < 100 or status_code > 599:
|
| 250 |
+
# return {"error": f"Unsupported status_code: {status_code}"}, 400
|
| 251 |
+
|
| 252 |
+
key_manager = await get_key_manager_instance()
|
| 253 |
+
keys_status = await key_manager.get_keys_by_status()
|
| 254 |
+
in_memory_keys = set(keys_status.get("valid_keys", [])) | set(
|
| 255 |
+
keys_status.get("invalid_keys", [])
|
| 256 |
+
)
|
| 257 |
+
stats_service = StatsService()
|
| 258 |
+
data = await stats_service.get_attention_keys_last_24h(
|
| 259 |
+
in_memory_keys, limit, status_code
|
| 260 |
+
)
|
| 261 |
+
return data
|
| 262 |
+
except Exception as e:
|
| 263 |
+
logger.error(f"Error fetching attention keys: {e}")
|
| 264 |
+
return {"error": "Internal server error"}, 500
|
| 265 |
+
|
| 266 |
+
@app.get("/api/stats/key-details")
|
| 267 |
+
async def api_stats_key_details(request: Request, key: str, period: str):
|
| 268 |
+
"""获取指定密钥在指定时间段内的调用详情"""
|
| 269 |
+
try:
|
| 270 |
+
auth_token = request.cookies.get("auth_token")
|
| 271 |
+
if not auth_token or not verify_auth_token(auth_token):
|
| 272 |
+
logger.warning("Unauthorized access attempt to API key stats details")
|
| 273 |
+
return {"error": "Unauthorized"}, 401
|
| 274 |
+
|
| 275 |
+
logger.info(
|
| 276 |
+
f"Fetching key call details for key=...{key[-4:] if key else ''}, period: {period}"
|
| 277 |
+
)
|
| 278 |
+
stats_service = StatsService()
|
| 279 |
+
details = await stats_service.get_key_call_details(key, period)
|
| 280 |
+
return details
|
| 281 |
+
except ValueError as e:
|
| 282 |
+
logger.warning(
|
| 283 |
+
f"Invalid period requested for key stats details: {period} - {str(e)}"
|
| 284 |
+
)
|
| 285 |
+
return {"error": str(e)}, 400
|
| 286 |
+
except Exception as e:
|
| 287 |
+
logger.error(
|
| 288 |
+
f"Error fetching key stats details for period {period}: {str(e)}"
|
| 289 |
+
)
|
| 290 |
+
return {"error": "Internal server error"}, 500
|
app/router/scheduler_routes.py
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
定时任务控制路由模块
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from fastapi import APIRouter, Request, HTTPException, status
|
| 6 |
+
from fastapi.responses import JSONResponse
|
| 7 |
+
|
| 8 |
+
from app.core.security import verify_auth_token
|
| 9 |
+
from app.scheduler.scheduled_tasks import start_scheduler, stop_scheduler
|
| 10 |
+
from app.log.logger import get_scheduler_routes
|
| 11 |
+
|
| 12 |
+
logger = get_scheduler_routes()
|
| 13 |
+
|
| 14 |
+
router = APIRouter(
|
| 15 |
+
prefix="/api/scheduler",
|
| 16 |
+
tags=["Scheduler"]
|
| 17 |
+
)
|
| 18 |
+
|
| 19 |
+
async def verify_token(request: Request):
|
| 20 |
+
auth_token = request.cookies.get("auth_token")
|
| 21 |
+
if not auth_token or not verify_auth_token(auth_token):
|
| 22 |
+
logger.warning("Unauthorized access attempt to scheduler API")
|
| 23 |
+
raise HTTPException(
|
| 24 |
+
status_code=status.HTTP_401_UNAUTHORIZED,
|
| 25 |
+
detail="Not authenticated",
|
| 26 |
+
headers={"WWW-Authenticate": "Bearer"},
|
| 27 |
+
)
|
| 28 |
+
|
| 29 |
+
@router.post("/start", summary="启动定时任务")
|
| 30 |
+
async def start_scheduler_endpoint(request: Request):
|
| 31 |
+
"""Start the background scheduler task"""
|
| 32 |
+
await verify_token(request)
|
| 33 |
+
try:
|
| 34 |
+
logger.info("Received request to start scheduler.")
|
| 35 |
+
start_scheduler()
|
| 36 |
+
return JSONResponse(content={"message": "Scheduler started successfully."}, status_code=status.HTTP_200_OK)
|
| 37 |
+
except Exception as e:
|
| 38 |
+
logger.error(f"Error starting scheduler: {str(e)}", exc_info=True)
|
| 39 |
+
raise HTTPException(
|
| 40 |
+
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
| 41 |
+
detail=f"Failed to start scheduler: {str(e)}"
|
| 42 |
+
)
|
| 43 |
+
|
| 44 |
+
@router.post("/stop", summary="停止定时任务")
|
| 45 |
+
async def stop_scheduler_endpoint(request: Request):
|
| 46 |
+
"""Stop the background scheduler task"""
|
| 47 |
+
await verify_token(request)
|
| 48 |
+
try:
|
| 49 |
+
logger.info("Received request to stop scheduler.")
|
| 50 |
+
stop_scheduler()
|
| 51 |
+
return JSONResponse(content={"message": "Scheduler stopped successfully."}, status_code=status.HTTP_200_OK)
|
| 52 |
+
except Exception as e:
|
| 53 |
+
logger.error(f"Error stopping scheduler: {str(e)}", exc_info=True)
|
| 54 |
+
raise HTTPException(
|
| 55 |
+
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
| 56 |
+
detail=f"Failed to stop scheduler: {str(e)}"
|
| 57 |
+
)
|
app/router/stats_routes.py
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import APIRouter, Depends, HTTPException, Request
|
| 2 |
+
from starlette import status
|
| 3 |
+
from app.core.security import verify_auth_token
|
| 4 |
+
from app.service.stats.stats_service import StatsService
|
| 5 |
+
from app.log.logger import get_stats_logger
|
| 6 |
+
from app.utils.helpers import redact_key_for_logging
|
| 7 |
+
|
| 8 |
+
logger = get_stats_logger()
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
async def verify_token(request: Request):
|
| 12 |
+
auth_token = request.cookies.get("auth_token")
|
| 13 |
+
if not auth_token or not verify_auth_token(auth_token):
|
| 14 |
+
logger.warning("Unauthorized access attempt to scheduler API")
|
| 15 |
+
raise HTTPException(
|
| 16 |
+
status_code=status.HTTP_401_UNAUTHORIZED,
|
| 17 |
+
detail="Not authenticated",
|
| 18 |
+
headers={"WWW-Authenticate": "Bearer"},
|
| 19 |
+
)
|
| 20 |
+
|
| 21 |
+
router = APIRouter(
|
| 22 |
+
prefix="/api",
|
| 23 |
+
tags=["stats"],
|
| 24 |
+
dependencies=[Depends(verify_token)]
|
| 25 |
+
)
|
| 26 |
+
|
| 27 |
+
stats_service = StatsService()
|
| 28 |
+
|
| 29 |
+
@router.get("/key-usage-details/{key}",
|
| 30 |
+
summary="获取指定密钥最近24小时的模型调用次数",
|
| 31 |
+
description="根据提供的 API 密钥,返回过去24小时内每个模型被调用的次数统计。")
|
| 32 |
+
async def get_key_usage_details(key: str):
|
| 33 |
+
"""
|
| 34 |
+
Retrieves the model usage count for a specific API key within the last 24 hours.
|
| 35 |
+
|
| 36 |
+
Args:
|
| 37 |
+
key: The API key to get usage details for.
|
| 38 |
+
|
| 39 |
+
Returns:
|
| 40 |
+
A dictionary with model names as keys and their call counts as values.
|
| 41 |
+
Example: {"gemini-pro": 10, "gemini-1.5-pro-latest": 5}
|
| 42 |
+
|
| 43 |
+
Raises:
|
| 44 |
+
HTTPException: If an error occurs during data retrieval.
|
| 45 |
+
"""
|
| 46 |
+
try:
|
| 47 |
+
usage_details = await stats_service.get_key_usage_details_last_24h(key)
|
| 48 |
+
if usage_details is None:
|
| 49 |
+
return {}
|
| 50 |
+
return usage_details
|
| 51 |
+
except Exception as e:
|
| 52 |
+
logger.error(f"Error fetching key usage details for key {redact_key_for_logging(key)}: {e}")
|
| 53 |
+
raise HTTPException(
|
| 54 |
+
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
| 55 |
+
detail=f"获取密钥使用详情时出错: {e}"
|
| 56 |
+
)
|
app/router/version_routes.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import APIRouter, HTTPException
|
| 2 |
+
from pydantic import BaseModel, Field
|
| 3 |
+
from typing import Optional
|
| 4 |
+
|
| 5 |
+
from app.service.update.update_service import check_for_updates
|
| 6 |
+
from app.utils.helpers import get_current_version
|
| 7 |
+
from app.log.logger import get_update_logger
|
| 8 |
+
|
| 9 |
+
router = APIRouter(prefix="/api/version", tags=["Version"])
|
| 10 |
+
logger = get_update_logger()
|
| 11 |
+
|
| 12 |
+
class VersionInfo(BaseModel):
|
| 13 |
+
current_version: str = Field(..., description="当前应用程序版本")
|
| 14 |
+
latest_version: Optional[str] = Field(None, description="可用的最新版本")
|
| 15 |
+
update_available: bool = Field(False, description="是否有可用更新")
|
| 16 |
+
error_message: Optional[str] = Field(None, description="检查更新时发生的错误信息")
|
| 17 |
+
|
| 18 |
+
@router.get("/check", response_model=VersionInfo, summary="检查应用程序更新")
|
| 19 |
+
async def get_version_info():
|
| 20 |
+
"""
|
| 21 |
+
检查当前应用程序版本与最新的 GitHub release 版本。
|
| 22 |
+
"""
|
| 23 |
+
try:
|
| 24 |
+
current_version = get_current_version()
|
| 25 |
+
update_available, latest_version, error_message = await check_for_updates()
|
| 26 |
+
|
| 27 |
+
logger.info(f"Version check API result: current={current_version}, latest={latest_version}, available={update_available}, error='{error_message}'")
|
| 28 |
+
|
| 29 |
+
return VersionInfo(
|
| 30 |
+
current_version=current_version,
|
| 31 |
+
latest_version=latest_version,
|
| 32 |
+
update_available=update_available,
|
| 33 |
+
error_message=error_message
|
| 34 |
+
)
|
| 35 |
+
except Exception as e:
|
| 36 |
+
logger.error(f"Error in /api/version/check endpoint: {e}", exc_info=True)
|
| 37 |
+
raise HTTPException(status_code=500, detail="检查版本信息时发生内部错误")
|
app/router/vertex_express_routes.py
ADDED
|
@@ -0,0 +1,191 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from copy import deepcopy
|
| 2 |
+
|
| 3 |
+
from fastapi import APIRouter, Depends, HTTPException
|
| 4 |
+
from fastapi.responses import JSONResponse, StreamingResponse
|
| 5 |
+
|
| 6 |
+
from app.config.config import settings
|
| 7 |
+
from app.core.constants import API_VERSION
|
| 8 |
+
from app.core.security import SecurityService
|
| 9 |
+
from app.domain.gemini_models import GeminiRequest
|
| 10 |
+
from app.handler.error_handler import handle_route_errors
|
| 11 |
+
from app.handler.retry_handler import RetryHandler
|
| 12 |
+
from app.log.logger import get_vertex_express_logger
|
| 13 |
+
from app.service.chat.vertex_express_chat_service import GeminiChatService
|
| 14 |
+
from app.service.key.key_manager import KeyManager, get_key_manager_instance
|
| 15 |
+
from app.service.model.model_service import ModelService
|
| 16 |
+
from app.utils.helpers import redact_key_for_logging
|
| 17 |
+
|
| 18 |
+
router = APIRouter(prefix=f"/vertex-express/{API_VERSION}")
|
| 19 |
+
logger = get_vertex_express_logger()
|
| 20 |
+
|
| 21 |
+
security_service = SecurityService()
|
| 22 |
+
model_service = ModelService()
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
async def get_key_manager():
|
| 26 |
+
"""获取密钥管理器实例"""
|
| 27 |
+
return await get_key_manager_instance()
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
async def get_next_working_key(key_manager: KeyManager = Depends(get_key_manager)):
|
| 31 |
+
"""获取下一个可用的API密钥"""
|
| 32 |
+
return await key_manager.get_next_working_vertex_key()
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
async def get_chat_service(key_manager: KeyManager = Depends(get_key_manager)):
|
| 36 |
+
"""获取Gemini聊天服务实例"""
|
| 37 |
+
return GeminiChatService(settings.VERTEX_EXPRESS_BASE_URL, key_manager)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
@router.get("/models")
|
| 41 |
+
async def list_models(
|
| 42 |
+
allowed_token=Depends(security_service.verify_key_or_goog_api_key),
|
| 43 |
+
key_manager: KeyManager = Depends(get_key_manager),
|
| 44 |
+
):
|
| 45 |
+
"""获取可用的 Gemini 模型列表,并根据配置添加衍生模型(搜索、图像、非思考)。"""
|
| 46 |
+
operation_name = "list_gemini_models"
|
| 47 |
+
logger.info("-" * 50 + operation_name + "-" * 50)
|
| 48 |
+
logger.info("Handling Gemini models list request")
|
| 49 |
+
|
| 50 |
+
try:
|
| 51 |
+
api_key = await key_manager.get_random_valid_key()
|
| 52 |
+
if not api_key:
|
| 53 |
+
raise HTTPException(
|
| 54 |
+
status_code=503, detail="No valid API keys available to fetch models."
|
| 55 |
+
)
|
| 56 |
+
logger.info(f"Using allowed token: {allowed_token}")
|
| 57 |
+
logger.info(f"Using API key: {redact_key_for_logging(api_key)}")
|
| 58 |
+
|
| 59 |
+
models_data = await model_service.get_gemini_models(api_key)
|
| 60 |
+
if not models_data or "models" not in models_data:
|
| 61 |
+
raise HTTPException(
|
| 62 |
+
status_code=500, detail="Failed to fetch base models list."
|
| 63 |
+
)
|
| 64 |
+
|
| 65 |
+
models_json = deepcopy(models_data)
|
| 66 |
+
model_mapping = {
|
| 67 |
+
x.get("name", "").split("/", maxsplit=1)[-1]: x
|
| 68 |
+
for x in models_json.get("models", [])
|
| 69 |
+
}
|
| 70 |
+
|
| 71 |
+
def add_derived_model(base_name, suffix, display_suffix):
|
| 72 |
+
model = model_mapping.get(base_name)
|
| 73 |
+
if not model:
|
| 74 |
+
logger.warning(
|
| 75 |
+
f"Base model '{base_name}' not found for derived model '{suffix}'."
|
| 76 |
+
)
|
| 77 |
+
return
|
| 78 |
+
item = deepcopy(model)
|
| 79 |
+
item["name"] = f"models/{base_name}{suffix}"
|
| 80 |
+
display_name = f'{item.get("displayName", base_name)}{display_suffix}'
|
| 81 |
+
item["displayName"] = display_name
|
| 82 |
+
item["description"] = display_name
|
| 83 |
+
models_json["models"].append(item)
|
| 84 |
+
|
| 85 |
+
if settings.SEARCH_MODELS:
|
| 86 |
+
for name in settings.SEARCH_MODELS:
|
| 87 |
+
add_derived_model(name, "-search", " For Search")
|
| 88 |
+
if settings.IMAGE_MODELS:
|
| 89 |
+
for name in settings.IMAGE_MODELS:
|
| 90 |
+
add_derived_model(name, "-image", " For Image")
|
| 91 |
+
if settings.THINKING_MODELS:
|
| 92 |
+
for name in settings.THINKING_MODELS:
|
| 93 |
+
add_derived_model(name, "-non-thinking", " Non Thinking")
|
| 94 |
+
|
| 95 |
+
logger.info("Gemini models list request successful")
|
| 96 |
+
return models_json
|
| 97 |
+
except HTTPException as http_exc:
|
| 98 |
+
raise http_exc
|
| 99 |
+
except Exception as e:
|
| 100 |
+
logger.error(f"Error getting Gemini models list: {str(e)}")
|
| 101 |
+
raise HTTPException(
|
| 102 |
+
status_code=500,
|
| 103 |
+
detail="Internal server error while fetching Gemini models list",
|
| 104 |
+
) from e
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
@router.post("/models/{model_name}:generateContent")
|
| 108 |
+
@RetryHandler(key_arg="api_key")
|
| 109 |
+
async def generate_content(
|
| 110 |
+
model_name: str,
|
| 111 |
+
request: GeminiRequest,
|
| 112 |
+
allowed_token=Depends(security_service.verify_key_or_goog_api_key),
|
| 113 |
+
api_key: str = Depends(get_next_working_key),
|
| 114 |
+
key_manager: KeyManager = Depends(get_key_manager),
|
| 115 |
+
chat_service: GeminiChatService = Depends(get_chat_service),
|
| 116 |
+
):
|
| 117 |
+
"""处理 Gemini 非流式内容生成请求。"""
|
| 118 |
+
operation_name = "gemini_generate_content"
|
| 119 |
+
async with handle_route_errors(
|
| 120 |
+
logger, operation_name, failure_message="Content generation failed"
|
| 121 |
+
):
|
| 122 |
+
logger.info(
|
| 123 |
+
f"Handling Gemini content generation request for model: {model_name}"
|
| 124 |
+
)
|
| 125 |
+
logger.debug(f"Request: \n{request.model_dump_json(indent=2)}")
|
| 126 |
+
logger.info(f"Using allowed token: {allowed_token}")
|
| 127 |
+
logger.info(f"Using API key: {redact_key_for_logging(api_key)}")
|
| 128 |
+
|
| 129 |
+
if not await model_service.check_model_support(model_name):
|
| 130 |
+
raise HTTPException(
|
| 131 |
+
status_code=400, detail=f"Model {model_name} is not supported"
|
| 132 |
+
)
|
| 133 |
+
|
| 134 |
+
response = await chat_service.generate_content(
|
| 135 |
+
model=model_name, request=request, api_key=api_key
|
| 136 |
+
)
|
| 137 |
+
return response
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
@router.post("/models/{model_name}:streamGenerateContent")
|
| 141 |
+
@RetryHandler(key_arg="api_key")
|
| 142 |
+
async def stream_generate_content(
|
| 143 |
+
model_name: str,
|
| 144 |
+
request: GeminiRequest,
|
| 145 |
+
allowed_token=Depends(security_service.verify_key_or_goog_api_key),
|
| 146 |
+
api_key: str = Depends(get_next_working_key),
|
| 147 |
+
key_manager: KeyManager = Depends(get_key_manager),
|
| 148 |
+
chat_service: GeminiChatService = Depends(get_chat_service),
|
| 149 |
+
):
|
| 150 |
+
"""处理 Gemini 流式内容生成请求。"""
|
| 151 |
+
operation_name = "gemini_stream_generate_content"
|
| 152 |
+
async with handle_route_errors(
|
| 153 |
+
logger, operation_name, failure_message="Streaming request initiation failed"
|
| 154 |
+
):
|
| 155 |
+
logger.info(
|
| 156 |
+
f"Handling Gemini streaming content generation for model: {model_name}"
|
| 157 |
+
)
|
| 158 |
+
logger.debug(f"Request: \n{request.model_dump_json(indent=2)}")
|
| 159 |
+
logger.info(f"Using allowed token: {allowed_token}")
|
| 160 |
+
logger.info(f"Using API key: {redact_key_for_logging(api_key)}")
|
| 161 |
+
|
| 162 |
+
if not await model_service.check_model_support(model_name):
|
| 163 |
+
raise HTTPException(
|
| 164 |
+
status_code=400, detail=f"Model {model_name} is not supported"
|
| 165 |
+
)
|
| 166 |
+
|
| 167 |
+
raw_stream = chat_service.stream_generate_content(
|
| 168 |
+
model=model_name, request=request, api_key=api_key
|
| 169 |
+
)
|
| 170 |
+
try:
|
| 171 |
+
# 尝试获取第一条数据,判断是正常 SSE(data: 前缀)还是错误 JSON
|
| 172 |
+
first_chunk = await raw_stream.__anext__()
|
| 173 |
+
except StopAsyncIteration:
|
| 174 |
+
# 如果流直接结束,退回标准 SSE 输出
|
| 175 |
+
return StreamingResponse(raw_stream, media_type="text/event-stream")
|
| 176 |
+
except Exception as e:
|
| 177 |
+
# 初始化流异常,直接返回 500 错误
|
| 178 |
+
return JSONResponse(
|
| 179 |
+
content={"error": {"code": e.args[0], "message": e.args[1]}},
|
| 180 |
+
status_code=e.args[0],
|
| 181 |
+
)
|
| 182 |
+
|
| 183 |
+
# 如果以 "data:" 开头,代表正常 SSE,将首块和后续块一起发送
|
| 184 |
+
if isinstance(first_chunk, str) and first_chunk.startswith("data:"):
|
| 185 |
+
|
| 186 |
+
async def combined():
|
| 187 |
+
yield first_chunk
|
| 188 |
+
async for chunk in raw_stream:
|
| 189 |
+
yield chunk
|
| 190 |
+
|
| 191 |
+
return StreamingResponse(combined(), media_type="text/event-stream")
|
app/scheduler/scheduled_tasks.py
ADDED
|
@@ -0,0 +1,196 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
| 2 |
+
|
| 3 |
+
from app.config.config import settings
|
| 4 |
+
from app.domain.gemini_models import GeminiContent, GeminiRequest
|
| 5 |
+
from app.log.logger import Logger
|
| 6 |
+
from app.service.chat.gemini_chat_service import GeminiChatService
|
| 7 |
+
from app.service.error_log.error_log_service import delete_old_error_logs
|
| 8 |
+
from app.service.files.files_service import get_files_service
|
| 9 |
+
from app.service.key.key_manager import get_key_manager_instance
|
| 10 |
+
from app.service.request_log.request_log_service import delete_old_request_logs_task
|
| 11 |
+
from app.utils.helpers import redact_key_for_logging
|
| 12 |
+
|
| 13 |
+
logger = Logger.setup_logger("scheduler")
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
async def check_failed_keys():
|
| 17 |
+
"""
|
| 18 |
+
定时检查失败次数大于0的API密钥,并尝试验证它们。
|
| 19 |
+
如果验证成功,重置失败计数;如果失败,增加失败计数。
|
| 20 |
+
"""
|
| 21 |
+
logger.info("Starting scheduled check for failed API keys...")
|
| 22 |
+
try:
|
| 23 |
+
key_manager = await get_key_manager_instance()
|
| 24 |
+
# 确保 KeyManager 已经初始化
|
| 25 |
+
if not key_manager or not hasattr(key_manager, "key_failure_counts"):
|
| 26 |
+
logger.warning(
|
| 27 |
+
"KeyManager instance not available or not initialized. Skipping check."
|
| 28 |
+
)
|
| 29 |
+
return
|
| 30 |
+
|
| 31 |
+
# 创建 GeminiChatService 实例用于验证
|
| 32 |
+
# 注意:这里直接创建实例,而不是通过依赖注入,因为这是后台任务
|
| 33 |
+
chat_service = GeminiChatService(settings.BASE_URL, key_manager)
|
| 34 |
+
|
| 35 |
+
# 获取需要检查的 key 列表 (失败次数 > 0)
|
| 36 |
+
keys_to_check = []
|
| 37 |
+
async with key_manager.failure_count_lock: # 访问共享数据需要加锁
|
| 38 |
+
# 复制一份以避免在迭代时修改字典
|
| 39 |
+
failure_counts_copy = key_manager.key_failure_counts.copy()
|
| 40 |
+
keys_to_check = [
|
| 41 |
+
key for key, count in failure_counts_copy.items() if count > 0
|
| 42 |
+
] # 检查所有失败次数大于0的key
|
| 43 |
+
|
| 44 |
+
if not keys_to_check:
|
| 45 |
+
logger.info("No keys with failure count > 0 found. Skipping verification.")
|
| 46 |
+
return
|
| 47 |
+
|
| 48 |
+
logger.info(
|
| 49 |
+
f"Found {len(keys_to_check)} keys with failure count > 0 to verify."
|
| 50 |
+
)
|
| 51 |
+
|
| 52 |
+
for key in keys_to_check:
|
| 53 |
+
# 隐藏部分 key 用于日志记录
|
| 54 |
+
log_key = redact_key_for_logging(key)
|
| 55 |
+
logger.info(f"Verifying key: {log_key}...")
|
| 56 |
+
try:
|
| 57 |
+
# 构造测试请求
|
| 58 |
+
gemini_request = GeminiRequest(
|
| 59 |
+
contents=[
|
| 60 |
+
GeminiContent(
|
| 61 |
+
role="user",
|
| 62 |
+
parts=[{"text": "hi"}],
|
| 63 |
+
)
|
| 64 |
+
]
|
| 65 |
+
)
|
| 66 |
+
await chat_service.generate_content(
|
| 67 |
+
settings.TEST_MODEL, gemini_request, key
|
| 68 |
+
)
|
| 69 |
+
logger.info(
|
| 70 |
+
f"Key {log_key} verification successful. Resetting failure count."
|
| 71 |
+
)
|
| 72 |
+
await key_manager.reset_key_failure_count(key)
|
| 73 |
+
except Exception as e:
|
| 74 |
+
logger.warning(
|
| 75 |
+
f"Key {log_key} verification failed: {str(e)}. Incrementing failure count."
|
| 76 |
+
)
|
| 77 |
+
# 直接操作计数器,需要加锁
|
| 78 |
+
async with key_manager.failure_count_lock:
|
| 79 |
+
# 再次检查 key 是否存在且失败次数未达上限
|
| 80 |
+
if (
|
| 81 |
+
key in key_manager.key_failure_counts
|
| 82 |
+
and key_manager.key_failure_counts[key]
|
| 83 |
+
< key_manager.MAX_FAILURES
|
| 84 |
+
):
|
| 85 |
+
key_manager.key_failure_counts[key] += 1
|
| 86 |
+
logger.info(
|
| 87 |
+
f"Failure count for key {log_key} incremented to {key_manager.key_failure_counts[key]}."
|
| 88 |
+
)
|
| 89 |
+
elif key in key_manager.key_failure_counts:
|
| 90 |
+
logger.warning(
|
| 91 |
+
f"Key {log_key} reached MAX_FAILURES ({key_manager.MAX_FAILURES}). Not incrementing further."
|
| 92 |
+
)
|
| 93 |
+
|
| 94 |
+
except Exception as e:
|
| 95 |
+
logger.error(
|
| 96 |
+
f"An error occurred during the scheduled key check: {str(e)}", exc_info=True
|
| 97 |
+
)
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
async def cleanup_expired_files():
|
| 101 |
+
"""
|
| 102 |
+
定时清理过期的文件记录
|
| 103 |
+
"""
|
| 104 |
+
logger.info("Starting scheduled cleanup for expired files...")
|
| 105 |
+
try:
|
| 106 |
+
files_service = await get_files_service()
|
| 107 |
+
deleted_count = await files_service.cleanup_expired_files()
|
| 108 |
+
|
| 109 |
+
if deleted_count > 0:
|
| 110 |
+
logger.info(f"Successfully cleaned up {deleted_count} expired files.")
|
| 111 |
+
else:
|
| 112 |
+
logger.info("No expired files to clean up.")
|
| 113 |
+
|
| 114 |
+
except Exception as e:
|
| 115 |
+
logger.error(
|
| 116 |
+
f"An error occurred during the scheduled file cleanup: {str(e)}",
|
| 117 |
+
exc_info=True,
|
| 118 |
+
)
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def setup_scheduler():
|
| 122 |
+
"""设置��启动 APScheduler"""
|
| 123 |
+
scheduler = AsyncIOScheduler(timezone=str(settings.TIMEZONE)) # 从配置读取时区
|
| 124 |
+
# 添加检查失败密钥的定时任务
|
| 125 |
+
if settings.CHECK_INTERVAL_HOURS != 0:
|
| 126 |
+
scheduler.add_job(
|
| 127 |
+
check_failed_keys,
|
| 128 |
+
"interval",
|
| 129 |
+
hours=settings.CHECK_INTERVAL_HOURS,
|
| 130 |
+
id="check_failed_keys_job",
|
| 131 |
+
name="Check Failed API Keys",
|
| 132 |
+
)
|
| 133 |
+
logger.info(
|
| 134 |
+
f"Key check job scheduled to run every {settings.CHECK_INTERVAL_HOURS} hour(s)."
|
| 135 |
+
)
|
| 136 |
+
|
| 137 |
+
# 新增:添加自动删除错误日志的定时任务,每天凌晨0点执行
|
| 138 |
+
scheduler.add_job(
|
| 139 |
+
delete_old_error_logs,
|
| 140 |
+
"cron",
|
| 141 |
+
hour=0,
|
| 142 |
+
minute=0,
|
| 143 |
+
id="delete_old_error_logs_job",
|
| 144 |
+
name="Delete Old Error Logs",
|
| 145 |
+
)
|
| 146 |
+
logger.info("Auto-delete error logs job scheduled to run daily at 3:00 AM.")
|
| 147 |
+
|
| 148 |
+
# 新增:添加自动删除请求日志的定时任务,每天凌晨0点执行
|
| 149 |
+
scheduler.add_job(
|
| 150 |
+
delete_old_request_logs_task,
|
| 151 |
+
"cron",
|
| 152 |
+
hour=0,
|
| 153 |
+
minute=0,
|
| 154 |
+
id="delete_old_request_logs_job",
|
| 155 |
+
name="Delete Old Request Logs",
|
| 156 |
+
)
|
| 157 |
+
logger.info(
|
| 158 |
+
f"Auto-delete request logs job scheduled to run daily at 3:05 AM, if enabled and AUTO_DELETE_REQUEST_LOGS_DAYS is set to {settings.AUTO_DELETE_REQUEST_LOGS_DAYS} days."
|
| 159 |
+
)
|
| 160 |
+
|
| 161 |
+
# 新增:添加文件过期清理的定时任务,每小时执行一次
|
| 162 |
+
if getattr(settings, "FILES_CLEANUP_ENABLED", True):
|
| 163 |
+
cleanup_interval = getattr(settings, "FILES_CLEANUP_INTERVAL_HOURS", 1)
|
| 164 |
+
scheduler.add_job(
|
| 165 |
+
cleanup_expired_files,
|
| 166 |
+
"interval",
|
| 167 |
+
hours=cleanup_interval,
|
| 168 |
+
id="cleanup_expired_files_job",
|
| 169 |
+
name="Cleanup Expired Files",
|
| 170 |
+
)
|
| 171 |
+
logger.info(
|
| 172 |
+
f"File cleanup job scheduled to run every {cleanup_interval} hour(s)."
|
| 173 |
+
)
|
| 174 |
+
|
| 175 |
+
scheduler.start()
|
| 176 |
+
logger.info("Scheduler started with all jobs.")
|
| 177 |
+
return scheduler
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
# 可以在这里添加一个全局的 scheduler 实例,以便在应用关闭时优雅地停止
|
| 181 |
+
scheduler_instance = None
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
def start_scheduler():
|
| 185 |
+
global scheduler_instance
|
| 186 |
+
if scheduler_instance is None or not scheduler_instance.running:
|
| 187 |
+
logger.info("Starting scheduler...")
|
| 188 |
+
scheduler_instance = setup_scheduler()
|
| 189 |
+
logger.info("Scheduler is already running.")
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
def stop_scheduler():
|
| 193 |
+
global scheduler_instance
|
| 194 |
+
if scheduler_instance and scheduler_instance.running:
|
| 195 |
+
scheduler_instance.shutdown()
|
| 196 |
+
logger.info("Scheduler stopped.")
|
app/service/chat/gemini_chat_service.py
ADDED
|
@@ -0,0 +1,545 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# app/services/chat_service.py
|
| 2 |
+
|
| 3 |
+
import datetime
|
| 4 |
+
import json
|
| 5 |
+
import re
|
| 6 |
+
import time
|
| 7 |
+
from typing import Any, AsyncGenerator, Dict, List
|
| 8 |
+
|
| 9 |
+
from app.config.config import settings
|
| 10 |
+
from app.core.constants import GEMINI_2_FLASH_EXP_SAFETY_SETTINGS
|
| 11 |
+
from app.database.services import add_error_log, add_request_log, get_file_api_key
|
| 12 |
+
from app.domain.gemini_models import GeminiRequest
|
| 13 |
+
from app.handler.response_handler import GeminiResponseHandler
|
| 14 |
+
from app.handler.stream_optimizer import gemini_optimizer
|
| 15 |
+
from app.log.logger import get_gemini_logger
|
| 16 |
+
from app.service.client.api_client import GeminiApiClient
|
| 17 |
+
from app.service.key.key_manager import KeyManager
|
| 18 |
+
from app.utils.helpers import redact_key_for_logging
|
| 19 |
+
|
| 20 |
+
logger = get_gemini_logger()
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def _has_image_parts(contents: List[Dict[str, Any]]) -> bool:
|
| 24 |
+
"""判断消息是否包含图片部分"""
|
| 25 |
+
for content in contents:
|
| 26 |
+
if "parts" in content:
|
| 27 |
+
for part in content["parts"]:
|
| 28 |
+
if "image_url" in part or "inline_data" in part:
|
| 29 |
+
return True
|
| 30 |
+
return False
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def _extract_file_references(contents: List[Dict[str, Any]]) -> List[str]:
|
| 34 |
+
"""從內容中提取文件引用"""
|
| 35 |
+
file_names = []
|
| 36 |
+
for content in contents:
|
| 37 |
+
if "parts" in content:
|
| 38 |
+
for part in content["parts"]:
|
| 39 |
+
if not isinstance(part, dict) or "fileData" not in part:
|
| 40 |
+
continue
|
| 41 |
+
file_data = part["fileData"]
|
| 42 |
+
if "fileUri" not in file_data:
|
| 43 |
+
continue
|
| 44 |
+
file_uri = file_data["fileUri"]
|
| 45 |
+
# 從 URI 中提取文件名
|
| 46 |
+
# 1. https://generativelanguage.googleapis.com/v1beta/files/{file_id}
|
| 47 |
+
match = re.match(
|
| 48 |
+
rf"{re.escape(settings.BASE_URL)}/(files/.*)", file_uri
|
| 49 |
+
)
|
| 50 |
+
if not match:
|
| 51 |
+
logger.warning(f"Invalid file URI: {file_uri}")
|
| 52 |
+
continue
|
| 53 |
+
file_id = match.group(1)
|
| 54 |
+
file_names.append(file_id)
|
| 55 |
+
logger.info(f"Found file reference: {file_id}")
|
| 56 |
+
return file_names
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def _clean_json_schema_properties(obj: Any) -> Any:
|
| 60 |
+
"""清理JSON Schema中Gemini API不支持的字段"""
|
| 61 |
+
if not isinstance(obj, dict):
|
| 62 |
+
return obj
|
| 63 |
+
|
| 64 |
+
# Gemini API不支持的JSON Schema字段
|
| 65 |
+
unsupported_fields = {
|
| 66 |
+
"exclusiveMaximum",
|
| 67 |
+
"exclusiveMinimum",
|
| 68 |
+
"const",
|
| 69 |
+
"examples",
|
| 70 |
+
"contentEncoding",
|
| 71 |
+
"contentMediaType",
|
| 72 |
+
"if",
|
| 73 |
+
"then",
|
| 74 |
+
"else",
|
| 75 |
+
"allOf",
|
| 76 |
+
"anyOf",
|
| 77 |
+
"oneOf",
|
| 78 |
+
"not",
|
| 79 |
+
"definitions",
|
| 80 |
+
"$schema",
|
| 81 |
+
"$id",
|
| 82 |
+
"$ref",
|
| 83 |
+
"$comment",
|
| 84 |
+
"readOnly",
|
| 85 |
+
"writeOnly",
|
| 86 |
+
}
|
| 87 |
+
|
| 88 |
+
cleaned = {}
|
| 89 |
+
for key, value in obj.items():
|
| 90 |
+
if key in unsupported_fields:
|
| 91 |
+
continue
|
| 92 |
+
if isinstance(value, dict):
|
| 93 |
+
cleaned[key] = _clean_json_schema_properties(value)
|
| 94 |
+
elif isinstance(value, list):
|
| 95 |
+
cleaned[key] = [_clean_json_schema_properties(item) for item in value]
|
| 96 |
+
else:
|
| 97 |
+
cleaned[key] = value
|
| 98 |
+
|
| 99 |
+
return cleaned
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def _build_tools(model: str, payload: Dict[str, Any]) -> List[Dict[str, Any]]:
|
| 103 |
+
"""构建工具"""
|
| 104 |
+
|
| 105 |
+
def _has_function_call(contents: List[Dict[str, Any]]) -> bool:
|
| 106 |
+
"""检查内容中是否包含 functionCall"""
|
| 107 |
+
if not contents or not isinstance(contents, list):
|
| 108 |
+
return False
|
| 109 |
+
for content in contents:
|
| 110 |
+
if not content or not isinstance(content, dict) or "parts" not in content:
|
| 111 |
+
continue
|
| 112 |
+
parts = content.get("parts", [])
|
| 113 |
+
if not parts or not isinstance(parts, list):
|
| 114 |
+
continue
|
| 115 |
+
for part in parts:
|
| 116 |
+
if isinstance(part, dict) and "functionCall" in part:
|
| 117 |
+
return True
|
| 118 |
+
return False
|
| 119 |
+
|
| 120 |
+
def _merge_tools(tools: List[Dict[str, Any]]) -> Dict[str, Any]:
|
| 121 |
+
record = dict()
|
| 122 |
+
for item in tools:
|
| 123 |
+
if not item or not isinstance(item, dict):
|
| 124 |
+
continue
|
| 125 |
+
|
| 126 |
+
for k, v in item.items():
|
| 127 |
+
if k == "functionDeclarations" and v and isinstance(v, list):
|
| 128 |
+
functions = record.get("functionDeclarations", [])
|
| 129 |
+
# 清理每个函数声明中的不支持字段
|
| 130 |
+
cleaned_functions = []
|
| 131 |
+
for func in v:
|
| 132 |
+
if isinstance(func, dict):
|
| 133 |
+
cleaned_func = _clean_json_schema_properties(func)
|
| 134 |
+
cleaned_functions.append(cleaned_func)
|
| 135 |
+
else:
|
| 136 |
+
cleaned_functions.append(func)
|
| 137 |
+
functions.extend(cleaned_functions)
|
| 138 |
+
record["functionDeclarations"] = functions
|
| 139 |
+
else:
|
| 140 |
+
record[k] = v
|
| 141 |
+
return record
|
| 142 |
+
|
| 143 |
+
def _is_structured_output_request(payload: Dict[str, Any]) -> bool:
|
| 144 |
+
"""检查请求是否要求结构化JSON输出"""
|
| 145 |
+
try:
|
| 146 |
+
generation_config = payload.get("generationConfig", {})
|
| 147 |
+
return generation_config.get("responseMimeType") == "application/json"
|
| 148 |
+
except (AttributeError, TypeError):
|
| 149 |
+
return False
|
| 150 |
+
|
| 151 |
+
tool = dict()
|
| 152 |
+
if payload and isinstance(payload, dict) and "tools" in payload:
|
| 153 |
+
if payload.get("tools") and isinstance(payload.get("tools"), dict):
|
| 154 |
+
payload["tools"] = [payload.get("tools")]
|
| 155 |
+
items = payload.get("tools", [])
|
| 156 |
+
if items and isinstance(items, list):
|
| 157 |
+
tool.update(_merge_tools(items))
|
| 158 |
+
|
| 159 |
+
# "Tool use with a response mime type: 'application/json' is unsupported"
|
| 160 |
+
# Gemini API限制:不支持同时使用tools和结构化输出(response_mime_type='application/json')
|
| 161 |
+
# 当请求指定了JSON响应格式时,跳过所有工具的添加以避免API错误
|
| 162 |
+
has_structured_output = _is_structured_output_request(payload)
|
| 163 |
+
if not has_structured_output:
|
| 164 |
+
if (
|
| 165 |
+
settings.TOOLS_CODE_EXECUTION_ENABLED
|
| 166 |
+
and not (model.endswith("-search") or "-thinking" in model)
|
| 167 |
+
and not _has_image_parts(payload.get("contents", []))
|
| 168 |
+
):
|
| 169 |
+
tool["codeExecution"] = {}
|
| 170 |
+
|
| 171 |
+
if model.endswith("-search"):
|
| 172 |
+
tool["googleSearch"] = {}
|
| 173 |
+
|
| 174 |
+
real_model = _get_real_model(model)
|
| 175 |
+
if real_model in settings.URL_CONTEXT_MODELS and settings.URL_CONTEXT_ENABLED:
|
| 176 |
+
tool["urlContext"] = {}
|
| 177 |
+
|
| 178 |
+
# 解决 "Tool use with function calling is unsupported" 问题
|
| 179 |
+
if tool.get("functionDeclarations") or _has_function_call(
|
| 180 |
+
payload.get("contents", [])
|
| 181 |
+
):
|
| 182 |
+
tool.pop("googleSearch", None)
|
| 183 |
+
tool.pop("codeExecution", None)
|
| 184 |
+
tool.pop("urlContext", None)
|
| 185 |
+
|
| 186 |
+
return [tool] if tool else []
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
def _get_real_model(model: str) -> str:
|
| 190 |
+
if model.endswith("-search"):
|
| 191 |
+
model = model[:-7]
|
| 192 |
+
if model.endswith("-image"):
|
| 193 |
+
model = model[:-6]
|
| 194 |
+
if model.endswith("-non-thinking"):
|
| 195 |
+
model = model[:-13]
|
| 196 |
+
if "-search" in model and "-non-thinking" in model:
|
| 197 |
+
model = model[:-20]
|
| 198 |
+
return model
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
def _get_safety_settings(model: str) -> List[Dict[str, str]]:
|
| 202 |
+
"""获取安全设置"""
|
| 203 |
+
if model == "gemini-2.0-flash-exp":
|
| 204 |
+
return GEMINI_2_FLASH_EXP_SAFETY_SETTINGS
|
| 205 |
+
return settings.SAFETY_SETTINGS
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
def _filter_empty_parts(contents: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
| 209 |
+
"""Filters out contents with empty or invalid parts."""
|
| 210 |
+
if not contents:
|
| 211 |
+
return []
|
| 212 |
+
|
| 213 |
+
filtered_contents = []
|
| 214 |
+
for content in contents:
|
| 215 |
+
if (
|
| 216 |
+
not content
|
| 217 |
+
or "parts" not in content
|
| 218 |
+
or not isinstance(content.get("parts"), list)
|
| 219 |
+
):
|
| 220 |
+
continue
|
| 221 |
+
|
| 222 |
+
valid_parts = [
|
| 223 |
+
part for part in content["parts"] if isinstance(part, dict) and part
|
| 224 |
+
]
|
| 225 |
+
|
| 226 |
+
if valid_parts:
|
| 227 |
+
new_content = content.copy()
|
| 228 |
+
new_content["parts"] = valid_parts
|
| 229 |
+
filtered_contents.append(new_content)
|
| 230 |
+
|
| 231 |
+
return filtered_contents
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
def _build_payload(model: str, request: GeminiRequest) -> Dict[str, Any]:
|
| 235 |
+
"""构建请求payload"""
|
| 236 |
+
request_dict = request.model_dump(exclude_none=False)
|
| 237 |
+
if request.generationConfig:
|
| 238 |
+
if request.generationConfig.maxOutputTokens is None:
|
| 239 |
+
# 如果未指定最大输出长度,则不传递该字段,解决截断的问题
|
| 240 |
+
if "maxOutputTokens" in request_dict["generationConfig"]:
|
| 241 |
+
request_dict["generationConfig"].pop("maxOutputTokens")
|
| 242 |
+
|
| 243 |
+
# 检查是否为TTS模型
|
| 244 |
+
is_tts_model = "tts" in model.lower()
|
| 245 |
+
|
| 246 |
+
if is_tts_model:
|
| 247 |
+
# TTS模型使用简化的payload,不包含tools和safetySettings
|
| 248 |
+
payload = {
|
| 249 |
+
"contents": _filter_empty_parts(request_dict.get("contents", [])),
|
| 250 |
+
"generationConfig": request_dict.get("generationConfig"),
|
| 251 |
+
}
|
| 252 |
+
|
| 253 |
+
# 只在有systemInstruction时才添加
|
| 254 |
+
if request_dict.get("systemInstruction"):
|
| 255 |
+
payload["systemInstruction"] = request_dict.get("systemInstruction")
|
| 256 |
+
else:
|
| 257 |
+
# 非TTS模型使用完整的payload
|
| 258 |
+
payload = {
|
| 259 |
+
"contents": _filter_empty_parts(request_dict.get("contents", [])),
|
| 260 |
+
"tools": _build_tools(model, request_dict),
|
| 261 |
+
"safetySettings": _get_safety_settings(model),
|
| 262 |
+
"generationConfig": request_dict.get("generationConfig"),
|
| 263 |
+
"systemInstruction": request_dict.get("systemInstruction"),
|
| 264 |
+
}
|
| 265 |
+
|
| 266 |
+
# 确保 generationConfig 不为 None
|
| 267 |
+
if payload["generationConfig"] is None:
|
| 268 |
+
payload["generationConfig"] = {}
|
| 269 |
+
|
| 270 |
+
if model.endswith("-image") or model.endswith("-image-generation"):
|
| 271 |
+
payload.pop("systemInstruction")
|
| 272 |
+
payload["generationConfig"]["responseModalities"] = ["Text", "Image"]
|
| 273 |
+
|
| 274 |
+
# 处理思考配置:优先使用客户端提供的配置,否则使���默认配置
|
| 275 |
+
client_thinking_config = None
|
| 276 |
+
if request.generationConfig and request.generationConfig.thinkingConfig:
|
| 277 |
+
client_thinking_config = request.generationConfig.thinkingConfig
|
| 278 |
+
|
| 279 |
+
if client_thinking_config is not None:
|
| 280 |
+
# 客户端提供了思考配置,直接使用
|
| 281 |
+
payload["generationConfig"]["thinkingConfig"] = client_thinking_config
|
| 282 |
+
else:
|
| 283 |
+
# 客户端没有提供思考配置,使用默认配置
|
| 284 |
+
if model.endswith("-non-thinking"):
|
| 285 |
+
if "gemini-2.5-pro" in model:
|
| 286 |
+
payload["generationConfig"]["thinkingConfig"] = {"thinkingBudget": 128}
|
| 287 |
+
else:
|
| 288 |
+
payload["generationConfig"]["thinkingConfig"] = {"thinkingBudget": 0}
|
| 289 |
+
elif _get_real_model(model) in settings.THINKING_BUDGET_MAP:
|
| 290 |
+
if settings.SHOW_THINKING_PROCESS:
|
| 291 |
+
payload["generationConfig"]["thinkingConfig"] = {
|
| 292 |
+
"thinkingBudget": settings.THINKING_BUDGET_MAP.get(model, 1000),
|
| 293 |
+
"includeThoughts": True,
|
| 294 |
+
}
|
| 295 |
+
else:
|
| 296 |
+
payload["generationConfig"]["thinkingConfig"] = {
|
| 297 |
+
"thinkingBudget": settings.THINKING_BUDGET_MAP.get(model, 1000)
|
| 298 |
+
}
|
| 299 |
+
|
| 300 |
+
return payload
|
| 301 |
+
|
| 302 |
+
|
| 303 |
+
class GeminiChatService:
|
| 304 |
+
"""聊天服务"""
|
| 305 |
+
|
| 306 |
+
def __init__(self, base_url: str, key_manager: KeyManager):
|
| 307 |
+
self.api_client = GeminiApiClient(base_url, settings.TIME_OUT)
|
| 308 |
+
self.key_manager = key_manager
|
| 309 |
+
self.response_handler = GeminiResponseHandler()
|
| 310 |
+
|
| 311 |
+
def _extract_text_from_response(self, response: Dict[str, Any]) -> str:
|
| 312 |
+
"""从响应中提取文本内容"""
|
| 313 |
+
if not response.get("candidates"):
|
| 314 |
+
return ""
|
| 315 |
+
|
| 316 |
+
candidate = response["candidates"][0]
|
| 317 |
+
content = candidate.get("content", {})
|
| 318 |
+
parts = content.get("parts", [])
|
| 319 |
+
|
| 320 |
+
if parts and "text" in parts[0]:
|
| 321 |
+
return parts[0].get("text", "")
|
| 322 |
+
return ""
|
| 323 |
+
|
| 324 |
+
def _create_char_response(
|
| 325 |
+
self, original_response: Dict[str, Any], text: str
|
| 326 |
+
) -> Dict[str, Any]:
|
| 327 |
+
"""创建包含指定文本的响应"""
|
| 328 |
+
response_copy = json.loads(json.dumps(original_response))
|
| 329 |
+
if response_copy.get("candidates") and response_copy["candidates"][0].get(
|
| 330 |
+
"content", {}
|
| 331 |
+
).get("parts"):
|
| 332 |
+
response_copy["candidates"][0]["content"]["parts"][0]["text"] = text
|
| 333 |
+
return response_copy
|
| 334 |
+
|
| 335 |
+
async def generate_content(
|
| 336 |
+
self, model: str, request: GeminiRequest, api_key: str
|
| 337 |
+
) -> Dict[str, Any]:
|
| 338 |
+
"""生成内容"""
|
| 339 |
+
# 檢查並獲取文件專用的 API key(如果有文件)
|
| 340 |
+
file_names = _extract_file_references(request.model_dump().get("contents", []))
|
| 341 |
+
if file_names:
|
| 342 |
+
logger.info(f"Request contains file references: {file_names}")
|
| 343 |
+
file_api_key = await get_file_api_key(file_names[0])
|
| 344 |
+
if file_api_key:
|
| 345 |
+
logger.info(
|
| 346 |
+
f"Found API key for file {file_names[0]}: {redact_key_for_logging(file_api_key)}"
|
| 347 |
+
)
|
| 348 |
+
api_key = file_api_key # 使用文件的 API key
|
| 349 |
+
else:
|
| 350 |
+
logger.warning(
|
| 351 |
+
f"No API key found for file {file_names[0]}, using default key: {redact_key_for_logging(api_key)}"
|
| 352 |
+
)
|
| 353 |
+
|
| 354 |
+
payload = _build_payload(model, request)
|
| 355 |
+
start_time = time.perf_counter()
|
| 356 |
+
request_datetime = datetime.datetime.now()
|
| 357 |
+
is_success = False
|
| 358 |
+
status_code = None
|
| 359 |
+
response = None
|
| 360 |
+
|
| 361 |
+
try:
|
| 362 |
+
response = await self.api_client.generate_content(payload, model, api_key)
|
| 363 |
+
is_success = True
|
| 364 |
+
status_code = 200
|
| 365 |
+
return self.response_handler.handle_response(response, model, stream=False)
|
| 366 |
+
except Exception as e:
|
| 367 |
+
is_success = False
|
| 368 |
+
status_code = e.args[0]
|
| 369 |
+
error_log_msg = e.args[1]
|
| 370 |
+
logger.error(f"Normal API call failed with error: {error_log_msg}")
|
| 371 |
+
|
| 372 |
+
await add_error_log(
|
| 373 |
+
gemini_key=api_key,
|
| 374 |
+
model_name=model,
|
| 375 |
+
error_type="gemini-chat-non-stream",
|
| 376 |
+
error_log=error_log_msg,
|
| 377 |
+
error_code=status_code,
|
| 378 |
+
request_msg=payload if settings.ERROR_LOG_RECORD_REQUEST_BODY else None,
|
| 379 |
+
request_datetime=request_datetime,
|
| 380 |
+
)
|
| 381 |
+
raise e
|
| 382 |
+
finally:
|
| 383 |
+
end_time = time.perf_counter()
|
| 384 |
+
latency_ms = int((end_time - start_time) * 1000)
|
| 385 |
+
await add_request_log(
|
| 386 |
+
model_name=model,
|
| 387 |
+
api_key=api_key,
|
| 388 |
+
is_success=is_success,
|
| 389 |
+
status_code=status_code,
|
| 390 |
+
latency_ms=latency_ms,
|
| 391 |
+
request_time=request_datetime,
|
| 392 |
+
)
|
| 393 |
+
|
| 394 |
+
async def count_tokens(
|
| 395 |
+
self, model: str, request: GeminiRequest, api_key: str
|
| 396 |
+
) -> Dict[str, Any]:
|
| 397 |
+
"""计算token数量"""
|
| 398 |
+
# countTokens API只需要contents
|
| 399 |
+
payload = {
|
| 400 |
+
"contents": _filter_empty_parts(request.model_dump().get("contents", []))
|
| 401 |
+
}
|
| 402 |
+
start_time = time.perf_counter()
|
| 403 |
+
request_datetime = datetime.datetime.now()
|
| 404 |
+
is_success = False
|
| 405 |
+
status_code = None
|
| 406 |
+
response = None
|
| 407 |
+
|
| 408 |
+
try:
|
| 409 |
+
response = await self.api_client.count_tokens(payload, model, api_key)
|
| 410 |
+
is_success = True
|
| 411 |
+
status_code = 200
|
| 412 |
+
return response
|
| 413 |
+
except Exception as e:
|
| 414 |
+
is_success = False
|
| 415 |
+
status_code = e.args[0]
|
| 416 |
+
error_log_msg = e.args[1]
|
| 417 |
+
logger.error(f"Count tokens API call failed with error: {error_log_msg}")
|
| 418 |
+
|
| 419 |
+
await add_error_log(
|
| 420 |
+
gemini_key=api_key,
|
| 421 |
+
model_name=model,
|
| 422 |
+
error_type="gemini-count-tokens",
|
| 423 |
+
error_log=error_log_msg,
|
| 424 |
+
error_code=status_code,
|
| 425 |
+
request_msg=payload if settings.ERROR_LOG_RECORD_REQUEST_BODY else None,
|
| 426 |
+
)
|
| 427 |
+
raise e
|
| 428 |
+
finally:
|
| 429 |
+
end_time = time.perf_counter()
|
| 430 |
+
latency_ms = int((end_time - start_time) * 1000)
|
| 431 |
+
await add_request_log(
|
| 432 |
+
model_name=model,
|
| 433 |
+
api_key=api_key,
|
| 434 |
+
is_success=is_success,
|
| 435 |
+
status_code=status_code,
|
| 436 |
+
latency_ms=latency_ms,
|
| 437 |
+
request_time=request_datetime,
|
| 438 |
+
)
|
| 439 |
+
|
| 440 |
+
async def stream_generate_content(
|
| 441 |
+
self, model: str, request: GeminiRequest, api_key: str
|
| 442 |
+
) -> AsyncGenerator[str, None]:
|
| 443 |
+
"""流式生成内容"""
|
| 444 |
+
# 檢查並獲取文件專用的 API key(如果有文件)
|
| 445 |
+
file_names = _extract_file_references(request.model_dump().get("contents", []))
|
| 446 |
+
if file_names:
|
| 447 |
+
logger.info(f"Request contains file references: {file_names}")
|
| 448 |
+
file_api_key = await get_file_api_key(file_names[0])
|
| 449 |
+
if file_api_key:
|
| 450 |
+
logger.info(
|
| 451 |
+
f"Found API key for file {file_names[0]}: {redact_key_for_logging(file_api_key)}"
|
| 452 |
+
)
|
| 453 |
+
api_key = file_api_key # 使用文件的 API key
|
| 454 |
+
else:
|
| 455 |
+
logger.warning(
|
| 456 |
+
f"No API key found for file {file_names[0]}, using default key: {redact_key_for_logging(api_key)}"
|
| 457 |
+
)
|
| 458 |
+
|
| 459 |
+
retries = 0
|
| 460 |
+
max_retries = settings.MAX_RETRIES
|
| 461 |
+
payload = _build_payload(model, request)
|
| 462 |
+
is_success = False
|
| 463 |
+
status_code = None
|
| 464 |
+
final_api_key = api_key
|
| 465 |
+
|
| 466 |
+
while retries < max_retries:
|
| 467 |
+
request_datetime = datetime.datetime.now()
|
| 468 |
+
start_time = time.perf_counter()
|
| 469 |
+
current_attempt_key = api_key
|
| 470 |
+
final_api_key = current_attempt_key
|
| 471 |
+
try:
|
| 472 |
+
async for line in self.api_client.stream_generate_content(
|
| 473 |
+
payload, model, current_attempt_key
|
| 474 |
+
):
|
| 475 |
+
# print(line)
|
| 476 |
+
if line.startswith("data:"):
|
| 477 |
+
line = line[6:]
|
| 478 |
+
response_data = self.response_handler.handle_response(
|
| 479 |
+
json.loads(line), model, stream=True
|
| 480 |
+
)
|
| 481 |
+
text = self._extract_text_from_response(response_data)
|
| 482 |
+
# 如果有文本内容,且开启了流式输出优化器,则使用流式输出优化器处理
|
| 483 |
+
if text and settings.STREAM_OPTIMIZER_ENABLED:
|
| 484 |
+
# 使用流式输出优化器处理文本输出
|
| 485 |
+
async for (
|
| 486 |
+
optimized_chunk
|
| 487 |
+
) in gemini_optimizer.optimize_stream_output(
|
| 488 |
+
text,
|
| 489 |
+
lambda t: self._create_char_response(response_data, t),
|
| 490 |
+
lambda c: "data: " + json.dumps(c) + "\n\n",
|
| 491 |
+
):
|
| 492 |
+
yield optimized_chunk
|
| 493 |
+
else:
|
| 494 |
+
# 如果没有文本内容(如工具调用等),整块输出
|
| 495 |
+
yield "data: " + json.dumps(response_data) + "\n\n"
|
| 496 |
+
logger.info("Streaming completed successfully")
|
| 497 |
+
is_success = True
|
| 498 |
+
status_code = 200
|
| 499 |
+
break
|
| 500 |
+
except Exception as e:
|
| 501 |
+
retries += 1
|
| 502 |
+
is_success = False
|
| 503 |
+
status_code = e.args[0]
|
| 504 |
+
error_log_msg = e.args[1]
|
| 505 |
+
logger.warning(
|
| 506 |
+
f"Streaming API call failed with error: {error_log_msg}. Attempt {retries} of {max_retries}"
|
| 507 |
+
)
|
| 508 |
+
|
| 509 |
+
await add_error_log(
|
| 510 |
+
gemini_key=current_attempt_key,
|
| 511 |
+
model_name=model,
|
| 512 |
+
error_type="gemini-chat-stream",
|
| 513 |
+
error_log=error_log_msg,
|
| 514 |
+
error_code=status_code,
|
| 515 |
+
request_msg=(
|
| 516 |
+
payload if settings.ERROR_LOG_RECORD_REQUEST_BODY else None
|
| 517 |
+
),
|
| 518 |
+
request_datetime=request_datetime,
|
| 519 |
+
)
|
| 520 |
+
|
| 521 |
+
api_key = await self.key_manager.handle_api_failure(
|
| 522 |
+
current_attempt_key, retries
|
| 523 |
+
)
|
| 524 |
+
if api_key:
|
| 525 |
+
logger.info(
|
| 526 |
+
f"Switched to new API key: {redact_key_for_logging(api_key)}"
|
| 527 |
+
)
|
| 528 |
+
else:
|
| 529 |
+
logger.error(f"No valid API key available after {retries} retries.")
|
| 530 |
+
raise
|
| 531 |
+
|
| 532 |
+
if retries >= max_retries:
|
| 533 |
+
logger.error(f"Max retries ({max_retries}) reached for streaming.")
|
| 534 |
+
raise
|
| 535 |
+
finally:
|
| 536 |
+
end_time = time.perf_counter()
|
| 537 |
+
latency_ms = int((end_time - start_time) * 1000)
|
| 538 |
+
await add_request_log(
|
| 539 |
+
model_name=model,
|
| 540 |
+
api_key=final_api_key,
|
| 541 |
+
is_success=is_success,
|
| 542 |
+
status_code=status_code,
|
| 543 |
+
latency_ms=latency_ms,
|
| 544 |
+
request_time=request_datetime,
|
| 545 |
+
)
|